A Conditional GAN
Build You a Conditional GAN For a Great Good
Imports
# python standard library
from pathlib import Path
import math
# pypi
from torch import nn
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import MNIST
from torchvision.utils import make_grid
import matplotlib.pyplot as pyplot
import torch
import torch.nn.functional as F
# my stuff
from graeae import Timer
Set Up
The Timer
# Timer instance; used further down as a context manager around the training loop
TIMER = Timer()
The Manual Seed
# seed torch's default random number generator with a fixed value
torch.manual_seed(0)
Plotting
# post slug; interpolated into the default ``folder`` argument of save_tensor_images
SLUG = "a-conditional-gan"
Helpers
def save_tensor_images(image_tensor: torch.Tensor,
                       filename: str,
                       title: str,
                       folder: str=f"files/posts/gans{SLUG}",
                       num_images: int=25, size: tuple=(1, 28, 28)):
    """Plot a grid of images from a tensor and save the figure to disk

    Args:
     image_tensor: tensor with the values for the images to plot
     filename: name to save the file under (joined onto ``folder``)
     title: title for the image
     folder: path to put the file in (created if it does not exist)
     num_images: how many images from the tensor to use
     size: the dimensions for each image (kept for compatibility; not used)
    """
    # (x + 1) / 2 maps values in [-1, 1] onto [0, 1] for plotting
    image_tensor = (image_tensor + 1) / 2
    image_unflat = image_tensor.detach().cpu()
    image_grid = make_grid(image_unflat[:num_images], nrow=5)

    pyplot.title(title)
    pyplot.grid(False)
    # make_grid returns (channels, height, width); imshow wants channels last
    pyplot.imshow(image_grid.permute(1, 2, 0).squeeze())
    pyplot.tick_params(bottom=False, top=False, labelbottom=False,
                       right=False, left=False, labelleft=False)

    # Join folder and filename as path components: plain string concatenation
    # silently produced "<folder><filename>" whenever the folder had no
    # trailing separator (as the default does not).
    target = Path(folder) / filename
    target.parent.mkdir(parents=True, exist_ok=True)
    pyplot.savefig(target)
    print(f"[[file:{filename}]]")
Noise
def make_some_noise(n_samples: int, z_dim: int, device: str='cpu') -> torch.Tensor:
    """Draw a batch of noise vectors from the standard normal distribution

    Args:
     n_samples: how many noise vectors to draw
     z_dim: the length of each noise vector
     device: the device to create the tensor on

    Returns:
     tensor of shape (n_samples, z_dim) filled by torch.randn.
    """
    shape = (n_samples, z_dim)
    return torch.randn(shape, device=device)
Middle
The Generator
The Generator and Discriminator are the same ones we used before except the z_dim
attribute has been renamed input_dim
to reflect the fact that the data is going to be augmented with the classification information.
class Generator(nn.Module):
    """Transposed-convolution (DCGAN-style) generator

    With the default kernel sizes and strides a 1x1 input grows
    3 -> 6 -> 13 -> 28 pixels per side across the four blocks.

    Args:
     input_dim: the length of the input vector
     im_chan: the number of channels in the generated images
       (MNIST is black-and-white, so the default is 1)
     hidden_dim: the base width of the hidden layers
    """
    def __init__(self, input_dim: int=10, im_chan: int=1, hidden_dim: int=64):
        super().__init__()
        self.input_dim = input_dim
        widest, middle = hidden_dim * 4, hidden_dim * 2
        blocks = [
            self.make_gen_block(input_dim, widest),
            self.make_gen_block(widest, middle, kernel_size=4, stride=1),
            self.make_gen_block(middle, hidden_dim),
            self.make_gen_block(hidden_dim, im_chan, kernel_size=4, final_layer=True),
        ]
        self.gen = nn.Sequential(*blocks)

    def make_gen_block(self, input_channels: int, output_channels: int,
                       kernel_size: int=3, stride: int=2,
                       final_layer: bool=False) -> nn.Sequential:
        """Build one generator block

        Hidden blocks are transposed convolution -> batch norm -> ReLU;
        the final block is transposed convolution -> Tanh.

        Args:
         input_channels: channels in the incoming feature map
         output_channels: channels in the outgoing feature map
         kernel_size: side length of the (square) convolutional filter
         stride: the stride of the transposed convolution
         final_layer: True for the last block (Tanh, no batch norm)

        Returns:
         the block as an nn.Sequential
        """
        upsample = nn.ConvTranspose2d(input_channels, output_channels, kernel_size, stride)
        if final_layer:
            return nn.Sequential(upsample, nn.Tanh())
        return nn.Sequential(
            upsample,
            nn.BatchNorm2d(output_channels),
            nn.ReLU(inplace=True),
        )

    def forward(self, noise: torch.Tensor) -> torch.Tensor:
        """Run the generator on a batch of input vectors

        Args:
         noise: tensor with dimensions (n_samples, input_dim)

        Returns:
         the generated images.
        """
        # give every vector a 1x1 spatial extent so the convolutions can grow it
        batch = len(noise)
        as_feature_map = noise.view(batch, self.input_dim, 1, 1)
        return self.gen(as_feature_map)
Discriminator
This differs a little from the DCGAN Discriminator in that the initial hidden dimension output goes up to 64 nodes from 16 in the original.
class Discriminator(nn.Module):
    """Convolutional (DCGAN-style) discriminator

    With the default kernel size and stride a 28x28 input shrinks
    13 -> 5 -> 1 pixels per side across the three blocks.

    Args:
     im_chan: the number of channels in the input images
       (MNIST is black-and-white, so 1 channel is the default)
     hidden_dim: the base width of the hidden layers
    """
    def __init__(self, im_chan: int=1, hidden_dim: int=64):
        super().__init__()
        blocks = [
            self.make_disc_block(im_chan, hidden_dim),
            self.make_disc_block(hidden_dim, hidden_dim * 2),
            self.make_disc_block(hidden_dim * 2, 1, final_layer=True),
        ]
        self.disc = nn.Sequential(*blocks)

    def make_disc_block(self, input_channels: int, output_channels: int,
                        kernel_size: int=4, stride: int=2,
                        final_layer: bool=False) -> nn.Sequential:
        """Build one discriminator block

        Hidden blocks are convolution -> batch norm -> LeakyReLU(0.2);
        the final block is the bare convolution (no norm, no activation).

        Args:
         input_channels: channels in the incoming feature map
         output_channels: channels in the outgoing feature map
         kernel_size: side length of the (square) convolutional filter
         stride: the stride of the convolution
         final_layer: True for the last block (convolution only)

        Returns:
         the block as an nn.Sequential
        """
        downsample = nn.Conv2d(input_channels, output_channels, kernel_size, stride)
        if final_layer:
            return nn.Sequential(downsample)
        return nn.Sequential(
            downsample,
            nn.BatchNorm2d(output_channels),
            nn.LeakyReLU(0.2),
        )

    def forward(self, image: torch.Tensor) -> torch.Tensor:
        """Run the discriminator on a batch of images

        Args:
         image: image tensor with dimensions (n_samples, im_chan, height, width)

        Returns:
         the final convolution's output flattened to one row per sample.
        """
        scores = self.disc(image)
        return scores.view(len(scores), -1)
The Class Input
One-Hot Encoder
In conditional GANs, the input vector for the generator will also need to include the class information. The class is represented using a one-hot encoded vector where its length is the number of classes and each index represents a class. The vector is all 0's and a 1 on the chosen class. Given the labels of multiple images (e.g. from a batch) and number of classes, please create one-hot vectors for each label. There is a class within the PyTorch functional library that can help you.
- This code can be done in one line.
- PyTorch documentation for F.one_hot
def get_one_hot_labels(labels: torch.Tensor, n_classes: int) -> torch.Tensor:
    """Create one-hot vectors for the labels

    Args:
     labels: tensor of integer class labels from the dataloader
     n_classes: the total number of classes in the dataset

    Returns:
     an int64 tensor of shape (*labels.shape, n_classes).
    """
    # F.one_hot only accepts int64 index tensors, so widen the narrower
    # integer dtypes first. Anything else (e.g. floating point) is passed
    # through untouched and is still rejected by F.one_hot.
    narrow_integers = (torch.uint8, torch.int8, torch.int16, torch.int32)
    if labels.dtype in narrow_integers:
        labels = labels.long()
    return F.one_hot(labels, num_classes=n_classes)
# sanity check: a (1, 3) batch of labels becomes a (1, 3, 3) stack of one-hot rows
assert get_one_hot_labels(
    labels=torch.tensor([[0, 2, 1]], dtype=torch.long),
    n_classes=3,
).tolist() == [[[1, 0, 0], [0, 0, 1], [0, 1, 0]]]
Combine Vectors
Next, you need to be able to concatenate the one-hot class vector to the noise vector before giving it to the generator. You will also need to do this when adding the class channels to the discriminator.
- This code can also be written in one line.
- See the documentation for torch.cat (specifically, look at what the dim argument of torch.cat does)
def combine_vectors(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """Concatenate two tensors along dimension 1 and return the result as float32

    Args:
     x: the first tensor, with shape (n_samples, ?)
     y: the second tensor, with shape (n_samples, ?)

    Returns:
     x followed by y along dimension 1, converted to float32 whatever the
     input dtypes were.
    """
    joined = torch.cat([x, y], dim=1)
    return joined.float()
# the elements have to keep their order, with x's columns ahead of y's
combined = combine_vectors(torch.tensor([[1, 2], [3, 4]]),
                           torch.tensor([[5, 6], [7, 8]]))
assert torch.all(combined == torch.tensor([[1, 2, 5, 6], [3, 4, 7, 8]]))

# integer inputs still have to come back as floats
assert isinstance(combined[0][0].item(), float)

# only dimension 1 grows, for float and for integer inputs alike
combined = combine_vectors(torch.randn(1, 4, 5), torch.randn(1, 8, 5))
assert tuple(combined.shape) == (1, 12, 5)
assert tuple(
    combine_vectors(
        torch.randn(1, 10, 12).long(),
        torch.randn(1, 20, 12).long(),
    ).shape
) == (1, 30, 12)
Training
First, you will define some new parameters:
- mnist_shape: the shape of each MNIST image, which is 28 x 28 pixels with one channel (because it's black-and-white), so 1 x 28 x 28
- n_classes: the number of classes in MNIST (10, since there are the digits from 0 to 9)
# ten classes, and one (channels, height, width) shape per image
n_classes = 10
mnist_shape = (1, 28, 28)
And you also include the same parameters from before:
- criterion: the loss function
- n_epochs: the number of times you iterate through the entire dataset when training
- z_dim: the dimension of the noise vector
- display_step: how often to display/visualize the images
- batch_size: the number of images per forward/backward pass
- lr: the learning rate
- device: the device type
# binary cross-entropy computed on raw (pre-sigmoid) outputs
criterion = nn.BCEWithLogitsLoss()
n_epochs = 200
z_dim = 64
display_step = 500
batch_size = 128
lr = 0.0002
# use the GPU when there is one, otherwise fall back to the CPU so the code
# still runs (a hard-coded 'cuda' fails on machines without CUDA)
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# convert each image to a tensor, then Normalize with mean 0.5 and
# standard deviation 0.5, i.e. (x - 0.5) / 0.5
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5,), std=(0.5,)),
])

# download=False: the dataset is expected to be on disk already
data_path = Path("~/pytorch-data/MNIST/").expanduser()

dataloader = DataLoader(
    dataset=MNIST(root=data_path, download=False, transform=transform),
    batch_size=batch_size,
    shuffle=True,
)
Input Dimensions
Then, you can initialize your generator, discriminator, and optimizers. To do this, you will need to update the input dimensions for both models. For the generator, you will need to calculate the size of the input vector; recall that for conditional GANs, the generator's input is the noise vector concatenated with the class vector. For the discriminator, you need to add a channel for every class.
def get_input_dimensions(z_dim: int, mnist_shape: tuple, n_classes: int) -> tuple:
    """Work out the input sizes of the conditional generator and discriminator

    Args:
     z_dim: the dimension of the noise vector
     mnist_shape: the shape of each image, channels first
       ((1, 28, 28) for MNIST)
     n_classes: the total number of classes in the dataset (10 for MNIST)

    Returns:
     generator_input_dim: length of the generator input (noise plus class vector)
     discriminator_im_chan: number of discriminator input channels
       (image channels plus one channel per class)
    """
    image_channels = mnist_shape[0]
    return z_dim + n_classes, n_classes + image_channels
def test_input_dims():
    """Check get_input_dimensions against a hand-computed example"""
    generator_size, discriminator_channels = get_input_dimensions(23, (12, 23, 52), 9)
    # 23 noise values + 9 classes
    assert generator_size == 32
    # 12 image channels + 9 classes
    assert discriminator_channels == 21


test_input_dims()
Initialize the Objects
(generator_input_dim,
 discriminator_im_chan) = get_input_dimensions(z_dim, mnist_shape, n_classes)

# the networks, generator first and discriminator second
gen = Generator(input_dim=generator_input_dim).to(device)
disc = Discriminator(im_chan=discriminator_im_chan).to(device)

# one Adam optimizer per network, both with the same learning rate
gen_opt = torch.optim.Adam(gen.parameters(), lr=lr)
disc_opt = torch.optim.Adam(disc.parameters(), lr=lr)
def weights_init(m) -> None:
    """Re-initialize the weights of convolution and batch-norm layers

    Conv2d, ConvTranspose2d and BatchNorm2d weights are drawn from a normal
    distribution with mean 0.0 and standard deviation 0.02; BatchNorm2d
    biases are set to 0. Any other module is left alone.

    Args:
     m: the module to initialize (meant to be passed to ``Module.apply``)
    """
    convolutions = (nn.Conv2d, nn.ConvTranspose2d)
    is_batch_norm = isinstance(m, nn.BatchNorm2d)
    if not (isinstance(m, convolutions) or is_batch_norm):
        return
    torch.nn.init.normal_(m.weight, 0.0, 0.02)
    if is_batch_norm:
        torch.nn.init.constant_(m.bias, 0)
# apply() runs weights_init on every sub-module and hands the network back
gen, disc = gen.apply(weights_init), disc.apply(weights_init)
The Training
Now to train, you would like both your generator and your discriminator to know what class of image should be generated.
For example, if you're generating a picture of the number "1", you would need to:
- Tell that to the generator, so that it knows it should be generating a "1"
- Tell that to the discriminator, so that it knows it should be looking at a "1". If the discriminator is told it should be looking at a 1 but sees something that's clearly an 8, it can guess that it's probably fake
cur_step = 0
generator_losses = []
discriminator_losses = []

with TIMER:
    for epoch in range(n_epochs):
        # the dataloader yields a batch of images together with their labels
        for real, labels in dataloader:
            cur_batch_size = len(real)
            real = real.to(device)

            # The class information in the two layouts the networks need:
            #  - one_hot_labels: (batch, n_classes), appended to the noise
            #  - image_one_hot_labels: (batch, n_classes, height, width), one
            #    constant plane per class, appended to the image channels
            one_hot_labels = get_one_hot_labels(labels.to(device), n_classes)
            image_one_hot_labels = one_hot_labels[:, :, None, None]
            image_one_hot_labels = image_one_hot_labels.repeat(
                1, 1, mnist_shape[1], mnist_shape[2])

            ### Update discriminator ###
            disc_opt.zero_grad()

            # generate a fake image for every real image, using the same labels
            fake_noise = make_some_noise(cur_batch_size, z_dim, device=device)
            noise_and_labels = combine_vectors(fake_noise, one_hot_labels)
            fake = gen(noise_and_labels)

            # Make sure that enough images were generated
            assert len(fake) == len(real)
            # Check that correct tensors were combined
            assert tuple(noise_and_labels.shape) == (
                cur_batch_size, fake_noise.shape[1] + one_hot_labels.shape[1])
            # It comes from the correct generator
            assert tuple(fake.shape) == (len(real), 1, 28, 28)

            # detach the fakes so the discriminator loss does not
            # backpropagate into the generator
            fake_image_and_labels = combine_vectors(fake.detach(), image_one_hot_labels)
            real_image_and_labels = combine_vectors(real, image_one_hot_labels)
            disc_fake_pred = disc(fake_image_and_labels)
            disc_real_pred = disc(real_image_and_labels)

            # Make sure shapes are correct
            assert tuple(fake_image_and_labels.shape) == (
                len(real), fake.detach().shape[1] + image_one_hot_labels.shape[1], 28, 28)
            assert tuple(real_image_and_labels.shape) == (
                len(real), real.shape[1] + image_one_hot_labels.shape[1], 28, 28)
            # Make sure that enough predictions were made
            assert len(disc_real_pred) == len(real)
            # Make sure that the inputs are different
            assert torch.any(fake_image_and_labels != real_image_and_labels)
            # Shapes must match
            assert tuple(fake_image_and_labels.shape) == tuple(real_image_and_labels.shape)
            assert tuple(disc_fake_pred.shape) == tuple(disc_real_pred.shape)

            # fakes should be scored as 0, reals as 1
            disc_fake_loss = criterion(disc_fake_pred, torch.zeros_like(disc_fake_pred))
            disc_real_loss = criterion(disc_real_pred, torch.ones_like(disc_real_pred))
            disc_loss = (disc_fake_loss + disc_real_loss) / 2
            # No retain_graph needed: the fakes were detached above and the
            # generator update below runs its own forward pass through disc.
            disc_loss.backward()
            disc_opt.step()
            discriminator_losses.append(disc_loss.item())

            ### Update generator ###
            gen_opt.zero_grad()
            # this time the fakes are NOT detached, so the gradients reach gen
            fake_image_and_labels = combine_vectors(fake, image_one_hot_labels)
            disc_fake_pred = disc(fake_image_and_labels)
            # the generator wants its fakes to be scored as real (1)
            gen_loss = criterion(disc_fake_pred, torch.ones_like(disc_fake_pred))
            gen_loss.backward()
            gen_opt.step()
            generator_losses.append(gen_loss.item())

            if cur_step % display_step == 0 and cur_step > 0:
                # mean of the losses since the previous report
                gen_mean = sum(generator_losses[-display_step:]) / display_step
                disc_mean = sum(discriminator_losses[-display_step:]) / display_step
                print(f"Step {cur_step}: Generator loss: {gen_mean}, discriminator loss: {disc_mean}")
            elif cur_step == 0:
                print("Congratulations! If you've gotten here, it's working. Please let this train until you're happy with how the generated numbers look, and then go on to the exploration!")
            cur_step += 1
Started: 2021-04-30 17:21:38.697176 Congratulations! If you've gotten here, it's working. Please let this train until you're happy with how the generated numbers look, and then go on to the exploration! Step 500: Generator loss: 2.2972163581848144, discriminator loss: 0.24098993314430117 Step 1000: Generator loss: 4.111798384666443, discriminator loss: 0.03910111421905458 Step 1500: Generator loss: 5.055200936317444, discriminator loss: 0.017988519712351263 Step 2000: Generator loss: 4.356978572845459, discriminator loss: 0.07956613468006253 Step 2500: Generator loss: 3.1875410358905794, discriminator loss: 0.1825093053430319 Step 3000: Generator loss: 2.7038124163150785, discriminator loss: 0.26903335136175155 Step 3500: Generator loss: 2.3201852326393126, discriminator loss: 0.30360834433138373 Step 4000: Generator loss: 2.3750923416614533, discriminator loss: 0.3380578280091286 Step 4500: Generator loss: 1.9010865423679353, discriminator loss: 0.3774027600288391 Step 5000: Generator loss: 1.9102082657814026, discriminator loss: 0.39759708976745606 Step 5500: Generator loss: 1.6987447377443314, discriminator loss: 0.43035421246290206 Step 6000: Generator loss: 1.6317225174903869, discriminator loss: 0.44889184486865996 Step 6500: Generator loss: 1.4883701887130738, discriminator loss: 0.47211092388629916 Step 7000: Generator loss: 1.4601867563724518, discriminator loss: 0.49546391534805295 Step 7500: Generator loss: 1.3467793072462082, discriminator loss: 0.520910717189312 Step 8000: Generator loss: 1.3130292971134185, discriminator loss: 0.5447795498371124 Step 8500: Generator loss: 1.1705783107280732, discriminator loss: 0.5716251953244209 Step 9000: Generator loss: 1.119933351278305, discriminator loss: 0.594505407333374 Step 9500: Generator loss: 1.0671374444961548, discriminator loss: 0.6109852703809738 Step 10000: Generator loss: 1.042488064646721, discriminator loss: 0.6221774860620498 Step 10500: Generator loss: 1.015932119846344, discriminator loss: 
0.6356924445033073 Step 11000: Generator loss: 0.9900961836576462, discriminator loss: 0.6363092860579491 Step 11500: Generator loss: 0.9709519152641296, discriminator loss: 0.642409072637558 Step 12000: Generator loss: 0.9450125323534012, discriminator loss: 0.6527952731847763 Step 12500: Generator loss: 0.9660063650608063, discriminator loss: 0.6581431583166123 Step 13000: Generator loss: 0.8976931695938111, discriminator loss: 0.6624037518501281 Step 13500: Generator loss: 0.8969441390037537, discriminator loss: 0.6631241027116775 Step 14000: Generator loss: 0.902042475938797, discriminator loss: 0.6669842832088471 Step 14500: Generator loss: 0.8777725909948348, discriminator loss: 0.6742114140987396 Step 15000: Generator loss: 0.8425910116434098, discriminator loss: 0.6705276243686676 Step 15500: Generator loss: 0.8403865963220596, discriminator loss: 0.6768578908443451 Step 16000: Generator loss: 0.8556022493839264, discriminator loss: 0.6762306448221207 Step 16500: Generator loss: 0.8764977214336396, discriminator loss: 0.678820540189743 Step 17000: Generator loss: 0.8166215040683746, discriminator loss: 0.6793025200366973 Step 17500: Generator loss: 0.8172620774507523, discriminator loss: 0.6820640956163406 Step 18000: Generator loss: 0.8534817943572998, discriminator loss: 0.679261458158493 Step 18500: Generator loss: 0.814961371421814, discriminator loss: 0.6819399018287658 Step 19000: Generator loss: 0.8046633821725845, discriminator loss: 0.6841713408231735 Step 19500: Generator loss: 0.8040647978782653, discriminator loss: 0.6816601184606552 Step 20000: Generator loss: 0.8188033536672592, discriminator loss: 0.6830086879730225 Step 20500: Generator loss: 0.8088009203672409, discriminator loss: 0.6811848978996277 Step 21000: Generator loss: 0.7853847659826279, discriminator loss: 0.6836997301578521 Step 21500: Generator loss: 0.7798927721977233, discriminator loss: 0.684486163020134 Step 22000: Generator loss: 0.7833593552112579, discriminator loss: 
0.685326477766037 Step 22500: Generator loss: 0.8140243920087814, discriminator loss: 0.6838948290348053 Step 23000: Generator loss: 0.7766883851289749, discriminator loss: 0.6896610424518586 Step 23500: Generator loss: 0.7677333387136459, discriminator loss: 0.6847020034790039 Step 24000: Generator loss: 0.7968860776424408, discriminator loss: 0.6859219465255737 Step 24500: Generator loss: 0.7655997285842896, discriminator loss: 0.6866924543380737 Step 25000: Generator loss: 0.7897603868246078, discriminator loss: 0.6862760508060455 Step 25500: Generator loss: 0.7602986326217651, discriminator loss: 0.6839702378511429 Step 26000: Generator loss: 0.7566630525588989, discriminator loss: 0.6884172073602677 Step 26500: Generator loss: 0.7693239089250564, discriminator loss: 0.6848342741727829 Step 27000: Generator loss: 0.7744117819070816, discriminator loss: 0.6884717727899552 Step 27500: Generator loss: 0.7754857275485992, discriminator loss: 0.6871565765142441 Step 28000: Generator loss: 0.7671164673566818, discriminator loss: 0.691150808095932 Step 28500: Generator loss: 0.7767693866491318, discriminator loss: 0.6899988718032837 Step 29000: Generator loss: 0.7584288560152054, discriminator loss: 0.6866833527088165 Step 29500: Generator loss: 0.7469037870168685, discriminator loss: 0.6892017160654068 Step 30000: Generator loss: 0.7409272351264954, discriminator loss: 0.6925988558530808 Step 30500: Generator loss: 0.7461127021312713, discriminator loss: 0.6910134963989257 Step 31000: Generator loss: 0.7623333480358124, discriminator loss: 0.6848250635862351 Step 31500: Generator loss: 0.7320846046209335, discriminator loss: 0.6922971439361573 Step 32000: Generator loss: 0.7360488106012344, discriminator loss: 0.6958958665132523 Step 32500: Generator loss: 0.7436227219104767, discriminator loss: 0.6927914987802506 Step 33000: Generator loss: 0.7528411923646927, discriminator loss: 0.6868532946109772 Step 33500: Generator loss: 0.7555540499687194, discriminator loss: 
0.6819704930782318 Step 34000: Generator loss: 0.7339509303569793, discriminator loss: 0.6947230596542359 Step 34500: Generator loss: 0.7203902735710144, discriminator loss: 0.694019063949585 Step 35000: Generator loss: 0.7161798032522202, discriminator loss: 0.694249948143959 Step 35500: Generator loss: 0.7100930047035218, discriminator loss: 0.6956643009185791 Step 36000: Generator loss: 0.7224245357513428, discriminator loss: 0.692036272764206 Step 36500: Generator loss: 0.7294702612161637, discriminator loss: 0.6839023213386536 Step 37000: Generator loss: 0.7326101566553116, discriminator loss: 0.6864855628013611 Step 37500: Generator loss: 0.7289662526845933, discriminator loss: 0.6891102294921875 Step 38000: Generator loss: 0.7277824294567108, discriminator loss: 0.6930312074422836 Step 38500: Generator loss: 0.7523093600273132, discriminator loss: 0.6822635132074356 Step 39000: Generator loss: 0.7260702294111252, discriminator loss: 0.6836128298044205 Step 39500: Generator loss: 0.7210463825464248, discriminator loss: 0.6865772886276245 Step 40000: Generator loss: 0.7197876414060592, discriminator loss: 0.6861994673013687 Step 40500: Generator loss: 0.7156198496818542, discriminator loss: 0.6897815141677857 Step 41000: Generator loss: 0.7411812788248062, discriminator loss: 0.6876297281980515 Step 41500: Generator loss: 0.7482703533172608, discriminator loss: 0.6831764079332352 Step 42000: Generator loss: 0.7353900390863418, discriminator loss: 0.6809069069623948 Step 42500: Generator loss: 0.726880151629448, discriminator loss: 0.6845077587366104 Step 43000: Generator loss: 0.7335763674974441, discriminator loss: 0.6855163406133652 Step 43500: Generator loss: 0.7247586588859558, discriminator loss: 0.684886796593666 Step 44000: Generator loss: 0.7244187197685241, discriminator loss: 0.6869175283908844 Step 44500: Generator loss: 0.7478935513496399, discriminator loss: 0.6783332238197327 Step 45000: Generator loss: 0.7392684471607208, discriminator loss: 
0.687694214463234 Step 45500: Generator loss: 0.7384519840478897, discriminator loss: 0.6806207147836685 Step 46000: Generator loss: 0.7173152709007263, discriminator loss: 0.6894198944568634 Step 46500: Generator loss: 0.7135227386951446, discriminator loss: 0.6902039344310761 Step 47000: Generator loss: 0.7121314022541047, discriminator loss: 0.691226885676384 Step 47500: Generator loss: 0.7153779380321502, discriminator loss: 0.6898772416114807 Step 48000: Generator loss: 0.7112214748859406, discriminator loss: 0.6919035356044769 Step 48500: Generator loss: 0.729472970366478, discriminator loss: 0.6832324341535568 Step 49000: Generator loss: 0.7259864670038223, discriminator loss: 0.6850444099903107 Step 49500: Generator loss: 0.7463545156717301, discriminator loss: 0.6876692290306091 Step 50000: Generator loss: 0.72439306807518, discriminator loss: 0.6840117316246033 Step 50500: Generator loss: 0.7304026707410812, discriminator loss: 0.6828315691947937 Step 51000: Generator loss: 0.735065841794014, discriminator loss: 0.6877049984931946 Step 51500: Generator loss: 0.738693750500679, discriminator loss: 0.6786749280691147 Step 52000: Generator loss: 0.7165734323263169, discriminator loss: 0.688656357049942 Step 52500: Generator loss: 0.7124545810222626, discriminator loss: 0.6885210503339767 Step 53000: Generator loss: 0.7169003388881683, discriminator loss: 0.6898472727537155 Step 53500: Generator loss: 0.7116240389347076, discriminator loss: 0.6890990349054337 Step 54000: Generator loss: 0.7254890002012253, discriminator loss: 0.686066904425621 Step 54500: Generator loss: 0.7279696422815323, discriminator loss: 0.6824959990978241 Step 55000: Generator loss: 0.7243433123826981, discriminator loss: 0.686788556933403 Step 55500: Generator loss: 0.72320248234272, discriminator loss: 0.6819899456501007 Step 56000: Generator loss: 0.7283236463069915, discriminator loss: 0.6813042680025101 Step 56500: Generator loss: 0.7257692145109177, discriminator loss: 
0.6882435537576675 Step 57000: Generator loss: 0.7204343225955964, discriminator loss: 0.6905163298845292 Step 57500: Generator loss: 0.7234136379957199, discriminator loss: 0.6828762836456299 Step 58000: Generator loss: 0.7213340125083924, discriminator loss: 0.6852367097139358 Step 58500: Generator loss: 0.7139561972618103, discriminator loss: 0.6901394550800324 Step 59000: Generator loss: 0.7128681792020798, discriminator loss: 0.6899428930282593 Step 59500: Generator loss: 0.7178032584190369, discriminator loss: 0.6901476013660431 Step 60000: Generator loss: 0.7218955677747726, discriminator loss: 0.6866856569051742 Step 60500: Generator loss: 0.7173091459274292, discriminator loss: 0.6909447896480561 Step 61000: Generator loss: 0.7196292532682419, discriminator loss: 0.6888659211397171 Step 61500: Generator loss: 0.7136147793531418, discriminator loss: 0.6911007264852523 Step 62000: Generator loss: 0.7167167031764984, discriminator loss: 0.6874131036996841 Step 62500: Generator loss: 0.7095696296691895, discriminator loss: 0.6924118340015412 Step 63000: Generator loss: 0.7100733149051667, discriminator loss: 0.6894952065944672 Step 63500: Generator loss: 0.7075963083505631, discriminator loss: 0.6918715183734894 Step 64000: Generator loss: 0.7087407541275025, discriminator loss: 0.6912821785211564 Step 64500: Generator loss: 0.7044790136814117, discriminator loss: 0.6919414196014404 Step 65000: Generator loss: 0.7120586842298507, discriminator loss: 0.6889722956418991 Step 65500: Generator loss: 0.7059948451519013, discriminator loss: 0.6913756219148636 Step 66000: Generator loss: 0.7103360829353332, discriminator loss: 0.6888430647850037 Step 66500: Generator loss: 0.7106574136018753, discriminator loss: 0.6923392252922058 Step 67000: Generator loss: 0.7205636972188949, discriminator loss: 0.6888139424324036 Step 67500: Generator loss: 0.7325763144493103, discriminator loss: 0.6851953419446946 Step 68000: Generator loss: 0.7144211075305938, discriminator 
loss: 0.6894719363451004 Step 68500: Generator loss: 0.7039347168207168, discriminator loss: 0.692310958981514 Step 69000: Generator loss: 0.707789731502533, discriminator loss: 0.690034374833107 Step 69500: Generator loss: 0.7080022550821304, discriminator loss: 0.6897176603078842 Step 70000: Generator loss: 0.706935028553009, discriminator loss: 0.6917025876045227 Step 70500: Generator loss: 0.7035844438076019, discriminator loss: 0.6928271135091781 Step 71000: Generator loss: 0.706664494395256, discriminator loss: 0.6913493415117263 Step 71500: Generator loss: 0.7080443944931031, discriminator loss: 0.6943384435176849 Step 72000: Generator loss: 0.7080535914897919, discriminator loss: 0.6904549078941346 Step 72500: Generator loss: 0.7195642621517181, discriminator loss: 0.6883307158946991 Step 73000: Generator loss: 0.7137477462291717, discriminator loss: 0.6895240060091019 Step 73500: Generator loss: 0.7089026942253113, discriminator loss: 0.6893982688188552 Step 74000: Generator loss: 0.71370064163208, discriminator loss: 0.6885940716266632 Step 74500: Generator loss: 0.7126090573072433, discriminator loss: 0.6913927717208862 Step 75000: Generator loss: 0.7061277792453766, discriminator loss: 0.6915859417915344 Step 75500: Generator loss: 0.7079737706184387, discriminator loss: 0.6918540188074112 Step 76000: Generator loss: 0.7094860315322876, discriminator loss: 0.6909938471317292 Step 76500: Generator loss: 0.7089288998842239, discriminator loss: 0.6928894543647766 Step 77000: Generator loss: 0.7099210443496704, discriminator loss: 0.6881472972631455 Step 77500: Generator loss: 0.7087316303253174, discriminator loss: 0.6922812685966492 Step 78000: Generator loss: 0.7124276860952378, discriminator loss: 0.686549712896347 Step 78500: Generator loss: 0.7118150774240494, discriminator loss: 0.6914097841978073 Step 79000: Generator loss: 0.7061567052602767, discriminator loss: 0.6910830926895142 Step 79500: Generator loss: 0.7130619381666183, discriminator loss: 
0.6901648813486099 Step 80000: Generator loss: 0.7189263315200806, discriminator loss: 0.6891602661609649 Step 80500: Generator loss: 0.7099695562124252, discriminator loss: 0.6893361113071441 Step 81000: Generator loss: 0.7043007851839066, discriminator loss: 0.6928421225547791 Step 81500: Generator loss: 0.7055111042261124, discriminator loss: 0.6913482803106308 Step 82000: Generator loss: 0.7107034167051315, discriminator loss: 0.6899493371248245 Step 82500: Generator loss: 0.7072652250528335, discriminator loss: 0.691617219209671 Step 83000: Generator loss: 0.7116999027729034, discriminator loss: 0.6887015362977982 Step 83500: Generator loss: 0.7103397004604339, discriminator loss: 0.6889865008592606 Step 84000: Generator loss: 0.702219740986824, discriminator loss: 0.6927198125123978 Step 84500: Generator loss: 0.7042218887805939, discriminator loss: 0.6910840107202529 Step 85000: Generator loss: 0.7036903632879257, discriminator loss: 0.6948130846023559 Step 85500: Generator loss: 0.7157929112911224, discriminator loss: 0.6877820398807526 Step 86000: Generator loss: 0.703074496269226, discriminator loss: 0.6928336225748062 Step 86500: Generator loss: 0.7018165578842163, discriminator loss: 0.6942198793888092 Step 87000: Generator loss: 0.7056693414449692, discriminator loss: 0.6903062870502472 Step 87500: Generator loss: 0.7070764343738556, discriminator loss: 0.690039452791214 Step 88000: Generator loss: 0.7018579070568085, discriminator loss: 0.6929991252422333 Step 88500: Generator loss: 0.7042791714668274, discriminator loss: 0.6906855113506317 Step 89000: Generator loss: 0.7052908551692962, discriminator loss: 0.6917922974824905 Step 89500: Generator loss: 0.7057228873968124, discriminator loss: 0.6917544984817505 Step 90000: Generator loss: 0.7041428442001343, discriminator loss: 0.6921311345100403 Step 90500: Generator loss: 0.7040341221094132, discriminator loss: 0.6916886166334152 Step 91000: Generator loss: 0.7016080802679062, discriminator loss: 
0.6918226274251937 Step 91500: Generator loss: 0.7047490992546082, discriminator loss: 0.6919238156080246 Step 92000: Generator loss: 0.7015802135467529, discriminator loss: 0.6937261604070664 Step 92500: Generator loss: 0.7035572265386582, discriminator loss: 0.6899326649904252 Step 93000: Generator loss: 0.7005916707515717, discriminator loss: 0.6933788905143737 Step 93500: Generator loss: 0.7005794968605041, discriminator loss: 0.6932051202058792 Ended: 2021-04-30 18:15:27.636306 Elapsed: 0:53:48.939130
Exploration
Before you explore, you should put the generator in eval mode, both in general and so that batch norm doesn't cause you issues and is using its eval statistics.
# eval() switches the module in place (and returns it), so no rebinding is needed
gen.eval()
Changing the Class Vector
You can generate some numbers with your new model! You can add interpolation as well to make it more interesting.
So starting from an image, you will produce intermediate images that look more and more like the ending image until you get to the final image. You're basically morphing one image into another. You can choose what these two images will be using your conditional GAN.
### Change me! ###
# Choose the interpolation: how many intermediate images you want
# + 2 (for the start and end image)
n_interpolation = 9

# one noise vector, repeated for every interpolation step so that only
# the class vector differs between the generated images
interpolation_noise = make_some_noise(
    1, z_dim, device=device
).repeat(n_interpolation, 1)
def interpolate_class(first_number, second_number):
    """Plot generated images while blending one class label into another

    The noise is held fixed (the module-level ``interpolation_noise``) and only
    the label part of the generator's input changes: it is blended linearly
    from the label vector for ``first_number`` to the one for ``second_number``
    over ``n_interpolation`` steps.

    Args:
     first_number: class whose label vector starts the interpolation
     second_number: class whose label vector ends the interpolation
    """
    first_label = get_one_hot_labels(torch.Tensor([first_number]).long(), n_classes)
    second_label = get_one_hot_labels(torch.Tensor([second_number]).long(), n_classes)

    # Calculate the interpolation vector between the two labels:
    # a column of weights going from 0 (all first label) to 1 (all second label)
    percent_second_label = torch.linspace(0, 1, n_interpolation)[:, None]
    interpolation_labels = (first_label * (1 - percent_second_label)
                            + second_label * percent_second_label)

    # Combine the noise and the labels
    noise_and_labels = combine_vectors(interpolation_noise,
                                       interpolation_labels.to(device))

    # This is inference only, so don't record an autograd graph for the forward pass
    with torch.no_grad():
        fake = gen(noise_and_labels)

    # NOTE(review): show_tensor_images isn't defined in the visible part of this
    # file (only save_tensor_images is) - confirm it exists before running.
    show_tensor_images(fake, num_images=n_interpolation,
                       nrow=int(math.sqrt(n_interpolation)), show=False)
### Change me! ###
start_plot_number = 1 # Choose the start digit
### Change me! ###
end_plot_number = 5 # Choose the end digit

# This file imports matplotlib.pyplot as ``pyplot`` (not ``plt``)
pyplot.figure(figsize=(8, 8))
# Draws the grid of images morphing from the start digit to the end digit
interpolate_class(start_plot_number, end_plot_number)
# Assign to a throwaway name so the return value isn't echoed as cell output
_ = pyplot.axis('off')
### Uncomment the following lines of code if you would like to visualize a set of pairwise class
### interpolations for a collection of different numbers, all in a single grid of interpolations.
### You'll also see another visualization like this in the next code block!
# plot_numbers = [2, 3, 4, 5, 7]
# n_numbers = len(plot_numbers)
# plt.figure(figsize=(8, 8))
# for i, first_plot_number in enumerate(plot_numbers):
# for j, second_plot_number in enumerate(plot_numbers):
# plt.subplot(n_numbers, n_numbers, i * n_numbers + j + 1)
# interpolate_class(first_plot_number, second_plot_number)
# plt.axis('off')
# plt.subplots_adjust(top=1, bottom=0, left=0, right=1, hspace=0.1, wspace=0)
# plt.show()
# plt.close()
Changing the Noise Vector
Now, what happens if you hold the class constant, but instead you change the noise vector? You can also interpolate the noise vector and generate an image at each step.
# Re-assigns the same value set in the class-interpolation section above;
# interpolation_label and interpolate_noise below both read it.
n_interpolation = 9 # How many intermediate images you want + 2 (for the start and end image)
This time you're interpolating between the noise instead of the labels
# Encode the fixed class (5) once, then tile it so that every interpolation
# step is paired with the same label; cast to float for the generator input.
interpolation_label = (
    get_one_hot_labels(torch.tensor([5], dtype=torch.long), n_classes)
    .repeat(n_interpolation, 1)
    .float()
)
def interpolate_noise(first_noise, second_noise):
    """Plot generated images while blending between two noise vectors

    The class label is held fixed (the module-level ``interpolation_label``)
    and only the noise part of the generator's input changes. The weight on
    ``first_noise`` runs from 0 to 1 over ``n_interpolation`` steps, so the
    first image is generated from ``second_noise`` and the last one from
    ``first_noise``.

    Args:
     first_noise: noise vector that the interpolation ends on
     second_noise: noise vector that the interpolation starts from
    """
    # This time you're interpolating between the noise instead of the labels
    percent_first_noise = torch.linspace(0, 1, n_interpolation)[:, None].to(device)
    interpolation_noise = (first_noise * percent_first_noise
                           + second_noise * (1 - percent_first_noise))

    # Combine the noise and the labels again
    noise_and_labels = combine_vectors(interpolation_noise,
                                       interpolation_label.to(device))

    # This is inference only, so don't record an autograd graph for the forward pass
    with torch.no_grad():
        fake = gen(noise_and_labels)

    # NOTE(review): show_tensor_images isn't defined in the visible part of this
    # file (only save_tensor_images is) - confirm it exists before running.
    show_tensor_images(fake, num_images=n_interpolation,
                       nrow=int(math.sqrt(n_interpolation)), show=False)
Generate noise vectors to interpolate between.
### Change me! ###
n_noise = 5 # Choose the number of noise examples in the grid

# The noise helper in this file is ``make_some_noise`` and matplotlib.pyplot is
# imported as ``pyplot``; ``get_noise`` and ``plt`` aren't defined in the
# visible part of the file.
plot_noises = [make_some_noise(1, z_dim, device=device) for _ in range(n_noise)]

pyplot.figure(figsize=(8, 8))
# One subplot for every ordered pair of noise vectors, laid out row-major
# (row i = first noise, column j = second noise)
for i, first_plot_noise in enumerate(plot_noises):
    for j, second_plot_noise in enumerate(plot_noises):
        pyplot.subplot(n_noise, n_noise, i * n_noise + j + 1)
        interpolate_noise(first_plot_noise, second_plot_noise)
        pyplot.axis('off')
pyplot.subplots_adjust(top=1, bottom=0, left=0, right=1, hspace=0.1, wspace=0)
pyplot.show()
pyplot.close()