Controllable Generation

Controllable Generation

In this notebook, we're going to implement a GAN controllability method using gradients from a classifier. By training a classifier to recognize a relevant feature, we can use it to change the generator's inputs (z-vectors) to make it generate images with more or less of that feature.

We will be started we off with a pre-trained generator and classifier, so that we can focus on the controllability aspects.

The classifier has the same archicture as the earlier critic (remember that the discriminator/critic is simply a classifier used to classify real and fake).

CelebA

Instead of the MNIST dataset, we will be using CelebA. CelebA is a dataset of annotated celebrity images. Since they are colored (not black-and-white), the images have three channels for red, green, and blue (RGB). We'll be using the pre-built pytorch Celeba dataset.

Imports

# python
from argparse import Namespace
from collections import namedtuple
from functools import partial
from pathlib import Path

import pickle

# pypi
from torch import nn
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import CelebA
from torchvision.utils import make_grid

import holoviews
import hvplot.pandas
import matplotlib.pyplot as pyplot
import pandas
import torch

# my stuff
from graeae import EmbedHoloviews, Timer

Set Up

The Timer

TIMER = Timer()

The Random Seed

torch.manual_seed(0)

Plotting

SLUG = "controllable-generation"
OUTPUT = f"files/posts/gans/{SLUG}/"

Embed = partial(EmbedHoloviews, folder_path=OUTPUT)

Plot = namedtuple("Plot", ["width", "height", "fontscale", "tan", "blue", "red"])
PLOT = Plot(
    width=900,
    height=750,
    fontscale=2,
    tan="#ddb377",
    blue="#4687b7",
    red="#ce7b6d",
 )

Paths

base_path = Path("~/models/gans/celeba/").expanduser()
assert base_path.is_dir()

prebuilt_models = Namespace(
    celeba = base_path/"pretrained_celeba.pth",
    classifier = base_path/"pretrained_classifier.pth"
)

data_path = Path("~/pytorch-data/").expanduser()
if not data_path.is_dir():
    data_path.mkdir()
assert prebuilt_models.celeba.is_file()
assert prebuilt_models.classifier.is_file()

Helpers

def save_tensor_images(image_tensor: torch.Tensor,
                       filename: str, 
                       title: str,
                       folder: str=f"files/posts/gans{SLUG}/",
                       num_images: int=16, size: tuple=(1, 28, 28), nrow=3):
    """Plot an Image Tensor

    Args:
     image_tensor: tensor with the values for the image to plot
     filename: name to save the file under
     folder: path to put the file in
     title: title for the image
     num_images: how many images from the tensor to use
     size: the dimensions for each image
    """
    image_tensor = (image_tensor + 1) / 2
    image_unflat = image_tensor.detach().cpu()
    image_grid = make_grid(image_unflat[:num_images], nrow=nrow)
    pyplot.title(title)
    pyplot.grid(False)
    pyplot.imshow(image_grid.permute(1, 2, 0).squeeze())
    pyplot.tick_params(bottom=False, top=False, labelbottom=False,
                       right=False, left=False, labelleft=False)
    pyplot.savefig(folder + filename)
    print(f"[[file:{filename}]]")
    return

The Generator

This is mostly the same as the other Generators but the images are now color so the channels are different and the model has more initial hidden nodes (and one extra hidden block).

class Generator(nn.Module):
    """Generator for the celeba images

    Args:
       z_dim: the dimension of the noise vector, a scalar
       im_chan: the number of channels in the images, fitted for the dataset used, a scalar
             (CelebA is rgb, so 3 is our default)
       hidden_dim: the inner dimension, a scalar
    """
    def __init__(self, z_dim: int=10, im_chan: int=3, hidden_dim: int=64):
        super().__init__()
        self.z_dim = z_dim
        self.gen = nn.Sequential(
            self.make_gen_block(z_dim, hidden_dim * 8),
            self.make_gen_block(hidden_dim * 8, hidden_dim * 4),
            self.make_gen_block(hidden_dim * 4, hidden_dim * 2),
            self.make_gen_block(hidden_dim * 2, hidden_dim),
            self.make_gen_block(hidden_dim, im_chan, kernel_size=4, final_layer=True),
        )

    def make_gen_block(self, input_channels: int, output_channels: int,
                       kernel_size: int=3, stride: int=2,
                       final_layer: bool=False) -> nn.Sequential:
        """Create a sequence of operations corresponding to a generator block of DCGAN

        - a transposed convolution
        - a batchnorm (except in the final layer)
        - an activation.

       Args:
           input_channels: how many channels the input feature representation has
           output_channels: how many channels the output feature representation should have
           kernel_size: the size of each convolutional filter, equivalent to (kernel_size, kernel_size)
           stride: the stride of the convolution
           final_layer: a boolean, true if it is the final layer and false otherwise 
                     (affects activation and batchnorm)
       Returns:
        sequence of layers
       """
        if not final_layer:
            return nn.Sequential(
                nn.ConvTranspose2d(input_channels, output_channels, kernel_size, stride),
                nn.BatchNorm2d(output_channels),
                nn.ReLU(inplace=True),
            )
        else:
            return nn.Sequential(
                nn.ConvTranspose2d(input_channels, output_channels, kernel_size, stride),
                nn.Tanh(),
            )

    def forward(self, noise: torch.Tensor) -> torch.Tensor:
        """Complete a forward pass of the generator

       Args:
       Parameters:
           noise: a noise tensor with dimensions (n_samples, z_dim)

       Returns:
        generated images.

       """
        x = noise.view(len(noise), self.z_dim, 1, 1)
        return self.gen(x)

Noise Alias

I still don't get this…

get_noise = torch.randn

Classifier

class Classifier(nn.Module):
    """The Classifier (Discriminator)

    Args:
       im_chan: the number of channels in the images, fitted for the dataset used, a scalar
             (CelebA is rgb, so 3 is our default)
       n_classes: the total number of classes in the dataset, an integer scalar
       hidden_dim: the inner dimension, a scalar
    """
    def __init__(self, im_chan: int=3, n_classes: int=2, hidden_dim: int=64):
        super().__init__()
        self.classifier = nn.Sequential(
            self.make_classifier_block(im_chan, hidden_dim),
            self.make_classifier_block(hidden_dim, hidden_dim * 2),
            self.make_classifier_block(hidden_dim * 2, hidden_dim * 4, stride=3),
            self.make_classifier_block(hidden_dim * 4, n_classes, final_layer=True),
        )

    def make_classifier_block(self, input_channels: int, output_channels: int,
                              kernel_size: int=4, stride: int=2,
                              final_layer: bool=False) -> nn.Sequential:
        """Create a sequence of operations corresponding to a classifier block

        - a convolution
        - a batchnorm (except in the final layer)
        - an activation (except in the final layer).

       Args:
           input_channels: how many channels the input feature representation has
           output_channels: how many channels the output feature representation should have
           kernel_size: the size of each convolutional filter, equivalent to (kernel_size, kernel_size)
           stride: the stride of the convolution
           final_layer: a boolean, true if it is the final layer and false otherwise 
                     (affects activation and batchnorm)

       Returns:
        Sequence of layers
       """
        if final_layer:
            return nn.Sequential(
                nn.Conv2d(input_channels, output_channels, kernel_size, stride),
            )
        else:
            return nn.Sequential(
                nn.Conv2d(input_channels, output_channels, kernel_size, stride),
                nn.BatchNorm2d(output_channels),
                nn.LeakyReLU(0.2, inplace=True),
            )

    def forward(self, image: torch.Tensor) -> torch.Tensor:
        """Complete a forward pass of the classifier

       Args:
           image: a flattened image tensor with im_chan channels

       Returns:
        an n_classes-dimension tensor representing fake/real.
       """
        class_pred = self.classifier(image)
        return class_pred.view(len(class_pred), -1)

Middle

Specifying Parameters

Before we begin training, we need to specify a few parameters:

  • zdim: the dimension of the noise vector
  • batchsize: the number of images per forward/backward pass
  • device: the device type
z_dim = 64
batch_size = 128
device = 'cuda'

Train a Classifier

Note: the Celeba class will sometimes raise an exception:

Traceback (most recent call last):
  File "/home/neurotic/download_celeba.py", line 27, in <module>
    CelebA(data_path, split='train', download=True, transform=transform),
  File "/home/neurotic/.conda/envs/neurotic-pytorch/lib/python3.9/site-packages/torchvision/datasets/celeba.py", line 77, in __init__
    self.download()
  File "/home/neurotic/.conda/envs/neurotic-pytorch/lib/python3.9/site-packages/torchvision/datasets/celeba.py", line 131, in download
    with zipfile.ZipFile(os.path.join(self.root, self.base_folder, "img_align_celeba.zip"), "r") as f:
  File "/home/neurotic/.conda/envs/neurotic-pytorch/lib/python3.9/zipfile.py", line 1257, in __init__
    self._RealGetContents()
  File "/home/neurotic/.conda/envs/neurotic-pytorch/lib/python3.9/zipfile.py", line 1324, in _RealGetContents
    raise BadZipFile("File is not a zip file")
zipfile.BadZipFile: File is not a zip file

According to this bug report the problem is that the files are stored on Google Drive which has a limit on the amount of data you can download per day and if it's been exceeded then when you try to download =imgalignceleba.zip" instead of the zip file you get an HTML page (of the same name) with this message:

Sorry, you can't view or download this file at this time.

Too many users have viewed or downloaded this file recently. Please try accessing the file again later. If the file you are trying to access is particularly large or is shared with many people, it may take up to 24 hours to be able to view or download the file. If you still can't access a file after 24 hours, contact your domain administrator.

The data is available on kaggle so if you download it from them and put the file where the bad file is it should work - except of course, it doesn't. It turns out that some of the text files were also replaced with warnings that the download limit was exceeded so I needed to download those as well, but the files on kaggle are formatted as comma-separated files while the original files are space-separated, but even replacing the commas with spaces won't pass the MD5 check - maybe the line endings are different too? Anyway, the images work so I just waited a day and downloaded the text files from the google drive, which seemed to fix it.

def train_classifier(filename: Path, data_path: Path, epochs: int=3,
                     learning_rate: float=0.001, display_step: int=500,
                     classifier: Classifier=None) -> list:
    """Trains the Classifier

    Args:
     filename: path to save the state-dict to
     data_path: path to the celeba data

    Returns:
     classifier losses
    """
    label_indices = range(40)

    display_step = 500
    beta_1 = 0.5
    beta_2 = 0.999
    image_size = 64
    best_loss = float("inf")
    transform = transforms.Compose([
        transforms.Resize(image_size),
        transforms.CenterCrop(image_size),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])

    dataloader = DataLoader(
        CelebA(str(data_path), split='train', download=False, transform=transform),
        batch_size=batch_size,
        shuffle=True)
    if classifier is None:
        classifier = Classifier(n_classes=len(label_indices)).to(device)
    class_opt = torch.optim.Adam(classifier.parameters(), lr=learning_rate, betas=(beta_1, beta_2))
    criterion = nn.BCEWithLogitsLoss()

    cur_step = 0
    classifier_losses = []
    # classifier_val_losses = []
    for epoch in range(epochs):
        for real, labels in dataloader:
            real = real.to(device)
            labels = labels[:, label_indices].to(device).float()

            class_opt.zero_grad()
            class_pred = classifier(real)
            class_loss = criterion(class_pred, labels)
            class_loss.backward() # Calculate the gradients
            class_opt.step() # Update the weights
            classifier_losses += [class_loss.item()] # Keep track of the average classifier loss

            ## Visualization code ##
            if classifier_losses[-1] < best_loss:
                torch.save({"classifier": classifier.state_dict()}, filename)
                best_loss = classifier_losses[-1]
            if cur_step % display_step == 0 and cur_step > 0:
                class_mean = sum(classifier_losses[-display_step:]) / display_step
                print(f"Step {cur_step}: Classifier loss: {class_mean}")
                step_bins = 20
            cur_step += 1
    return classifier_losses
classifier_state_dict = Path("~/models/gans/celeba/trained_classifier.pth").expanduser()
with TIMER:
    classifier_losses = train_classifier(classifier_state_dict, data_path, epochs=100)
Started: 2021-05-10 16:52:26.506156
Step 500: Classifier loss: 0.2693843246996403
Step 1000: Classifier loss: 0.24250468423962593
Step 1500: Classifier loss: 0.2307623517513275
Step 2000: Classifier loss: 0.22525288465619087
Step 2500: Classifier loss: 0.22275795283913613
Step 3000: Classifier loss: 0.2154263758957386
Step 3500: Classifier loss: 0.21474265044927596
Step 4000: Classifier loss: 0.21102887812256813
Step 4500: Classifier loss: 0.20789319404959677
Step 5000: Classifier loss: 0.20887315857410432
Step 5500: Classifier loss: 0.20212965056300164
Step 6000: Classifier loss: 0.20280044555664062
Step 6500: Classifier loss: 0.20041452285647393
Step 7000: Classifier loss: 0.19656063199043275
Step 7500: Classifier loss: 0.19845477828383445
Step 8000: Classifier loss: 0.19205777409672736
Step 8500: Classifier loss: 0.19296112078428268
Step 9000: Classifier loss: 0.19257169529795648
Step 9500: Classifier loss: 0.18672975289821625
Step 10000: Classifier loss: 0.18999777460098266
Step 10500: Classifier loss: 0.18432555946707727
Step 11000: Classifier loss: 0.18430076670646667
Step 11500: Classifier loss: 0.18558891993761062
Step 12000: Classifier loss: 0.17852741411328316
Step 12500: Classifier loss: 0.18172698724269867
Step 13000: Classifier loss: 0.1773110607266426
Step 13500: Classifier loss: 0.17672735232114792
Step 14000: Classifier loss: 0.17991324526071548
Step 14500: Classifier loss: 0.17025035387277604
Step 15000: Classifier loss: 0.17529139894247056
Step 15500: Classifier loss: 0.17104978796839715
Step 16000: Classifier loss: 0.17034502020478248
Step 16500: Classifier loss: 0.17325083956122397
Step 17000: Classifier loss: 0.1642009498178959
Step 17500: Classifier loss: 0.16845661264657974
Step 18000: Classifier loss: 0.1664019832611084
Step 18500: Classifier loss: 0.1633680825829506
Step 19000: Classifier loss: 0.16797509816288947
Step 19500: Classifier loss: 0.15925687023997306
Step 20000: Classifier loss: 0.16292004188895226
Step 20500: Classifier loss: 0.16216046965122222
Step 21000: Classifier loss: 0.15743515598773955
Step 21500: Classifier loss: 0.16243972438573837
Step 22000: Classifier loss: 0.1545857997238636
Step 22500: Classifier loss: 0.15797922548651694
Step 23000: Classifier loss: 0.15804208835959435
Step 23500: Classifier loss: 0.15213489854335785
Step 24000: Classifier loss: 0.15730918619036674
Step 24500: Classifier loss: 0.1511693196296692
Step 25000: Classifier loss: 0.1527680770754814
Step 25500: Classifier loss: 0.15510675182938577
Step 26000: Classifier loss: 0.14683012741804122
Step 26500: Classifier loss: 0.15305917632579805
Step 27000: Classifier loss: 0.14754199008643626
Step 27500: Classifier loss: 0.14820717003941536
Step 28000: Classifier loss: 0.15238315638899802
Step 28500: Classifier loss: 0.14171919177472592
Step 29000: Classifier loss: 0.14881789454817773
Step 29500: Classifier loss: 0.1449408364146948
Step 30000: Classifier loss: 0.1441956951916218
Step 30500: Classifier loss: 0.1483478535115719
Step 31000: Classifier loss: 0.13893532317876817
Step 31500: Classifier loss: 0.1450331158787012
Step 32000: Classifier loss: 0.14139907719194889
Step 32500: Classifier loss: 0.1396861730515957
Step 33000: Classifier loss: 0.1451952086240053
Step 33500: Classifier loss: 0.1358419010192156
Step 34000: Classifier loss: 0.14111693547666074
Step 34500: Classifier loss: 0.1400791739821434
Step 35000: Classifier loss: 0.1358947957903147
Step 35500: Classifier loss: 0.14151665523648263
Step 36000: Classifier loss: 0.1336766537129879
Step 36500: Classifier loss: 0.13722201707959175
Step 37000: Classifier loss: 0.1379301232844591
Step 37500: Classifier loss: 0.13219603390991688
Step 38000: Classifier loss: 0.13811730867624283
Step 38500: Classifier loss: 0.13158722695708275
Step 39000: Classifier loss: 0.13359902986884117
Step 39500: Classifier loss: 0.1366793801188469
Step 40000: Classifier loss: 0.12849617034196853
Step 40500: Classifier loss: 0.13549049003422262
Step 41000: Classifier loss: 0.12929423077404498
Step 41500: Classifier loss: 0.13080933578312398
Step 42000: Classifier loss: 0.13500430592894555
Step 42500: Classifier loss: 0.12454062223434448
Step 43000: Classifier loss: 0.13214491476118564
Step 43500: Classifier loss: 0.1284936859458685
Step 44000: Classifier loss: 0.12763021168112754
Step 44500: Classifier loss: 0.13298917169868946
Step 45000: Classifier loss: 0.12208985219895839
Step 45500: Classifier loss: 0.129048362582922
Step 46000: Classifier loss: 0.12678204217553138
Step 46500: Classifier loss: 0.12455842156708241
Step 47000: Classifier loss: 0.1303500325381756
Step 47500: Classifier loss: 0.12025414818525314
Step 48000: Classifier loss: 0.12684993542730807
Step 48500: Classifier loss: 0.1252559674978256
Step 49000: Classifier loss: 0.12153738121688366
Step 49500: Classifier loss: 0.12777481034398078
Step 50000: Classifier loss: 0.118936713129282
Step 50500: Classifier loss: 0.12405500474572181
Ended: 2021-05-10 18:56:36.980805
Elapsed: 2:04:10.474649
losses = pandas.DataFrame.from_dict(dict(Loss=classifier_losses))
plot = losses.hvplot(y="Loss", title="Classifier Loss", color=PLOT.tan).opts(width=PLOT.width, height=PLOT.height)
output = Embed(plot=plot, file_name="classifier_loss")()
print(output)

Figure Missing

Take Two

n_classes = 40
classifier = Classifier(n_classes=n_classes).to(device)
class_dict = torch.load(classifier_state_dict, map_location=torch.device(device))["classifier"]
classifier.load_state_dict(class_dict)
classifier.eval()
with TIMER:
    classifier_losses = train_classifier(classifier_state_dict, data_path,
                                         epochs=40,
                                         classifier=classifier)
Started: 2021-05-11 16:02:16.181203
Step 500: Classifier loss: 0.1181784438341856
Step 1000: Classifier loss: 0.12448641647398472
Step 1500: Classifier loss: 0.1214247584193945
Step 2000: Classifier loss: 0.1198666417747736
Step 2500: Classifier loss: 0.1255625690817833
Step 3000: Classifier loss: 0.11589906251430511
Step 3500: Classifier loss: 0.12224359685182572
Step 4000: Classifier loss: 0.11944249965250492
Step 4500: Classifier loss: 0.1175859476029873
Step 5000: Classifier loss: 0.12318077574670315
Step 5500: Classifier loss: 0.11450052106380462
Step 6000: Classifier loss: 0.11944048409163951
Step 6500: Classifier loss: 0.11928777326643467
Step 7000: Classifier loss: 0.11463624723255635
Step 7500: Classifier loss: 0.12107200682163238
Step 8000: Classifier loss: 0.11355004295706748
Step 8500: Classifier loss: 0.11719673483073711
Step 9000: Classifier loss: 0.11821492326259612
Step 9500: Classifier loss: 0.11198448015749454
Step 10000: Classifier loss: 0.11870198084414005
Step 10500: Classifier loss: 0.11221958647668362
Step 11000: Classifier loss: 0.11476752410829068
Step 11500: Classifier loss: 0.11772396117448806
Step 12000: Classifier loss: 0.10936097744107247
Step 12500: Classifier loss: 0.11677812692523003
Step 13000: Classifier loss: 0.11107682411372662
Step 13500: Classifier loss: 0.11222303664684295
Step 14000: Classifier loss: 0.11760448211431504
Step 14500: Classifier loss: 0.10662877394258977
Step 15000: Classifier loss: 0.11471863305568696
Step 15500: Classifier loss: 0.11056565625965595
Step 16000: Classifier loss: 0.11046012189984322
Step 16500: Classifier loss: 0.1158019468486309
Step 17000: Classifier loss: 0.10568901741504669
Step 17500: Classifier loss: 0.11223984396457672
Step 18000: Classifier loss: 0.11002579489350318
Step 18500: Classifier loss: 0.10752195838093757
Step 19000: Classifier loss: 0.11419818633794784
Step 19500: Classifier loss: 0.10464896529912948
Step 20000: Classifier loss: 0.11005591739714146
Step 20500: Classifier loss: 0.10996675519645215
Step 21000: Classifier loss: 0.10543355357646943
Step 21500: Classifier loss: 0.11205300988256932
Step 22000: Classifier loss: 0.1038715885579586
Step 22500: Classifier loss: 0.10818033437430859
Step 23000: Classifier loss: 0.10912492156028747
Step 23500: Classifier loss: 0.10302072758972645
Step 24000: Classifier loss: 0.11008756360411644
Step 24500: Classifier loss: 0.10342664630711079
Step 25000: Classifier loss: 0.10618587562441825
Step 25500: Classifier loss: 0.10913233712315559
Step 26000: Classifier loss: 0.10061963592469693
Step 26500: Classifier loss: 0.10828037586808205
Step 27000: Classifier loss: 0.10266246040165425
Step 27500: Classifier loss: 0.1047897623181343
Step 28000: Classifier loss: 0.10866250747442245
Step 28500: Classifier loss: 0.09820086953043938
Step 29000: Classifier loss: 0.10674160474538803
Step 29500: Classifier loss: 0.10230921612679958
Step 30000: Classifier loss: 0.1021555609256029
Step 30500: Classifier loss: 0.10775842162966728
Step 31000: Classifier loss: 0.09722121758759021
Step 31500: Classifier loss: 0.10439497400820255
Step 32000: Classifier loss: 0.10229390095174312
Step 32500: Classifier loss: 0.10003190772235393
Step 33000: Classifier loss: 0.10617333140969276
Step 33500: Classifier loss: 0.09686395044624806
Step 34000: Classifier loss: 0.10285020883381367
Step 34500: Classifier loss: 0.10199978332221508
Step 35000: Classifier loss: 0.09819360673427582
Step 35500: Classifier loss: 0.10397693109512329
Step 36000: Classifier loss: 0.09642438031733036
Step 36500: Classifier loss: 0.10087257397174836
Step 37000: Classifier loss: 0.10197833214700222
Step 37500: Classifier loss: 0.09598418261110783
Step 38000: Classifier loss: 0.10283542364835739
Step 38500: Classifier loss: 0.09644483177363873
Step 39000: Classifier loss: 0.09908602401614189
Step 39500: Classifier loss: 0.10129908196628094
Step 40000: Classifier loss: 0.0939527989178896
Step 40500: Classifier loss: 0.1016722819507122
Step 41000: Classifier loss: 0.09578396078944207
Step 41500: Classifier loss: 0.09706279496848583
Step 42000: Classifier loss: 0.10207961940765381
Step 42500: Classifier loss: 0.09211373472213745
Step 43000: Classifier loss: 0.09958744782209396
Step 43500: Classifier loss: 0.09534277887642384
Step 44000: Classifier loss: 0.0952163600474596
Step 44500: Classifier loss: 0.10136887782812118
Step 45000: Classifier loss: 0.09021547995507717
Step 45500: Classifier loss: 0.09812712541222572
Step 46000: Classifier loss: 0.09560927426815033
Step 46500: Classifier loss: 0.09358323478698731
Step 47000: Classifier loss: 0.09991893386840821
Step 47500: Classifier loss: 0.0899157041311264
Step 48000: Classifier loss: 0.096542285323143
Step 48500: Classifier loss: 0.09535252919793129
Step 49000: Classifier loss: 0.09194727616012097
Step 49500: Classifier loss: 0.09831891848146915
Step 50000: Classifier loss: 0.0901611197590828
Step 50500: Classifier loss: 0.09490065774321556
Ended: 2021-05-11 18:06:19.399986
Elapsed: 2:04:03.218783
losses = pandas.DataFrame.from_dict(dict(Loss=classifier_losses))
plot = losses.hvplot(y="Loss", title="Classifier Loss Session 2", color=PLOT.tan).opts(width=PLOT.width, height=PLOT.height)
output = Embed(plot=plot, file_name="classifier_loss_2")()
print(output)

Figure Missing

Take Three

n_classes = 40
classifier = Classifier(n_classes=n_classes).to(device)
class_dict = torch.load(classifier_state_dict, map_location=torch.device(device))["classifier"]
classifier.load_state_dict(class_dict)
classifier.eval()
with TIMER:
    classifier_losses = train_classifier(classifier_state_dict, data_path,
                                         epochs=40,
                                         classifier=classifier)
Started: 2021-05-11 21:11:32.420506
Step 500: Classifier loss: 0.09006546361744404
Step 1000: Classifier loss: 0.09647199404239655
Step 1500: Classifier loss: 0.092734768897295
Step 2000: Classifier loss: 0.09196118661761284
Step 2500: Classifier loss: 0.09789373110234738
Step 3000: Classifier loss: 0.08788785541057587
Step 3500: Classifier loss: 0.0945415845811367
Step 4000: Classifier loss: 0.09308994428813458
Step 4500: Classifier loss: 0.09022623193264008
Step 5000: Classifier loss: 0.09615834753215313
Step 5500: Classifier loss: 0.08742603194713593
Step 6000: Classifier loss: 0.09316775412857532
Step 6500: Classifier loss: 0.09275233449041843
Step 7000: Classifier loss: 0.08822614887356758
Step 7500: Classifier loss: 0.09528714890778064
Step 8000: Classifier loss: 0.08681275172531605
Step 8500: Classifier loss: 0.09150236696004868
Step 9000: Classifier loss: 0.09338522186875343
Step 9500: Classifier loss: 0.08638478130102158
Step 10000: Classifier loss: 0.09388372772932052
Step 10500: Classifier loss: 0.08720742921531201
Step 11000: Classifier loss: 0.09009483934938908
Step 11500: Classifier loss: 0.0929495030939579
Step 12000: Classifier loss: 0.08460890363156795
Step 12500: Classifier loss: 0.0924714410007
Step 13000: Classifier loss: 0.08704712373018265
Step 13500: Classifier loss: 0.08819058662652969
Step 14000: Classifier loss: 0.09366303083300591
Step 14500: Classifier loss: 0.08295501434803008
Step 15000: Classifier loss: 0.09084490737318993
Step 15500: Classifier loss: 0.08707242746651173
Step 16000: Classifier loss: 0.08690852355957031
Step 16500: Classifier loss: 0.09254233407974242
Step 17000: Classifier loss: 0.08242024271190167
Step 17500: Classifier loss: 0.08904271678626538
Step 18000: Classifier loss: 0.08771026766300201
Step 18500: Classifier loss: 0.08471861970424652
Step 19000: Classifier loss: 0.09134728060662746
Step 19500: Classifier loss: 0.08233513435721397
Step 20000: Classifier loss: 0.08778411850333213
Step 20500: Classifier loss: 0.08791485584527255
Step 21000: Classifier loss: 0.08345357306301594
Step 21500: Classifier loss: 0.08975999920070171
Step 22000: Classifier loss: 0.08225472408533097
Step 22500: Classifier loss: 0.08668080273270606
Step 23000: Classifier loss: 0.08786206224560737
Step 23500: Classifier loss: 0.08155409483611584
Step 24000: Classifier loss: 0.08907847443222999
Step 24500: Classifier loss: 0.08202618369460106
Step 25000: Classifier loss: 0.08517973597347736
Step 25500: Classifier loss: 0.08817093770205975
Step 26000: Classifier loss: 0.08008052316308022
Step 26500: Classifier loss: 0.08741954331099987
Step 27000: Classifier loss: 0.08247932478785515
Step 27500: Classifier loss: 0.08377225384116173
Step 28000: Classifier loss: 0.08846944206953049
Step 28500: Classifier loss: 0.07859189368784428
Step 29000: Classifier loss: 0.08617163190245629
Step 29500: Classifier loss: 0.0824531610161066
Step 30000: Classifier loss: 0.08195052224397659
Step 30500: Classifier loss: 0.08803890940546989
Step 31000: Classifier loss: 0.07793828934431075
Step 31500: Classifier loss: 0.08464510484039783
Step 32000: Classifier loss: 0.08275749842077494
Step 32500: Classifier loss: 0.0805082704871893
Step 33000: Classifier loss: 0.08703124921023846
Step 33500: Classifier loss: 0.0772736611738801
Step 34000: Classifier loss: 0.08353734220564366
Step 34500: Classifier loss: 0.08343685203790664
Step 35000: Classifier loss: 0.07905932680517436
Step 35500: Classifier loss: 0.08568261863291264
Step 36000: Classifier loss: 0.07762402860075235
Step 36500: Classifier loss: 0.08223582464456558
Step 37000: Classifier loss: 0.08341778349876404
Step 37500: Classifier loss: 0.07801838412880897
Step 38000: Classifier loss: 0.0842266542762518
Step 38500: Classifier loss: 0.07764634099602699
Step 39000: Classifier loss: 0.08104524739086628
Step 39500: Classifier loss: 0.08389902476221323
Step 40000: Classifier loss: 0.07612183248996734
Step 40500: Classifier loss: 0.08296740844845772
Step 41000: Classifier loss: 0.0781253460124135
Step 41500: Classifier loss: 0.07980525248497725
Step 42000: Classifier loss: 0.08405549557507039
Step 42500: Classifier loss: 0.0743530157059431
Step 43000: Classifier loss: 0.08219673927128315
Step 43500: Classifier loss: 0.07845095673948527
Step 44000: Classifier loss: 0.07780187250673772
Step 44500: Classifier loss: 0.08399353076517582
Step 45000: Classifier loss: 0.07365029990673065
Step 45500: Classifier loss: 0.0808380290567875
Step 46000: Classifier loss: 0.0786423703506589
Step 46500: Classifier loss: 0.07693013155460357
Step 47000: Classifier loss: 0.08244228959083558
Step 47500: Classifier loss: 0.07365631985664367
Step 48000: Classifier loss: 0.07970952866971492
Step 48500: Classifier loss: 0.07868134459108114
Step 49000: Classifier loss: 0.07539512529224157
Step 49500: Classifier loss: 0.08191524033248425
Step 50000: Classifier loss: 0.07361406400799751
Step 50500: Classifier loss: 0.07847459720075131
Ended: 2021-05-11 23:15:24.444278
Elapsed: 2:03:52.023772
losses = pandas.DataFrame.from_dict(dict(Loss=classifier_losses))
plot = losses.hvplot(y="Loss", title="Classifier Loss Session 3", color=PLOT.tan).opts(width=PLOT.width, height=PLOT.height)
output = Embed(plot=plot, file_name="classifier_loss_3")()
print(output)

Figure Missing

Loading the Pretrained Models

We will then load the pretrained generator and classifier using the following code. (If we trained our own classifier, we can load that one here instead.)

import torch
gen = Generator(z_dim).to(device)
gen_dict = torch.load(prebuilt_models.celeba, map_location=torch.device(device))["gen"]
gen.load_state_dict(gen_dict)
gen.eval()

n_classes = 40
classifier = Classifier(n_classes=n_classes).to(device)
class_dict = torch.load(prebuilt_models.classifier, map_location=torch.device(device))["classifier"]
classifier.load_state_dict(class_dict)
classifier.eval()

opt = torch.optim.Adam(classifier.parameters(), lr=0.01)

Training

Now we can start implementing a method for controlling our GAN.

Update Noise

For training, we need to write the code to update the noise to produce more of our desired feature. We do this by performing stochastic gradient ascent. We use stochastic gradient ascent to find the local maxima, as opposed to stochastic gradient descent which finds the local minima. Gradient ascent is gradient descent over the negative of the value being optimized. Their formulas are essentially the same, however, instead of subtracting the weighted value, stochastic gradient ascent adds it; it can be calculated by \(new = old + (∇ old * weight)\), where ∇ is the gradient of old. We perform stochastic gradient ascent to try and maximize the amount of the feature we want. If we wanted to reduce the amount of the feature, we would perform gradient descent. However, in this assignment we are interested in maximize our feature using gradient ascent, since many features in the dataset are not present much more often than they're present and we are trying to add a feature to the images, not remove.

Given the noise with its gradient already calculated through the classifier, we want to return the new noise vector.

  1. Remember the equation for gradient ascent: \(new = old + (∇ old * weight)\).
def calculate_updated_noise(noise: torch.Tensor, weight: float) -> torch.Tensor:
    """Update noise vectors with stochastic gradient ascent.

    Args:
     noise: the current noise vectors. 
           We have already called the backwards function on the target class
           so we can access the gradient of the output class with respect 
           to the noise by using noise.grad
     weight: the scalar amount by which we should weight the noise gradient

    Returns:
     updated noise
    """
    new_noise = noise + (noise.grad * weight)
    return new_noise
  • UNIT TEST

    Check that the basic function works.

    opt.zero_grad()
    noise = torch.ones(20, 20) * 2
    noise.requires_grad_()
    fake_classes = (noise ** 2).mean()
    fake_classes.backward()
    new_noise = calculate_updated_noise(noise, 0.1)
    assert type(new_noise) == torch.Tensor
    assert tuple(new_noise.shape) == (20, 20)
    assert new_noise.max() == 2.0010
    assert new_noise.min() == 2.0010
    assert torch.isclose(new_noise.sum(), torch.tensor(0.4) + 20 * 20 * 2)
    

    Check that it works for generated images

    opt.zero_grad()
    noise = get_noise(32, z_dim).to(device).requires_grad_()
    fake = gen(noise)
    fake_classes = classifier(fake)[:, 0]
    fake_classes.mean().backward()
    noise.data = calculate_updated_noise(noise, 0.01)
    fake = gen(noise)
    fake_classes_new = classifier(fake)[:, 0]
    assert torch.all(fake_classes_new > fake_classes)
    

Generation

Now, we can use the classifier along with stochastic gradient ascent to make noise that generates more of a certain feature. In the code given to us here, we can generate smiling faces. Feel free to change the target index and control some of the other features in the list! We will notice that some features are easier to detect and control than others.

The list we have here are the features labeled in CelebA, which we used to train our classifier. If we wanted to control another feature, we would need to get data that is labeled with that feature and train a classifier on that feature.

First generate a bunch of images with the generator.

n_images = 8
fake_image_history = []
grad_steps = 10 # Number of gradient steps to take
skip = 2 # Number of gradient steps to skip in the visualization

feature_names = ["5oClockShadow", "ArchedEyebrows", "Attractive", "BagsUnderEyes", "Bald", "Bangs",
"BigLips", "BigNose", "BlackHair", "BlondHair", "Blurry", "BrownHair", "BushyEyebrows", "Chubby",
"DoubleChin", "Eyeglasses", "Goatee", "GrayHair", "HeavyMakeup", "HighCheekbones", "Male", 
"MouthSlightlyOpen", "Mustache", "NarrowEyes", "NoBeard", "OvalFace", "PaleSkin", "PointyNose", 
"RecedingHairline", "RosyCheeks", "Sideburn", "Smiling", "StraightHair", "WavyHair", "WearingEarrings", 
"WearingHat", "WearingLipstick", "WearingNecklace", "WearingNecktie", "Weng"]

### Change me! ###
target_indices = feature_names.index("Weng") # Feel free to change this value to any string from feature_names!

noise = get_noise(n_images, z_dim).to(device).requires_grad_()
for i in range(grad_steps):
    opt.zero_grad()
    fake = gen(noise)
    fake_image_history += [fake]
    fake_classes_score = classifier(fake)[:, target_indices].mean()
    fake_classes_score.backward()
    noise.data = calculate_updated_noise(noise, 1 / grad_steps)

pyplot.rcParams['figure.figsize'] = [n_images * 2, grad_steps * 2]
save_tensor_images(image_tensor=torch.cat(fake_image_history[::skip], dim=2), 
filename="weng.png", folder=OUTPUT, title="Weng",
num_images=n_images, nrow=n_images)

weng.png

Entanglement and Regularization

We may also notice that sometimes more features than just the target feature change. This is because some features are entangled. To fix this, we can try to isolate the target feature more by holding the classes outside of the target class constant. One way we can implement this is by penalizing the differences from the original class with L2 regularization. This L2 regularization would apply a penalty for this difference using the L2 norm and this would just be an additional term on the loss function.

Here, we'll have to implement the score function: the higher, the better. The score is calculated by adding the target score and a penalty – note that the penalty is meant to lower the score, so it should have a negative value.

For every non-target class, take the difference between the current noise and the old noise. The greater this value is, the more features outside the target have changed. We will calculate the magnitude of the change, take the mean, and negate it. Finally, add this penalty to the target score. The target score is the mean of the target class in the current noise.

  1. The higher the score, the better!
  2. We want to calculate the loss per image, so we'll need to pass a dim argument to torch.norm.
  3. Calculating the magnitude of the change requires we to take the norm of the difference between the classifications, not the difference of the norms.

Note: torch.norm is deprecated, they want you to use torch.linalg.norm instead.

def get_score(current_classifications: torch.Tensor,
              original_classifications: torch.Tensor,
              target_indices: torch.Tensor,
              other_indices: torch.Tensor,
              penalty_weight: float) -> torch.Tensor:
    """Score the current classifications, L2 Norm penalty

    Args:
       current_classifications: the classifications associated with the current noise
       original_classifications: the classifications associated with the original noise
       target_indices: the index of the target class
       other_indices: the indices of the other classes
       penalty_weight: the amount that the penalty should be weighted in the overall score

    Returns: 
     the score of the current classification with L2 Norm penalty
    """
    # Steps: 1) Calculate the change between the original and current classifications (as a tensor)
    #           by indexing into the other_indices we're trying to preserve, like in x[:, features].
    #        2) Calculate the norm (magnitude) of changes per example.
    #        3) Multiply the mean of the example norms by the penalty weight. 
    #           This will be our other_class_penalty.
    #           Make sure to negate the value since it's a penalty!
    #        4) Take the mean of the current classifications for the target feature over all the examples.
    #           This mean will be our target_score.
    # Calculate the norm (magnitude) of changes per example and multiply by penalty weight
    other_class_penalty = -(torch.mean(
        torch.linalg.norm(original_classifications[:, other_indices]
                          - current_classifications[:, other_indices], dim=1))
                           * penalty_weight)
    # Take the mean of the current classifications for the target feature
    target_score = torch.mean(current_classifications[:, target_indices])
    return target_score + other_class_penalty

UNIT TEST

assert torch.isclose(
    get_score(torch.ones(4, 3), torch.zeros(4, 3), [0], [1, 2], 0.2), 
    1 - torch.sqrt(torch.tensor(2.)) * 0.2
)
rows = 10
current_class = torch.tensor([[1] * rows, [2] * rows, [3] * rows, [4] * rows]).T.float()
original_class = torch.tensor([[1] * rows, [2] * rows, [3] * rows, [4] * rows]).T.float()

# Must be 3
assert get_score(current_class, original_class, [1, 3] , [0, 2], 0.2).item() == 3

current_class = torch.tensor([[1] * rows, [2] * rows, [3] * rows, [4] * rows]).T.float()
original_class = torch.tensor([[4] * rows, [4] * rows, [2] * rows, [1] * rows]).T.float()

# Must be 3 - 0.2 * sqrt(10)
assert torch.isclose(get_score(current_class, original_class, [1, 3] , [0, 2], 0.2), 
                     -torch.sqrt(torch.tensor(10.0)) * 0.2 + 3)

In the following block of code, we will run the gradient ascent with this new score function. We might notice a few things after running it:

  1. It may fail more often at producing the target feature when compared to the original approach. This suggests that the model may not be able to generate an image that has the target feature without changing the other features. This makes sense! For example, it may not be able to generate a face that's smiling but whose mouth is NOT slightly open. This may also expose a limitation of the generator.

Alternatively, even if the generator can produce an image with the intended features, it might require many intermediate changes to get there and may get stuck in a local minimum.

  1. This process may change features which the classifier was not trained to recognize since there is no way to penalize them with this method. Whether it's possible to train models to avoid changing unsupervised features is an open question.
fake_image_history = []
### Change me! ###
target_indices = feature_names.index("Goatee") # Feel free to change this value to any string from feature_names from earlier!
other_indices = [cur_idx != target_indices for cur_idx, _ in enumerate(feature_names)]
noise = get_noise(n_images, z_dim).to(device).requires_grad_()
original_classifications = classifier(gen(noise)).detach()
for i in range(grad_steps):
    opt.zero_grad()
    fake = gen(noise)
    fake_image_history += [fake]
    fake_score = get_score(
        classifier(fake), 
        original_classifications,
        target_indices,
        other_indices,
        penalty_weight=0.1
    )
    fake_score.backward()
    noise.data = calculate_updated_noise(noise, 1 / grad_steps)

pyplot.rcParams['figure.figsize'] = [n_images * 2, grad_steps * 2]
save_tensor_images(torch.cat(fake_image_history[::skip], dim=2), num_images=n_images, nrow=n_images, filename="goatee.png", folder=OUTPUT, title="Goatee")

goatee.png

End

Sources

  • Liu, Z, Luo, P, Wang, X, Tang, X, Deep Learning Face Attributes in the Wild. In Proceedings of International Conference on Computer Vision (ICCV) 2015 .