Controllable Generation

In this notebook, we're going to implement a GAN controllability method using gradients from a classifier. By training a classifier to recognize a relevant feature, we can use it to change the generator's inputs (z-vectors) to make it generate images with more or less of that feature.

We start off with a pre-trained generator and classifier so that we can focus on the controllability aspects.

The classifier has the same architecture as the earlier critic (remember that the discriminator/critic is simply a classifier used to classify real and fake).

CelebA

Instead of the MNIST dataset, we will be using CelebA. CelebA is a dataset of annotated celebrity images. Since the images are in color (not black-and-white), they have three channels for red, green, and blue (RGB). We'll be using the pre-built PyTorch CelebA dataset.

Imports

# python
from argparse import Namespace
from collections import namedtuple
from functools import partial
from pathlib import Path

import pickle

# pypi
from torch import nn
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import CelebA
from torchvision.utils import make_grid

import holoviews
import hvplot.pandas
import matplotlib.pyplot as pyplot
import pandas
import torch

# my stuff
from graeae import EmbedHoloviews, Timer

Set Up

The Timer

TIMER = Timer()

The Random Seed

torch.manual_seed(0)

Plotting

SLUG = "controllable-generation"
OUTPUT = f"files/posts/gans/{SLUG}/"

Embed = partial(EmbedHoloviews, folder_path=OUTPUT)

Plot = namedtuple("Plot", ["width", "height", "fontscale", "tan", "blue", "red"])
PLOT = Plot(
    width=900,
    height=750,
    fontscale=2,
    tan="#ddb377",
    blue="#4687b7",
    red="#ce7b6d",
 )

Paths

base_path = Path("~/models/gans/celeba/").expanduser()
assert base_path.is_dir()

prebuilt_models = Namespace(
    celeba = base_path/"pretrained_celeba.pth",
    classifier = base_path/"pretrained_classifier.pth"
)

data_path = Path("~/pytorch-data/").expanduser()
if not data_path.is_dir():
    data_path.mkdir()
assert prebuilt_models.celeba.is_file()
assert prebuilt_models.classifier.is_file()

Helpers

def save_tensor_images(image_tensor: torch.Tensor,
                       filename: str, 
                       title: str,
                       folder: str=f"files/posts/gans{SLUG}/",
                       num_images: int=16, size: tuple=(1, 28, 28), nrow=3):
    """Plot an Image Tensor

    Args:
     image_tensor: tensor with the values for the image to plot
     filename: name to save the file under
     folder: path to put the file in
     title: title for the image
     num_images: how many images from the tensor to use
     size: the dimensions for each image
     nrow: the number of images per row in the grid
    """
    image_tensor = (image_tensor + 1) / 2
    image_unflat = image_tensor.detach().cpu()
    image_grid = make_grid(image_unflat[:num_images], nrow=nrow)
    pyplot.title(title)
    pyplot.grid(False)
    pyplot.imshow(image_grid.permute(1, 2, 0).squeeze())
    pyplot.tick_params(bottom=False, top=False, labelbottom=False,
                       right=False, left=False, labelleft=False)
    pyplot.savefig(folder + filename)
    print(f"[[file:{filename}]]")
    return
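As a quick smoke test (my addition, not from the original notebook), calling the helper with random data confirms the file path and grid layout behave as expected.

# hypothetical sanity check: save a grid of random "images"
random_images = torch.randn(16, 3, 64, 64).clamp(-1, 1)
save_tensor_images(random_images, filename="random_check.png",
                   title="Random Noise", folder=OUTPUT,
                   num_images=16, nrow=4)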

The Generator

This is mostly the same as the other Generators, but the images are now in color, so the channels are different, and the model has more initial hidden nodes (and one extra hidden block).

class Generator(nn.Module):
    """Generator for the celeba images

    Args:
       z_dim: the dimension of the noise vector, a scalar
       im_chan: the number of channels in the images, fitted for the dataset used, a scalar
             (CelebA is rgb, so 3 is our default)
       hidden_dim: the inner dimension, a scalar
    """
    def __init__(self, z_dim: int=10, im_chan: int=3, hidden_dim: int=64):
        super().__init__()
        self.z_dim = z_dim
        self.gen = nn.Sequential(
            self.make_gen_block(z_dim, hidden_dim * 8),
            self.make_gen_block(hidden_dim * 8, hidden_dim * 4),
            self.make_gen_block(hidden_dim * 4, hidden_dim * 2),
            self.make_gen_block(hidden_dim * 2, hidden_dim),
            self.make_gen_block(hidden_dim, im_chan, kernel_size=4, final_layer=True),
        )

    def make_gen_block(self, input_channels: int, output_channels: int,
                       kernel_size: int=3, stride: int=2,
                       final_layer: bool=False) -> nn.Sequential:
        """Create a sequence of operations corresponding to a generator block of DCGAN

        - a transposed convolution
        - a batchnorm (except in the final layer)
        - an activation.

       Args:
           input_channels: how many channels the input feature representation has
           output_channels: how many channels the output feature representation should have
           kernel_size: the size of each convolutional filter, equivalent to (kernel_size, kernel_size)
           stride: the stride of the convolution
           final_layer: a boolean, true if it is the final layer and false otherwise 
                     (affects activation and batchnorm)
       Returns:
        sequence of layers
       """
        if not final_layer:
            return nn.Sequential(
                nn.ConvTranspose2d(input_channels, output_channels, kernel_size, stride),
                nn.BatchNorm2d(output_channels),
                nn.ReLU(inplace=True),
            )
        else:
            return nn.Sequential(
                nn.ConvTranspose2d(input_channels, output_channels, kernel_size, stride),
                nn.Tanh(),
            )

    def forward(self, noise: torch.Tensor) -> torch.Tensor:
        """Complete a forward pass of the generator

       Args:
           noise: a noise tensor with dimensions (n_samples, z_dim)

       Returns:
        generated images.

       """
        x = noise.view(len(noise), self.z_dim, 1, 1)
        return self.gen(x)
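As a sanity check (mine, assuming the z_dim of 64 used later in this post), we can verify that the five blocks upsample the 1 x 1 noise into a 64 x 64 RGB image.

test_generator = Generator(z_dim=64)
test_noise = torch.randn(2, 64)
with torch.no_grad():
    test_images = test_generator(test_noise)
# spatial sizes: 1 -> 3 -> 7 -> 15 -> 31 -> 64
assert test_images.shape == (2, 3, 64, 64)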

Noise Alias

I still don't get this…

get_noise = torch.randn
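My best guess at why the alias works: torch.randn already takes the size as positional integers plus a device keyword, so it has exactly the signature that the get_noise helpers from the earlier posts had.

# get_noise(n_samples, z_dim, device=...) is just torch.randn(n_samples, z_dim, device=...)
noise = get_noise(4, 64)
assert noise.shape == (4, 64)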

Classifier

class Classifier(nn.Module):
    """The Classifier (Discriminator)

    Args:
       im_chan: the number of channels in the images, fitted for the dataset used, a scalar
             (CelebA is rgb, so 3 is our default)
       n_classes: the total number of classes in the dataset, an integer scalar
       hidden_dim: the inner dimension, a scalar
    """
    def __init__(self, im_chan: int=3, n_classes: int=2, hidden_dim: int=64):
        super().__init__()
        self.classifier = nn.Sequential(
            self.make_classifier_block(im_chan, hidden_dim),
            self.make_classifier_block(hidden_dim, hidden_dim * 2),
            self.make_classifier_block(hidden_dim * 2, hidden_dim * 4, stride=3),
            self.make_classifier_block(hidden_dim * 4, n_classes, final_layer=True),
        )

    def make_classifier_block(self, input_channels: int, output_channels: int,
                              kernel_size: int=4, stride: int=2,
                              final_layer: bool=False) -> nn.Sequential:
        """Create a sequence of operations corresponding to a classifier block

        - a convolution
        - a batchnorm (except in the final layer)
        - an activation (except in the final layer).

       Args:
           input_channels: how many channels the input feature representation has
           output_channels: how many channels the output feature representation should have
           kernel_size: the size of each convolutional filter, equivalent to (kernel_size, kernel_size)
           stride: the stride of the convolution
           final_layer: a boolean, true if it is the final layer and false otherwise 
                     (affects activation and batchnorm)

       Returns:
        Sequence of layers
       """
        if final_layer:
            return nn.Sequential(
                nn.Conv2d(input_channels, output_channels, kernel_size, stride),
            )
        else:
            return nn.Sequential(
                nn.Conv2d(input_channels, output_channels, kernel_size, stride),
                nn.BatchNorm2d(output_channels),
                nn.LeakyReLU(0.2, inplace=True),
            )

    def forward(self, image: torch.Tensor) -> torch.Tensor:
        """Complete a forward pass of the classifier

       Args:
           image: a flattened image tensor with im_chan channels

       Returns:
        an n_classes-dimension tensor representing fake/real.
       """
        class_pred = self.classifier(image)
        return class_pred.view(len(class_pred), -1)
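The matching sanity sketch (also mine) for the classifier: with a 64 x 64 RGB input, the four blocks shrink the feature map to 1 x 1, so the view flattens the output to (batch, n_classes).

test_classifier = Classifier(n_classes=40)
test_batch = torch.randn(2, 3, 64, 64)
with torch.no_grad():
    test_predictions = test_classifier(test_batch)
# spatial sizes: 64 -> 31 -> 14 -> 4 -> 1
assert test_predictions.shape == (2, 40)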

Middle

Specifying Parameters

Before we begin training, we need to specify a few parameters:

  • z_dim: the dimension of the noise vector
  • batch_size: the number of images per forward/backward pass
  • device: the device type
z_dim = 64
batch_size = 128
device = 'cuda'
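If you're following along without a GPU, a guarded assignment (a substitution of mine, not in the original) avoids a crash when CUDA isn't available:

device = "cuda" if torch.cuda.is_available() else "cpu"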

Train a Classifier

Note: the CelebA class will sometimes raise an exception when downloading:

Traceback (most recent call last):
  File "/home/neurotic/download_celeba.py", line 27, in <module>
    CelebA(data_path, split='train', download=True, transform=transform),
  File "/home/neurotic/.conda/envs/neurotic-pytorch/lib/python3.9/site-packages/torchvision/datasets/celeba.py", line 77, in __init__
    self.download()
  File "/home/neurotic/.conda/envs/neurotic-pytorch/lib/python3.9/site-packages/torchvision/datasets/celeba.py", line 131, in download
    with zipfile.ZipFile(os.path.join(self.root, self.base_folder, "img_align_celeba.zip"), "r") as f:
  File "/home/neurotic/.conda/envs/neurotic-pytorch/lib/python3.9/zipfile.py", line 1257, in __init__
    self._RealGetContents()
  File "/home/neurotic/.conda/envs/neurotic-pytorch/lib/python3.9/zipfile.py", line 1324, in _RealGetContents
    raise BadZipFile("File is not a zip file")
zipfile.BadZipFile: File is not a zip file

According to this bug report the problem is that the files are stored on Google Drive, which has a limit on the amount of data you can download per day. If the limit has been exceeded, then when you try to download img_align_celeba.zip, instead of the zip file you get an HTML page (of the same name) with this message:

Sorry, you can't view or download this file at this time.

Too many users have viewed or downloaded this file recently. Please try accessing the file again later. If the file you are trying to access is particularly large or is shared with many people, it may take up to 24 hours to be able to view or download the file. If you still can't access a file after 24 hours, contact your domain administrator.

The data is also available on Kaggle, so if you download it from there and put the file where the bad file is, it should work - except, of course, it doesn't. It turns out that some of the text files were also replaced with warnings that the download limit was exceeded, so I needed to download those as well. The files on Kaggle are comma-separated while the original files are space-separated, and even replacing the commas with spaces won't pass the MD5 check - maybe the line endings are different too? The images work, though, so I just waited a day and downloaded the text files from Google Drive, which seemed to fix it.
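If you want to check whether a downloaded file is the real archive or one of those Google Drive warning pages, comparing MD5 hashes settles it. This is a sketch of mine using hashlib; the expected hashes are listed in torchvision's celeba.py (the file_list attribute of the CelebA class).

# hypothetical helper to verify a downloaded CelebA file
import hashlib

def md5_of(path: Path, chunk_size: int=2**20) -> str:
    """Compute the MD5 hex digest of a file."""
    digest = hashlib.md5()
    with path.open("rb") as stream:
        for chunk in iter(lambda: stream.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()

# compare this against the hash in torchvision's file_list for the same file
print(md5_of(data_path/"celeba"/"img_align_celeba.zip"))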

def train_classifier(filename: Path, data_path: Path, epochs: int=3,
                     learning_rate: float=0.001, display_step: int=500,
                     classifier: Classifier=None) -> list:
    """Trains the Classifier

    Args:
     filename: path to save the state-dict to
     data_path: path to the celeba data
     epochs: the number of passes through the training data
     learning_rate: the learning rate for the Adam optimizer
     display_step: how often (in steps) to print the mean loss
     classifier: an existing classifier to continue training (a new one is created if None)

    Returns:
     classifier losses
    """
    label_indices = range(40)

    beta_1 = 0.5
    beta_2 = 0.999
    image_size = 64
    best_loss = float("inf")
    transform = transforms.Compose([
        transforms.Resize(image_size),
        transforms.CenterCrop(image_size),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])

    dataloader = DataLoader(
        CelebA(str(data_path), split='train', download=False, transform=transform),
        batch_size=batch_size,
        shuffle=True)
    if classifier is None:
        classifier = Classifier(n_classes=len(label_indices)).to(device)
    class_opt = torch.optim.Adam(classifier.parameters(), lr=learning_rate, betas=(beta_1, beta_2))
    criterion = nn.BCEWithLogitsLoss()

    cur_step = 0
    classifier_losses = []
    # classifier_val_losses = []
    for epoch in range(epochs):
        for real, labels in dataloader:
            real = real.to(device)
            labels = labels[:, label_indices].to(device).float()

            class_opt.zero_grad()
            class_pred = classifier(real)
            class_loss = criterion(class_pred, labels)
            class_loss.backward() # Calculate the gradients
            class_opt.step() # Update the weights
            classifier_losses += [class_loss.item()] # Keep track of the average classifier loss

            ## Visualization code ##
            if classifier_losses[-1] < best_loss:
                torch.save({"classifier": classifier.state_dict()}, filename)
                best_loss = classifier_losses[-1]
            if cur_step % display_step == 0 and cur_step > 0:
                class_mean = sum(classifier_losses[-display_step:]) / display_step
                print(f"Step {cur_step}: Classifier loss: {class_mean}")
            cur_step += 1
    return classifier_losses
classifier_state_dict = Path("~/models/gans/celeba/trained_classifier.pth").expanduser()
with TIMER:
    classifier_losses = train_classifier(classifier_state_dict, data_path, epochs=100)
Started: 2021-05-10 16:52:26.506156
Step 500: Classifier loss: 0.2693843246996403
Step 1000: Classifier loss: 0.24250468423962593
Step 1500: Classifier loss: 0.2307623517513275
Step 2000: Classifier loss: 0.22525288465619087
Step 2500: Classifier loss: 0.22275795283913613
Step 3000: Classifier loss: 0.2154263758957386
Step 3500: Classifier loss: 0.21474265044927596
Step 4000: Classifier loss: 0.21102887812256813
Step 4500: Classifier loss: 0.20789319404959677
Step 5000: Classifier loss: 0.20887315857410432
Step 5500: Classifier loss: 0.20212965056300164
Step 6000: Classifier loss: 0.20280044555664062
Step 6500: Classifier loss: 0.20041452285647393
Step 7000: Classifier loss: 0.19656063199043275
Step 7500: Classifier loss: 0.19845477828383445
Step 8000: Classifier loss: 0.19205777409672736
Step 8500: Classifier loss: 0.19296112078428268
Step 9000: Classifier loss: 0.19257169529795648
Step 9500: Classifier loss: 0.18672975289821625
Step 10000: Classifier loss: 0.18999777460098266
Step 10500: Classifier loss: 0.18432555946707727
Step 11000: Classifier loss: 0.18430076670646667
Step 11500: Classifier loss: 0.18558891993761062
Step 12000: Classifier loss: 0.17852741411328316
Step 12500: Classifier loss: 0.18172698724269867
Step 13000: Classifier loss: 0.1773110607266426
Step 13500: Classifier loss: 0.17672735232114792
Step 14000: Classifier loss: 0.17991324526071548
Step 14500: Classifier loss: 0.17025035387277604
Step 15000: Classifier loss: 0.17529139894247056
Step 15500: Classifier loss: 0.17104978796839715
Step 16000: Classifier loss: 0.17034502020478248
Step 16500: Classifier loss: 0.17325083956122397
Step 17000: Classifier loss: 0.1642009498178959
Step 17500: Classifier loss: 0.16845661264657974
Step 18000: Classifier loss: 0.1664019832611084
Step 18500: Classifier loss: 0.1633680825829506
Step 19000: Classifier loss: 0.16797509816288947
Step 19500: Classifier loss: 0.15925687023997306
Step 20000: Classifier loss: 0.16292004188895226
Step 20500: Classifier loss: 0.16216046965122222
Step 21000: Classifier loss: 0.15743515598773955
Step 21500: Classifier loss: 0.16243972438573837
Step 22000: Classifier loss: 0.1545857997238636
Step 22500: Classifier loss: 0.15797922548651694
Step 23000: Classifier loss: 0.15804208835959435
Step 23500: Classifier loss: 0.15213489854335785
Step 24000: Classifier loss: 0.15730918619036674
Step 24500: Classifier loss: 0.1511693196296692
Step 25000: Classifier loss: 0.1527680770754814
Step 25500: Classifier loss: 0.15510675182938577
Step 26000: Classifier loss: 0.14683012741804122
Step 26500: Classifier loss: 0.15305917632579805
Step 27000: Classifier loss: 0.14754199008643626
Step 27500: Classifier loss: 0.14820717003941536
Step 28000: Classifier loss: 0.15238315638899802
Step 28500: Classifier loss: 0.14171919177472592
Step 29000: Classifier loss: 0.14881789454817773
Step 29500: Classifier loss: 0.1449408364146948
Step 30000: Classifier loss: 0.1441956951916218
Step 30500: Classifier loss: 0.1483478535115719
Step 31000: Classifier loss: 0.13893532317876817
Step 31500: Classifier loss: 0.1450331158787012
Step 32000: Classifier loss: 0.14139907719194889
Step 32500: Classifier loss: 0.1396861730515957
Step 33000: Classifier loss: 0.1451952086240053
Step 33500: Classifier loss: 0.1358419010192156
Step 34000: Classifier loss: 0.14111693547666074
Step 34500: Classifier loss: 0.1400791739821434
Step 35000: Classifier loss: 0.1358947957903147
Step 35500: Classifier loss: 0.14151665523648263
Step 36000: Classifier loss: 0.1336766537129879
Step 36500: Classifier loss: 0.13722201707959175
Step 37000: Classifier loss: 0.1379301232844591
Step 37500: Classifier loss: 0.13219603390991688
Step 38000: Classifier loss: 0.13811730867624283
Step 38500: Classifier loss: 0.13158722695708275
Step 39000: Classifier loss: 0.13359902986884117
Step 39500: Classifier loss: 0.1366793801188469
Step 40000: Classifier loss: 0.12849617034196853
Step 40500: Classifier loss: 0.13549049003422262
Step 41000: Classifier loss: 0.12929423077404498
Step 41500: Classifier loss: 0.13080933578312398
Step 42000: Classifier loss: 0.13500430592894555
Step 42500: Classifier loss: 0.12454062223434448
Step 43000: Classifier loss: 0.13214491476118564
Step 43500: Classifier loss: 0.1284936859458685
Step 44000: Classifier loss: 0.12763021168112754
Step 44500: Classifier loss: 0.13298917169868946
Step 45000: Classifier loss: 0.12208985219895839
Step 45500: Classifier loss: 0.129048362582922
Step 46000: Classifier loss: 0.12678204217553138
Step 46500: Classifier loss: 0.12455842156708241
Step 47000: Classifier loss: 0.1303500325381756
Step 47500: Classifier loss: 0.12025414818525314
Step 48000: Classifier loss: 0.12684993542730807
Step 48500: Classifier loss: 0.1252559674978256
Step 49000: Classifier loss: 0.12153738121688366
Step 49500: Classifier loss: 0.12777481034398078
Step 50000: Classifier loss: 0.118936713129282
Step 50500: Classifier loss: 0.12405500474572181
Ended: 2021-05-10 18:56:36.980805
Elapsed: 2:04:10.474649
losses = pandas.DataFrame.from_dict(dict(Loss=classifier_losses))
plot = losses.hvplot(y="Loss", title="Classifier Loss", color=PLOT.tan).opts(width=PLOT.width, height=PLOT.height)
output = Embed(plot=plot, file_name="classifier_loss")()
print(output)

Figure Missing

Take Two

n_classes = 40
classifier = Classifier(n_classes=n_classes).to(device)
class_dict = torch.load(classifier_state_dict, map_location=torch.device(device))["classifier"]
classifier.load_state_dict(class_dict)
classifier.eval()
with TIMER:
    classifier_losses = train_classifier(classifier_state_dict, data_path,
                                         epochs=40,
                                         classifier=classifier)
Started: 2021-05-11 16:02:16.181203
Step 500: Classifier loss: 0.1181784438341856
Step 1000: Classifier loss: 0.12448641647398472
Step 1500: Classifier loss: 0.1214247584193945
Step 2000: Classifier loss: 0.1198666417747736
Step 2500: Classifier loss: 0.1255625690817833
Step 3000: Classifier loss: 0.11589906251430511
Step 3500: Classifier loss: 0.12224359685182572
Step 4000: Classifier loss: 0.11944249965250492
Step 4500: Classifier loss: 0.1175859476029873
Step 5000: Classifier loss: 0.12318077574670315
Step 5500: Classifier loss: 0.11450052106380462
Step 6000: Classifier loss: 0.11944048409163951
Step 6500: Classifier loss: 0.11928777326643467
Step 7000: Classifier loss: 0.11463624723255635
Step 7500: Classifier loss: 0.12107200682163238
Step 8000: Classifier loss: 0.11355004295706748
Step 8500: Classifier loss: 0.11719673483073711
Step 9000: Classifier loss: 0.11821492326259612
Step 9500: Classifier loss: 0.11198448015749454
Step 10000: Classifier loss: 0.11870198084414005
Step 10500: Classifier loss: 0.11221958647668362
Step 11000: Classifier loss: 0.11476752410829068
Step 11500: Classifier loss: 0.11772396117448806
Step 12000: Classifier loss: 0.10936097744107247
Step 12500: Classifier loss: 0.11677812692523003
Step 13000: Classifier loss: 0.11107682411372662
Step 13500: Classifier loss: 0.11222303664684295
Step 14000: Classifier loss: 0.11760448211431504
Step 14500: Classifier loss: 0.10662877394258977
Step 15000: Classifier loss: 0.11471863305568696
Step 15500: Classifier loss: 0.11056565625965595
Step 16000: Classifier loss: 0.11046012189984322
Step 16500: Classifier loss: 0.1158019468486309
Step 17000: Classifier loss: 0.10568901741504669
Step 17500: Classifier loss: 0.11223984396457672
Step 18000: Classifier loss: 0.11002579489350318
Step 18500: Classifier loss: 0.10752195838093757
Step 19000: Classifier loss: 0.11419818633794784
Step 19500: Classifier loss: 0.10464896529912948
Step 20000: Classifier loss: 0.11005591739714146
Step 20500: Classifier loss: 0.10996675519645215
Step 21000: Classifier loss: 0.10543355357646943
Step 21500: Classifier loss: 0.11205300988256932
Step 22000: Classifier loss: 0.1038715885579586
Step 22500: Classifier loss: 0.10818033437430859
Step 23000: Classifier loss: 0.10912492156028747
Step 23500: Classifier loss: 0.10302072758972645
Step 24000: Classifier loss: 0.11008756360411644
Step 24500: Classifier loss: 0.10342664630711079
Step 25000: Classifier loss: 0.10618587562441825
Step 25500: Classifier loss: 0.10913233712315559
Step 26000: Classifier loss: 0.10061963592469693
Step 26500: Classifier loss: 0.10828037586808205
Step 27000: Classifier loss: 0.10266246040165425
Step 27500: Classifier loss: 0.1047897623181343
Step 28000: Classifier loss: 0.10866250747442245
Step 28500: Classifier loss: 0.09820086953043938
Step 29000: Classifier loss: 0.10674160474538803
Step 29500: Classifier loss: 0.10230921612679958
Step 30000: Classifier loss: 0.1021555609256029
Step 30500: Classifier loss: 0.10775842162966728
Step 31000: Classifier loss: 0.09722121758759021
Step 31500: Classifier loss: 0.10439497400820255
Step 32000: Classifier loss: 0.10229390095174312
Step 32500: Classifier loss: 0.10003190772235393
Step 33000: Classifier loss: 0.10617333140969276
Step 33500: Classifier loss: 0.09686395044624806
Step 34000: Classifier loss: 0.10285020883381367
Step 34500: Classifier loss: 0.10199978332221508
Step 35000: Classifier loss: 0.09819360673427582
Step 35500: Classifier loss: 0.10397693109512329
Step 36000: Classifier loss: 0.09642438031733036
Step 36500: Classifier loss: 0.10087257397174836
Step 37000: Classifier loss: 0.10197833214700222
Step 37500: Classifier loss: 0.09598418261110783
Step 38000: Classifier loss: 0.10283542364835739
Step 38500: Classifier loss: 0.09644483177363873
Step 39000: Classifier loss: 0.09908602401614189
Step 39500: Classifier loss: 0.10129908196628094
Step 40000: Classifier loss: 0.0939527989178896
Step 40500: Classifier loss: 0.1016722819507122
Step 41000: Classifier loss: 0.09578396078944207
Step 41500: Classifier loss: 0.09706279496848583
Step 42000: Classifier loss: 0.10207961940765381
Step 42500: Classifier loss: 0.09211373472213745
Step 43000: Classifier loss: 0.09958744782209396
Step 43500: Classifier loss: 0.09534277887642384
Step 44000: Classifier loss: 0.0952163600474596
Step 44500: Classifier loss: 0.10136887782812118
Step 45000: Classifier loss: 0.09021547995507717
Step 45500: Classifier loss: 0.09812712541222572
Step 46000: Classifier loss: 0.09560927426815033
Step 46500: Classifier loss: 0.09358323478698731
Step 47000: Classifier loss: 0.09991893386840821
Step 47500: Classifier loss: 0.0899157041311264
Step 48000: Classifier loss: 0.096542285323143
Step 48500: Classifier loss: 0.09535252919793129
Step 49000: Classifier loss: 0.09194727616012097
Step 49500: Classifier loss: 0.09831891848146915
Step 50000: Classifier loss: 0.0901611197590828
Step 50500: Classifier loss: 0.09490065774321556
Ended: 2021-05-11 18:06:19.399986
Elapsed: 2:04:03.218783
losses = pandas.DataFrame.from_dict(dict(Loss=classifier_losses))
plot = losses.hvplot(y="Loss", title="Classifier Loss Session 2", color=PLOT.tan).opts(width=PLOT.width, height=PLOT.height)
output = Embed(plot=plot, file_name="classifier_loss_2")()
print(output)

Figure Missing

Take Three

n_classes = 40
classifier = Classifier(n_classes=n_classes).to(device)
class_dict = torch.load(classifier_state_dict, map_location=torch.device(device))["classifier"]
classifier.load_state_dict(class_dict)
classifier.eval()
with TIMER:
    classifier_losses = train_classifier(classifier_state_dict, data_path,
                                         epochs=40,
                                         classifier=classifier)
Started: 2021-05-11 21:11:32.420506
Step 500: Classifier loss: 0.09006546361744404
Step 1000: Classifier loss: 0.09647199404239655
Step 1500: Classifier loss: 0.092734768897295
Step 2000: Classifier loss: 0.09196118661761284
Step 2500: Classifier loss: 0.09789373110234738
Step 3000: Classifier loss: 0.08788785541057587
Step 3500: Classifier loss: 0.0945415845811367
Step 4000: Classifier loss: 0.09308994428813458
Step 4500: Classifier loss: 0.09022623193264008
Step 5000: Classifier loss: 0.09615834753215313
Step 5500: Classifier loss: 0.08742603194713593
Step 6000: Classifier loss: 0.09316775412857532
Step 6500: Classifier loss: 0.09275233449041843
Step 7000: Classifier loss: 0.08822614887356758
Step 7500: Classifier loss: 0.09528714890778064
Step 8000: Classifier loss: 0.08681275172531605
Step 8500: Classifier loss: 0.09150236696004868
Step 9000: Classifier loss: 0.09338522186875343
Step 9500: Classifier loss: 0.08638478130102158
Step 10000: Classifier loss: 0.09388372772932052
Step 10500: Classifier loss: 0.08720742921531201
Step 11000: Classifier loss: 0.09009483934938908
Step 11500: Classifier loss: 0.0929495030939579
Step 12000: Classifier loss: 0.08460890363156795
Step 12500: Classifier loss: 0.0924714410007
Step 13000: Classifier loss: 0.08704712373018265
Step 13500: Classifier loss: 0.08819058662652969
Step 14000: Classifier loss: 0.09366303083300591
Step 14500: Classifier loss: 0.08295501434803008
Step 15000: Classifier loss: 0.09084490737318993
Step 15500: Classifier loss: 0.08707242746651173
Step 16000: Classifier loss: 0.08690852355957031
Step 16500: Classifier loss: 0.09254233407974242
Step 17000: Classifier loss: 0.08242024271190167
Step 17500: Classifier loss: 0.08904271678626538
Step 18000: Classifier loss: 0.08771026766300201
Step 18500: Classifier loss: 0.08471861970424652
Step 19000: Classifier loss: 0.09134728060662746
Step 19500: Classifier loss: 0.08233513435721397
Step 20000: Classifier loss: 0.08778411850333213
Step 20500: Classifier loss: 0.08791485584527255
Step 21000: Classifier loss: 0.08345357306301594
Step 21500: Classifier loss: 0.08975999920070171
Step 22000: Classifier loss: 0.08225472408533097
Step 22500: Classifier loss: 0.08668080273270606
Step 23000: Classifier loss: 0.08786206224560737
Step 23500: Classifier loss: 0.08155409483611584
Step 24000: Classifier loss: 0.08907847443222999
Step 24500: Classifier loss: 0.08202618369460106
Step 25000: Classifier loss: 0.08517973597347736
Step 25500: Classifier loss: 0.08817093770205975
Step 26000: Classifier loss: 0.08008052316308022
Step 26500: Classifier loss: 0.08741954331099987
Step 27000: Classifier loss: 0.08247932478785515
Step 27500: Classifier loss: 0.08377225384116173
Step 28000: Classifier loss: 0.08846944206953049
Step 28500: Classifier loss: 0.07859189368784428
Step 29000: Classifier loss: 0.08617163190245629
Step 29500: Classifier loss: 0.0824531610161066
Step 30000: Classifier loss: 0.08195052224397659
Step 30500: Classifier loss: 0.08803890940546989
Step 31000: Classifier loss: 0.07793828934431075
Step 31500: Classifier loss: 0.08464510484039783
Step 32000: Classifier loss: 0.08275749842077494
Step 32500: Classifier loss: 0.0805082704871893
Step 33000: Classifier loss: 0.08703124921023846
Step 33500: Classifier loss: 0.0772736611738801
Step 34000: Classifier loss: 0.08353734220564366
Step 34500: Classifier loss: 0.08343685203790664
Step 35000: Classifier loss: 0.07905932680517436
Step 35500: Classifier loss: 0.08568261863291264
Step 36000: Classifier loss: 0.07762402860075235
Step 36500: Classifier loss: 0.08223582464456558
Step 37000: Classifier loss: 0.08341778349876404
Step 37500: Classifier loss: 0.07801838412880897
Step 38000: Classifier loss: 0.0842266542762518
Step 38500: Classifier loss: 0.07764634099602699
Step 39000: Classifier loss: 0.08104524739086628
Step 39500: Classifier loss: 0.08389902476221323
Step 40000: Classifier loss: 0.07612183248996734
Step 40500: Classifier loss: 0.08296740844845772
Step 41000: Classifier loss: 0.0781253460124135
Step 41500: Classifier loss: 0.07980525248497725
Step 42000: Classifier loss: 0.08405549557507039
Step 42500: Classifier loss: 0.0743530157059431
Step 43000: Classifier loss: 0.08219673927128315
Step 43500: Classifier loss: 0.07845095673948527
Step 44000: Classifier loss: 0.07780187250673772
Step 44500: Classifier loss: 0.08399353076517582
Step 45000: Classifier loss: 0.07365029990673065
Step 45500: Classifier loss: 0.0808380290567875
Step 46000: Classifier loss: 0.0786423703506589
Step 46500: Classifier loss: 0.07693013155460357
Step 47000: Classifier loss: 0.08244228959083558
Step 47500: Classifier loss: 0.07365631985664367
Step 48000: Classifier loss: 0.07970952866971492
Step 48500: Classifier loss: 0.07868134459108114
Step 49000: Classifier loss: 0.07539512529224157
Step 49500: Classifier loss: 0.08191524033248425
Step 50000: Classifier loss: 0.07361406400799751
Step 50500: Classifier loss: 0.07847459720075131
Ended: 2021-05-11 23:15:24.444278
Elapsed: 2:03:52.023772
losses = pandas.DataFrame.from_dict(dict(Loss=classifier_losses))
plot = losses.hvplot(y="Loss", title="Classifier Loss Session 3", color=PLOT.tan).opts(width=PLOT.width, height=PLOT.height)
output = Embed(plot=plot, file_name="classifier_loss_3")()
print(output)

Figure Missing

Loading the Pretrained Models

We will then load the pretrained generator and classifier using the following code. (If we trained our own classifier, we can load that one here instead.)

gen = Generator(z_dim).to(device)
gen_dict = torch.load(prebuilt_models.celeba, map_location=torch.device(device))["gen"]
gen.load_state_dict(gen_dict)
gen.eval()

n_classes = 40
classifier = Classifier(n_classes=n_classes).to(device)
class_dict = torch.load(prebuilt_models.classifier, map_location=torch.device(device))["classifier"]
classifier.load_state_dict(class_dict)
classifier.eval()

opt = torch.optim.Adam(classifier.parameters(), lr=0.01)

Training

Now we can start implementing a method for controlling our GAN.

Update Noise

For training, we need to write the code to update the noise to produce more of our desired feature. We do this by performing stochastic gradient ascent. Gradient ascent finds a local maximum, as opposed to gradient descent, which finds a local minimum; it is equivalent to gradient descent on the negative of the value being optimized. The formulas are essentially the same, except that instead of subtracting the weighted value, stochastic gradient ascent adds it: \(new = old + (∇ old * weight)\), where \(∇ old\) is the gradient of the classifier's score with respect to the old noise. We perform stochastic gradient ascent to try to maximize the amount of the feature we want; if we wanted to reduce the amount of the feature, we would perform gradient descent instead. Here we are interested in maximizing our feature with gradient ascent, since most features in the dataset are absent more often than they are present, and we are trying to add a feature to the images, not remove one.

Given the noise with its gradient already calculated through the classifier, we want to return the new noise vector.

  1. Remember the equation for gradient ascent: \(new = old + (∇ old * weight)\).
def calculate_updated_noise(noise: torch.Tensor, weight: float) -> torch.Tensor:
    """Update noise vectors with stochastic gradient ascent.

    Args:
     noise: the current noise vectors. 
           We have already called the backwards function on the target class
           so we can access the gradient of the output class with respect 
           to the noise by using noise.grad
     weight: the scalar amount by which we should weight the noise gradient

    Returns:
     updated noise
    """
    new_noise = noise + (noise.grad * weight)
    return new_noise
  • UNIT TEST

    Check that the basic function works.

    opt.zero_grad()
    noise = torch.ones(20, 20) * 2
    noise.requires_grad_()
    fake_classes = (noise ** 2).mean()
    fake_classes.backward()
    new_noise = calculate_updated_noise(noise, 0.1)
    assert type(new_noise) == torch.Tensor
    assert tuple(new_noise.shape) == (20, 20)
    assert new_noise.max() == 2.0010
    assert new_noise.min() == 2.0010
    assert torch.isclose(new_noise.sum(), torch.tensor(0.4) + 20 * 20 * 2)
    

    Check that it works for generated images

    opt.zero_grad()
    noise = get_noise(32, z_dim).to(device).requires_grad_()
    fake = gen(noise)
    fake_classes = classifier(fake)[:, 0]
    fake_classes.mean().backward()
    noise.data = calculate_updated_noise(noise, 0.01)
    fake = gen(noise)
    fake_classes_new = classifier(fake)[:, 0]
    assert torch.all(fake_classes_new > fake_classes)
    

Generation

Now, we can use the classifier along with stochastic gradient ascent to make noise that generates more of a certain feature. The code we were given targets smiling faces by default; feel free to change the target index and control some of the other features in the list! We will notice that some features are easier to detect and control than others.

The list we have here holds the features labeled in CelebA, which we used to train our classifier. If we wanted to control another feature, we would need data labeled with that feature and a classifier trained on it.

First generate a bunch of images with the generator.

n_images = 8
fake_image_history = []
grad_steps = 10 # Number of gradient steps to take
skip = 2 # Number of gradient steps to skip in the visualization

feature_names = ["5oClockShadow", "ArchedEyebrows", "Attractive", "BagsUnderEyes", "Bald", "Bangs",
"BigLips", "BigNose", "BlackHair", "BlondHair", "Blurry", "BrownHair", "BushyEyebrows", "Chubby",
"DoubleChin", "Eyeglasses", "Goatee", "GrayHair", "HeavyMakeup", "HighCheekbones", "Male", 
"MouthSlightlyOpen", "Mustache", "NarrowEyes", "NoBeard", "OvalFace", "PaleSkin", "PointyNose", 
"RecedingHairline", "RosyCheeks", "Sideburn", "Smiling", "StraightHair", "WavyHair", "WearingEarrings", 
"WearingHat", "WearingLipstick", "WearingNecklace", "WearingNecktie", "Weng"]

### Change me! ###
target_indices = feature_names.index("Weng") # Feel free to change this value to any string from feature_names!

noise = get_noise(n_images, z_dim).to(device).requires_grad_()
for i in range(grad_steps):
    opt.zero_grad()
    fake = gen(noise)
    fake_image_history += [fake]
    fake_classes_score = classifier(fake)[:, target_indices].mean()
    fake_classes_score.backward()
    noise.data = calculate_updated_noise(noise, 1 / grad_steps)

pyplot.rcParams['figure.figsize'] = [n_images * 2, grad_steps * 2]
save_tensor_images(image_tensor=torch.cat(fake_image_history[::skip], dim=2),
                   filename="young.png", folder=OUTPUT, title="Young",
                   num_images=n_images, nrow=n_images)

young.png

Entanglement and Regularization

We may also notice that sometimes more features than just the target feature change. This is because some features are entangled. To fix this, we can try to isolate the target feature more by holding the classes outside of the target class constant. One way to implement this is to penalize differences from the original classes with L2 regularization: the L2 norm of the difference becomes an additional term on the loss function.

Here, we'll have to implement the score function: the higher, the better. The score is calculated by adding the target score and a penalty – note that the penalty is meant to lower the score, so it should have a negative value.

For every non-target class, take the difference between the current classifications and the original classifications. The greater this value is, the more features outside the target have changed. We will calculate the magnitude of the change, take the mean, and negate it. Finally, add this penalty to the target score. The target score is the mean of the target class in the current classifications.

  1. The higher the score, the better!
  2. We want to calculate the loss per image, so we'll need to pass a dim argument to torch.norm.
  3. Calculating the magnitude of the change requires us to take the norm of the difference between the classifications, not the difference of the norms.

Note: torch.norm is deprecated; PyTorch recommends torch.linalg.norm instead.
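To make point 3 concrete, here's a tiny example (mine): the norm of the difference detects that the classifications moved, while the difference of the norms can be zero even when everything changed.

before = torch.tensor([[1.0, 0.0]])
after = torch.tensor([[0.0, 1.0]])
# norm of the difference: the classifications really moved
print(torch.linalg.norm(after - before, dim=1))  # tensor([1.4142])
# difference of the norms: misleadingly zero
print(torch.linalg.norm(after, dim=1) - torch.linalg.norm(before, dim=1))  # tensor([0.])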

def get_score(current_classifications: torch.Tensor,
              original_classifications: torch.Tensor,
              target_indices: torch.Tensor,
              other_indices: torch.Tensor,
              penalty_weight: float) -> torch.Tensor:
    """Score the current classifications, L2 Norm penalty

    Args:
       current_classifications: the classifications associated with the current noise
       original_classifications: the classifications associated with the original noise
       target_indices: the index of the target class
       other_indices: the indices of the other classes
       penalty_weight: the amount that the penalty should be weighted in the overall score

    Returns: 
     the score of the current classification with L2 Norm penalty
    """
    # Steps: 1) Calculate the change between the original and current classifications (as a tensor)
    #           by indexing into the other_indices we're trying to preserve, like in x[:, features].
    #        2) Calculate the norm (magnitude) of changes per example.
    #        3) Multiply the mean of the example norms by the penalty weight. 
    #           This will be our other_class_penalty.
    #           Make sure to negate the value since it's a penalty!
    #        4) Take the mean of the current classifications for the target feature over all the examples.
    #           This mean will be our target_score.
    # Calculate the norm (magnitude) of changes per example and multiply by penalty weight
    other_class_penalty = -(torch.mean(
        torch.linalg.norm(original_classifications[:, other_indices]
                          - current_classifications[:, other_indices], dim=1))
                           * penalty_weight)
    # Take the mean of the current classifications for the target feature
    target_score = torch.mean(current_classifications[:, target_indices])
    return target_score + other_class_penalty

UNIT TEST

assert torch.isclose(
    get_score(torch.ones(4, 3), torch.zeros(4, 3), [0], [1, 2], 0.2), 
    1 - torch.sqrt(torch.tensor(2.)) * 0.2
)
rows = 10
current_class = torch.tensor([[1] * rows, [2] * rows, [3] * rows, [4] * rows]).T.float()
original_class = torch.tensor([[1] * rows, [2] * rows, [3] * rows, [4] * rows]).T.float()

# Must be 3
assert get_score(current_class, original_class, [1, 3] , [0, 2], 0.2).item() == 3

current_class = torch.tensor([[1] * rows, [2] * rows, [3] * rows, [4] * rows]).T.float()
original_class = torch.tensor([[4] * rows, [4] * rows, [2] * rows, [1] * rows]).T.float()

# Must be 3 - 0.2 * sqrt(10)
assert torch.isclose(get_score(current_class, original_class, [1, 3] , [0, 2], 0.2), 
                     -torch.sqrt(torch.tensor(10.0)) * 0.2 + 3)

In the following block of code, we will run the gradient ascent with this new score function. We might notice a few things after running it:

  1. It may fail more often at producing the target feature when compared to the original approach. This suggests that the model may not be able to generate an image that has the target feature without changing the other features. This makes sense! For example, it may not be able to generate a face that's smiling but whose mouth is NOT slightly open. This may also expose a limitation of the generator.

Alternatively, even if the generator can produce an image with the intended features, it might require many intermediate changes to get there and may get stuck in a local optimum.

  2. This process may change features which the classifier was not trained to recognize, since there is no way to penalize them with this method. Whether it's possible to train models to avoid changing unsupervised features is an open question.
fake_image_history = []
### Change me! ###
target_indices = feature_names.index("Goatee") # Feel free to change this value to any string from feature_names from earlier!
other_indices = [cur_idx != target_indices for cur_idx, _ in enumerate(feature_names)]
noise = get_noise(n_images, z_dim).to(device).requires_grad_()
original_classifications = classifier(gen(noise)).detach()
for i in range(grad_steps):
    opt.zero_grad()
    fake = gen(noise)
    fake_image_history += [fake]
    fake_score = get_score(
        classifier(fake), 
        original_classifications,
        target_indices,
        other_indices,
        penalty_weight=0.1
    )
    fake_score.backward()
    noise.data = calculate_updated_noise(noise, 1 / grad_steps)

pyplot.rcParams['figure.figsize'] = [n_images * 2, grad_steps * 2]
save_tensor_images(torch.cat(fake_image_history[::skip], dim=2),
                   num_images=n_images, nrow=n_images,
                   filename="goatee.png", folder=OUTPUT, title="Goatee")

goatee.png

End

Sources

  • Liu, Z., Luo, P., Wang, X., Tang, X. Deep Learning Face Attributes in the Wild. In Proceedings of International Conference on Computer Vision (ICCV), 2015.

A Conditional GAN

Build You a Conditional GAN For a Great Good

Imports

# python standard library
from pathlib import Path

import math

# pypi
from torch import nn
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import MNIST
from torchvision.utils import make_grid

import matplotlib.pyplot as pyplot
import torch
import torch.nn.functional as F

# my stuff
from graeae import Timer

Set Up

The Timer

TIMER = Timer()

The Manual Seed

torch.manual_seed(0)

Plotting

SLUG = "a-conditional-gan"

Helpers

def save_tensor_images(image_tensor: torch.Tensor,
                       filename: str, 
                       title: str,
                       folder: str=f"files/posts/gans{SLUG}",
                       num_images: int=25, size: tuple=(1, 28, 28)):
    """Plot an Image Tensor

    Args:
     image_tensor: tensor with the values for the image to plot
     filename: name to save the file under
     folder: path to put the file in
     title: title for the image
     num_images: how many images from the tensor to use
     size: the dimensions for each image
    """
    image_tensor = (image_tensor + 1) / 2
    image_unflat = image_tensor.detach().cpu()
    image_grid = make_grid(image_unflat[:num_images], nrow=5)
    pyplot.title(title)
    pyplot.grid(False)
    pyplot.imshow(image_grid.permute(1, 2, 0).squeeze())
    pyplot.tick_params(bottom=False, top=False, labelbottom=False,
                       right=False, left=False, labelleft=False)
    pyplot.savefig(folder + filename)
    print(f"[[file:{filename}]]")
    return

Noise

def make_some_noise(n_samples: int, z_dim: int, device: str='cpu') -> torch.Tensor:
    """Alias for torch.randn

    Args:
      n_samples: the number of samples to generate
      z_dim: the dimension of the noise vector
      device: the device type

    Returns:
     tensor with random numbers from the normal distribution.
    """
    return torch.randn(n_samples, z_dim, device=device)

Middle

The Generator

The Generator and Discriminator are the same ones we used before except the z_dim attribute has been renamed input_dim to reflect the fact that the data is going to be augmented with the classification information.

class Generator(nn.Module):
    """The DCGAN Generator

    Args:
       input_dim: the dimension of the input vector
       im_chan: the number of channels in the images, fitted for the dataset used
             (MNIST is black-and-white, so 1 channel is the default)
       hidden_dim: the inner dimension,
    """
    def __init__(self, input_dim: int=10, im_chan: int=1, hidden_dim: int=64):
        super().__init__()
        self.input_dim = input_dim

        self.gen = nn.Sequential(
            self.make_gen_block(input_dim, hidden_dim * 4),
            self.make_gen_block(hidden_dim * 4, hidden_dim * 2, kernel_size=4, stride=1),
            self.make_gen_block(hidden_dim * 2, hidden_dim),
            self.make_gen_block(hidden_dim, im_chan, kernel_size=4, final_layer=True),
        )

    def make_gen_block(self, input_channels: int, output_channels: int,
                       kernel_size: int=3, stride: int=2,
                       final_layer: bool=False) -> nn.Sequential:
        """Creates a block for the generator (sub sequence)

       The parts
        - a transposed convolution
        - a batchnorm (except for in the last layer)
        - an activation.

       Args:
           input_channels: how many channels the input feature representation has
           output_channels: how many channels the output feature representation should have
           kernel_size: the size of each convolutional filter, equivalent to (kernel_size, kernel_size)
           stride: the stride of the convolution
           final_layer: a boolean, true if it is the final layer and false otherwise 
                     (affects activation and batchnorm)

       Returns:
        the sub-sequence of layers
       """

        if not final_layer:
            return nn.Sequential(
                nn.ConvTranspose2d(input_channels, output_channels, kernel_size, stride),
                nn.BatchNorm2d(output_channels),
                nn.ReLU(inplace=True),
            )
        else:
            return nn.Sequential(
                nn.ConvTranspose2d(input_channels, output_channels, kernel_size, stride),
                nn.Tanh(),
            )

    def forward(self, noise: torch.Tensor) -> torch.Tensor:
        """complete a forward pass of the generator: Given a noise tensor, 

       Args:
        noise: a noise tensor with dimensions (n_samples, z_dim)

       Returns:
        generated images.
       """
        # unsqueeze the noise
        x = noise.view(len(noise), self.input_dim, 1, 1)
        return self.gen(x)
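A shape sketch (mine, anticipating the z_dim of 64 and ten MNIST classes used below): with an input vector of 64 + 10 = 74 entries, the generator emits 28 x 28 single-channel images.

test_generator = Generator(input_dim=74)
test_input = torch.randn(2, 74)
with torch.no_grad():
    test_images = test_generator(test_input)
# spatial sizes: 1 -> 3 -> 6 -> 13 -> 28
assert test_images.shape == (2, 1, 28, 28)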

Discriminator

This differs a little from the DCGAN Discriminator in that the initial hidden dimension output goes up to 64 nodes from 16 in the original.

class Discriminator(nn.Module):
    """The DCGAN Discriminator

    Args:
     im_chan: the number of channels in the images, fitted for the dataset used
             (MNIST is black-and-white, so 1 channel is the default)
     hidden_dim: the inner dimension,
    """
    def __init__(self, im_chan: int=1, hidden_dim: int=64):
        super(Discriminator, self).__init__()
        self.disc = nn.Sequential(
            self.make_disc_block(im_chan, hidden_dim),
            self.make_disc_block(hidden_dim, hidden_dim * 2),
            self.make_disc_block(hidden_dim * 2, 1, final_layer=True),
        )
        return

    def make_disc_block(self, input_channels: int, output_channels: int,
                        kernel_size: int=4, stride: int=2,
                        final_layer: bool=False) -> nn.Sequential:
        """Make a sub-block of layers for the discriminator

        - a convolution
        - a batchnorm (except for in the last layer)
        - an activation.

       Args:
         input_channels: how many channels the input feature representation has
         output_channels: how many channels the output feature representation should have
         kernel_size: the size of each convolutional filter, equivalent to (kernel_size, kernel_size)
         stride: the stride of the convolution
         final_layer: if true it is the final layer and otherwise not
                     (affects activation and batchnorm)
       """        
        # Build the neural block
        if not final_layer:
            return nn.Sequential(
                nn.Conv2d(input_channels, output_channels, kernel_size, stride),
                nn.BatchNorm2d(output_channels),
                nn.LeakyReLU(0.2)
            )
        else: # Final Layer
            return nn.Sequential(
                nn.Conv2d(input_channels, output_channels, kernel_size, stride),
            )

    def forward(self, image: torch.Tensor) -> torch.Tensor:
        """Complete a forward pass of the discriminator

       Args:
         image: a flattened image tensor with dimension (im_dim)

       Returns:
        a 1-dimension tensor representing fake/real.
       """
        disc_pred = self.disc(image)
        return disc_pred.view(len(disc_pred), -1)
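And the matching sketch for the discriminator: MNIST images carry one channel, so with the ten class channels concatenated the input has eleven, and the three blocks reduce 28 x 28 to a single logit per image.

test_discriminator = Discriminator(im_chan=11)
test_batch = torch.randn(2, 11, 28, 28)
with torch.no_grad():
    test_predictions = test_discriminator(test_batch)
# spatial sizes: 28 -> 13 -> 5 -> 1
assert test_predictions.shape == (2, 1)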

The Class Input

One-Hot Encoder

In conditional GANs, the input vector for the generator will also need to include the class information. The class is represented using a one-hot encoded vector where its length is the number of classes and each index represents a class. The vector is all 0's and a 1 on the chosen class. Given the labels of multiple images (e.g. from a batch) and the number of classes, please create one-hot vectors for each label. There is a function within the PyTorch functional library that can help you.

  1. This code can be done in one line.
  2. PyTorch documentation for F.one_hot
def get_one_hot_labels(labels: torch.tensor, n_classes: int) -> torch.Tensor:
    """Create one-hot vectors for the labels

    Args:
       labels: tensor of labels from the dataloader
       n_classes: the total number of classes in the dataset

    Returns:
     a tensor of shape (labels size, num_classes).
    """
    #### START CODE HERE ####
    return F.one_hot(labels, n_classes)
    #### END CODE HERE ####
assert (
    get_one_hot_labels(
        labels=torch.Tensor([[0, 2, 1]]).long(),
        n_classes=3
    ).tolist() == 
    [[
      [1, 0, 0], 
      [0, 0, 1], 
      [0, 1, 0]
    ]]
)

Combine Vectors

Next, you need to be able to concatenate the one-hot class vector to the noise vector before giving it to the generator. You will also need to do this when adding the class channels to the discriminator.

  1. This code can also be written in one line.
  2. See the documentation torch.cat ( Specifically, look at what the dim argument of torch.cat does)
def combine_vectors(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """Combine two vectors with shapes (n_samples, ?) and (n_samples, ?).

    Args:
      x: the first vector.
      y: the second vector.

    Returns:
     the concatenation of x and y along dimension 1, as a float tensor
    """
    # Note: Make sure this function outputs a float no matter what inputs it receives
    #### START CODE HERE ####
    combined = torch.cat((x, y), 1).to(torch.float32)
    #### END CODE HERE ####
    return combined
combined = combine_vectors(torch.tensor([[1, 2], [3, 4]]), torch.tensor([[5, 6], [7, 8]]));
# Check exact order of elements
assert torch.all(combined == torch.tensor([[1, 2, 5, 6], [3, 4, 7, 8]]))
# Tests that items are of float type
assert (type(combined[0][0].item()) == float)
# Check shapes
combined = combine_vectors(torch.randn(1, 4, 5), torch.randn(1, 8, 5));
assert tuple(combined.shape) == (1, 12, 5)
assert tuple(combine_vectors(torch.randn(1, 10, 12).long(), torch.randn(1, 20, 12).long()).shape) == (1, 30, 12)

Training

First, you will define some new parameters:

  • mnist_shape: the number of pixels in each MNIST image, which has dimensions 28 x 28 and one channel (because it's black-and-white) so 1 x 28 x 28
  • n_classes: the number of classes in MNIST (10, since there are the digits from 0 to 9)
mnist_shape = (1, 28, 28)
n_classes = 10

And you also include the same parameters from before:

  • criterion: the loss function
  • n_epochs: the number of times you iterate through the entire dataset when training
  • z_dim: the dimension of the noise vector
  • display_step: how often to display/visualize the images
  • batch_size: the number of images per forward/backward pass
  • lr: the learning rate
  • device: the device type
criterion = nn.BCEWithLogitsLoss()
n_epochs = 200
z_dim = 64
display_step = 500
batch_size = 128
lr = 0.0002
device = 'cuda'

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])

data_path = Path("~/pytorch-data/MNIST/").expanduser()
dataloader = DataLoader(
    MNIST(data_path, download=False, transform=transform),
    batch_size=batch_size,
    shuffle=True)

Input Dimensions

Then, you can initialize your generator, discriminator, and optimizers. To do this, you will need to update the input dimensions for both models. For the generator, you will need to calculate the size of the input vector; recall that for conditional GANs, the generator's input is the noise vector concatenated with the class vector. For the discriminator, you need to add a channel for every class.

def get_input_dimensions(z_dim: int, mnist_shape: tuple, n_classes: int) -> tuple:
    """Calculates the size of the conditional input dimensions 

    Args:
       z_dim: the dimension of the noise vector
       mnist_shape: the shape of each MNIST image as (C, W, H), which is (1, 28, 28)
       n_classes: the total number of classes in the dataset, an integer scalar
               (10 for MNIST)
    Returns: 
       generator_input_dim: the input dimensionality of the conditional generator, 
                         which takes the noise and class vectors
       discriminator_im_chan: the number of input channels to the discriminator
                           (the image channels plus one channel per class)
    """
    #### START CODE HERE ####
    generator_input_dim = z_dim + n_classes
    discriminator_im_chan = n_classes + mnist_shape[0]
    #### END CODE HERE ####
    return generator_input_dim, discriminator_im_chan
def test_input_dims():
    gen_dim, disc_dim = get_input_dimensions(23, (12, 23, 52), 9)
    assert gen_dim == 32
    assert disc_dim == 21
test_input_dims()

Initialize the Objects

generator_input_dim, discriminator_im_chan = get_input_dimensions(z_dim, mnist_shape, n_classes)

gen = Generator(input_dim=generator_input_dim).to(device)
gen_opt = torch.optim.Adam(gen.parameters(), lr=lr)
disc = Discriminator(im_chan=discriminator_im_chan).to(device)
disc_opt = torch.optim.Adam(disc.parameters(), lr=lr)
def weights_init(m) -> None:
    """Initialize the weights from the normal distribution

    Args:
     m: object to initialize
    """
    if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d):
        torch.nn.init.normal_(m.weight, 0.0, 0.02)
    elif isinstance(m, nn.BatchNorm2d):
        torch.nn.init.normal_(m.weight, 0.0, 0.02)
        torch.nn.init.constant_(m.bias, 0)
    return
gen = gen.apply(weights_init)
disc = disc.apply(weights_init)
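A quick check of mine that the initialization took: the convolution weights should now have a standard deviation near 0.02.

first_block = gen.gen[0]            # the first generator block (a Sequential)
print(first_block[0].weight.std())  # the ConvTranspose2d weights, should be close to 0.02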

The Training

Now to train, you would like both your generator and your discriminator to know what class of image should be generated.

For example, if you're generating a picture of the number "1", you would need to:

  1. Tell that to the generator, so that it knows it should be generating a "1"
  2. Tell that to the discriminator, so that it knows it should be looking at a "1". If the discriminator is told it should be looking at a 1 but sees something that's clearly an 8, it can guess that it's probably fake
cur_step = 0
generator_losses = []
discriminator_losses = []

noise_and_labels = False
fake = False

fake_image_and_labels = False
real_image_and_labels = False
disc_fake_pred = False
disc_real_pred = False
with TIMER:
    for epoch in range(n_epochs):
        # Dataloader returns the batches and the labels
        for real, labels in dataloader:
            cur_batch_size = len(real)
            # Flatten the batch of real images from the dataset
            real = real.to(device)

            one_hot_labels = get_one_hot_labels(labels.to(device), n_classes)
            image_one_hot_labels = one_hot_labels[:, :, None, None]
            image_one_hot_labels = image_one_hot_labels.repeat(1, 1, mnist_shape[1], mnist_shape[2])

            ### Update discriminator ###
            # Zero out the discriminator gradients
            disc_opt.zero_grad()
            # Get noise corresponding to the current batch_size 
            fake_noise = make_some_noise(cur_batch_size, z_dim, device=device)

            # Now you can get the images from the generator
            # Steps: 1) Combine the noise vectors and the one-hot labels for the generator
            #        2) Generate the conditioned fake images

            #### START CODE HERE ####
            noise_and_labels = combine_vectors(fake_noise, one_hot_labels)
            fake = gen(noise_and_labels)
            #### END CODE HERE ####

            # Make sure that enough images were generated
            assert len(fake) == len(real)
            # Check that correct tensors were combined
            assert tuple(noise_and_labels.shape) == (cur_batch_size, fake_noise.shape[1] + one_hot_labels.shape[1])
            # It comes from the correct generator
            assert tuple(fake.shape) == (len(real), 1, 28, 28)

            # Now you can get the predictions from the discriminator
            # Steps: 1) Create the input for the discriminator
            #           a) Combine the fake images with image_one_hot_labels, 
            #              remember to detach the generator (.detach()) so you do not backpropagate through it
            #           b) Combine the real images with image_one_hot_labels
            #        2) Get the discriminator's prediction on the fakes as disc_fake_pred
            #        3) Get the discriminator's prediction on the reals as disc_real_pred

            #### START CODE HERE ####
            fake_image_and_labels = combine_vectors(fake.detach(), image_one_hot_labels)
            real_image_and_labels = combine_vectors(real, image_one_hot_labels)
            disc_fake_pred = disc(fake_image_and_labels)
            disc_real_pred = disc(real_image_and_labels)
            #### END CODE HERE ####

            # Make sure shapes are correct 
            assert tuple(fake_image_and_labels.shape) == (len(real), fake.detach().shape[1] + image_one_hot_labels.shape[1], 28, 28)
            assert tuple(real_image_and_labels.shape) == (len(real), real.shape[1] + image_one_hot_labels.shape[1], 28, 28)
            # Make sure that enough predictions were made
            assert len(disc_real_pred) == len(real)
            # Make sure that the inputs are different
            assert torch.any(fake_image_and_labels != real_image_and_labels)
            # Shapes must match
            assert tuple(fake_image_and_labels.shape) == tuple(real_image_and_labels.shape)
            assert tuple(disc_fake_pred.shape) == tuple(disc_real_pred.shape)


            disc_fake_loss = criterion(disc_fake_pred, torch.zeros_like(disc_fake_pred))
            disc_real_loss = criterion(disc_real_pred, torch.ones_like(disc_real_pred))
            disc_loss = (disc_fake_loss + disc_real_loss) / 2
            disc_loss.backward(retain_graph=True)
            disc_opt.step() 

            # Keep track of the average discriminator loss
            discriminator_losses += [disc_loss.item()]

            ### Update generator ###
            # Zero out the generator gradients
            gen_opt.zero_grad()

            fake_image_and_labels = combine_vectors(fake, image_one_hot_labels)
            # This will error if you didn't concatenate your labels to your image correctly
            disc_fake_pred = disc(fake_image_and_labels)
            gen_loss = criterion(disc_fake_pred, torch.ones_like(disc_fake_pred))
            gen_loss.backward()
            gen_opt.step()

            # Keep track of the generator losses
            generator_losses += [gen_loss.item()]

            if cur_step % display_step == 0 and cur_step > 0:
                gen_mean = sum(generator_losses[-display_step:]) / display_step
                disc_mean = sum(discriminator_losses[-display_step:]) / display_step
                print(f"Step {cur_step}: Generator loss: {gen_mean}, discriminator loss: {disc_mean}")
                # show_tensor_images(fake)
                # show_tensor_images(real)
                # step_bins and x_axis feed the (commented-out) loss plots below
                step_bins = 20
                # x_axis = sorted([i * step_bins for i in range(len(generator_losses) // step_bins)] * step_bins)
                # num_examples = (len(generator_losses) // step_bins) * step_bins
                # plt.plot(
                #     range(num_examples // step_bins), 
                #     torch.Tensor(generator_losses[:num_examples]).view(-1, step_bins).mean(1),
                #     label="Generator Loss"
                # )
                # plt.plot(
                #     range(num_examples // step_bins), 
                #     torch.Tensor(discriminator_losses[:num_examples]).view(-1, step_bins).mean(1),
                #     label="Discriminator Loss"
                # )
                # plt.legend()
                # plt.show()
            elif cur_step == 0:
                print("Congratulations! If you've gotten here, it's working. Please let this train until you're happy with how the generated numbers look, and then go on to the exploration!")
            cur_step += 1
Started: 2021-04-30 17:21:38.697176
Congratulations! If you've gotten here, it's working. Please let this train until you're happy with how the generated numbers look, and then go on to the exploration!
Step 500: Generator loss: 2.2972163581848144, discriminator loss: 0.24098993314430117
Step 1000: Generator loss: 4.111798384666443, discriminator loss: 0.03910111421905458
Step 1500: Generator loss: 5.055200936317444, discriminator loss: 0.017988519712351263
Step 2000: Generator loss: 4.356978572845459, discriminator loss: 0.07956613468006253
Step 2500: Generator loss: 3.1875410358905794, discriminator loss: 0.1825093053430319
Step 3000: Generator loss: 2.7038124163150785, discriminator loss: 0.26903335136175155
Step 3500: Generator loss: 2.3201852326393126, discriminator loss: 0.30360834433138373
Step 4000: Generator loss: 2.3750923416614533, discriminator loss: 0.3380578280091286
Step 4500: Generator loss: 1.9010865423679353, discriminator loss: 0.3774027600288391
Step 5000: Generator loss: 1.9102082657814026, discriminator loss: 0.39759708976745606
Step 5500: Generator loss: 1.6987447377443314, discriminator loss: 0.43035421246290206
Step 6000: Generator loss: 1.6317225174903869, discriminator loss: 0.44889184486865996
Step 6500: Generator loss: 1.4883701887130738, discriminator loss: 0.47211092388629916
Step 7000: Generator loss: 1.4601867563724518, discriminator loss: 0.49546391534805295
Step 7500: Generator loss: 1.3467793072462082, discriminator loss: 0.520910717189312
Step 8000: Generator loss: 1.3130292971134185, discriminator loss: 0.5447795498371124
Step 8500: Generator loss: 1.1705783107280732, discriminator loss: 0.5716251953244209
Step 9000: Generator loss: 1.119933351278305, discriminator loss: 0.594505407333374
Step 9500: Generator loss: 1.0671374444961548, discriminator loss: 0.6109852703809738
Step 10000: Generator loss: 1.042488064646721, discriminator loss: 0.6221774860620498
Step 10500: Generator loss: 1.015932119846344, discriminator loss: 0.6356924445033073
Step 11000: Generator loss: 0.9900961836576462, discriminator loss: 0.6363092860579491
Step 11500: Generator loss: 0.9709519152641296, discriminator loss: 0.642409072637558
Step 12000: Generator loss: 0.9450125323534012, discriminator loss: 0.6527952731847763
Step 12500: Generator loss: 0.9660063650608063, discriminator loss: 0.6581431583166123
Step 13000: Generator loss: 0.8976931695938111, discriminator loss: 0.6624037518501281
Step 13500: Generator loss: 0.8969441390037537, discriminator loss: 0.6631241027116775
Step 14000: Generator loss: 0.902042475938797, discriminator loss: 0.6669842832088471
Step 14500: Generator loss: 0.8777725909948348, discriminator loss: 0.6742114140987396
Step 15000: Generator loss: 0.8425910116434098, discriminator loss: 0.6705276243686676
Step 15500: Generator loss: 0.8403865963220596, discriminator loss: 0.6768578908443451
Step 16000: Generator loss: 0.8556022493839264, discriminator loss: 0.6762306448221207
Step 16500: Generator loss: 0.8764977214336396, discriminator loss: 0.678820540189743
Step 17000: Generator loss: 0.8166215040683746, discriminator loss: 0.6793025200366973
Step 17500: Generator loss: 0.8172620774507523, discriminator loss: 0.6820640956163406
Step 18000: Generator loss: 0.8534817943572998, discriminator loss: 0.679261458158493
Step 18500: Generator loss: 0.814961371421814, discriminator loss: 0.6819399018287658
Step 19000: Generator loss: 0.8046633821725845, discriminator loss: 0.6841713408231735
Step 19500: Generator loss: 0.8040647978782653, discriminator loss: 0.6816601184606552
Step 20000: Generator loss: 0.8188033536672592, discriminator loss: 0.6830086879730225
Step 20500: Generator loss: 0.8088009203672409, discriminator loss: 0.6811848978996277
Step 21000: Generator loss: 0.7853847659826279, discriminator loss: 0.6836997301578521
Step 21500: Generator loss: 0.7798927721977233, discriminator loss: 0.684486163020134
Step 22000: Generator loss: 0.7833593552112579, discriminator loss: 0.685326477766037
Step 22500: Generator loss: 0.8140243920087814, discriminator loss: 0.6838948290348053
Step 23000: Generator loss: 0.7766883851289749, discriminator loss: 0.6896610424518586
Step 23500: Generator loss: 0.7677333387136459, discriminator loss: 0.6847020034790039
Step 24000: Generator loss: 0.7968860776424408, discriminator loss: 0.6859219465255737
Step 24500: Generator loss: 0.7655997285842896, discriminator loss: 0.6866924543380737
Step 25000: Generator loss: 0.7897603868246078, discriminator loss: 0.6862760508060455
Step 25500: Generator loss: 0.7602986326217651, discriminator loss: 0.6839702378511429
Step 26000: Generator loss: 0.7566630525588989, discriminator loss: 0.6884172073602677
Step 26500: Generator loss: 0.7693239089250564, discriminator loss: 0.6848342741727829
Step 27000: Generator loss: 0.7744117819070816, discriminator loss: 0.6884717727899552
Step 27500: Generator loss: 0.7754857275485992, discriminator loss: 0.6871565765142441
Step 28000: Generator loss: 0.7671164673566818, discriminator loss: 0.691150808095932
Step 28500: Generator loss: 0.7767693866491318, discriminator loss: 0.6899988718032837
Step 29000: Generator loss: 0.7584288560152054, discriminator loss: 0.6866833527088165
Step 29500: Generator loss: 0.7469037870168685, discriminator loss: 0.6892017160654068
Step 30000: Generator loss: 0.7409272351264954, discriminator loss: 0.6925988558530808
Step 30500: Generator loss: 0.7461127021312713, discriminator loss: 0.6910134963989257
Step 31000: Generator loss: 0.7623333480358124, discriminator loss: 0.6848250635862351
Step 31500: Generator loss: 0.7320846046209335, discriminator loss: 0.6922971439361573
Step 32000: Generator loss: 0.7360488106012344, discriminator loss: 0.6958958665132523
Step 32500: Generator loss: 0.7436227219104767, discriminator loss: 0.6927914987802506
Step 33000: Generator loss: 0.7528411923646927, discriminator loss: 0.6868532946109772
Step 33500: Generator loss: 0.7555540499687194, discriminator loss: 0.6819704930782318
Step 34000: Generator loss: 0.7339509303569793, discriminator loss: 0.6947230596542359
Step 34500: Generator loss: 0.7203902735710144, discriminator loss: 0.694019063949585
Step 35000: Generator loss: 0.7161798032522202, discriminator loss: 0.694249948143959
Step 35500: Generator loss: 0.7100930047035218, discriminator loss: 0.6956643009185791
Step 36000: Generator loss: 0.7224245357513428, discriminator loss: 0.692036272764206
Step 36500: Generator loss: 0.7294702612161637, discriminator loss: 0.6839023213386536
Step 37000: Generator loss: 0.7326101566553116, discriminator loss: 0.6864855628013611
Step 37500: Generator loss: 0.7289662526845933, discriminator loss: 0.6891102294921875
Step 38000: Generator loss: 0.7277824294567108, discriminator loss: 0.6930312074422836
Step 38500: Generator loss: 0.7523093600273132, discriminator loss: 0.6822635132074356
Step 39000: Generator loss: 0.7260702294111252, discriminator loss: 0.6836128298044205
Step 39500: Generator loss: 0.7210463825464248, discriminator loss: 0.6865772886276245
Step 40000: Generator loss: 0.7197876414060592, discriminator loss: 0.6861994673013687
Step 40500: Generator loss: 0.7156198496818542, discriminator loss: 0.6897815141677857
Step 41000: Generator loss: 0.7411812788248062, discriminator loss: 0.6876297281980515
Step 41500: Generator loss: 0.7482703533172608, discriminator loss: 0.6831764079332352
Step 42000: Generator loss: 0.7353900390863418, discriminator loss: 0.6809069069623948
Step 42500: Generator loss: 0.726880151629448, discriminator loss: 0.6845077587366104
Step 43000: Generator loss: 0.7335763674974441, discriminator loss: 0.6855163406133652
Step 43500: Generator loss: 0.7247586588859558, discriminator loss: 0.684886796593666
Step 44000: Generator loss: 0.7244187197685241, discriminator loss: 0.6869175283908844
Step 44500: Generator loss: 0.7478935513496399, discriminator loss: 0.6783332238197327
Step 45000: Generator loss: 0.7392684471607208, discriminator loss: 0.687694214463234
Step 45500: Generator loss: 0.7384519840478897, discriminator loss: 0.6806207147836685
Step 46000: Generator loss: 0.7173152709007263, discriminator loss: 0.6894198944568634
Step 46500: Generator loss: 0.7135227386951446, discriminator loss: 0.6902039344310761
Step 47000: Generator loss: 0.7121314022541047, discriminator loss: 0.691226885676384
Step 47500: Generator loss: 0.7153779380321502, discriminator loss: 0.6898772416114807
Step 48000: Generator loss: 0.7112214748859406, discriminator loss: 0.6919035356044769
Step 48500: Generator loss: 0.729472970366478, discriminator loss: 0.6832324341535568
Step 49000: Generator loss: 0.7259864670038223, discriminator loss: 0.6850444099903107
Step 49500: Generator loss: 0.7463545156717301, discriminator loss: 0.6876692290306091
Step 50000: Generator loss: 0.72439306807518, discriminator loss: 0.6840117316246033
Step 50500: Generator loss: 0.7304026707410812, discriminator loss: 0.6828315691947937
Step 51000: Generator loss: 0.735065841794014, discriminator loss: 0.6877049984931946
Step 51500: Generator loss: 0.738693750500679, discriminator loss: 0.6786749280691147
Step 52000: Generator loss: 0.7165734323263169, discriminator loss: 0.688656357049942
Step 52500: Generator loss: 0.7124545810222626, discriminator loss: 0.6885210503339767
Step 53000: Generator loss: 0.7169003388881683, discriminator loss: 0.6898472727537155
Step 53500: Generator loss: 0.7116240389347076, discriminator loss: 0.6890990349054337
Step 54000: Generator loss: 0.7254890002012253, discriminator loss: 0.686066904425621
Step 54500: Generator loss: 0.7279696422815323, discriminator loss: 0.6824959990978241
Step 55000: Generator loss: 0.7243433123826981, discriminator loss: 0.686788556933403
Step 55500: Generator loss: 0.72320248234272, discriminator loss: 0.6819899456501007
Step 56000: Generator loss: 0.7283236463069915, discriminator loss: 0.6813042680025101
Step 56500: Generator loss: 0.7257692145109177, discriminator loss: 0.6882435537576675
Step 57000: Generator loss: 0.7204343225955964, discriminator loss: 0.6905163298845292
Step 57500: Generator loss: 0.7234136379957199, discriminator loss: 0.6828762836456299
Step 58000: Generator loss: 0.7213340125083924, discriminator loss: 0.6852367097139358
Step 58500: Generator loss: 0.7139561972618103, discriminator loss: 0.6901394550800324
Step 59000: Generator loss: 0.7128681792020798, discriminator loss: 0.6899428930282593
Step 59500: Generator loss: 0.7178032584190369, discriminator loss: 0.6901476013660431
Step 60000: Generator loss: 0.7218955677747726, discriminator loss: 0.6866856569051742
Step 60500: Generator loss: 0.7173091459274292, discriminator loss: 0.6909447896480561
Step 61000: Generator loss: 0.7196292532682419, discriminator loss: 0.6888659211397171
Step 61500: Generator loss: 0.7136147793531418, discriminator loss: 0.6911007264852523
Step 62000: Generator loss: 0.7167167031764984, discriminator loss: 0.6874131036996841
Step 62500: Generator loss: 0.7095696296691895, discriminator loss: 0.6924118340015412
Step 63000: Generator loss: 0.7100733149051667, discriminator loss: 0.6894952065944672
Step 63500: Generator loss: 0.7075963083505631, discriminator loss: 0.6918715183734894
Step 64000: Generator loss: 0.7087407541275025, discriminator loss: 0.6912821785211564
Step 64500: Generator loss: 0.7044790136814117, discriminator loss: 0.6919414196014404
Step 65000: Generator loss: 0.7120586842298507, discriminator loss: 0.6889722956418991
Step 65500: Generator loss: 0.7059948451519013, discriminator loss: 0.6913756219148636
Step 66000: Generator loss: 0.7103360829353332, discriminator loss: 0.6888430647850037
Step 66500: Generator loss: 0.7106574136018753, discriminator loss: 0.6923392252922058
Step 67000: Generator loss: 0.7205636972188949, discriminator loss: 0.6888139424324036
Step 67500: Generator loss: 0.7325763144493103, discriminator loss: 0.6851953419446946
Step 68000: Generator loss: 0.7144211075305938, discriminator loss: 0.6894719363451004
Step 68500: Generator loss: 0.7039347168207168, discriminator loss: 0.692310958981514
Step 69000: Generator loss: 0.707789731502533, discriminator loss: 0.690034374833107
Step 69500: Generator loss: 0.7080022550821304, discriminator loss: 0.6897176603078842
Step 70000: Generator loss: 0.706935028553009, discriminator loss: 0.6917025876045227
Step 70500: Generator loss: 0.7035844438076019, discriminator loss: 0.6928271135091781
Step 71000: Generator loss: 0.706664494395256, discriminator loss: 0.6913493415117263
Step 71500: Generator loss: 0.7080443944931031, discriminator loss: 0.6943384435176849
Step 72000: Generator loss: 0.7080535914897919, discriminator loss: 0.6904549078941346
Step 72500: Generator loss: 0.7195642621517181, discriminator loss: 0.6883307158946991
Step 73000: Generator loss: 0.7137477462291717, discriminator loss: 0.6895240060091019
Step 73500: Generator loss: 0.7089026942253113, discriminator loss: 0.6893982688188552
Step 74000: Generator loss: 0.71370064163208, discriminator loss: 0.6885940716266632
Step 74500: Generator loss: 0.7126090573072433, discriminator loss: 0.6913927717208862
Step 75000: Generator loss: 0.7061277792453766, discriminator loss: 0.6915859417915344
Step 75500: Generator loss: 0.7079737706184387, discriminator loss: 0.6918540188074112
Step 76000: Generator loss: 0.7094860315322876, discriminator loss: 0.6909938471317292
Step 76500: Generator loss: 0.7089288998842239, discriminator loss: 0.6928894543647766
Step 77000: Generator loss: 0.7099210443496704, discriminator loss: 0.6881472972631455
Step 77500: Generator loss: 0.7087316303253174, discriminator loss: 0.6922812685966492
Step 78000: Generator loss: 0.7124276860952378, discriminator loss: 0.686549712896347
Step 78500: Generator loss: 0.7118150774240494, discriminator loss: 0.6914097841978073
Step 79000: Generator loss: 0.7061567052602767, discriminator loss: 0.6910830926895142
Step 79500: Generator loss: 0.7130619381666183, discriminator loss: 0.6901648813486099
Step 80000: Generator loss: 0.7189263315200806, discriminator loss: 0.6891602661609649
Step 80500: Generator loss: 0.7099695562124252, discriminator loss: 0.6893361113071441
Step 81000: Generator loss: 0.7043007851839066, discriminator loss: 0.6928421225547791
Step 81500: Generator loss: 0.7055111042261124, discriminator loss: 0.6913482803106308
Step 82000: Generator loss: 0.7107034167051315, discriminator loss: 0.6899493371248245
Step 82500: Generator loss: 0.7072652250528335, discriminator loss: 0.691617219209671
Step 83000: Generator loss: 0.7116999027729034, discriminator loss: 0.6887015362977982
Step 83500: Generator loss: 0.7103397004604339, discriminator loss: 0.6889865008592606
Step 84000: Generator loss: 0.702219740986824, discriminator loss: 0.6927198125123978
Step 84500: Generator loss: 0.7042218887805939, discriminator loss: 0.6910840107202529
Step 85000: Generator loss: 0.7036903632879257, discriminator loss: 0.6948130846023559
Step 85500: Generator loss: 0.7157929112911224, discriminator loss: 0.6877820398807526
Step 86000: Generator loss: 0.703074496269226, discriminator loss: 0.6928336225748062
Step 86500: Generator loss: 0.7018165578842163, discriminator loss: 0.6942198793888092
Step 87000: Generator loss: 0.7056693414449692, discriminator loss: 0.6903062870502472
Step 87500: Generator loss: 0.7070764343738556, discriminator loss: 0.690039452791214
Step 88000: Generator loss: 0.7018579070568085, discriminator loss: 0.6929991252422333
Step 88500: Generator loss: 0.7042791714668274, discriminator loss: 0.6906855113506317
Step 89000: Generator loss: 0.7052908551692962, discriminator loss: 0.6917922974824905
Step 89500: Generator loss: 0.7057228873968124, discriminator loss: 0.6917544984817505
Step 90000: Generator loss: 0.7041428442001343, discriminator loss: 0.6921311345100403
Step 90500: Generator loss: 0.7040341221094132, discriminator loss: 0.6916886166334152
Step 91000: Generator loss: 0.7016080802679062, discriminator loss: 0.6918226274251937
Step 91500: Generator loss: 0.7047490992546082, discriminator loss: 0.6919238156080246
Step 92000: Generator loss: 0.7015802135467529, discriminator loss: 0.6937261604070664
Step 92500: Generator loss: 0.7035572265386582, discriminator loss: 0.6899326649904252
Step 93000: Generator loss: 0.7005916707515717, discriminator loss: 0.6933788905143737
Step 93500: Generator loss: 0.7005794968605041, discriminator loss: 0.6932051202058792
Ended: 2021-04-30 18:15:27.636306
Elapsed: 0:53:48.939130

Exploration

Before you explore, put the generator in eval mode, both as a general good practice and so that batch norm doesn't cause you issues: in eval mode it uses its stored running statistics instead of the current batch's.

gen = gen.eval()
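
As a quick standalone illustration of why this matters (not part of the notebook's pipeline): in train mode a batch norm layer normalizes with the current batch's statistics, while in eval mode it uses its stored running statistics.

import torch
from torch import nn

batch_norm = nn.BatchNorm2d(3)
x = torch.randn(2, 3, 4, 4)

batch_norm.train()
train_output = batch_norm(x)  # normalized with this batch's mean and variance

batch_norm.eval()
eval_output = batch_norm(x)   # normalized with the stored running statistics

print(torch.allclose(train_output, eval_output))  # False, in general
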

Changing the Class Vector

You can generate some numbers with your new model! You can add interpolation as well to make it more interesting.

Starting from one image, you will produce intermediate images that look more and more like the ending image until you reach the final image. You're basically morphing one image into another, and your conditional GAN lets you choose the two endpoint digits.
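
To see what the interpolated class vectors look like, here is a standalone sketch with made-up values: three steps between classes 1 and 5, showing only six classes for brevity.

import torch

first_label = torch.tensor([0., 1., 0., 0., 0., 0.])   # class 1
second_label = torch.tensor([0., 0., 0., 0., 0., 1.])  # class 5
percent_second_label = torch.linspace(0, 1, 3)[:, None]
print(first_label * (1 - percent_second_label) + second_label * percent_second_label)
# tensor([[0.0000, 1.0000, 0.0000, 0.0000, 0.0000, 0.0000],
#         [0.0000, 0.5000, 0.0000, 0.0000, 0.0000, 0.5000],
#         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 1.0000]])

The middle row is half "1" and half "5", which is what produces the in-between images.
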

### Change me! ###

n_interpolation = 9 # the number of images in the interpolation: the intermediate images + 2 (the start and end image)
interpolation_noise = make_some_noise(1, z_dim, device=device).repeat(n_interpolation, 1)
def interpolate_class(first_number, second_number):
    first_label = get_one_hot_labels(torch.Tensor([first_number]).long(), n_classes)
    second_label = get_one_hot_labels(torch.Tensor([second_number]).long(), n_classes)

    # Calculate the interpolation vector between the two labels
    percent_second_label = torch.linspace(0, 1, n_interpolation)[:, None]
    interpolation_labels = first_label * (1 - percent_second_label) + second_label * percent_second_label

    # Combine the noise and the labels
    noise_and_labels = combine_vectors(interpolation_noise, interpolation_labels.to(device))
    fake = gen(noise_and_labels)
    show_tensor_images(fake, num_images=n_interpolation, nrow=int(math.sqrt(n_interpolation)), show=False)
### Change me! ###
start_plot_number = 1 # Choose the start digit
### Change me! ###
end_plot_number = 5 # Choose the end digit

plt.figure(figsize=(8, 8))
interpolate_class(start_plot_number, end_plot_number)
_ = plt.axis('off')

### Uncomment the following lines of code if you would like to visualize a set of pairwise class 
### interpolations for a collection of different numbers, all in a single grid of interpolations.
### You'll also see another visualization like this in the next code block!
# plot_numbers = [2, 3, 4, 5, 7]
# n_numbers = len(plot_numbers)
# plt.figure(figsize=(8, 8))
# for i, first_plot_number in enumerate(plot_numbers):
#     for j, second_plot_number in enumerate(plot_numbers):
#         plt.subplot(n_numbers, n_numbers, i * n_numbers + j + 1)
#         interpolate_class(first_plot_number, second_plot_number)
#         plt.axis('off')
# plt.subplots_adjust(top=1, bottom=0, left=0, right=1, hspace=0.1, wspace=0)
# plt.show()
# plt.close()

Changing the Noise Vector

Now, what happens if you hold the class constant, but instead you change the noise vector? You can also interpolate the noise vector and generate an image at each step.
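
The interpolation itself is the same linear blend used for the labels, just applied to the noise; here is a standalone sketch with an illustrative z_dim:

import torch

z_dim, steps = 4, 3
first_noise = torch.randn(1, z_dim)
second_noise = torch.randn(1, z_dim)
percent = torch.linspace(0, 1, steps)[:, None]
interpolated = first_noise * (1 - percent) + second_noise * percent
print(interpolated.shape)  # torch.Size([3, 4]): one noise vector per step
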

n_interpolation = 9 # How many intermediate images you want + 2 (for the start and end image)

This time you hold the class label fixed (a "5") and interpolate between two noise vectors instead of between the labels.

interpolation_label = get_one_hot_labels(torch.Tensor([5]).long(), n_classes).repeat(n_interpolation, 1).float()
def interpolate_noise(first_noise, second_noise):
    # This time you're interpolating between the noise instead of the labels
    percent_first_noise = torch.linspace(0, 1, n_interpolation)[:, None].to(device)
    interpolation_noise = first_noise * percent_first_noise + second_noise * (1 - percent_first_noise)

    # Combine the noise and the labels again
    noise_and_labels = combine_vectors(interpolation_noise, interpolation_label.to(device))
    fake = gen(noise_and_labels)
    show_tensor_images(fake, num_images=n_interpolation, nrow=int(math.sqrt(n_interpolation)), show=False)

Generate noise vectors to interpolate between.

### Change me! ###
n_noise = 5 # Choose the number of noise examples in the grid
plot_noises = [make_some_noise(1, z_dim, device=device) for i in range(n_noise)]
plt.figure(figsize=(8, 8))
for i, first_plot_noise in enumerate(plot_noises):
    for j, second_plot_noise in enumerate(plot_noises):
        plt.subplot(n_noise, n_noise, i * n_noise + j + 1)
        interpolate_noise(first_plot_noise, second_plot_noise)
        plt.axis('off')
plt.subplots_adjust(top=1, bottom=0, left=0, right=1, hspace=0.1, wspace=0)
plt.show()
plt.close()

End