Controllable Generation
In this notebook, we're going to implement a GAN controllability method using gradients from a classifier. By training a classifier to recognize a relevant feature, we can use it to change the generator's inputs (z-vectors) to make it generate images with more or less of that feature.
We start off with a pre-trained generator and classifier so that we can focus on the controllability aspects.
The classifier has the same architecture as the earlier critic (remember that the discriminator/critic is simply a classifier used to classify real and fake).
Instead of the MNIST dataset, we will be using CelebA. CelebA is a dataset of annotated celebrity images. Since they are colored (not black-and-white), the images have three channels for red, green, and blue (RGB). We'll be using the pre-built PyTorch CelebA dataset.
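As a rough, standalone sketch (my own addition, separate from the setup that follows, and assuming the CelebA files have already been downloaded), this is what a sample from the torchvision CelebA dataset looks like: an RGB image tensor plus a 40-dimensional vector of binary attribute labels.
# a minimal sketch - assumes the CelebA files already exist under `root`
from pathlib import Path
from torchvision import transforms
from torchvision.datasets import CelebA

root = Path("~/pytorch-data/").expanduser()
peek = CelebA(str(root), split="train", target_type="attr",
              transform=transforms.ToTensor(), download=False)
image, attributes = peek[0]
print(image.shape)       # torch.Size([3, 218, 178]) - three color channels
print(attributes.shape)  # torch.Size([40]) - one binary label per annotated attribute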
# python
from argparse import Namespace
from collections import namedtuple
from functools import partial
from pathlib import Path
import pickle
# pypi
from torch import nn
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import CelebA
from torchvision.utils import make_grid
import holoviews
import hvplot.pandas
import matplotlib.pyplot as pyplot
import pandas
import torch
# my stuff
from graeae import EmbedHoloviews, Timer
TIMER = Timer()
torch.manual_seed(0)
SLUG = "controllable-generation"
OUTPUT = f"files/posts/gans/{SLUG}/"
Embed = partial(EmbedHoloviews, folder_path=OUTPUT)
Plot = namedtuple("Plot", ["width", "height", "fontscale", "tan", "blue", "red"])
PLOT = Plot(
width=900,
height=750,
fontscale=2,
tan="#ddb377",
blue="#4687b7",
red="#ce7b6d",
)
base_path = Path("~/models/gans/celeba/").expanduser()
assert base_path.is_dir()
prebuilt_models = Namespace(
celeba = base_path/"pretrained_celeba.pth",
classifier = base_path/"pretrained_classifier.pth"
)
data_path = Path("~/pytorch-data/").expanduser()
if not data_path.is_dir():
data_path.mkdir()
assert prebuilt_models.celeba.is_file()
assert prebuilt_models.classifier.is_file()
def save_tensor_images(image_tensor: torch.Tensor,
filename: str,
title: str,
folder: str=f"files/posts/gans{SLUG}/",
num_images: int=16, size: tuple=(1, 28, 28), nrow=3):
"""Plot an Image Tensor
Args:
image_tensor: tensor with the values for the image to plot
filename: name to save the file under
folder: path to put the file in
title: title for the image
num_images: how many images from the tensor to use
size: the dimensions for each image
"""
image_tensor = (image_tensor + 1) / 2
image_unflat = image_tensor.detach().cpu()
image_grid = make_grid(image_unflat[:num_images], nrow=nrow)
pyplot.title(title)
pyplot.grid(False)
pyplot.imshow(image_grid.permute(1, 2, 0).squeeze())
pyplot.tick_params(bottom=False, top=False, labelbottom=False,
right=False, left=False, labelleft=False)
pyplot.savefig(folder + filename)
print(f"[[file:{filename}]]")
return
This is mostly the same as the other Generators but the images are now color so the channels are different and the model has more initial hidden nodes (and one extra hidden block).
class Generator(nn.Module):
"""Generator for the celeba images
Args:
z_dim: the dimension of the noise vector, a scalar
im_chan: the number of channels in the images, fitted for the dataset used, a scalar
(CelebA is rgb, so 3 is our default)
hidden_dim: the inner dimension, a scalar
"""
def __init__(self, z_dim: int=10, im_chan: int=3, hidden_dim: int=64):
super().__init__()
self.z_dim = z_dim
self.gen = nn.Sequential(
self.make_gen_block(z_dim, hidden_dim * 8),
self.make_gen_block(hidden_dim * 8, hidden_dim * 4),
self.make_gen_block(hidden_dim * 4, hidden_dim * 2),
self.make_gen_block(hidden_dim * 2, hidden_dim),
self.make_gen_block(hidden_dim, im_chan, kernel_size=4, final_layer=True),
)
def make_gen_block(self, input_channels: int, output_channels: int,
kernel_size: int=3, stride: int=2,
final_layer: bool=False) -> nn.Sequential:
"""Create a sequence of operations corresponding to a generator block of DCGAN
- a transposed convolution
- a batchnorm (except in the final layer)
- an activation.
Args:
input_channels: how many channels the input feature representation has
output_channels: how many channels the output feature representation should have
kernel_size: the size of each convolutional filter, equivalent to (kernel_size, kernel_size)
stride: the stride of the convolution
final_layer: a boolean, true if it is the final layer and false otherwise
(affects activation and batchnorm)
Returns:
sequence of layers
"""
if not final_layer:
return nn.Sequential(
nn.ConvTranspose2d(input_channels, output_channels, kernel_size, stride),
nn.BatchNorm2d(output_channels),
nn.ReLU(inplace=True),
)
else:
return nn.Sequential(
nn.ConvTranspose2d(input_channels, output_channels, kernel_size, stride),
nn.Tanh(),
)
def forward(self, noise: torch.Tensor) -> torch.Tensor:
"""Complete a forward pass of the generator
Args:
Parameters:
noise: a noise tensor with dimensions (n_samples, z_dim)
Returns:
generated images.
"""
x = noise.view(len(noise), self.z_dim, 1, 1)
return self.gen(x)
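As a quick sanity check (my own addition, a minimal sketch rather than part of the original notebook), a batch of 64-dimensional noise vectors should come out of this architecture as 3-channel, 64 x 64 images:
# feed random noise through an untrained Generator on the CPU and check the shape
sanity_generator = Generator(z_dim=64)
sanity_noise = torch.randn(2, 64)
with torch.no_grad():
    sanity_images = sanity_generator(sanity_noise)
print(sanity_images.shape)  # expected: torch.Size([2, 3, 64, 64])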
get_noise is just an alias for torch.randn (I still don't entirely get why the original notebook aliases it), so get_noise(n_samples, z_dim) returns a batch of noise vectors drawn from the standard normal distribution.
get_noise = torch.randn
class Classifier(nn.Module):
"""The Classifier (Discriminator)
Args:
im_chan: the number of channels in the images, fitted for the dataset used, a scalar
(CelebA is rgb, so 3 is our default)
n_classes: the total number of classes in the dataset, an integer scalar
hidden_dim: the inner dimension, a scalar
"""
def __init__(self, im_chan: int=3, n_classes: int=2, hidden_dim: int=64):
super().__init__()
self.classifier = nn.Sequential(
self.make_classifier_block(im_chan, hidden_dim),
self.make_classifier_block(hidden_dim, hidden_dim * 2),
self.make_classifier_block(hidden_dim * 2, hidden_dim * 4, stride=3),
self.make_classifier_block(hidden_dim * 4, n_classes, final_layer=True),
)
def make_classifier_block(self, input_channels: int, output_channels: int,
kernel_size: int=4, stride: int=2,
final_layer: bool=False) -> nn.Sequential:
"""Create a sequence of operations corresponding to a classifier block
- a convolution
- a batchnorm (except in the final layer)
- an activation (except in the final layer).
Args:
input_channels: how many channels the input feature representation has
output_channels: how many channels the output feature representation should have
kernel_size: the size of each convolutional filter, equivalent to (kernel_size, kernel_size)
stride: the stride of the convolution
final_layer: a boolean, true if it is the final layer and false otherwise
(affects activation and batchnorm)
Returns:
Sequence of layers
"""
if final_layer:
return nn.Sequential(
nn.Conv2d(input_channels, output_channels, kernel_size, stride),
)
else:
return nn.Sequential(
nn.Conv2d(input_channels, output_channels, kernel_size, stride),
nn.BatchNorm2d(output_channels),
nn.LeakyReLU(0.2, inplace=True),
)
def forward(self, image: torch.Tensor) -> torch.Tensor:
"""Complete a forward pass of the classifier
Args:
        image: an image tensor with im_chan channels
    Returns:
        an n_classes-dimension tensor of class scores (one logit per attribute).
"""
class_pred = self.classifier(image)
return class_pred.view(len(class_pred), -1)
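And the same kind of sanity check (again my own addition) for the classifier: a batch of 3 x 64 x 64 images should map to one logit per CelebA attribute.
# an untrained Classifier on random images, CPU only
sanity_classifier = Classifier(n_classes=40)
sanity_batch = torch.randn(2, 3, 64, 64)
with torch.no_grad():
    sanity_logits = sanity_classifier(sanity_batch)
print(sanity_logits.shape)  # expected: torch.Size([2, 40])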
Before we begin training, we need to specify a few parameters:
z_dim = 64
batch_size = 128
device = 'cuda'
Note: the CelebA class will sometimes raise an exception:
Traceback (most recent call last):
  File "/home/neurotic/download_celeba.py", line 27, in <module>
    CelebA(data_path, split='train', download=True, transform=transform),
  File "/home/neurotic/.conda/envs/neurotic-pytorch/lib/python3.9/site-packages/torchvision/datasets/celeba.py", line 77, in __init__
    self.download()
  File "/home/neurotic/.conda/envs/neurotic-pytorch/lib/python3.9/site-packages/torchvision/datasets/celeba.py", line 131, in download
    with zipfile.ZipFile(os.path.join(self.root, self.base_folder, "img_align_celeba.zip"), "r") as f:
  File "/home/neurotic/.conda/envs/neurotic-pytorch/lib/python3.9/zipfile.py", line 1257, in __init__
    self._RealGetContents()
  File "/home/neurotic/.conda/envs/neurotic-pytorch/lib/python3.9/zipfile.py", line 1324, in _RealGetContents
    raise BadZipFile("File is not a zip file")
zipfile.BadZipFile: File is not a zip file
According to this bug report the problem is that the files are stored on Google Drive, which limits the amount of data you can download per day. If the limit has been exceeded, then when you try to download img_align_celeba.zip you get, instead of the zip file, an HTML page (of the same name) with this message:
Sorry, you can't view or download this file at this time. Too many users have viewed or downloaded this file recently. Please try accessing the file again later. If the file you are trying to access is particularly large or is shared with many people, it may take up to 24 hours to be able to view or download the file. If you still can't access a file after 24 hours, contact your domain administrator.
The data is available on Kaggle, so if you download it from there and put the file where the bad file is it should work - except, of course, it doesn't. It turns out that some of the text files were also replaced with warnings that the download limit was exceeded, so I needed to download those as well. The files on Kaggle are comma-separated while the original files are space-separated, and even replacing the commas with spaces won't pass the MD5 check - maybe the line endings are different too? Anyway, the images work, so I just waited a day and downloaded the text files from Google Drive, which seemed to fix it.
def train_classifier(filename: Path, data_path: Path, epochs: int=3,
learning_rate: float=0.001, display_step: int=500,
classifier: Classifier=None) -> list:
"""Trains the Classifier
Args:
filename: path to save the state-dict to
        data_path: path to the celeba data
        epochs: number of passes over the training data
        learning_rate: learning rate for the Adam optimizer
        display_step: how often (in steps) to print the running mean loss
        classifier: an existing Classifier to continue training (a new one is created if None)
Returns:
classifier losses
"""
label_indices = range(40)
beta_1 = 0.5
beta_2 = 0.999
image_size = 64
best_loss = float("inf")
transform = transforms.Compose([
transforms.Resize(image_size),
transforms.CenterCrop(image_size),
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])
dataloader = DataLoader(
CelebA(str(data_path), split='train', download=False, transform=transform),
batch_size=batch_size,
shuffle=True)
if classifier is None:
classifier = Classifier(n_classes=len(label_indices)).to(device)
class_opt = torch.optim.Adam(classifier.parameters(), lr=learning_rate, betas=(beta_1, beta_2))
criterion = nn.BCEWithLogitsLoss()
cur_step = 0
classifier_losses = []
# classifier_val_losses = []
for epoch in range(epochs):
for real, labels in dataloader:
real = real.to(device)
labels = labels[:, label_indices].to(device).float()
class_opt.zero_grad()
class_pred = classifier(real)
class_loss = criterion(class_pred, labels)
class_loss.backward() # Calculate the gradients
class_opt.step() # Update the weights
classifier_losses += [class_loss.item()] # Keep track of the average classifier loss
## Visualization code ##
if classifier_losses[-1] < best_loss:
torch.save({"classifier": classifier.state_dict()}, filename)
best_loss = classifier_losses[-1]
if cur_step % display_step == 0 and cur_step > 0:
class_mean = sum(classifier_losses[-display_step:]) / display_step
print(f"Step {cur_step}: Classifier loss: {class_mean}")
step_bins = 20
cur_step += 1
return classifier_losses
classifier_state_dict = Path("~/models/gans/celeba/trained_classifier.pth").expanduser()
with TIMER:
classifier_losses = train_classifier(classifier_state_dict, data_path, epochs=100)
Started: 2021-05-10 16:52:26.506156 Step 500: Classifier loss: 0.2693843246996403 Step 1000: Classifier loss: 0.24250468423962593 Step 1500: Classifier loss: 0.2307623517513275 Step 2000: Classifier loss: 0.22525288465619087 Step 2500: Classifier loss: 0.22275795283913613 Step 3000: Classifier loss: 0.2154263758957386 Step 3500: Classifier loss: 0.21474265044927596 Step 4000: Classifier loss: 0.21102887812256813 Step 4500: Classifier loss: 0.20789319404959677 Step 5000: Classifier loss: 0.20887315857410432 Step 5500: Classifier loss: 0.20212965056300164 Step 6000: Classifier loss: 0.20280044555664062 Step 6500: Classifier loss: 0.20041452285647393 Step 7000: Classifier loss: 0.19656063199043275 Step 7500: Classifier loss: 0.19845477828383445 Step 8000: Classifier loss: 0.19205777409672736 Step 8500: Classifier loss: 0.19296112078428268 Step 9000: Classifier loss: 0.19257169529795648 Step 9500: Classifier loss: 0.18672975289821625 Step 10000: Classifier loss: 0.18999777460098266 Step 10500: Classifier loss: 0.18432555946707727 Step 11000: Classifier loss: 0.18430076670646667 Step 11500: Classifier loss: 0.18558891993761062 Step 12000: Classifier loss: 0.17852741411328316 Step 12500: Classifier loss: 0.18172698724269867 Step 13000: Classifier loss: 0.1773110607266426 Step 13500: Classifier loss: 0.17672735232114792 Step 14000: Classifier loss: 0.17991324526071548 Step 14500: Classifier loss: 0.17025035387277604 Step 15000: Classifier loss: 0.17529139894247056 Step 15500: Classifier loss: 0.17104978796839715 Step 16000: Classifier loss: 0.17034502020478248 Step 16500: Classifier loss: 0.17325083956122397 Step 17000: Classifier loss: 0.1642009498178959 Step 17500: Classifier loss: 0.16845661264657974 Step 18000: Classifier loss: 0.1664019832611084 Step 18500: Classifier loss: 0.1633680825829506 Step 19000: Classifier loss: 0.16797509816288947 Step 19500: Classifier loss: 0.15925687023997306 Step 20000: Classifier loss: 0.16292004188895226 Step 20500: Classifier loss: 0.16216046965122222 Step 21000: Classifier loss: 0.15743515598773955 Step 21500: Classifier loss: 0.16243972438573837 Step 22000: Classifier loss: 0.1545857997238636 Step 22500: Classifier loss: 0.15797922548651694 Step 23000: Classifier loss: 0.15804208835959435 Step 23500: Classifier loss: 0.15213489854335785 Step 24000: Classifier loss: 0.15730918619036674 Step 24500: Classifier loss: 0.1511693196296692 Step 25000: Classifier loss: 0.1527680770754814 Step 25500: Classifier loss: 0.15510675182938577 Step 26000: Classifier loss: 0.14683012741804122 Step 26500: Classifier loss: 0.15305917632579805 Step 27000: Classifier loss: 0.14754199008643626 Step 27500: Classifier loss: 0.14820717003941536 Step 28000: Classifier loss: 0.15238315638899802 Step 28500: Classifier loss: 0.14171919177472592 Step 29000: Classifier loss: 0.14881789454817773 Step 29500: Classifier loss: 0.1449408364146948 Step 30000: Classifier loss: 0.1441956951916218 Step 30500: Classifier loss: 0.1483478535115719 Step 31000: Classifier loss: 0.13893532317876817 Step 31500: Classifier loss: 0.1450331158787012 Step 32000: Classifier loss: 0.14139907719194889 Step 32500: Classifier loss: 0.1396861730515957 Step 33000: Classifier loss: 0.1451952086240053 Step 33500: Classifier loss: 0.1358419010192156 Step 34000: Classifier loss: 0.14111693547666074 Step 34500: Classifier loss: 0.1400791739821434 Step 35000: Classifier loss: 0.1358947957903147 Step 35500: Classifier loss: 0.14151665523648263 Step 36000: Classifier loss: 0.1336766537129879 Step 36500: Classifier loss: 
0.13722201707959175 Step 37000: Classifier loss: 0.1379301232844591 Step 37500: Classifier loss: 0.13219603390991688 Step 38000: Classifier loss: 0.13811730867624283 Step 38500: Classifier loss: 0.13158722695708275 Step 39000: Classifier loss: 0.13359902986884117 Step 39500: Classifier loss: 0.1366793801188469 Step 40000: Classifier loss: 0.12849617034196853 Step 40500: Classifier loss: 0.13549049003422262 Step 41000: Classifier loss: 0.12929423077404498 Step 41500: Classifier loss: 0.13080933578312398 Step 42000: Classifier loss: 0.13500430592894555 Step 42500: Classifier loss: 0.12454062223434448 Step 43000: Classifier loss: 0.13214491476118564 Step 43500: Classifier loss: 0.1284936859458685 Step 44000: Classifier loss: 0.12763021168112754 Step 44500: Classifier loss: 0.13298917169868946 Step 45000: Classifier loss: 0.12208985219895839 Step 45500: Classifier loss: 0.129048362582922 Step 46000: Classifier loss: 0.12678204217553138 Step 46500: Classifier loss: 0.12455842156708241 Step 47000: Classifier loss: 0.1303500325381756 Step 47500: Classifier loss: 0.12025414818525314 Step 48000: Classifier loss: 0.12684993542730807 Step 48500: Classifier loss: 0.1252559674978256 Step 49000: Classifier loss: 0.12153738121688366 Step 49500: Classifier loss: 0.12777481034398078 Step 50000: Classifier loss: 0.118936713129282 Step 50500: Classifier loss: 0.12405500474572181 Ended: 2021-05-10 18:56:36.980805 Elapsed: 2:04:10.474649
losses = pandas.DataFrame.from_dict(dict(Loss=classifier_losses))
plot = losses.hvplot(y="Loss", title="Classifier Loss", color=PLOT.tan).opts(width=PLOT.width, height=PLOT.height)
output = Embed(plot=plot, file_name="classifier_loss")()
print(output)
n_classes = 40
classifier = Classifier(n_classes=n_classes).to(device)
class_dict = torch.load(classifier_state_dict, map_location=torch.device(device))["classifier"]
classifier.load_state_dict(class_dict)
classifier.eval()
with TIMER:
classifier_losses = train_classifier(classifier_state_dict, data_path,
epochs=40,
classifier=classifier)
Started: 2021-05-11 16:02:16.181203 Step 500: Classifier loss: 0.1181784438341856 Step 1000: Classifier loss: 0.12448641647398472 Step 1500: Classifier loss: 0.1214247584193945 Step 2000: Classifier loss: 0.1198666417747736 Step 2500: Classifier loss: 0.1255625690817833 Step 3000: Classifier loss: 0.11589906251430511 Step 3500: Classifier loss: 0.12224359685182572 Step 4000: Classifier loss: 0.11944249965250492 Step 4500: Classifier loss: 0.1175859476029873 Step 5000: Classifier loss: 0.12318077574670315 Step 5500: Classifier loss: 0.11450052106380462 Step 6000: Classifier loss: 0.11944048409163951 Step 6500: Classifier loss: 0.11928777326643467 Step 7000: Classifier loss: 0.11463624723255635 Step 7500: Classifier loss: 0.12107200682163238 Step 8000: Classifier loss: 0.11355004295706748 Step 8500: Classifier loss: 0.11719673483073711 Step 9000: Classifier loss: 0.11821492326259612 Step 9500: Classifier loss: 0.11198448015749454 Step 10000: Classifier loss: 0.11870198084414005 Step 10500: Classifier loss: 0.11221958647668362 Step 11000: Classifier loss: 0.11476752410829068 Step 11500: Classifier loss: 0.11772396117448806 Step 12000: Classifier loss: 0.10936097744107247 Step 12500: Classifier loss: 0.11677812692523003 Step 13000: Classifier loss: 0.11107682411372662 Step 13500: Classifier loss: 0.11222303664684295 Step 14000: Classifier loss: 0.11760448211431504 Step 14500: Classifier loss: 0.10662877394258977 Step 15000: Classifier loss: 0.11471863305568696 Step 15500: Classifier loss: 0.11056565625965595 Step 16000: Classifier loss: 0.11046012189984322 Step 16500: Classifier loss: 0.1158019468486309 Step 17000: Classifier loss: 0.10568901741504669 Step 17500: Classifier loss: 0.11223984396457672 Step 18000: Classifier loss: 0.11002579489350318 Step 18500: Classifier loss: 0.10752195838093757 Step 19000: Classifier loss: 0.11419818633794784 Step 19500: Classifier loss: 0.10464896529912948 Step 20000: Classifier loss: 0.11005591739714146 Step 20500: Classifier loss: 0.10996675519645215 Step 21000: Classifier loss: 0.10543355357646943 Step 21500: Classifier loss: 0.11205300988256932 Step 22000: Classifier loss: 0.1038715885579586 Step 22500: Classifier loss: 0.10818033437430859 Step 23000: Classifier loss: 0.10912492156028747 Step 23500: Classifier loss: 0.10302072758972645 Step 24000: Classifier loss: 0.11008756360411644 Step 24500: Classifier loss: 0.10342664630711079 Step 25000: Classifier loss: 0.10618587562441825 Step 25500: Classifier loss: 0.10913233712315559 Step 26000: Classifier loss: 0.10061963592469693 Step 26500: Classifier loss: 0.10828037586808205 Step 27000: Classifier loss: 0.10266246040165425 Step 27500: Classifier loss: 0.1047897623181343 Step 28000: Classifier loss: 0.10866250747442245 Step 28500: Classifier loss: 0.09820086953043938 Step 29000: Classifier loss: 0.10674160474538803 Step 29500: Classifier loss: 0.10230921612679958 Step 30000: Classifier loss: 0.1021555609256029 Step 30500: Classifier loss: 0.10775842162966728 Step 31000: Classifier loss: 0.09722121758759021 Step 31500: Classifier loss: 0.10439497400820255 Step 32000: Classifier loss: 0.10229390095174312 Step 32500: Classifier loss: 0.10003190772235393 Step 33000: Classifier loss: 0.10617333140969276 Step 33500: Classifier loss: 0.09686395044624806 Step 34000: Classifier loss: 0.10285020883381367 Step 34500: Classifier loss: 0.10199978332221508 Step 35000: Classifier loss: 0.09819360673427582 Step 35500: Classifier loss: 0.10397693109512329 Step 36000: Classifier loss: 0.09642438031733036 Step 36500: 
Classifier loss: 0.10087257397174836 Step 37000: Classifier loss: 0.10197833214700222 Step 37500: Classifier loss: 0.09598418261110783 Step 38000: Classifier loss: 0.10283542364835739 Step 38500: Classifier loss: 0.09644483177363873 Step 39000: Classifier loss: 0.09908602401614189 Step 39500: Classifier loss: 0.10129908196628094 Step 40000: Classifier loss: 0.0939527989178896 Step 40500: Classifier loss: 0.1016722819507122 Step 41000: Classifier loss: 0.09578396078944207 Step 41500: Classifier loss: 0.09706279496848583 Step 42000: Classifier loss: 0.10207961940765381 Step 42500: Classifier loss: 0.09211373472213745 Step 43000: Classifier loss: 0.09958744782209396 Step 43500: Classifier loss: 0.09534277887642384 Step 44000: Classifier loss: 0.0952163600474596 Step 44500: Classifier loss: 0.10136887782812118 Step 45000: Classifier loss: 0.09021547995507717 Step 45500: Classifier loss: 0.09812712541222572 Step 46000: Classifier loss: 0.09560927426815033 Step 46500: Classifier loss: 0.09358323478698731 Step 47000: Classifier loss: 0.09991893386840821 Step 47500: Classifier loss: 0.0899157041311264 Step 48000: Classifier loss: 0.096542285323143 Step 48500: Classifier loss: 0.09535252919793129 Step 49000: Classifier loss: 0.09194727616012097 Step 49500: Classifier loss: 0.09831891848146915 Step 50000: Classifier loss: 0.0901611197590828 Step 50500: Classifier loss: 0.09490065774321556 Ended: 2021-05-11 18:06:19.399986 Elapsed: 2:04:03.218783
losses = pandas.DataFrame.from_dict(dict(Loss=classifier_losses))
plot = losses.hvplot(y="Loss", title="Classifier Loss Session 2", color=PLOT.tan).opts(width=PLOT.width, height=PLOT.height)
output = Embed(plot=plot, file_name="classifier_loss_2")()
print(output)
n_classes = 40
classifier = Classifier(n_classes=n_classes).to(device)
class_dict = torch.load(classifier_state_dict, map_location=torch.device(device))["classifier"]
classifier.load_state_dict(class_dict)
classifier.eval()
with TIMER:
classifier_losses = train_classifier(classifier_state_dict, data_path,
epochs=40,
classifier=classifier)
Started: 2021-05-11 21:11:32.420506 Step 500: Classifier loss: 0.09006546361744404 Step 1000: Classifier loss: 0.09647199404239655 Step 1500: Classifier loss: 0.092734768897295 Step 2000: Classifier loss: 0.09196118661761284 Step 2500: Classifier loss: 0.09789373110234738 Step 3000: Classifier loss: 0.08788785541057587 Step 3500: Classifier loss: 0.0945415845811367 Step 4000: Classifier loss: 0.09308994428813458 Step 4500: Classifier loss: 0.09022623193264008 Step 5000: Classifier loss: 0.09615834753215313 Step 5500: Classifier loss: 0.08742603194713593 Step 6000: Classifier loss: 0.09316775412857532 Step 6500: Classifier loss: 0.09275233449041843 Step 7000: Classifier loss: 0.08822614887356758 Step 7500: Classifier loss: 0.09528714890778064 Step 8000: Classifier loss: 0.08681275172531605 Step 8500: Classifier loss: 0.09150236696004868 Step 9000: Classifier loss: 0.09338522186875343 Step 9500: Classifier loss: 0.08638478130102158 Step 10000: Classifier loss: 0.09388372772932052 Step 10500: Classifier loss: 0.08720742921531201 Step 11000: Classifier loss: 0.09009483934938908 Step 11500: Classifier loss: 0.0929495030939579 Step 12000: Classifier loss: 0.08460890363156795 Step 12500: Classifier loss: 0.0924714410007 Step 13000: Classifier loss: 0.08704712373018265 Step 13500: Classifier loss: 0.08819058662652969 Step 14000: Classifier loss: 0.09366303083300591 Step 14500: Classifier loss: 0.08295501434803008 Step 15000: Classifier loss: 0.09084490737318993 Step 15500: Classifier loss: 0.08707242746651173 Step 16000: Classifier loss: 0.08690852355957031 Step 16500: Classifier loss: 0.09254233407974242 Step 17000: Classifier loss: 0.08242024271190167 Step 17500: Classifier loss: 0.08904271678626538 Step 18000: Classifier loss: 0.08771026766300201 Step 18500: Classifier loss: 0.08471861970424652 Step 19000: Classifier loss: 0.09134728060662746 Step 19500: Classifier loss: 0.08233513435721397 Step 20000: Classifier loss: 0.08778411850333213 Step 20500: Classifier loss: 0.08791485584527255 Step 21000: Classifier loss: 0.08345357306301594 Step 21500: Classifier loss: 0.08975999920070171 Step 22000: Classifier loss: 0.08225472408533097 Step 22500: Classifier loss: 0.08668080273270606 Step 23000: Classifier loss: 0.08786206224560737 Step 23500: Classifier loss: 0.08155409483611584 Step 24000: Classifier loss: 0.08907847443222999 Step 24500: Classifier loss: 0.08202618369460106 Step 25000: Classifier loss: 0.08517973597347736 Step 25500: Classifier loss: 0.08817093770205975 Step 26000: Classifier loss: 0.08008052316308022 Step 26500: Classifier loss: 0.08741954331099987 Step 27000: Classifier loss: 0.08247932478785515 Step 27500: Classifier loss: 0.08377225384116173 Step 28000: Classifier loss: 0.08846944206953049 Step 28500: Classifier loss: 0.07859189368784428 Step 29000: Classifier loss: 0.08617163190245629 Step 29500: Classifier loss: 0.0824531610161066 Step 30000: Classifier loss: 0.08195052224397659 Step 30500: Classifier loss: 0.08803890940546989 Step 31000: Classifier loss: 0.07793828934431075 Step 31500: Classifier loss: 0.08464510484039783 Step 32000: Classifier loss: 0.08275749842077494 Step 32500: Classifier loss: 0.0805082704871893 Step 33000: Classifier loss: 0.08703124921023846 Step 33500: Classifier loss: 0.0772736611738801 Step 34000: Classifier loss: 0.08353734220564366 Step 34500: Classifier loss: 0.08343685203790664 Step 35000: Classifier loss: 0.07905932680517436 Step 35500: Classifier loss: 0.08568261863291264 Step 36000: Classifier loss: 0.07762402860075235 Step 36500: 
Classifier loss: 0.08223582464456558 Step 37000: Classifier loss: 0.08341778349876404 Step 37500: Classifier loss: 0.07801838412880897 Step 38000: Classifier loss: 0.0842266542762518 Step 38500: Classifier loss: 0.07764634099602699 Step 39000: Classifier loss: 0.08104524739086628 Step 39500: Classifier loss: 0.08389902476221323 Step 40000: Classifier loss: 0.07612183248996734 Step 40500: Classifier loss: 0.08296740844845772 Step 41000: Classifier loss: 0.0781253460124135 Step 41500: Classifier loss: 0.07980525248497725 Step 42000: Classifier loss: 0.08405549557507039 Step 42500: Classifier loss: 0.0743530157059431 Step 43000: Classifier loss: 0.08219673927128315 Step 43500: Classifier loss: 0.07845095673948527 Step 44000: Classifier loss: 0.07780187250673772 Step 44500: Classifier loss: 0.08399353076517582 Step 45000: Classifier loss: 0.07365029990673065 Step 45500: Classifier loss: 0.0808380290567875 Step 46000: Classifier loss: 0.0786423703506589 Step 46500: Classifier loss: 0.07693013155460357 Step 47000: Classifier loss: 0.08244228959083558 Step 47500: Classifier loss: 0.07365631985664367 Step 48000: Classifier loss: 0.07970952866971492 Step 48500: Classifier loss: 0.07868134459108114 Step 49000: Classifier loss: 0.07539512529224157 Step 49500: Classifier loss: 0.08191524033248425 Step 50000: Classifier loss: 0.07361406400799751 Step 50500: Classifier loss: 0.07847459720075131 Ended: 2021-05-11 23:15:24.444278 Elapsed: 2:03:52.023772
losses = pandas.DataFrame.from_dict(dict(Loss=classifier_losses))
plot = losses.hvplot(y="Loss", title="Classifier Loss Session 3", color=PLOT.tan).opts(width=PLOT.width, height=PLOT.height)
output = Embed(plot=plot, file_name="classifier_loss_3")()
print(output)
We will then load the pretrained generator and classifier using the following code. (If we trained our own classifier, we can load that one here instead.)
import torch
gen = Generator(z_dim).to(device)
gen_dict = torch.load(prebuilt_models.celeba, map_location=torch.device(device))["gen"]
gen.load_state_dict(gen_dict)
gen.eval()
n_classes = 40
classifier = Classifier(n_classes=n_classes).to(device)
class_dict = torch.load(prebuilt_models.classifier, map_location=torch.device(device))["classifier"]
classifier.load_state_dict(class_dict)
classifier.eval()
opt = torch.optim.Adam(classifier.parameters(), lr=0.01)
Now we can start implementing a method for controlling our GAN.
For training, we need to write the code to update the noise so that it produces more of our desired feature. We do this with stochastic gradient ascent: where stochastic gradient descent follows the negative gradient to find a local minimum, stochastic gradient ascent follows the gradient itself to find a local maximum. The formulas are essentially the same, except that instead of subtracting the weighted gradient, gradient ascent adds it: \(\text{new} = \text{old} + \text{weight} \cdot \nabla_{\text{old}}\), where \(\nabla_{\text{old}}\) is the gradient of the classifier's output with respect to the old noise.
We perform stochastic gradient ascent to try to maximize the amount of the feature we want. If we wanted to reduce the amount of the feature, we would perform gradient descent instead. Here we are interested in maximizing our feature with gradient ascent, since many features in the dataset are absent more often than they are present and we are trying to add a feature to the images, not remove one.
Given the noise with its gradient already calculated through the classifier, we want to return the new noise vector.
def calculate_updated_noise(noise: torch.Tensor, weight: float) -> torch.Tensor:
"""Update noise vectors with stochastic gradient ascent.
Args:
noise: the current noise vectors.
We have already called the backwards function on the target class
so we can access the gradient of the output class with respect
to the noise by using noise.grad
weight: the scalar amount by which we should weight the noise gradient
Returns:
updated noise
"""
new_noise = noise + (noise.grad * weight)
return new_noise
Check that the basic function works.
opt.zero_grad()
noise = torch.ones(20, 20) * 2
noise.requires_grad_()
fake_classes = (noise ** 2).mean()
fake_classes.backward()
new_noise = calculate_updated_noise(noise, 0.1)
assert type(new_noise) == torch.Tensor
assert tuple(new_noise.shape) == (20, 20)
assert new_noise.max() == 2.0010
assert new_noise.min() == 2.0010
assert torch.isclose(new_noise.sum(), torch.tensor(0.4) + 20 * 20 * 2)
Check that it works for generated images
opt.zero_grad()
noise = get_noise(32, z_dim).to(device).requires_grad_()
fake = gen(noise)
fake_classes = classifier(fake)[:, 0]
fake_classes.mean().backward()
noise.data = calculate_updated_noise(noise, 0.01)
fake = gen(noise)
fake_classes_new = classifier(fake)[:, 0]
assert torch.all(fake_classes_new > fake_classes)
Now we can use the classifier along with stochastic gradient ascent to make noise that generates more of a certain feature. The code below targets the "Young" feature, but feel free to change the target index and control some of the other features in the list! We will notice that some features are easier to detect and control than others.
The list here contains the features labeled in CelebA, which we used to train our classifier. If we wanted to control a different feature, we would need data labeled with that feature and a classifier trained on it.
First generate a bunch of images with the generator.
n_images = 8
fake_image_history = []
grad_steps = 10 # Number of gradient steps to take
skip = 2 # Number of gradient steps to skip in the visualization
feature_names = ["5oClockShadow", "ArchedEyebrows", "Attractive", "BagsUnderEyes", "Bald", "Bangs",
"BigLips", "BigNose", "BlackHair", "BlondHair", "Blurry", "BrownHair", "BushyEyebrows", "Chubby",
"DoubleChin", "Eyeglasses", "Goatee", "GrayHair", "HeavyMakeup", "HighCheekbones", "Male",
"MouthSlightlyOpen", "Mustache", "NarrowEyes", "NoBeard", "OvalFace", "PaleSkin", "PointyNose",
"RecedingHairline", "RosyCheeks", "Sideburn", "Smiling", "StraightHair", "WavyHair", "WearingEarrings",
"WearingHat", "WearingLipstick", "WearingNecklace", "WearingNecktie", "Weng"]
### Change me! ###
target_indices = feature_names.index("Weng") # Feel free to change this value to any string from feature_names!
noise = get_noise(n_images, z_dim).to(device).requires_grad_()
for i in range(grad_steps):
opt.zero_grad()
fake = gen(noise)
fake_image_history += [fake]
fake_classes_score = classifier(fake)[:, target_indices].mean()
fake_classes_score.backward()
noise.data = calculate_updated_noise(noise, 1 / grad_steps)
pyplot.rcParams['figure.figsize'] = [n_images * 2, grad_steps * 2]
save_tensor_images(image_tensor=torch.cat(fake_image_history[::skip], dim=2),
filename="weng.png", folder=OUTPUT, title="Weng",
num_images=n_images, nrow=n_images)
We may also notice that sometimes more features than just the target feature change. This is because some features are entangled. To fix this, we can try to isolate the target feature more by holding the classes outside of the target class constant. One way we can implement this is by penalizing the differences from the original class with L2 regularization. This L2 regularization would apply a penalty for this difference using the L2 norm and this would just be an additional term on the loss function.
Here, we'll have to implement the score function: the higher, the better. The score is calculated by adding the target score and a penalty – note that the penalty is meant to lower the score, so it should have a negative value.
For every non-target class, take the difference between the current noise and the old noise. The greater this value is, the more features outside the target have changed. We will calculate the magnitude of the change, take the mean, and negate it. Finally, add this penalty to the target score. The target score is the mean of the target class in the current noise.
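Putting that description into a formula (my own summary of the steps above), for a batch of \(n\) examples with current classifications \(c_i\), original classifications \(o_i\), and penalty weight \(w\):

\[ \text{score} = \frac{1}{n}\sum_{i=1}^{n} c_i[\text{target}] \;-\; w \cdot \frac{1}{n}\sum_{i=1}^{n} \big\lVert c_i[\text{other}] - o_i[\text{other}] \big\rVert_2 \]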
The torch.norm function is useful here. Note: torch.norm is deprecated, and PyTorch recommends using torch.linalg.norm instead.
def get_score(current_classifications: torch.Tensor,
original_classifications: torch.Tensor,
target_indices: torch.Tensor,
other_indices: torch.Tensor,
penalty_weight: float) -> torch.Tensor:
"""Score the current classifications, L2 Norm penalty
Args:
current_classifications: the classifications associated with the current noise
original_classifications: the classifications associated with the original noise
target_indices: the index of the target class
other_indices: the indices of the other classes
penalty_weight: the amount that the penalty should be weighted in the overall score
Returns:
the score of the current classification with L2 Norm penalty
"""
# Steps: 1) Calculate the change between the original and current classifications (as a tensor)
# by indexing into the other_indices we're trying to preserve, like in x[:, features].
# 2) Calculate the norm (magnitude) of changes per example.
# 3) Multiply the mean of the example norms by the penalty weight.
# This will be our other_class_penalty.
# Make sure to negate the value since it's a penalty!
# 4) Take the mean of the current classifications for the target feature over all the examples.
# This mean will be our target_score.
# Calculate the norm (magnitude) of changes per example and multiply by penalty weight
other_class_penalty = -(torch.mean(
torch.linalg.norm(original_classifications[:, other_indices]
- current_classifications[:, other_indices], dim=1))
* penalty_weight)
# Take the mean of the current classifications for the target feature
target_score = torch.mean(current_classifications[:, target_indices])
return target_score + other_class_penalty
assert torch.isclose(
get_score(torch.ones(4, 3), torch.zeros(4, 3), [0], [1, 2], 0.2),
1 - torch.sqrt(torch.tensor(2.)) * 0.2
)
rows = 10
current_class = torch.tensor([[1] * rows, [2] * rows, [3] * rows, [4] * rows]).T.float()
original_class = torch.tensor([[1] * rows, [2] * rows, [3] * rows, [4] * rows]).T.float()
# Must be 3
assert get_score(current_class, original_class, [1, 3] , [0, 2], 0.2).item() == 3
current_class = torch.tensor([[1] * rows, [2] * rows, [3] * rows, [4] * rows]).T.float()
original_class = torch.tensor([[4] * rows, [4] * rows, [2] * rows, [1] * rows]).T.float()
# Must be 3 - 0.2 * sqrt(10)
assert torch.isclose(get_score(current_class, original_class, [1, 3] , [0, 2], 0.2),
-torch.sqrt(torch.tensor(10.0)) * 0.2 + 3)
In the following block of code, we will run the gradient ascent with this new score function. We might notice a few things after running it: with the penalty in place, the generator may fail to produce the target feature as strongly, since it can't always add the feature without also changing the entangled features we're holding constant. Alternatively, even if the generator can produce an image with the intended feature, it might require many intermediate changes to get there and may get stuck in a local maximum of the score along the way.
fake_image_history = []
### Change me! ###
target_indices = feature_names.index("Goatee") # Feel free to change this value to any string from feature_names from earlier!
other_indices = [cur_idx != target_indices for cur_idx, _ in enumerate(feature_names)]
noise = get_noise(n_images, z_dim).to(device).requires_grad_()
original_classifications = classifier(gen(noise)).detach()
for i in range(grad_steps):
opt.zero_grad()
fake = gen(noise)
fake_image_history += [fake]
fake_score = get_score(
classifier(fake),
original_classifications,
target_indices,
other_indices,
penalty_weight=0.1
)
fake_score.backward()
noise.data = calculate_updated_noise(noise, 1 / grad_steps)
pyplot.rcParams['figure.figsize'] = [n_images * 2, grad_steps * 2]
save_tensor_images(torch.cat(fake_image_history[::skip], dim=2), num_images=n_images, nrow=n_images, filename="goatee.png", folder=OUTPUT, title="Goatee")
# python standard library
from pathlib import Path
import math
# pypi
from torch import nn
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import MNIST
from torchvision.utils import make_grid
import matplotlib.pyplot as pyplot
import torch
import torch.nn.functional as F
# my stuff
from graeae import Timer
TIMER = Timer()
torch.manual_seed(0)
SLUG = "a-conditional-gan"
def save_tensor_images(image_tensor: torch.Tensor,
filename: str,
title: str,
folder: str=f"files/posts/gans{SLUG}",
num_images: int=25, size: tuple=(1, 28, 28)):
"""Plot an Image Tensor
Args:
image_tensor: tensor with the values for the image to plot
filename: name to save the file under
folder: path to put the file in
title: title for the image
num_images: how many images from the tensor to use
size: the dimensions for each image
"""
image_tensor = (image_tensor + 1) / 2
image_unflat = image_tensor.detach().cpu()
image_grid = make_grid(image_unflat[:num_images], nrow=5)
pyplot.title(title)
pyplot.grid(False)
pyplot.imshow(image_grid.permute(1, 2, 0).squeeze())
pyplot.tick_params(bottom=False, top=False, labelbottom=False,
right=False, left=False, labelleft=False)
pyplot.savefig(folder + filename)
print(f"[[file:{filename}]]")
return
def make_some_noise(n_samples: int, z_dim: int, device: str='cpu') -> torch.Tensor:
"""Alias for torch.randn
Args:
n_samples: the number of samples to generate
z_dim: the dimension of the noise vector
device: the device type
Returns:
tensor with random numbers from the normal distribution.
"""
return torch.randn(n_samples, z_dim, device=device)
The Generator and Discriminator are the same ones we used before, except that the z_dim attribute has been renamed input_dim to reflect the fact that the data is going to be augmented with the classification information.
class Generator(nn.Module):
"""The DCGAN Generator
Args:
input_dim: the dimension of the input vector
im_chan: the number of channels in the images, fitted for the dataset used
(MNIST is black-and-white, so 1 channel is your default)
hidden_dim: the inner dimension,
"""
def __init__(self, input_dim: int=10, im_chan: int=1, hidden_dim: int=64):
super().__init__()
self.input_dim = input_dim
self.gen = nn.Sequential(
self.make_gen_block(input_dim, hidden_dim * 4),
self.make_gen_block(hidden_dim * 4, hidden_dim * 2, kernel_size=4, stride=1),
self.make_gen_block(hidden_dim * 2, hidden_dim),
self.make_gen_block(hidden_dim, im_chan, kernel_size=4, final_layer=True),
)
def make_gen_block(self, input_channels: int, output_channels: int,
kernel_size: int=3, stride: int=2,
final_layer: bool=False) -> nn.Sequential:
"""Creates a block for the generator (sub sequence)
The parts
- a transposed convolution
- a batchnorm (except for in the last layer)
- an activation.
Args:
input_channels: how many channels the input feature representation has
output_channels: how many channels the output feature representation should have
kernel_size: the size of each convolutional filter, equivalent to (kernel_size, kernel_size)
stride: the stride of the convolution
final_layer: a boolean, true if it is the final layer and false otherwise
(affects activation and batchnorm)
Returns:
the sub-sequence of layers
"""
if not final_layer:
return nn.Sequential(
nn.ConvTranspose2d(input_channels, output_channels, kernel_size, stride),
nn.BatchNorm2d(output_channels),
nn.ReLU(inplace=True),
)
else:
return nn.Sequential(
nn.ConvTranspose2d(input_channels, output_channels, kernel_size, stride),
nn.Tanh(),
)
def forward(self, noise: torch.Tensor) -> torch.Tensor:
"""complete a forward pass of the generator: Given a noise tensor,
Args:
noise: a noise tensor with dimensions (n_samples, z_dim)
Returns:
generated images.
"""
# unsqueeze the noise
x = noise.view(len(noise), self.input_dim, 1, 1)
return self.gen(x)
This differs a little from the DCGAN Discriminator in that the initial hidden dimension output goes up to 64 nodes from 16 in the original.
class Discriminator(nn.Module):
"""The DCGAN Discriminator
Args:
im_chan: the number of channels in the images, fitted for the dataset used
(MNIST is black-and-white, so 1 channel is the default)
hidden_dim: the inner dimension,
"""
def __init__(self, im_chan: int=1, hidden_dim: int=64):
super(Discriminator, self).__init__()
self.disc = nn.Sequential(
self.make_disc_block(im_chan, hidden_dim),
self.make_disc_block(hidden_dim, hidden_dim * 2),
self.make_disc_block(hidden_dim * 2, 1, final_layer=True),
)
return
def make_disc_block(self, input_channels: int, output_channels: int,
kernel_size: int=4, stride: int=2,
final_layer: bool=False) -> nn.Sequential:
"""Make a sub-block of layers for the discriminator
- a convolution
- a batchnorm (except for in the last layer)
- an activation.
Args:
input_channels: how many channels the input feature representation has
output_channels: how many channels the output feature representation should have
kernel_size: the size of each convolutional filter, equivalent to (kernel_size, kernel_size)
stride: the stride of the convolution
final_layer: if true it is the final layer and otherwise not
(affects activation and batchnorm)
"""
# Build the neural block
if not final_layer:
return nn.Sequential(
nn.Conv2d(input_channels, output_channels, kernel_size, stride),
nn.BatchNorm2d(output_channels),
nn.LeakyReLU(0.2)
)
else: # Final Layer
return nn.Sequential(
nn.Conv2d(input_channels, output_channels, kernel_size, stride),
)
def forward(self, image: torch.Tensor) -> torch.Tensor:
"""Complete a forward pass of the discriminator
Args:
        image: an image tensor with im_chan channels (including any concatenated class channels)
Returns:
a 1-dimension tensor representing fake/real.
"""
disc_pred = self.disc(image)
return disc_pred.view(len(disc_pred), -1)
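As another small sanity check of my own (the exact dimensions here are hypothetical, matching the z_dim=64, n_classes=10 MNIST setup used further down): a conditional generator fed 64 noise dimensions plus 10 class dimensions should produce 1 x 28 x 28 images, and a discriminator given 1 image channel plus 10 class channels should produce a single logit per example.
# a minimal sketch with hypothetical dimensions, untrained models on the CPU
sanity_gen = Generator(input_dim=64 + 10)
sanity_disc = Discriminator(im_chan=1 + 10)
with torch.no_grad():
    sanity_fake = sanity_gen(torch.randn(2, 74))
    sanity_pred = sanity_disc(torch.randn(2, 11, 28, 28))
print(sanity_fake.shape)  # expected: torch.Size([2, 1, 28, 28])
print(sanity_pred.shape)  # expected: torch.Size([2, 1])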
In conditional GANs, the input vector for the generator will also need to include the class information. The class is represented using a one-hot encoded vector: its length is the number of classes, each index represents a class, and the vector is all 0's with a 1 at the chosen class. Given the labels of multiple images (e.g. from a batch) and the number of classes, create one-hot vectors for each label. There is a function within the PyTorch functional library that can help.
def get_one_hot_labels(labels: torch.tensor, n_classes: int) -> torch.Tensor:
"""Create one-hot vectors for the labels
Args:
labels: tensor of labels from the dataloader
n_classes: the total number of classes in the dataset
Returns:
a tensor of shape (labels size, num_classes).
"""
#### START CODE HERE ####
return F.one_hot(labels, n_classes)
#### END CODE HERE ####
assert (
get_one_hot_labels(
labels=torch.Tensor([[0, 2, 1]]).long(),
n_classes=3
).tolist() ==
[[
[1, 0, 0],
[0, 0, 1],
[0, 1, 0]
]]
)
Next, you need to be able to concatenate the one-hot class vector to the noise vector before giving it to the generator. You will also need to do this when adding the class channels to the discriminator. Make sure the tensors are joined along the right dimension (that is what the dim argument of torch.cat does).
def combine_vectors(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
"""Combine two vectors with shapes (n_samples, ?) and (n_samples, ?).
Args:
x: the first vector.
y: the second vector.
"""
# Note: Make sure this function outputs a float no matter what inputs it receives
#### START CODE HERE ####
combined = torch.cat((x, y), 1).to(torch.float32)
#### END CODE HERE ####
return combined
combined = combine_vectors(torch.tensor([[1, 2], [3, 4]]), torch.tensor([[5, 6], [7, 8]]));
# Check exact order of elements
assert torch.all(combined == torch.tensor([[1, 2, 5, 6], [3, 4, 7, 8]]))
# Tests that items are of float type
assert (type(combined[0][0].item()) == float)
# Check shapes
combined = combine_vectors(torch.randn(1, 4, 5), torch.randn(1, 8, 5));
assert tuple(combined.shape) == (1, 12, 5)
assert tuple(combine_vectors(torch.randn(1, 10, 12).long(), torch.randn(1, 20, 12).long()).shape) == (1, 30, 12)
First, you will define some new parameters:
mnist_shape = (1, 28, 28)
n_classes = 10
And you also include the same parameters from before:
criterion = nn.BCEWithLogitsLoss()
n_epochs = 200
z_dim = 64
display_step = 500
batch_size = 128
lr = 0.0002
device = 'cuda'
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.5,), (0.5,)),
])
data_path = Path("~/pytorch-data/MNIST/").expanduser()
dataloader = DataLoader(
MNIST(data_path, download=False, transform=transform),
batch_size=batch_size,
shuffle=True)
Then, you can initialize your generator, discriminator, and optimizers. To do this, you will need to update the input dimensions for both models. For the generator, you will need to calculate the size of the input vector; recall that for conditional GANs, the generator's input is the noise vector concatenated with the class vector. For the discriminator, you need to add a channel for every class.
def get_input_dimensions(z_dim: int, mnist_shape: tuple, n_classes: int) -> tuple:
"""Calculates the size of the conditional input dimensions
Args:
z_dim: the dimension of the noise vector
mnist_shape: the shape of each MNIST image as (C, W, H), which is (1, 28, 28)
n_classes: the total number of classes in the dataset, an integer scalar
(10 for MNIST)
Returns:
generator_input_dim: the input dimensionality of the conditional generator,
which takes the noise and class vectors
discriminator_im_chan: the number of input channels to the discriminator
(e.g. C x 28 x 28 for MNIST)
"""
#### START CODE HERE ####
generator_input_dim = z_dim + n_classes
discriminator_im_chan = n_classes + mnist_shape[0]
#### END CODE HERE ####
return generator_input_dim, discriminator_im_chan
def test_input_dims():
gen_dim, disc_dim = get_input_dimensions(23, (12, 23, 52), 9)
assert gen_dim == 32
assert disc_dim == 21
test_input_dims()
generator_input_dim, discriminator_im_chan = get_input_dimensions(z_dim, mnist_shape, n_classes)
gen = Generator(input_dim=generator_input_dim).to(device)
gen_opt = torch.optim.Adam(gen.parameters(), lr=lr)
disc = Discriminator(im_chan=discriminator_im_chan).to(device)
disc_opt = torch.optim.Adam(disc.parameters(), lr=lr)
def weights_init(m) -> None:
"""Initialize the weights from the normal distribution
Args:
m: object to initialize
"""
if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d):
torch.nn.init.normal_(m.weight, 0.0, 0.02)
elif isinstance(m, nn.BatchNorm2d):
torch.nn.init.normal_(m.weight, 0.0, 0.02)
torch.nn.init.constant_(m.bias, 0)
return
gen = gen.apply(weights_init)
disc = disc.apply(weights_init)
Now to train, you would like both your generator and your discriminator to know what class of image should be generated.
For example, if you're generating a picture of the number "1", you would need to tell the generator that it should be generating a "1", and tell the discriminator that the image it is looking at is supposed to be a "1".
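Here is a small sketch of my own (not part of the training loop below) showing how that class information gets attached in both places, reusing the helpers defined above:
# a minimal sketch, reusing get_one_hot_labels, combine_vectors, and make_some_noise
example_labels = torch.tensor([1])                                  # we want a "1"
example_one_hot = get_one_hot_labels(example_labels, n_classes)     # shape (1, 10)
example_noise = make_some_noise(1, z_dim)                           # shape (1, 64)
generator_input = combine_vectors(example_noise, example_one_hot)   # (1, 74) -> generator
# for the discriminator, the one-hot vector becomes extra image channels
image_one_hot = example_one_hot[:, :, None, None].repeat(1, 1, 28, 28)  # (1, 10, 28, 28)
# concatenating a (1, 1, 28, 28) image with image_one_hot gives the
# (1, 11, 28, 28) tensor the conditional discriminator expects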
cur_step = 0
generator_losses = []
discriminator_losses = []
noise_and_labels = False
fake = False
fake_image_and_labels = False
real_image_and_labels = False
disc_fake_pred = False
disc_real_pred = False
with TIMER:
for epoch in range(n_epochs):
# Dataloader returns the batches and the labels
for real, labels in dataloader:
cur_batch_size = len(real)
# Flatten the batch of real images from the dataset
real = real.to(device)
one_hot_labels = get_one_hot_labels(labels.to(device), n_classes)
image_one_hot_labels = one_hot_labels[:, :, None, None]
image_one_hot_labels = image_one_hot_labels.repeat(1, 1, mnist_shape[1], mnist_shape[2])
### Update discriminator ###
# Zero out the discriminator gradients
disc_opt.zero_grad()
# Get noise corresponding to the current batch_size
fake_noise = make_some_noise(cur_batch_size, z_dim, device=device)
# Now you can get the images from the generator
# Steps: 1) Combine the noise vectors and the one-hot labels for the generator
# 2) Generate the conditioned fake images
#### START CODE HERE ####
noise_and_labels = combine_vectors(fake_noise, one_hot_labels)
fake = gen(noise_and_labels)
#### END CODE HERE ####
# Make sure that enough images were generated
assert len(fake) == len(real)
# Check that correct tensors were combined
assert tuple(noise_and_labels.shape) == (cur_batch_size, fake_noise.shape[1] + one_hot_labels.shape[1])
# It comes from the correct generator
assert tuple(fake.shape) == (len(real), 1, 28, 28)
# Now you can get the predictions from the discriminator
# Steps: 1) Create the input for the discriminator
# a) Combine the fake images with image_one_hot_labels,
# remember to detach the generator (.detach()) so you do not backpropagate through it
# b) Combine the real images with image_one_hot_labels
# 2) Get the discriminator's prediction on the fakes as disc_fake_pred
# 3) Get the discriminator's prediction on the reals as disc_real_pred
#### START CODE HERE ####
fake_image_and_labels = combine_vectors(fake.detach(), image_one_hot_labels)
real_image_and_labels = combine_vectors(real, image_one_hot_labels)
disc_fake_pred = disc(fake_image_and_labels)
disc_real_pred = disc(real_image_and_labels)
#### END CODE HERE ####
# Make sure shapes are correct
assert tuple(fake_image_and_labels.shape) == (len(real), fake.detach().shape[1] + image_one_hot_labels.shape[1], 28 ,28)
assert tuple(real_image_and_labels.shape) == (len(real), real.shape[1] + image_one_hot_labels.shape[1], 28 ,28)
# Make sure that enough predictions were made
assert len(disc_real_pred) == len(real)
# Make sure that the inputs are different
assert torch.any(fake_image_and_labels != real_image_and_labels)
# Shapes must match
assert tuple(fake_image_and_labels.shape) == tuple(real_image_and_labels.shape)
assert tuple(disc_fake_pred.shape) == tuple(disc_real_pred.shape)
disc_fake_loss = criterion(disc_fake_pred, torch.zeros_like(disc_fake_pred))
disc_real_loss = criterion(disc_real_pred, torch.ones_like(disc_real_pred))
disc_loss = (disc_fake_loss + disc_real_loss) / 2
disc_loss.backward(retain_graph=True)
disc_opt.step()
# Keep track of the average discriminator loss
discriminator_losses += [disc_loss.item()]
### Update generator ###
# Zero out the generator gradients
gen_opt.zero_grad()
fake_image_and_labels = combine_vectors(fake, image_one_hot_labels)
# This will error if you didn't concatenate your labels to your image correctly
disc_fake_pred = disc(fake_image_and_labels)
gen_loss = criterion(disc_fake_pred, torch.ones_like(disc_fake_pred))
gen_loss.backward()
gen_opt.step()
# Keep track of the generator losses
generator_losses += [gen_loss.item()]
#
if cur_step % display_step == 0 and cur_step > 0:
gen_mean = sum(generator_losses[-display_step:]) / display_step
disc_mean = sum(discriminator_losses[-display_step:]) / display_step
print(f"Step {cur_step}: Generator loss: {gen_mean}, discriminator loss: {disc_mean}")
# show_tensor_images(fake)
# show_tensor_images(real)
step_bins = 20
x_axis = sorted([i * step_bins for i in range(len(generator_losses) // step_bins)] * step_bins)
# num_examples = (len(generator_losses) // step_bins) * step_bins
# plt.plot(
# range(num_examples // step_bins),
# torch.Tensor(generator_losses[:num_examples]).view(-1, step_bins).mean(1),
# label="Generator Loss"
# )
# plt.plot(
# range(num_examples // step_bins),
# torch.Tensor(discriminator_losses[:num_examples]).view(-1, step_bins).mean(1),
# label="Discriminator Loss"
# )
# plt.legend()
# plt.show()
elif cur_step == 0:
print("Congratulations! If you've gotten here, it's working. Please let this train until you're happy with how the generated numbers look, and then go on to the exploration!")
cur_step += 1
Congratulations! If you've gotten here, it's working. Please let this train until you're happy with how the generated numbers look, and then go on to the exploration!
The losses were logged every 500 steps; the early checkpoints and then every 5,000th step are shown below. The discriminator loss settles near ln 2 ≈ 0.693, which is what you expect from binary cross-entropy once the discriminator can no longer tell real from fake.
Started: 2021-04-30 17:21:38.697176
Step 500: Generator loss: 2.2972163581848144, discriminator loss: 0.24098993314430117
Step 1000: Generator loss: 4.111798384666443, discriminator loss: 0.03910111421905458
Step 1500: Generator loss: 5.055200936317444, discriminator loss: 0.017988519712351263
Step 2000: Generator loss: 4.356978572845459, discriminator loss: 0.07956613468006253
Step 2500: Generator loss: 3.1875410358905794, discriminator loss: 0.1825093053430319
Step 3000: Generator loss: 2.7038124163150785, discriminator loss: 0.26903335136175155
Step 3500: Generator loss: 2.3201852326393126, discriminator loss: 0.30360834433138373
Step 4000: Generator loss: 2.3750923416614533, discriminator loss: 0.3380578280091286
Step 4500: Generator loss: 1.9010865423679353, discriminator loss: 0.3774027600288391
Step 5000: Generator loss: 1.9102082657814026, discriminator loss: 0.39759708976745606
Step 10000: Generator loss: 1.042488064646721, discriminator loss: 0.6221774860620498
Step 15000: Generator loss: 0.8425910116434098, discriminator loss: 0.6705276243686676
Step 20000: Generator loss: 0.8188033536672592, discriminator loss: 0.6830086879730225
Step 25000: Generator loss: 0.7897603868246078, discriminator loss: 0.6862760508060455
Step 30000: Generator loss: 0.7409272351264954, discriminator loss: 0.6925988558530808
Step 35000: Generator loss: 0.7161798032522202, discriminator loss: 0.694249948143959
Step 40000: Generator loss: 0.7197876414060592, discriminator loss: 0.6861994673013687
Step 45000: Generator loss: 0.7392684471607208, discriminator loss: 0.687694214463234
Step 50000: Generator loss: 0.72439306807518, discriminator loss: 0.6840117316246033
Step 55000: Generator loss: 0.7243433123826981, discriminator loss: 0.686788556933403
Step 60000: Generator loss: 0.7218955677747726, discriminator loss: 0.6866856569051742
Step 65000: Generator loss: 0.7120586842298507, discriminator loss: 0.6889722956418991
Step 70000: Generator loss: 0.706935028553009, discriminator loss: 0.6917025876045227
Step 75000: Generator loss: 0.7061277792453766, discriminator loss: 0.6915859417915344
Step 80000: Generator loss: 0.7189263315200806, discriminator loss: 0.6891602661609649
Step 85000: Generator loss: 0.7036903632879257, discriminator loss: 0.6948130846023559
Step 90000: Generator loss: 0.7041428442001343, discriminator loss: 0.6921311345100403
Step 93500: Generator loss: 0.7005794968605041, discriminator loss: 0.6932051202058792
Ended: 2021-04-30 18:15:27.636306
Elapsed: 0:53:48.939130
Before you explore, put the generator in eval mode, both as a general best practice and so that batch norm uses its running (evaluation) statistics rather than per-batch statistics.
gen = gen.eval()
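As a quick sanity check that the trained weights are doing what you expect, you can sample a small batch in eval mode. This is just a minimal sketch, not part of the original notebook: it assumes the gen, z_dim, n_classes, and device values from the training code, plus the get_noise, get_one_hot_labels, combine_vectors, and show_tensor_images helpers used elsewhere in this post.
# A quick eval-mode sanity check (a sketch; assumes gen, z_dim, n_classes, device,
# get_noise, get_one_hot_labels, combine_vectors, and show_tensor_images from earlier).
with torch.no_grad():
    check_noise = get_noise(16, z_dim, device=device)
    check_labels = get_one_hot_labels(torch.arange(16) % n_classes, n_classes).float().to(device)
    check_fake = gen(combine_vectors(check_noise, check_labels))
show_tensor_images(check_fake, num_images=16, nrow=4, show=False)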
You can generate some numbers with your new model, and you can add interpolation to make things more interesting.
Starting from one image, you'll produce intermediate images that look more and more like the ending image until you reach the final image. You're basically morphing one image into another, and because this is a conditional GAN you get to choose which two images those are.
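To make the label interpolation concrete before running the real thing, here's a tiny, self-contained illustration with hypothetical values, independent of the generator code below:
# A tiny standalone illustration of blending two one-hot labels (hypothetical example).
one = torch.zeros(10)
one[1] = 1                                    # one-hot label for the digit 1
five = torch.zeros(10)
five[5] = 1                                   # one-hot label for the digit 5
weights = torch.linspace(0, 1, 3)[:, None]    # interpolation weights 0.0, 0.5, 1.0
print((1 - weights) * one + weights * five)
# The first row is pure "1", the last is pure "5", and the middle row is half of
# each; feeding such blended labels to the generator produces the in-between images.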
### Change me! ###
n_interpolation = 9 # Choose the interpolation: how many intermediate images you want + 2 (for the start and end image)
interpolation_noise = make_some_noise(1, z_dim, device=device).repeat(n_interpolation, 1)
def interpolate_class(first_number, second_number):
    first_label = get_one_hot_labels(torch.Tensor([first_number]).long(), n_classes)
    second_label = get_one_hot_labels(torch.Tensor([second_number]).long(), n_classes)
    # Calculate the interpolation vector between the two labels
    percent_second_label = torch.linspace(0, 1, n_interpolation)[:, None]
    interpolation_labels = first_label * (1 - percent_second_label) + second_label * percent_second_label
    # Combine the noise and the labels
    noise_and_labels = combine_vectors(interpolation_noise, interpolation_labels.to(device))
    fake = gen(noise_and_labels)
    show_tensor_images(fake, num_images=n_interpolation, nrow=int(math.sqrt(n_interpolation)), show=False)
### Change me! ###
start_plot_number = 1 # Choose the start digit
### Change me! ###
end_plot_number = 5 # Choose the end digit
plt.figure(figsize=(8, 8))
interpolate_class(start_plot_number, end_plot_number)
_ = plt.axis('off')
### Uncomment the following lines of code if you would like to visualize a set of pairwise class
### interpolations for a collection of different numbers, all in a single grid of interpolations.
### You'll also see another visualization like this in the next code block!
# plot_numbers = [2, 3, 4, 5, 7]
# n_numbers = len(plot_numbers)
# plt.figure(figsize=(8, 8))
# for i, first_plot_number in enumerate(plot_numbers):
#     for j, second_plot_number in enumerate(plot_numbers):
#         plt.subplot(n_numbers, n_numbers, i * n_numbers + j + 1)
#         interpolate_class(first_plot_number, second_plot_number)
#         plt.axis('off')
# plt.subplots_adjust(top=1, bottom=0, left=0, right=1, hspace=0.1, wspace=0)
# plt.show()
# plt.close()
Now, what happens if you hold the class constant but change the noise vector instead? You can interpolate the noise vector too, generating an image at each step (a quick single-pair example follows the helper function below).
n_interpolation = 9 # How many intermediate images you want + 2 (for the start and end image)
# This time you're interpolating between the noise instead of the labels
interpolation_label = get_one_hot_labels(torch.Tensor([5]).long(), n_classes).repeat(n_interpolation, 1).float()
def interpolate_noise(first_noise, second_noise):
    # This time you're interpolating between the noise instead of the labels
    percent_first_noise = torch.linspace(0, 1, n_interpolation)[:, None].to(device)
    interpolation_noise = first_noise * percent_first_noise + second_noise * (1 - percent_first_noise)
    # Combine the noise and the labels again
    noise_and_labels = combine_vectors(interpolation_noise, interpolation_label.to(device))
    fake = gen(noise_and_labels)
    show_tensor_images(fake, num_images=n_interpolation, nrow=int(math.sqrt(n_interpolation)), show=False)
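If you just want a single interpolation strip before building the full grid, you can call the helper directly. This is a minimal sketch, assuming the get_noise, z_dim, and device names from earlier in the post:
# Single-pair usage (a sketch; assumes get_noise, z_dim, and device from earlier).
single_first_noise = get_noise(1, z_dim, device=device)
single_second_noise = get_noise(1, z_dim, device=device)
plt.figure(figsize=(4, 4))
interpolate_noise(single_first_noise, single_second_noise)
_ = plt.axis('off')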
Generate noise vectors to interpolate between.
### Change me! ###
n_noise = 5 # Choose the number of noise examples in the grid
plot_noises = [get_noise(1, z_dim, device=device) for i in range(n_noise)]
plt.figure(figsize=(8, 8))
for i, first_plot_noise in enumerate(plot_noises):
    for j, second_plot_noise in enumerate(plot_noises):
        plt.subplot(n_noise, n_noise, i * n_noise + j + 1)
        interpolate_noise(first_plot_noise, second_plot_noise)
        plt.axis('off')
plt.subplots_adjust(top=1, bottom=0, left=0, right=1, hspace=0.1, wspace=0)
plt.show()
plt.close()