A Conditional GAN

Build You a Conditional GAN For a Great Good

Imports

# python standard library
from pathlib import Path

import math

# pypi
from torch import nn
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import MNIST
from torchvision.utils import make_grid

import matplotlib.pyplot as pyplot
import torch
import torch.nn.functional as F

# my stuff
from graeae import Timer

Set Up

The Timer

# module-level Timer (imported from graeae); used below as a context manager
# (``with TIMER:``) wrapped around the training loop
TIMER = Timer()

The Manual Seed

# seed PyTorch's global random number generator so repeated runs draw the same values
torch.manual_seed(0)

Plotting

# identifier for this post; interpolated into the default ``folder`` argument
# of save_tensor_images
SLUG = "a-conditional-gan"

Helpers

def save_tensor_images(image_tensor: torch.Tensor,
                       filename: str,
                       title: str,
                       folder: str=f"files/posts/gans{SLUG}",
                       num_images: int=25, size: tuple=(1, 28, 28)):
    """Plot a grid of images from a tensor and save it to a file

    The first ``num_images`` images are arranged five to a row, drawn with
    matplotlib and written to ``folder/filename``. A ``[[file:...]]`` link
    holding just the filename is printed afterwards.

    Args:
     image_tensor: tensor with the values for the images to plot
     filename: name to save the file under
     title: title for the image
     folder: path to put the file in
     num_images: how many images from the tensor to use
     size: the dimensions for each image (currently unused, kept so existing
           callers don't break)
    """
    # NOTE(review): the default folder has no "/" between "gans" and the slug,
    # which looks unintended - confirm against the site layout before changing it
    # shift and scale with (x + 1) / 2, which maps values in [-1, 1] onto [0, 1]
    image_tensor = (image_tensor + 1) / 2
    image_unflat = image_tensor.detach().cpu()
    image_grid = make_grid(image_unflat[:num_images], nrow=5)
    pyplot.title(title)
    pyplot.grid(False)
    # make_grid gives (channels, height, width); imshow wants the channels last
    pyplot.imshow(image_grid.permute(1, 2, 0).squeeze())
    pyplot.tick_params(bottom=False, top=False, labelbottom=False,
                       right=False, left=False, labelleft=False)
    # join with Path so a separator is always placed between folder and filename
    # (plain string concatenation glued the filename onto the folder name)
    pyplot.savefig(Path(folder) / filename)
    print(f"[[file:{filename}]]")
    return

Noise

def make_some_noise(n_samples: int, z_dim: int, device: str='cpu') -> torch.Tensor:
    """Draw a batch of noise vectors from the standard normal distribution

    Args:
      n_samples: how many noise vectors to draw (rows)
      z_dim: length of each noise vector (columns)
      device: the device to create the tensor on

    Returns:
     tensor of shape (n_samples, z_dim) filled by torch.randn.
    """
    noise_shape = (n_samples, z_dim)
    return torch.randn(noise_shape, device=device)

Middle

The Generator

The Generator and Discriminator are the same ones we used before except the z_dim attribute has been renamed input_dim to reflect the fact that the data is going to be augmented with the classification information.

class Generator(nn.Module):
    """The DCGAN Generator

    Args:
       input_dim: the dimension of the input vector
       im_chan: the number of channels in the images, fitted for the dataset used
             (MNIST is black-and-white, so 1 channel is your default)
       hidden_dim: the inner dimension,
    """
    def __init__(self, input_dim: int=10, im_chan: int=1, hidden_dim: int=64):
        super().__init__()
        self.input_dim = input_dim

        self.gen = nn.Sequential(
            self.make_gen_block(input_dim, hidden_dim * 4),
            self.make_gen_block(hidden_dim * 4, hidden_dim * 2, kernel_size=4, stride=1),
            self.make_gen_block(hidden_dim * 2, hidden_dim),
            self.make_gen_block(hidden_dim, im_chan, kernel_size=4, final_layer=True),
        )

    def make_gen_block(self, input_channels: int, output_channels: int,
                       kernel_size: int=3, stride: int=2,
                       final_layer: bool=False) -> nn.Sequential:
        """Creates a block for the generator (sub sequence)

       The parts
        - a transposed convolution
        - a batchnorm (except for in the last layer)
        - an activation.

       Args:
           input_channels: how many channels the input feature representation has
           output_channels: how many channels the output feature representation should have
           kernel_size: the size of each convolutional filter, equivalent to (kernel_size, kernel_size)
           stride: the stride of the convolution
           final_layer: a boolean, true if it is the final layer and false otherwise 
                     (affects activation and batchnorm)

       Returns:
        the sub-sequence of layers
       """

        if not final_layer:
            return nn.Sequential(
                nn.ConvTranspose2d(input_channels, output_channels, kernel_size, stride),
                nn.BatchNorm2d(output_channels),
                nn.ReLU(inplace=True),
            )
        else:
            return nn.Sequential(
                nn.ConvTranspose2d(input_channels, output_channels, kernel_size, stride),
                nn.Tanh(),
            )

    def forward(self, noise: torch.Tensor) -> torch.Tensor:
        """complete a forward pass of the generator: Given a noise tensor, 

       Args:
        noise: a noise tensor with dimensions (n_samples, z_dim)

       Returns:
        generated images.
       """
        # unsqueeze the noise
        x = noise.view(len(noise), self.input_dim, 1, 1)
        return self.gen(x)

Discriminator

This differs a little from the DCGAN Discriminator in that the initial hidden dimension output goes up to 64 nodes from 16 in the original.

class Discriminator(nn.Module):
    """The DCGAN Discriminator

    Args:
     im_chan: the number of channels in the images, fitted for the dataset used
             (MNIST is black-and-white, so 1 channel is the default)
     hidden_dim: the inner dimension,
    """
    def __init__(self, im_chan: int=1, hidden_dim: int=64):
        super(Discriminator, self).__init__()
        self.disc = nn.Sequential(
            self.make_disc_block(im_chan, hidden_dim),
            self.make_disc_block(hidden_dim, hidden_dim * 2),
            self.make_disc_block(hidden_dim * 2, 1, final_layer=True),
        )
        return

    def make_disc_block(self, input_channels: int, output_channels: int,
                        kernel_size: int=4, stride: int=2,
                        final_layer: bool=False) -> nn.Sequential:
        """Make a sub-block of layers for the discriminator

        - a convolution
        - a batchnorm (except for in the last layer)
        - an activation.

       Args:
         input_channels: how many channels the input feature representation has
         output_channels: how many channels the output feature representation should have
         kernel_size: the size of each convolutional filter, equivalent to (kernel_size, kernel_size)
         stride: the stride of the convolution
         final_layer: if true it is the final layer and otherwise not
                     (affects activation and batchnorm)
       """        
        # Build the neural block
        if not final_layer:
            return nn.Sequential(
                nn.Conv2d(input_channels, output_channels, kernel_size, stride),
                nn.BatchNorm2d(output_channels),
                nn.LeakyReLU(0.2)
            )
        else: # Final Layer
            return nn.Sequential(
                nn.Conv2d(input_channels, output_channels, kernel_size, stride),
            )

    def forward(self, image: torch.Tensor) -> torch.Tensor:
        """Complete a forward pass of the discriminator

       Args:
         image: a flattened image tensor with dimension (im_dim)

       Returns:
        a 1-dimension tensor representing fake/real.
       """
        disc_pred = self.disc(image)
        return disc_pred.view(len(disc_pred), -1)

The Class Input

One-Hot Encoder

In conditional GANs, the input vector for the generator will also need to include the class information. The class is represented using a one-hot encoded vector whose length is the number of classes and where each index represents a class. The vector is all 0's except for a 1 at the index of the chosen class. Given the labels of multiple images (e.g. from a batch) and the number of classes, please create one-hot vectors for each label. There is a function within the PyTorch functional library (torch.nn.functional) that can help you.

  1. This code can be done in one line.
  2. See the PyTorch documentation for F.one_hot.
def get_one_hot_labels(labels: torch.Tensor, n_classes: int) -> torch.Tensor:
    """Create one-hot vectors for the labels

    Args:
       labels: integer tensor of class labels from the dataloader
       n_classes: the total number of classes in the dataset

    Returns:
     a tensor with the shape of ``labels`` plus a trailing dimension of
     size ``n_classes``.
    """
    # the annotation is torch.Tensor (the type), not torch.tensor (the factory function)
    return F.one_hot(labels, num_classes=n_classes)
# sanity check: labels 0, 2, 1 with three classes give rows with the 1 at index 0, 2, 1
assert get_one_hot_labels(
    labels=torch.tensor([[0, 2, 1]], dtype=torch.long),
    n_classes=3,
).tolist() == [[[1, 0, 0], [0, 0, 1], [0, 1, 0]]]

Combine Vectors

Next, you need to be able to concatenate the one-hot class vector to the noise vector before giving it to the generator. You will also need to do this when adding the class channels to the discriminator.

  1. This code can also be written in one line.
  2. See the documentation torch.cat ( Specifically, look at what the dim argument of torch.cat does)
def combine_vectors(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """Concatenate two tensors along dimension 1 and cast the result to float32

    Args:
      x: the first tensor, shape (n_samples, ?, ...)
      y: the second tensor, shape (n_samples, ?, ...)

    Returns:
     x followed by y along dimension 1, always with dtype float32.
    """
    joined = torch.cat([x, y], dim=1)
    # the cast guarantees a float result no matter what dtype the inputs had
    return joined.to(dtype=torch.float32)
# element order: the columns of x come first, then the columns of y
combined = combine_vectors(torch.tensor([[1, 2], [3, 4]]), torch.tensor([[5, 6], [7, 8]]))
assert (combined == torch.tensor([[1, 2, 5, 6], [3, 4, 7, 8]])).all()
# the result holds floats even though both inputs were integers
assert isinstance(combined[0][0].item(), float)
# shapes: only dimension 1 grows, for float inputs ...
combined = combine_vectors(torch.randn(1, 4, 5), torch.randn(1, 8, 5))
assert tuple(combined.shape) == (1, 12, 5)
# ... and for integer inputs
long_combined_shape = combine_vectors(torch.randn(1, 10, 12).long(), torch.randn(1, 20, 12).long()).shape
assert tuple(long_combined_shape) == (1, 30, 12)

Training

First, you will define some new parameters:

  • mnist_shape: the shape of each MNIST image, which is 28 x 28 pixels with one channel (because it's black-and-white), so 1 x 28 x 28
  • n_classes: the number of classes in MNIST (10, since there are the digits from 0 to 9)
# (channels, height, width) of one MNIST image
mnist_shape = (1, 28, 28)
# one class for each digit
n_classes = 10

And you also include the same parameters from before:

  • criterion: the loss function
  • n_epochs: the number of times you iterate through the entire dataset when training
  • z_dim: the dimension of the noise vector
  • display_step: how often to display/visualize the images
  • batch_size: the number of images per forward/backward pass
  • lr: the learning rate
  • device: the device type
# binary cross-entropy computed from raw scores (logits)
criterion = nn.BCEWithLogitsLoss()
n_epochs = 200
z_dim = 64
display_step = 500
batch_size = 128
lr = 0.0002
# prefer the GPU, but fall back to the CPU so moving the models and tensors
# with ``.to(device)`` doesn't raise on a machine without CUDA
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# convert each image to a tensor, then apply (x - 0.5) / 0.5 to its single channel
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])

# download=False: the MNIST files are expected to already be under data_path
data_path = Path("~/pytorch-data/MNIST/").expanduser()
dataloader = DataLoader(
    MNIST(data_path, download=False, transform=transform),
    batch_size=batch_size,
    shuffle=True)

Input Dimensions

Then, you can initialize your generator, discriminator, and optimizers. To do this, you will need to update the input dimensions for both models. For the generator, you will need to calculate the size of the input vector; recall that for conditional GANs, the generator's input is the noise vector concatenated with the class vector. For the discriminator, you need to add a channel for every class.

def get_input_dimensions(z_dim: int, mnist_shape: tuple, n_classes: int) -> tuple:
    """Work out the input sizes for the conditional generator and discriminator

    Args:
       z_dim: the dimension of the noise vector
       mnist_shape: the shape of each image with the channel count first,
                 e.g. (1, 28, 28) for MNIST
       n_classes: the total number of classes in the dataset (10 for MNIST)

    Returns:
       generator_input_dim: length of the generator's input vector
                         (noise vector plus one entry per class)
       discriminator_im_chan: number of input channels for the discriminator
                           (image channels plus one channel per class)
    """
    image_channels = mnist_shape[0]
    return z_dim + n_classes, n_classes + image_channels
def test_input_dims():
    """Check get_input_dimensions against hand-computed sizes"""
    # 23 noise entries + 9 classes = 32; 12 image channels + 9 classes = 21
    gen_dim, disc_dim = get_input_dimensions(z_dim=23, mnist_shape=(12, 23, 52), n_classes=9)
    assert gen_dim == 32
    assert disc_dim == 21


test_input_dims()

Initialize the Objects

# size both networks for the conditional inputs (noise + class vector for the
# generator, image channels + class channels for the discriminator)
generator_input_dim, discriminator_im_chan = get_input_dimensions(z_dim, mnist_shape, n_classes)

# build each model on the training device and give each its own Adam optimizer
gen = Generator(input_dim=generator_input_dim).to(device)
gen_opt = torch.optim.Adam(gen.parameters(), lr=lr)
disc = Discriminator(im_chan=discriminator_im_chan).to(device)
disc_opt = torch.optim.Adam(disc.parameters(), lr=lr)
def weights_init(m) -> None:
    """Re-initialize the weights of convolution and batchnorm layers in place

    Convolution and transposed-convolution weights are drawn from a normal
    distribution with mean 0.0 and standard deviation 0.02. Batchnorm weights
    are drawn from the same distribution and the batchnorm bias is set to 0.
    Any other kind of module is left alone.

    Args:
     m: the module to initialize
    """
    if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
        torch.nn.init.normal_(m.weight, mean=0.0, std=0.02)
        return
    if isinstance(m, nn.BatchNorm2d):
        torch.nn.init.normal_(m.weight, mean=0.0, std=0.02)
        torch.nn.init.constant_(m.bias, 0)
    return
# Module.apply runs weights_init on every submodule and on the model itself,
# then returns the model
gen = gen.apply(weights_init)
disc = disc.apply(weights_init)

The Training

Now to train, you would like both your generator and your discriminator to know what class of image should be generated.

For example, if you're generating a picture of the number "1", you would need to:

  1. Tell that to the generator, so that it knows it should be generating a "1"
  2. Tell that to the discriminator, so that it knows it should be looking at a "1". If the discriminator is told it should be looking at a 1 but sees something that's clearly an 8, it can guess that it's probably fake
# count of processed batches and the per-batch loss histories
cur_step = 0
generator_losses = []
discriminator_losses = []

# placeholders; every one of these is reassigned on each batch of the training loop
noise_and_labels = False
fake = False

fake_image_and_labels = False
real_image_and_labels = False
disc_fake_pred = False
disc_real_pred = False
# Train the conditional GAN: on every batch update the discriminator first and
# then the generator, giving both networks the class of each image.
with TIMER:
    for epoch in range(n_epochs):
        # the dataloader yields a batch of images and their class labels
        for real, labels in dataloader:
            cur_batch_size = len(real)
            # move the real images to the training device
            real = real.to(device)

            # class information as vectors for the generator ...
            one_hot_labels = get_one_hot_labels(labels.to(device), n_classes)
            # ... and as one constant image-sized plane per class for the discriminator
            image_one_hot_labels = one_hot_labels[:, :, None, None]
            image_one_hot_labels = image_one_hot_labels.repeat(1, 1, mnist_shape[1], mnist_shape[2])

            ### Update discriminator ###
            # Zero out the discriminator gradients
            disc_opt.zero_grad()
            # Get noise corresponding to the current batch_size
            fake_noise = make_some_noise(cur_batch_size, z_dim, device=device)

            # condition the generator: noise and one-hot labels go in together
            noise_and_labels = combine_vectors(fake_noise, one_hot_labels)
            fake = gen(noise_and_labels)

            # Make sure that enough images were generated
            assert len(fake) == len(real)
            # Check that correct tensors were combined
            assert tuple(noise_and_labels.shape) == (cur_batch_size, fake_noise.shape[1] + one_hot_labels.shape[1])
            # It comes from the correct generator
            assert tuple(fake.shape) == (len(real), 1, 28, 28)

            # condition the discriminator: stack the class planes onto the images
            # as extra channels; the fakes are detached so the discriminator's
            # backward pass doesn't reach into the generator
            fake_image_and_labels = combine_vectors(fake.detach(), image_one_hot_labels)
            real_image_and_labels = combine_vectors(real, image_one_hot_labels)
            disc_fake_pred = disc(fake_image_and_labels)
            disc_real_pred = disc(real_image_and_labels)

            # Make sure shapes are correct
            assert tuple(fake_image_and_labels.shape) == (len(real), fake.detach().shape[1] + image_one_hot_labels.shape[1], 28 ,28)
            assert tuple(real_image_and_labels.shape) == (len(real), real.shape[1] + image_one_hot_labels.shape[1], 28 ,28)
            # Make sure that enough predictions were made
            assert len(disc_real_pred) == len(real)
            # Make sure that the inputs are different
            assert torch.any(fake_image_and_labels != real_image_and_labels)
            # Shapes must match
            assert tuple(fake_image_and_labels.shape) == tuple(real_image_and_labels.shape)
            assert tuple(disc_fake_pred.shape) == tuple(disc_real_pred.shape)

            # the discriminator should score the fakes as 0 and the reals as 1
            disc_fake_loss = criterion(disc_fake_pred, torch.zeros_like(disc_fake_pred))
            disc_real_loss = criterion(disc_real_pred, torch.ones_like(disc_real_pred))
            disc_loss = (disc_fake_loss + disc_real_loss) / 2
            # no retain_graph needed: this graph only covers the discriminator
            # (the fakes were detached) and the generator update below runs a
            # fresh forward pass through the discriminator
            disc_loss.backward()
            disc_opt.step()

            # Keep track of the discriminator losses
            discriminator_losses.append(disc_loss.item())

            ### Update generator ###
            # Zero out the generator gradients
            gen_opt.zero_grad()

            # not detached this time, so the gradients flow back into the generator
            fake_image_and_labels = combine_vectors(fake, image_one_hot_labels)
            # This will error if you didn't concatenate your labels to your image correctly
            disc_fake_pred = disc(fake_image_and_labels)
            # the generator is rewarded when the discriminator scores its fakes as real
            gen_loss = criterion(disc_fake_pred, torch.ones_like(disc_fake_pred))
            gen_loss.backward()
            gen_opt.step()

            # Keep track of the generator losses
            generator_losses.append(gen_loss.item())

            if cur_step % display_step == 0 and cur_step > 0:
                # average the losses over the most recent display_step batches
                gen_mean = sum(generator_losses[-display_step:]) / display_step
                disc_mean = sum(discriminator_losses[-display_step:]) / display_step
                print(f"Step {cur_step}: Generator loss: {gen_mean}, discriminator loss: {disc_mean}")
                # show_tensor_images(fake)
                # show_tensor_images(real)
                # NOTE(review): step_bins and x_axis are only referenced by the
                # disabled plotting below; kept in case later code reads them
                step_bins = 20
                x_axis = sorted([i * step_bins for i in range(len(generator_losses) // step_bins)] * step_bins)
                # num_examples = (len(generator_losses) // step_bins) * step_bins
                # plt.plot(
                #     range(num_examples // step_bins), 
                #     torch.Tensor(generator_losses[:num_examples]).view(-1, step_bins).mean(1),
                #     label="Generator Loss"
                # )
                # plt.plot(
                #     range(num_examples // step_bins), 
                #     torch.Tensor(discriminator_losses[:num_examples]).view(-1, step_bins).mean(1),
                #     label="Discriminator Loss"
                # )
                # plt.legend()
                # plt.show()
            elif cur_step == 0:
                print("Congratulations! If you've gotten here, it's working. Please let this train until you're happy with how the generated numbers look, and then go on to the exploration!")
            cur_step += 1
Started: 2021-04-30 17:21:38.697176
Congratulations! If you've gotten here, it's working. Please let this train until you're happy with how the generated numbers look, and then go on to the exploration!
Step 500: Generator loss: 2.2972163581848144, discriminator loss: 0.24098993314430117
Step 1000: Generator loss: 4.111798384666443, discriminator loss: 0.03910111421905458
Step 1500: Generator loss: 5.055200936317444, discriminator loss: 0.017988519712351263
Step 2000: Generator loss: 4.356978572845459, discriminator loss: 0.07956613468006253
Step 2500: Generator loss: 3.1875410358905794, discriminator loss: 0.1825093053430319
Step 3000: Generator loss: 2.7038124163150785, discriminator loss: 0.26903335136175155
Step 3500: Generator loss: 2.3201852326393126, discriminator loss: 0.30360834433138373
Step 4000: Generator loss: 2.3750923416614533, discriminator loss: 0.3380578280091286
Step 4500: Generator loss: 1.9010865423679353, discriminator loss: 0.3774027600288391
Step 5000: Generator loss: 1.9102082657814026, discriminator loss: 0.39759708976745606
Step 5500: Generator loss: 1.6987447377443314, discriminator loss: 0.43035421246290206
Step 6000: Generator loss: 1.6317225174903869, discriminator loss: 0.44889184486865996
Step 6500: Generator loss: 1.4883701887130738, discriminator loss: 0.47211092388629916
Step 7000: Generator loss: 1.4601867563724518, discriminator loss: 0.49546391534805295
Step 7500: Generator loss: 1.3467793072462082, discriminator loss: 0.520910717189312
Step 8000: Generator loss: 1.3130292971134185, discriminator loss: 0.5447795498371124
Step 8500: Generator loss: 1.1705783107280732, discriminator loss: 0.5716251953244209
Step 9000: Generator loss: 1.119933351278305, discriminator loss: 0.594505407333374
Step 9500: Generator loss: 1.0671374444961548, discriminator loss: 0.6109852703809738
Step 10000: Generator loss: 1.042488064646721, discriminator loss: 0.6221774860620498
Step 10500: Generator loss: 1.015932119846344, discriminator loss: 0.6356924445033073
Step 11000: Generator loss: 0.9900961836576462, discriminator loss: 0.6363092860579491
Step 11500: Generator loss: 0.9709519152641296, discriminator loss: 0.642409072637558
Step 12000: Generator loss: 0.9450125323534012, discriminator loss: 0.6527952731847763
Step 12500: Generator loss: 0.9660063650608063, discriminator loss: 0.6581431583166123
Step 13000: Generator loss: 0.8976931695938111, discriminator loss: 0.6624037518501281
Step 13500: Generator loss: 0.8969441390037537, discriminator loss: 0.6631241027116775
Step 14000: Generator loss: 0.902042475938797, discriminator loss: 0.6669842832088471
Step 14500: Generator loss: 0.8777725909948348, discriminator loss: 0.6742114140987396
Step 15000: Generator loss: 0.8425910116434098, discriminator loss: 0.6705276243686676
Step 15500: Generator loss: 0.8403865963220596, discriminator loss: 0.6768578908443451
Step 16000: Generator loss: 0.8556022493839264, discriminator loss: 0.6762306448221207
Step 16500: Generator loss: 0.8764977214336396, discriminator loss: 0.678820540189743
Step 17000: Generator loss: 0.8166215040683746, discriminator loss: 0.6793025200366973
Step 17500: Generator loss: 0.8172620774507523, discriminator loss: 0.6820640956163406
Step 18000: Generator loss: 0.8534817943572998, discriminator loss: 0.679261458158493
Step 18500: Generator loss: 0.814961371421814, discriminator loss: 0.6819399018287658
Step 19000: Generator loss: 0.8046633821725845, discriminator loss: 0.6841713408231735
Step 19500: Generator loss: 0.8040647978782653, discriminator loss: 0.6816601184606552
Step 20000: Generator loss: 0.8188033536672592, discriminator loss: 0.6830086879730225
Step 20500: Generator loss: 0.8088009203672409, discriminator loss: 0.6811848978996277
Step 21000: Generator loss: 0.7853847659826279, discriminator loss: 0.6836997301578521
Step 21500: Generator loss: 0.7798927721977233, discriminator loss: 0.684486163020134
Step 22000: Generator loss: 0.7833593552112579, discriminator loss: 0.685326477766037
Step 22500: Generator loss: 0.8140243920087814, discriminator loss: 0.6838948290348053
Step 23000: Generator loss: 0.7766883851289749, discriminator loss: 0.6896610424518586
Step 23500: Generator loss: 0.7677333387136459, discriminator loss: 0.6847020034790039
Step 24000: Generator loss: 0.7968860776424408, discriminator loss: 0.6859219465255737
Step 24500: Generator loss: 0.7655997285842896, discriminator loss: 0.6866924543380737
Step 25000: Generator loss: 0.7897603868246078, discriminator loss: 0.6862760508060455
Step 25500: Generator loss: 0.7602986326217651, discriminator loss: 0.6839702378511429
Step 26000: Generator loss: 0.7566630525588989, discriminator loss: 0.6884172073602677
Step 26500: Generator loss: 0.7693239089250564, discriminator loss: 0.6848342741727829
Step 27000: Generator loss: 0.7744117819070816, discriminator loss: 0.6884717727899552
Step 27500: Generator loss: 0.7754857275485992, discriminator loss: 0.6871565765142441
Step 28000: Generator loss: 0.7671164673566818, discriminator loss: 0.691150808095932
Step 28500: Generator loss: 0.7767693866491318, discriminator loss: 0.6899988718032837
Step 29000: Generator loss: 0.7584288560152054, discriminator loss: 0.6866833527088165
Step 29500: Generator loss: 0.7469037870168685, discriminator loss: 0.6892017160654068
Step 30000: Generator loss: 0.7409272351264954, discriminator loss: 0.6925988558530808
Step 30500: Generator loss: 0.7461127021312713, discriminator loss: 0.6910134963989257
Step 31000: Generator loss: 0.7623333480358124, discriminator loss: 0.6848250635862351
Step 31500: Generator loss: 0.7320846046209335, discriminator loss: 0.6922971439361573
Step 32000: Generator loss: 0.7360488106012344, discriminator loss: 0.6958958665132523
Step 32500: Generator loss: 0.7436227219104767, discriminator loss: 0.6927914987802506
Step 33000: Generator loss: 0.7528411923646927, discriminator loss: 0.6868532946109772
Step 33500: Generator loss: 0.7555540499687194, discriminator loss: 0.6819704930782318
Step 34000: Generator loss: 0.7339509303569793, discriminator loss: 0.6947230596542359
Step 34500: Generator loss: 0.7203902735710144, discriminator loss: 0.694019063949585
Step 35000: Generator loss: 0.7161798032522202, discriminator loss: 0.694249948143959
Step 35500: Generator loss: 0.7100930047035218, discriminator loss: 0.6956643009185791
Step 36000: Generator loss: 0.7224245357513428, discriminator loss: 0.692036272764206
Step 36500: Generator loss: 0.7294702612161637, discriminator loss: 0.6839023213386536
Step 37000: Generator loss: 0.7326101566553116, discriminator loss: 0.6864855628013611
Step 37500: Generator loss: 0.7289662526845933, discriminator loss: 0.6891102294921875
Step 38000: Generator loss: 0.7277824294567108, discriminator loss: 0.6930312074422836
Step 38500: Generator loss: 0.7523093600273132, discriminator loss: 0.6822635132074356
Step 39000: Generator loss: 0.7260702294111252, discriminator loss: 0.6836128298044205
Step 39500: Generator loss: 0.7210463825464248, discriminator loss: 0.6865772886276245
Step 40000: Generator loss: 0.7197876414060592, discriminator loss: 0.6861994673013687
Step 40500: Generator loss: 0.7156198496818542, discriminator loss: 0.6897815141677857
Step 41000: Generator loss: 0.7411812788248062, discriminator loss: 0.6876297281980515
Step 41500: Generator loss: 0.7482703533172608, discriminator loss: 0.6831764079332352
Step 42000: Generator loss: 0.7353900390863418, discriminator loss: 0.6809069069623948
Step 42500: Generator loss: 0.726880151629448, discriminator loss: 0.6845077587366104
Step 43000: Generator loss: 0.7335763674974441, discriminator loss: 0.6855163406133652
Step 43500: Generator loss: 0.7247586588859558, discriminator loss: 0.684886796593666
Step 44000: Generator loss: 0.7244187197685241, discriminator loss: 0.6869175283908844
Step 44500: Generator loss: 0.7478935513496399, discriminator loss: 0.6783332238197327
Step 45000: Generator loss: 0.7392684471607208, discriminator loss: 0.687694214463234
Step 45500: Generator loss: 0.7384519840478897, discriminator loss: 0.6806207147836685
Step 46000: Generator loss: 0.7173152709007263, discriminator loss: 0.6894198944568634
Step 46500: Generator loss: 0.7135227386951446, discriminator loss: 0.6902039344310761
Step 47000: Generator loss: 0.7121314022541047, discriminator loss: 0.691226885676384
Step 47500: Generator loss: 0.7153779380321502, discriminator loss: 0.6898772416114807
Step 48000: Generator loss: 0.7112214748859406, discriminator loss: 0.6919035356044769
Step 48500: Generator loss: 0.729472970366478, discriminator loss: 0.6832324341535568
Step 49000: Generator loss: 0.7259864670038223, discriminator loss: 0.6850444099903107
Step 49500: Generator loss: 0.7463545156717301, discriminator loss: 0.6876692290306091
Step 50000: Generator loss: 0.72439306807518, discriminator loss: 0.6840117316246033
Step 50500: Generator loss: 0.7304026707410812, discriminator loss: 0.6828315691947937
Step 51000: Generator loss: 0.735065841794014, discriminator loss: 0.6877049984931946
Step 51500: Generator loss: 0.738693750500679, discriminator loss: 0.6786749280691147
Step 52000: Generator loss: 0.7165734323263169, discriminator loss: 0.688656357049942
Step 52500: Generator loss: 0.7124545810222626, discriminator loss: 0.6885210503339767
Step 53000: Generator loss: 0.7169003388881683, discriminator loss: 0.6898472727537155
Step 53500: Generator loss: 0.7116240389347076, discriminator loss: 0.6890990349054337
Step 54000: Generator loss: 0.7254890002012253, discriminator loss: 0.686066904425621
Step 54500: Generator loss: 0.7279696422815323, discriminator loss: 0.6824959990978241
Step 55000: Generator loss: 0.7243433123826981, discriminator loss: 0.686788556933403
Step 55500: Generator loss: 0.72320248234272, discriminator loss: 0.6819899456501007
Step 56000: Generator loss: 0.7283236463069915, discriminator loss: 0.6813042680025101
Step 56500: Generator loss: 0.7257692145109177, discriminator loss: 0.6882435537576675
Step 57000: Generator loss: 0.7204343225955964, discriminator loss: 0.6905163298845292
Step 57500: Generator loss: 0.7234136379957199, discriminator loss: 0.6828762836456299
Step 58000: Generator loss: 0.7213340125083924, discriminator loss: 0.6852367097139358
Step 58500: Generator loss: 0.7139561972618103, discriminator loss: 0.6901394550800324
Step 59000: Generator loss: 0.7128681792020798, discriminator loss: 0.6899428930282593
Step 59500: Generator loss: 0.7178032584190369, discriminator loss: 0.6901476013660431
Step 60000: Generator loss: 0.7218955677747726, discriminator loss: 0.6866856569051742
Step 60500: Generator loss: 0.7173091459274292, discriminator loss: 0.6909447896480561
Step 61000: Generator loss: 0.7196292532682419, discriminator loss: 0.6888659211397171
Step 61500: Generator loss: 0.7136147793531418, discriminator loss: 0.6911007264852523
Step 62000: Generator loss: 0.7167167031764984, discriminator loss: 0.6874131036996841
Step 62500: Generator loss: 0.7095696296691895, discriminator loss: 0.6924118340015412
Step 63000: Generator loss: 0.7100733149051667, discriminator loss: 0.6894952065944672
Step 63500: Generator loss: 0.7075963083505631, discriminator loss: 0.6918715183734894
Step 64000: Generator loss: 0.7087407541275025, discriminator loss: 0.6912821785211564
Step 64500: Generator loss: 0.7044790136814117, discriminator loss: 0.6919414196014404
Step 65000: Generator loss: 0.7120586842298507, discriminator loss: 0.6889722956418991
Step 65500: Generator loss: 0.7059948451519013, discriminator loss: 0.6913756219148636
Step 66000: Generator loss: 0.7103360829353332, discriminator loss: 0.6888430647850037
Step 66500: Generator loss: 0.7106574136018753, discriminator loss: 0.6923392252922058
Step 67000: Generator loss: 0.7205636972188949, discriminator loss: 0.6888139424324036
Step 67500: Generator loss: 0.7325763144493103, discriminator loss: 0.6851953419446946
Step 68000: Generator loss: 0.7144211075305938, discriminator loss: 0.6894719363451004
Step 68500: Generator loss: 0.7039347168207168, discriminator loss: 0.692310958981514
Step 69000: Generator loss: 0.707789731502533, discriminator loss: 0.690034374833107
Step 69500: Generator loss: 0.7080022550821304, discriminator loss: 0.6897176603078842
Step 70000: Generator loss: 0.706935028553009, discriminator loss: 0.6917025876045227
Step 70500: Generator loss: 0.7035844438076019, discriminator loss: 0.6928271135091781
Step 71000: Generator loss: 0.706664494395256, discriminator loss: 0.6913493415117263
Step 71500: Generator loss: 0.7080443944931031, discriminator loss: 0.6943384435176849
Step 72000: Generator loss: 0.7080535914897919, discriminator loss: 0.6904549078941346
Step 72500: Generator loss: 0.7195642621517181, discriminator loss: 0.6883307158946991
Step 73000: Generator loss: 0.7137477462291717, discriminator loss: 0.6895240060091019
Step 73500: Generator loss: 0.7089026942253113, discriminator loss: 0.6893982688188552
Step 74000: Generator loss: 0.71370064163208, discriminator loss: 0.6885940716266632
Step 74500: Generator loss: 0.7126090573072433, discriminator loss: 0.6913927717208862
Step 75000: Generator loss: 0.7061277792453766, discriminator loss: 0.6915859417915344
Step 75500: Generator loss: 0.7079737706184387, discriminator loss: 0.6918540188074112
Step 76000: Generator loss: 0.7094860315322876, discriminator loss: 0.6909938471317292
Step 76500: Generator loss: 0.7089288998842239, discriminator loss: 0.6928894543647766
Step 77000: Generator loss: 0.7099210443496704, discriminator loss: 0.6881472972631455
Step 77500: Generator loss: 0.7087316303253174, discriminator loss: 0.6922812685966492
Step 78000: Generator loss: 0.7124276860952378, discriminator loss: 0.686549712896347
Step 78500: Generator loss: 0.7118150774240494, discriminator loss: 0.6914097841978073
Step 79000: Generator loss: 0.7061567052602767, discriminator loss: 0.6910830926895142
Step 79500: Generator loss: 0.7130619381666183, discriminator loss: 0.6901648813486099
Step 80000: Generator loss: 0.7189263315200806, discriminator loss: 0.6891602661609649
Step 80500: Generator loss: 0.7099695562124252, discriminator loss: 0.6893361113071441
Step 81000: Generator loss: 0.7043007851839066, discriminator loss: 0.6928421225547791
Step 81500: Generator loss: 0.7055111042261124, discriminator loss: 0.6913482803106308
Step 82000: Generator loss: 0.7107034167051315, discriminator loss: 0.6899493371248245
Step 82500: Generator loss: 0.7072652250528335, discriminator loss: 0.691617219209671
Step 83000: Generator loss: 0.7116999027729034, discriminator loss: 0.6887015362977982
Step 83500: Generator loss: 0.7103397004604339, discriminator loss: 0.6889865008592606
Step 84000: Generator loss: 0.702219740986824, discriminator loss: 0.6927198125123978
Step 84500: Generator loss: 0.7042218887805939, discriminator loss: 0.6910840107202529
Step 85000: Generator loss: 0.7036903632879257, discriminator loss: 0.6948130846023559
Step 85500: Generator loss: 0.7157929112911224, discriminator loss: 0.6877820398807526
Step 86000: Generator loss: 0.703074496269226, discriminator loss: 0.6928336225748062
Step 86500: Generator loss: 0.7018165578842163, discriminator loss: 0.6942198793888092
Step 87000: Generator loss: 0.7056693414449692, discriminator loss: 0.6903062870502472
Step 87500: Generator loss: 0.7070764343738556, discriminator loss: 0.690039452791214
Step 88000: Generator loss: 0.7018579070568085, discriminator loss: 0.6929991252422333
Step 88500: Generator loss: 0.7042791714668274, discriminator loss: 0.6906855113506317
Step 89000: Generator loss: 0.7052908551692962, discriminator loss: 0.6917922974824905
Step 89500: Generator loss: 0.7057228873968124, discriminator loss: 0.6917544984817505
Step 90000: Generator loss: 0.7041428442001343, discriminator loss: 0.6921311345100403
Step 90500: Generator loss: 0.7040341221094132, discriminator loss: 0.6916886166334152
Step 91000: Generator loss: 0.7016080802679062, discriminator loss: 0.6918226274251937
Step 91500: Generator loss: 0.7047490992546082, discriminator loss: 0.6919238156080246
Step 92000: Generator loss: 0.7015802135467529, discriminator loss: 0.6937261604070664
Step 92500: Generator loss: 0.7035572265386582, discriminator loss: 0.6899326649904252
Step 93000: Generator loss: 0.7005916707515717, discriminator loss: 0.6933788905143737
Step 93500: Generator loss: 0.7005794968605041, discriminator loss: 0.6932051202058792
Ended: 2021-04-30 18:15:27.636306
Elapsed: 0:53:48.939130

Exploration

Before you explore, put the generator in eval mode, both as a general practice and so that batch norm uses its eval statistics and doesn't cause you issues.

# Put the generator in eval mode before sampling so that layers such as
# batch norm use their eval-time behaviour.
gen = gen.eval()

Changing the Class Vector

You can generate some numbers with your new model! You can add interpolation as well to make it more interesting.

Starting from one image, you will produce intermediate images that look more and more like the ending image until you get to the final image. You're basically morphing one image into another. You can choose what these two images will be using your conditional GAN.

### Change me! ###

# Choose the interpolation: how many intermediate images you want + 2
# (for the start and end image).
n_interpolation = 9
# Draw one noise vector and tile it so that every image in the sequence
# shares the same noise; only the class vector will differ between images.
interpolation_noise = (
    make_some_noise(n_samples=1, z_dim=z_dim, device=device)
    .repeat(n_interpolation, 1)
)
def interpolate_class(first_number, second_number):
    """Plot generated images that morph from one digit class to another.

    The noise is held fixed (the module-level ``interpolation_noise``) and
    only the class vector changes: the one-hot labels for the two digits are
    blended linearly over ``n_interpolation`` steps.

    Args:
     first_number: digit whose label is used for the first image
     second_number: digit whose label is used for the last image
    """
    first_label = get_one_hot_labels(
        torch.tensor([first_number], dtype=torch.long), n_classes)
    second_label = get_one_hot_labels(
        torch.tensor([second_number], dtype=torch.long), n_classes)

    # Weight given to the second label at each step, running from 0 to 1.
    # The added trailing axis turns it into a column so that it broadcasts
    # against the label vectors (assumes get_one_hot_labels returns a
    # (1, n_classes) row -- TODO confirm).
    percent_second_label = torch.linspace(0, 1, n_interpolation)[:, None]
    interpolation_labels = (first_label * (1 - percent_second_label)
                            + second_label * percent_second_label)

    # Combine the noise and the labels
    noise_and_labels = combine_vectors(interpolation_noise,
                                       interpolation_labels.to(device))
    # This is sampling only, so skip autograd bookkeeping for the forward pass.
    with torch.no_grad():
        fake = gen(noise_and_labels)
    # NOTE(review): only save_tensor_images is visible among the helpers at
    # the top of this file -- confirm show_tensor_images is defined.
    show_tensor_images(fake, num_images=n_interpolation,
                       nrow=int(math.sqrt(n_interpolation)), show=False)
### Change me! ###
start_plot_number = 1 # Choose the start digit
### Change me! ###
end_plot_number = 5 # Choose the end digit

# The file imports matplotlib.pyplot under the name `pyplot`, so use that
# alias here.
pyplot.figure(figsize=(8, 8))
interpolate_class(start_plot_number, end_plot_number)
# Assign the return value so it is not echoed as cell output.
_ = pyplot.axis('off')

### Uncomment the following lines of code if you would like to visualize a set of pairwise class 
### interpolations for a collection of different numbers, all in a single grid of interpolations.
### You'll also see another visualization like this in the next code block!
# plot_numbers = [2, 3, 4, 5, 7]
# n_numbers = len(plot_numbers)
# pyplot.figure(figsize=(8, 8))
# for i, first_plot_number in enumerate(plot_numbers):
#     for j, second_plot_number in enumerate(plot_numbers):
#         pyplot.subplot(n_numbers, n_numbers, i * n_numbers + j + 1)
#         interpolate_class(first_plot_number, second_plot_number)
#         pyplot.axis('off')
# pyplot.subplots_adjust(top=1, bottom=0, left=0, right=1, hspace=0.1, wspace=0)
# pyplot.show()
# pyplot.close()

Changing the Noise Vector

Now, what happens if you hold the class constant, but instead you change the noise vector? You can also interpolate the noise vector and generate an image at each step.

# How many intermediate images you want + 2 (for the start and end image)
n_interpolation = 9

This time you're interpolating between noise vectors instead of labels, so the label is held fixed (here, the digit 5).

# One-hot label for the digit 5, tiled once per interpolation step and cast
# to float so that every generated image is conditioned on the same class.
interpolation_label = (
    get_one_hot_labels(torch.tensor([5], dtype=torch.long), n_classes)
    .repeat(n_interpolation, 1)
    .float()
)
def interpolate_noise(first_noise, second_noise):
    """Plot generated images that morph between two noise vectors.

    The class is held fixed (the module-level ``interpolation_label``) and
    only the noise changes. The weight on ``first_noise`` rises from 0 to 1
    over ``n_interpolation`` steps, so the sequence starts at
    ``second_noise`` and ends at ``first_noise``.

    Args:
     first_noise: noise vector that the sequence ends on
     second_noise: noise vector that the sequence starts from
    """
    # This time you're interpolating between the noise instead of the labels.
    # The added trailing axis makes the weights a column so that they
    # broadcast across the noise vector's entries.
    percent_first_noise = torch.linspace(0, 1, n_interpolation)[:, None].to(device)
    interpolation_noise = (first_noise * percent_first_noise
                           + second_noise * (1 - percent_first_noise))

    # Combine the noise and the labels again
    noise_and_labels = combine_vectors(interpolation_noise,
                                       interpolation_label.to(device))
    # This is sampling only, so skip autograd bookkeeping for the forward pass.
    with torch.no_grad():
        fake = gen(noise_and_labels)
    # NOTE(review): only save_tensor_images is visible among the helpers at
    # the top of this file -- confirm show_tensor_images is defined.
    show_tensor_images(fake, num_images=n_interpolation,
                       nrow=int(math.sqrt(n_interpolation)), show=False)

Generate noise vectors to interpolate between.

### Change me! ###
n_noise = 5 # Choose the number of noise examples in the grid
# make_some_noise is the noise helper this file defines (it is also used for
# the class interpolation above); each call draws one independent vector.
plot_noises = [make_some_noise(1, z_dim, device=device) for _ in range(n_noise)]
# The file imports matplotlib.pyplot under the name `pyplot`.
pyplot.figure(figsize=(8, 8))
# Fill an n_noise x n_noise grid: cell (i, j) shows the interpolation between
# noise vector i and noise vector j.
for i, first_plot_noise in enumerate(plot_noises):
    for j, second_plot_noise in enumerate(plot_noises):
        # Subplot indices are 1-based and run row by row.
        pyplot.subplot(n_noise, n_noise, i * n_noise + j + 1)
        interpolate_noise(first_plot_noise, second_plot_noise)
        pyplot.axis('off')
pyplot.subplots_adjust(top=1, bottom=0, left=0, right=1, hspace=0.1, wspace=0)
pyplot.show()
pyplot.close()

End