Wasserstein GAN With Gradient Penalty

A Wasserstein GAN with Gradient Penalty (WGAN-GP)

We're going to build a Wasserstein GAN with Gradient Penalty (WGAN-GP) that addresses some of the stability issues with training GANs. Specifically, we'll use a special kind of loss function known as the W-loss, where W stands for Wasserstein, along with a gradient penalty that enforces a Lipschitz constraint on the critic, which helps prevent mode collapse and vanishing gradients (see Wasserstein Metric).
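
To give away the ending, the loss we'll minimize for the critic (this is my restatement of the objective from the Gulrajani et al. paper cited at the end, not something derived in this post) is:

\[
L_{critic} = \mathbb{E}[C(G(z))] - \mathbb{E}[C(x)] + \lambda \, \mathbb{E}\left[(\lVert \nabla_{\hat{x}} C(\hat{x}) \rVert_2 - 1)^2\right]
\]

Here \(C\) is the critic, \(G\) is the generator, \(x\) is a real image, \(z\) is a noise vector, and \(\hat{x} = \epsilon x + (1 - \epsilon) G(z)\) is a random interpolation between a real and a fake image. The generator's loss is just \(-\mathbb{E}[C(G(z))]\).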

The Wasserstein metric is named after Leonid Vaseršteĭn, a mathematician at Penn State.

Imports

# python
from pathlib import Path

# from pypi
from torch import nn
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import MNIST
from torchvision.utils import make_grid

import holoviews
import matplotlib.pyplot as pyplot
import torch

# my stuff
from graeae import EmbedHoloviews, Timer

Set Up

The Random Seed

torch.manual_seed(0)

Plotting and the Timer

TIMER = Timer()
SLUG = "wasserstein-gan-with-gradient-penalty"

Helper Functions

def save_tensor_images(image_tensor: torch.Tensor,
                       filename: str, 
                       title: str,
                       folder: str=f"files/posts/gans/{SLUG}/",
                       num_images: int=25, size: tuple=(1, 28, 28)):
    """Plot an Image Tensor

    Args:
     image_tensor: tensor with the values for the image to plot
     filename: name to save the file under
     title: title for the image
     folder: path to put the file in
     num_images: how many images from the tensor to use
     size: the dimensions for each image
    """
    image_tensor = (image_tensor + 1) / 2
    image_unflat = image_tensor.detach().cpu()
    image_grid = make_grid(image_unflat[:num_images], nrow=5)
    pyplot.title(title)
    pyplot.grid(False)
    pyplot.imshow(image_grid.permute(1, 2, 0).squeeze())
    pyplot.tick_params(bottom=False, top=False, labelbottom=False,
                       right=False, left=False, labelleft=False)
    pyplot.savefig(folder + filename)
    print(f"[[file:{filename}]]")
    return
def holoviews_image(image: torch.Tensor, num_images: int=25) -> holoviews.Image:
    """Convert a batch of image tensors to a holoviews Image"""
    image_tensor = (image + 1) / 2
    image_unflat = image_tensor.detach().cpu()
    image_grid = make_grid(image_unflat[:num_images], nrow=5)
    # make_grid returns three identical channels for grayscale input,
    # so use the first one as a two-dimensional image
    return holoviews.Image(image_grid[0].numpy())

Gradient Hook

This helper keeps track of the gradients so that we can plot them later.

def make_grad_hook() -> tuple:
    """
    Function to keep track of gradients for visualization purposes, 
    which fills the grads list when using model.apply(grad_hook).
    """
    grads = []
    def grad_hook(m):
        if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d):
            grads.append(m.weight.grad)
    return grads, grad_hook
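
As a minimal usage sketch (assuming the Critic defined later in the post, and a completed backward pass so that the weight gradients exist):

grads, grad_hook = make_grad_hook()
# record each convolution's current weight gradient in the grads list
crit.apply(grad_hook)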

Noise

def make_noise(n_samples: int, z_dim: int, device: str='cpu') -> torch.Tensor:
    """Alias for torch.randn

    Args:
      n_samples: the number of samples to generate
      z_dim: the dimension of the noise vector
      device: the device type

    Returns:
     tensor with random numbers from the normal distribution.
    """
    return torch.randn(n_samples, z_dim, device=device)

Middle

The Generator

This is the generator from the Deep Convolutional GAN we built before.

class Generator(nn.Module):
    """The DCGAN Generator

    Args:
       z_dim: the dimension of the noise vector
       im_chan: the number of channels in the images, fitted for the dataset used
             (MNIST is black-and-white, so 1 channel is your default)
       hidden_dim: the inner dimension
    """
    def __init__(self, z_dim: int=10, im_chan: int=1, hidden_dim: int=64):
        super().__init__()
        self.z_dim = z_dim

        self.gen = nn.Sequential(
            self.make_gen_block(z_dim, hidden_dim * 4),
            self.make_gen_block(hidden_dim * 4, hidden_dim * 2, kernel_size=4, stride=1),
            self.make_gen_block(hidden_dim * 2, hidden_dim),
            self.make_gen_block(hidden_dim, im_chan, kernel_size=4, final_layer=True),
        )

    def make_gen_block(self, input_channels: int, output_channels: int,
                       kernel_size: int=3, stride: int=2,
                       final_layer: bool=False) -> nn.Sequential:
        """Creates a block for the generator (sub sequence)

       The parts
        - a transposed convolution
        - a batchnorm (except for in the last layer)
        - an activation.

       Args:
           input_channels: how many channels the input feature representation has
           output_channels: how many channels the output feature representation should have
           kernel_size: the size of each convolutional filter, equivalent to (kernel_size, kernel_size)
           stride: the stride of the convolution
           final_layer: a boolean, true if it is the final layer and false otherwise 
                     (affects activation and batchnorm)

       Returns:
        the sub-sequence of layers
       """

        if not final_layer:
            return nn.Sequential(
                nn.ConvTranspose2d(input_channels, output_channels, kernel_size, stride),
                nn.BatchNorm2d(output_channels),
                nn.ReLU(inplace=True),
            )
        else:
            return nn.Sequential(
                nn.ConvTranspose2d(input_channels, output_channels, kernel_size, stride),
                nn.Tanh(),
            )

    def forward(self, noise: torch.Tensor) -> torch.Tensor:
        """complete a forward pass of the generator: Given a noise tensor, 

       Args:
        noise: a noise tensor with dimensions (n_samples, z_dim)

       Returns:
        generated images.
       """
        # unsqueeze the noise
        x = noise.view(len(noise), self.z_dim, 1, 1)
        return self.gen(x)
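
As a quick sanity check on the architecture (a little sketch I'm adding here, not part of the original training code), the default blocks upsample the 1x1 noise "image" to MNIST's 28x28:

gen_check = Generator(z_dim=10)
images = gen_check(make_noise(2, 10))
# 1x1 -> 3x3 -> 6x6 -> 13x13 -> 28x28
assert images.shape == (2, 1, 28, 28)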

The Critic

This is essentially the same as our Discriminator class from before, renamed: with W-loss the network outputs an unbounded score rather than a probability that an image is real, so it's called a critic instead of a discriminator.

class Critic(nn.Module):
    """
    Critic Class

    Args:
       im_chan: the number of channels in the images, fitted for the dataset used
             (MNIST is black-and-white, so 1 channel is your default)
       hidden_dim: the inner dimension
    """
    def __init__(self, im_chan: int=1, hidden_dim: int=64):
        super().__init__()
        self.crit = nn.Sequential(
            self.make_crit_block(im_chan, hidden_dim),
            self.make_crit_block(hidden_dim, hidden_dim * 2),
            self.make_crit_block(hidden_dim * 2, 1, final_layer=True),
        )

    def make_crit_block(self, input_channels: int, output_channels: int,
                        kernel_size: int=4, stride: int=2,
                        final_layer: bool=False) -> nn.Sequential:
        """Creates a sub-block for the network

        - a convolution
        - a batchnorm (except in the final layer)
        - an activation (except in the final layer).

       Args:
           input_channels: how many channels the input feature representation has
           output_channels: how many channels the output feature representation should have
           kernel_size: the size of each convolutional filter, equivalent to (kernel_size, kernel_size)
           stride: the stride of the convolution
           final_layer: a boolean, true if it is the final layer and false otherwise 
                     (affects activation and batchnorm)
       """
        if not final_layer:
            return nn.Sequential(
                nn.Conv2d(input_channels, output_channels, kernel_size,
                          stride),
                nn.BatchNorm2d(output_channels),
                nn.LeakyReLU(0.2),
            )
        else:
            return nn.Sequential(
                nn.Conv2d(input_channels, output_channels, kernel_size,
                          stride),
            )

    def forward(self, image: torch.Tensor) -> torch.Tensor:
        """Run a forward pass of the critic

       Args:
           image: an image tensor with dimensions (n_samples, im_chan, height, width)

       Returns:
        an (n_samples, 1) tensor of scores (higher means judged more real).
       """
        crit_pred = self.crit(image)
        return crit_pred.view(len(crit_pred), -1)
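
And the matching check for the critic (again, a sketch added for illustration): three stride-2 convolutions reduce the 28x28 images to a single score per image:

crit_check = Critic()
scores = crit_check(torch.randn(2, 1, 28, 28))
# 28x28 -> 13x13 -> 5x5 -> 1x1, flattened to one score per image
assert scores.shape == (2, 1)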

Training

Hyperparameters

As usual, we'll start by setting the parameters:

  • n_epochs: the number of times you iterate through the entire dataset when training
  • z_dim: the dimension of the noise vector
  • display_step: how often to display/visualize the images
  • batch_size: the number of images per forward/backward pass
  • lr: the learning rate
  • beta_1, beta_2: the momentum terms for the Adam optimizers
  • c_lambda: the weight of the gradient penalty
  • crit_repeats: the number of times to update the critic per generator update - there are more details about this in the Running the Training section
  • device: the device type
n_epochs = 100
z_dim = 64
display_step = 50
batch_size = 128
lr = 0.0002
beta_1 = 0.5
beta_2 = 0.999
c_lambda = 10
crit_repeats = 5
device = 'cuda'

The Data

Once again we'll be using the MNIST dataset. The Normalize transform maps the pixel values from [0, 1] (the range ToTensor produces) to [-1, 1] so that the real images match the range of the generator's tanh output.

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])

data_path = Path("~/pytorch-data/MNIST/").expanduser()
dataloader = DataLoader(
    MNIST(data_path, download=True, transform=transform),
    batch_size=batch_size,
    shuffle=True)
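
A quick check (another illustrative sketch) that the transform behaves as advertised - the first batch should be full-sized with pixel values in [-1, 1]:

real_batch, _ = next(iter(dataloader))
assert real_batch.shape == (batch_size, 1, 28, 28)
assert real_batch.min() >= -1 and real_batch.max() <= 1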

Setup For Training

gen = Generator(z_dim).to(device)
gen_opt = torch.optim.Adam(gen.parameters(), lr=lr, betas=(beta_1, beta_2))
crit = Critic().to(device) 
crit_opt = torch.optim.Adam(crit.parameters(), lr=lr, betas=(beta_1, beta_2))
def weights_init(m):
    if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d):
        torch.nn.init.normal_(m.weight, 0.0, 0.02)
    if isinstance(m, nn.BatchNorm2d):
        torch.nn.init.normal_(m.weight, 0.0, 0.02)
        torch.nn.init.constant_(m.bias, 0)
    return
gen = gen.apply(weights_init)
crit = crit.apply(weights_init)

The Gradient

Calculating the gradient penalty can be broken into two functions: (1) compute the gradient with respect to the images and (2) compute the gradient penalty given the gradient.

You can start by getting the gradient. To compute it you first create a mixed image by weighting a real and a fake image with epsilon and (1 - epsilon) and adding them together. Once you have this intermediate image, you get the critic's scores on it. Finally, you compute the gradient of the critic's scores on the mixed images (the outputs) with respect to the pixels of the mixed images (the inputs).

def get_gradient(crit: Critic, real: torch.Tensor, fake: torch.Tensor,
                 epsilon: torch.Tensor) -> torch.Tensor:
    """Gradient of the critic's scores with respect to mixes of real and fake images.

    Args:
       crit: the critic model
       real: a batch of real images
       fake: a batch of fake images
       epsilon: a vector of the uniformly random proportions of real/fake per mixed image

    Returns:
       gradient: the gradient of the critic's scores, with respect to the mixed image
    """
    # Mix the images together
    mixed_images = real * epsilon + fake * (1 - epsilon)

    # Calculate the critic's scores on the mixed images
    mixed_scores = crit(mixed_images)

    # Take the gradient of the scores with respect to the images
    gradient = torch.autograd.grad(
        # take the gradient of the outputs (the scores) with respect
        # to the inputs (the mixed images)
        inputs=mixed_images,
        outputs=mixed_scores,
        # since the scores aren't a scalar, supply the vector for the
        # vector-Jacobian product (all ones weights the scores equally)
        grad_outputs=torch.ones_like(mixed_scores),
        # build and keep the graph so the penalty can be backpropagated
        create_graph=True,
        retain_graph=True,
    )[0]
    return gradient
  • Unit Tests
    def test_get_gradient(image_shape):
        real = torch.randn(*image_shape, device=device) + 1
        fake = torch.randn(*image_shape, device=device) - 1
        epsilon_shape = [1 for _ in image_shape]
        epsilon_shape[0] = image_shape[0]
        epsilon = torch.rand(epsilon_shape, device=device).requires_grad_()
        gradient = get_gradient(crit, real, fake, epsilon)
        assert tuple(gradient.shape) == image_shape
        assert gradient.max() > 0
        assert gradient.min() < 0
        return gradient
    
    gradient = test_get_gradient((256, 1, 28, 28))
    

The Gradient Penalty

The second function computes the gradient penalty given the gradient. First, you calculate the magnitude (the L2 norm) of each image's gradient. Then you calculate the penalty as the mean of the squared distances between each magnitude and the ideal norm of 1. As a quick example, a gradient of all zeros gives every image a norm of 0, so its penalty is (0 - 1)^2 = 1, which is what the first unit test below checks.
def gradient_penalty(gradient: torch.Tensor) -> torch.Tensor:
    """Calculate the size of each image's gradient
    and penalize the mean quadratic distance of each magnitude to 1.

    Args:
       gradient: the gradient of the critic's scores, with respect to the mixed image

    Returns:
       penalty: the gradient penalty
    """
    # Flatten the gradients so that each row captures one image
    gradient = gradient.view(len(gradient), -1)

    # Calculate the magnitude of every row
    gradient_norm = gradient.norm(2, dim=1)

    # Penalize the mean squared distance of the gradient norms from 1
    penalty = torch.mean(torch.square(gradient_norm - 1))
    return penalty
  • Unit Testing
    def test_gradient_penalty(image_shape: tuple):
        bad_gradient = torch.zeros(*image_shape)
        bad_gradient_penalty = gradient_penalty(bad_gradient)
        assert torch.isclose(bad_gradient_penalty, torch.tensor(1.))
    
        image_size = torch.prod(torch.Tensor(image_shape[1:]))
        good_gradient = torch.ones(*image_shape) / torch.sqrt(image_size)
        good_gradient_penalty = gradient_penalty(good_gradient)
        assert torch.isclose(good_gradient_penalty, torch.tensor(0.))
    
        random_gradient = test_get_gradient(image_shape)
        random_gradient_penalty = gradient_penalty(random_gradient)
        assert torch.abs(random_gradient_penalty - 1) < 0.1
    
    test_gradient_penalty((256, 1, 28, 28))
    

Losses

Next, you need to calculate the loss for the generator and the critic.

  • Generator Loss

    For the generator, the loss is calculated by maximizing the critic's scores on the generator's fake images. The argument holds the critic's scores for all of the fake images in the batch; you take their mean and negate it, which turns "maximize the scores" into a loss you can minimize, all in one line.
    def get_gen_loss(crit_fake_pred: torch.Tensor) -> torch.Tensor:
        """loss of generator given critic's scores of generator's fake images.
    
        Args:
           crit_fake_pred: the critic's scores of the fake images
    
        Returns:
           gen_loss: a scalar loss value for the current batch of the generator
        """
        return -torch.mean(crit_fake_pred)
    
    assert torch.isclose(
        get_gen_loss(torch.tensor(1.)), torch.tensor(-1.0)
    )
    
    assert torch.isclose(
        get_gen_loss(torch.rand(10000)), torch.tensor(-0.5), 0.05
    )
    
  • The Critic Loss

    For the critic, the loss is calculated by maximizing the distance between the critic's predictions on the real images and its predictions on the fake images, while also adding a gradient penalty weighted by c_lambda. The arguments hold the scores for all of the images in the batch, and you use their means.

    1. The higher the mean fake score, the higher the critic's loss.
    2. Conversely, the higher the mean real score, the lower the critic's loss.
    3. The higher the gradient penalty, the higher the critic's loss, in proportion to c_lambda.
    def get_crit_loss(crit_fake_pred: torch.Tensor, crit_real_pred: torch.Tensor,
                      gp: torch.Tensor, c_lambda: float) -> torch.Tensor:
        """loss of a critic given critic's scores for fake and real images,
        the gradient penalty, and gradient penalty weight.
    
        Args:
           crit_fake_pred: the critic's scores of the fake images
           crit_real_pred: the critic's scores of the real images
           gp: the unweighted gradient penalty
           c_lambda: the current weight of the gradient penalty 
    
        Returns:
           crit_loss: a scalar for the critic's loss, accounting for the relevant factors
        """
        return torch.mean(crit_fake_pred) - torch.mean(crit_real_pred) + c_lambda * gp
    
    assert torch.isclose(
        get_crit_loss(torch.tensor(1.), torch.tensor(2.), torch.tensor(3.), 0.1),
        torch.tensor(-0.7)
    )
    assert torch.isclose(
        get_crit_loss(torch.tensor(20.), torch.tensor(-20.), torch.tensor(2.), 10),
        torch.tensor(60.)
    )
    

Running the Training

Before you put everything together, there are a few things to note.

  1. Even on a GPU, the training will run more slowly than in previous labs because the gradient penalty requires you to compute the gradient of a gradient, which can mean a few minutes per epoch. For best results, run this for as long as you can while on a GPU.
  2. One important difference from earlier versions is that you will update the critic multiple times for every generator update. This helps prevent the generator from overpowering the critic. Sometimes you might see the reverse, with the generator updated more times than the critic. This depends on architectural choices (e.g. the depth and width of the network) and algorithmic choices (e.g. which loss you're using).
  3. WGAN-GP isn't necessarily meant to improve the overall performance of a GAN; it increases stability and helps to avoid mode collapse. In general, a WGAN will train in a much more stable way than the vanilla DCGAN from before, though it will generally run a bit slower. You should also be able to train the model for more epochs without it collapsing.
def update_critic(critic, critic_optimizer, generator, generator_optimizer, batch_size, z_dim, real):
    """Run one update step for the critic

    Returns:
     this step's share of the batch's mean critic loss, and the fake images
    """
    critic_optimizer.zero_grad()
    fake_noise = make_noise(batch_size, z_dim, device=device)
    fake = generator(fake_noise)
    crit_fake_pred = critic(fake.detach())
    crit_real_pred = critic(real)

    epsilon = torch.rand(len(real), 1, 1, 1, device=device, requires_grad=True)
    gradient = get_gradient(critic, real, fake.detach(), epsilon)
    gp = gradient_penalty(gradient)
    crit_loss = get_crit_loss(crit_fake_pred, crit_real_pred, gp, c_lambda)

    # keep track of this step's share of the average critic loss for the batch
    mean_iteration_critic_loss = crit_loss.detach().item() / crit_repeats
    # calculate the gradients
    crit_loss.backward()
    # update the weights
    critic_optimizer.step()
    return mean_iteration_critic_loss, fake
def update_generator(generator, generator_optimizer, critic, critic_optimizer,
                     batch_size, z_dim):
    """Run one update step for the generator

    Returns:
     a single-element list with the generator's loss for this step
    """
    generator_optimizer.zero_grad()
    fake_noise_2 = make_noise(batch_size, z_dim, device=device)
    fake_2 = generator(fake_noise_2)
    crit_fake_pred = critic(fake_2)

    gen_loss = get_gen_loss(crit_fake_pred)
    gen_loss.backward()

    # update the weights
    generator_optimizer.step()
    return [gen_loss.detach().item()]
cur_step = 0
generator_losses = []
critic_losses = []
fakes = []

with TIMER:
    for epoch in range(n_epochs):
        # Dataloader returns the batches
        for real, _ in dataloader:
            cur_batch_size = len(real)
            real = real.to(device)

            mean_iteration_critic_loss = 0
            for _ in range(crit_repeats):
                ### Update critic ###
                this_loss, fake = update_critic(crit, crit_opt, gen, gen_opt,
                                                cur_batch_size, z_dim, real)
                mean_iteration_critic_loss += this_loss
            critic_losses += [mean_iteration_critic_loss]

            ### Update generator ###
            # Keep track of the average generator loss
            generator_losses += update_generator(gen, gen_opt, crit, crit_opt,
                                                 cur_batch_size, z_dim)

            ### Visualization code ###
            if cur_step % display_step == 0 and cur_step > 0:
                gen_mean = sum(generator_losses[-display_step:]) / display_step
                crit_mean = sum(critic_losses[-display_step:]) / display_step
                print(f"Step {cur_step}: Generator loss: {gen_mean}, critic loss: {crit_mean}")
                fakes.append(fake)

            cur_step += 1
Started: 2021-04-23 16:44:37.086571
Step 50: Generator loss: 1.2940945455431938, critic loss: -2.5389487731456755
Step 100: Generator loss: 1.8233803486824036, critic loss: -10.170887191772463
Step 150: Generator loss: -0.8236922709643841, critic loss: -25.889275665283208
Step 200: Generator loss: -1.9489177632331849, critic loss: -57.93669644165039
Step 250: Generator loss: -1.6910316547751427, critic loss: -98.02721130371094
Step 300: Generator loss: -1.057899413406849, critic loss: -148.77607403564457
Step 350: Generator loss: -1.0930944073200226, critic loss: -199.94886077880858
Step 400: Generator loss: 1.900166620016098, critic loss: -245.53067184448247
Step 450: Generator loss: -18.928784263134002, critic loss: -251.46439450645448
Step 500: Generator loss: -7.688082475662231, critic loss: -289.45334830856325
Step 550: Generator loss: 12.447209596633911, critic loss: -395.351733947754
Step 600: Generator loss: 4.604712443947792, critic loss: -442.96986193847647
Step 650: Generator loss: 2.1788939160108565, critic loss: -480.044010131836
Step 700: Generator loss: 2.979072951376438, critic loss: -519.4769331054689
Step 750: Generator loss: -49.77768729448319, critic loss: -406.99980457305907
Step 800: Generator loss: 0.28986886143684387, critic loss: -444.8244698066711
Step 850: Generator loss: 31.1217813873291, critic loss: -608.1500103759765
Step 900: Generator loss: 12.006675623655319, critic loss: -632.0770750808719
Step 950: Generator loss: -0.15041383981704712, critic loss: -659.3277660064699
Step 1000: Generator loss: 15.936325817108154, critic loss: -629.952421447754
Step 1050: Generator loss: -43.25309041261673, critic loss: -504.54743419075004
Step 1100: Generator loss: -127.80617136001587, critic loss: -347.7993973159789
Step 1150: Generator loss: 4.186352119445801, critic loss: -461.6966292152405
Step 1200: Generator loss: 19.471285017728807, critic loss: -417.6742295103073
Step 1250: Generator loss: 34.04052387237549, critic loss: -327.74495936584475
Step 1300: Generator loss: -61.267093954086306, critic loss: -114.96264076042176
Step 1350: Generator loss: -56.96540081501007, critic loss: -257.8397505912781
Step 1400: Generator loss: -58.51407446861267, critic loss: -284.2404485015868
Step 1450: Generator loss: -31.23556293010712, critic loss: -282.15282668590544
Step 1500: Generator loss: 21.97936663866043, critic loss: -201.8184239835738
Step 1550: Generator loss: -35.051265001297, critic loss: -268.2542330398559
Step 1600: Generator loss: -13.768656857013703, critic loss: -201.92625104904172
Step 1650: Generator loss: 22.134875717163087, critic loss: -222.15251140356065
Step 1700: Generator loss: -33.80421092987061, critic loss: -196.00927429389947
Step 1750: Generator loss: -57.25435597419739, critic loss: -182.85244289588928
Step 1800: Generator loss: -41.60410815238953, critic loss: -213.254286611557
Step 1850: Generator loss: -4.978743267059326, critic loss: -101.88668561553959
Step 1900: Generator loss: 43.375376815795896, critic loss: 24.468120357513428
Step 1950: Generator loss: 37.55927352905273, critic loss: 19.142875072479246
Step 2000: Generator loss: 30.793880767822266, critic loss: 27.632160606384268
Step 2050: Generator loss: 28.9916410446167, critic loss: 37.41749234771728
Step 2100: Generator loss: 28.57459102630615, critic loss: 36.46667390441895
Step 2150: Generator loss: 27.179994583129883, critic loss: 37.36057964324953
Step 2200: Generator loss: 26.722407608032228, critic loss: 36.42123816680908
Step 2250: Generator loss: 26.215636711120606, critic loss: 35.10568865203857
Step 2300: Generator loss: 25.28977954864502, critic loss: 38.4949776916504
Step 2350: Generator loss: 25.161714172363283, critic loss: 30.91700393295288
Step 2400: Generator loss: 25.609521713256836, critic loss: 27.127794273376463
Step 2450: Generator loss: 26.457210426330565, critic loss: 23.25596778869629
Step 2500: Generator loss: 27.144473686218262, critic loss: 18.423582084655763
Step 2550: Generator loss: 28.104863624572754, critic loss: 16.720462280273438
Step 2600: Generator loss: 29.460466690063477, critic loss: 13.846090631484987
Step 2650: Generator loss: 31.16196632385254, critic loss: 10.717047594070436
Step 2700: Generator loss: 32.86851013183594, critic loss: 8.973742393493652
Step 2750: Generator loss: 33.90616256713867, critic loss: 9.844469717025756
Step 2800: Generator loss: 34.65669334411621, critic loss: 8.557393852233888
Step 2850: Generator loss: 35.84923110961914, critic loss: 6.227309632301333
Step 2900: Generator loss: 37.1290372467041, critic loss: 4.664727992773057
Step 2950: Generator loss: 39.00773422241211, critic loss: 3.6960949053764343
Step 3000: Generator loss: 41.04932693481445, critic loss: 3.064339481592179
Step 3050: Generator loss: 43.54303398132324, critic loss: 1.5976664029359815
Step 3100: Generator loss: 46.25879165649414, critic loss: 0.43558707976341254
Step 3150: Generator loss: 48.358483200073245, critic loss: -0.8735819962918758
Step 3200: Generator loss: 49.9193138885498, critic loss: -1.9399951877593993
Step 3250: Generator loss: 50.604149169921875, critic loss: -2.96596682035923
Step 3300: Generator loss: 51.37260269165039, critic loss: -4.266795755624772
Step 3350: Generator loss: 50.53414665222168, critic loss: -6.2572406907081595
Step 3400: Generator loss: 49.34995780944824, critic loss: -8.031075536847114
Step 3450: Generator loss: 46.14337966918945, critic loss: -8.019683789610863
Step 3500: Generator loss: 42.769298782348635, critic loss: -9.498445952415468
Step 3550: Generator loss: 37.38293798446655, critic loss: -9.02791331624985
Step 3600: Generator loss: 32.84453460693359, critic loss: -9.934100509524345
Step 3650: Generator loss: 29.88087886810303, critic loss: -9.069164658904075
Step 3700: Generator loss: 27.295934791564942, critic loss: -12.109804625511167
Step 3750: Generator loss: 23.694135398864745, critic loss: -14.327697192192076
Step 3800: Generator loss: 22.836445541381835, critic loss: -15.450897558450697
Step 3850: Generator loss: 21.66964967727661, critic loss: -18.371595690727236
Step 3900: Generator loss: 22.644691734313966, critic loss: -18.472765784740442
Step 3950: Generator loss: 23.275020160675048, critic loss: -14.622903740763663
Step 4000: Generator loss: 20.404177145957945, critic loss: -20.265531128883364
Step 4050: Generator loss: 20.57322360277176, critic loss: -22.811122689247135
Step 4100: Generator loss: 20.653975734710695, critic loss: -21.081045699119564
Step 4150: Generator loss: 22.07396845817566, critic loss: -25.1140656299591
Step 4200: Generator loss: 23.147041385173797, critic loss: -25.637423175573346
Step 4250: Generator loss: 24.7466512966156, critic loss: -27.446938713431358
Step 4300: Generator loss: 23.155011949539183, critic loss: -29.866371290445326
Step 4350: Generator loss: 28.670740413665772, critic loss: -27.526438851594932
Step 4400: Generator loss: 28.197952184677124, critic loss: -32.777981144189845
Step 4450: Generator loss: 30.352355518341064, critic loss: -27.1594803442955
Step 4500: Generator loss: 28.54464930534363, critic loss: -33.94081681919097
Step 4550: Generator loss: 30.315768175125122, critic loss: -32.86432695555688
Step 4600: Generator loss: 31.542511186599732, critic loss: -30.20407930350304
Step 4650: Generator loss: 32.1046596121788, critic loss: -25.409390352845193
Step 4700: Generator loss: 32.14258025169372, critic loss: -31.69375462341309
Step 4750: Generator loss: 34.99601099014282, critic loss: -17.207461384415634
Step 4800: Generator loss: 34.72456073760986, critic loss: -28.68983098757266
Step 4850: Generator loss: 43.15867195129395, critic loss: -3.741025509417056
Step 4900: Generator loss: 39.205870933532715, critic loss: -10.995340047717095
Step 4950: Generator loss: 33.214964599609374, critic loss: -22.35341439080238
Step 5000: Generator loss: 36.83505029678345, critic loss: -22.059852074146274
Step 5050: Generator loss: 44.310142288208006, critic loss: -9.833503689646719
Step 5100: Generator loss: 46.455570983886716, critic loss: -6.97827914196253
Step 5150: Generator loss: 50.3965446472168, critic loss: 2.86564082187414
Step 5200: Generator loss: 49.87795219421387, critic loss: -1.3452879690229889
Step 5250: Generator loss: 47.53674819946289, critic loss: -2.096805039405823
Step 5300: Generator loss: 46.8746314239502, critic loss: -3.2593628435134883
Step 5350: Generator loss: 45.44812057495117, critic loss: -8.15779336643219
Step 5400: Generator loss: 44.419895820617675, critic loss: -14.570247013330457
Step 5450: Generator loss: 46.02410781860352, critic loss: -15.177982830524446
Step 5500: Generator loss: 49.54875686645508, critic loss: -9.89209368979931
Step 5550: Generator loss: 48.06167510986328, critic loss: -14.110691767692567
Step 5600: Generator loss: 49.201857833862306, critic loss: -14.137419148623945
Step 5650: Generator loss: 50.152088012695316, critic loss: -12.306397112727165
Step 5700: Generator loss: 48.29638786315918, critic loss: -16.661144974470133
Step 5750: Generator loss: 48.57353067398071, critic loss: -14.890159791767603
Step 5800: Generator loss: 49.75064552307129, critic loss: -18.844482659339906
Step 5850: Generator loss: 60.04904914855957, critic loss: -6.717597324132919
Step 5900: Generator loss: 51.537723999023434, critic loss: -16.97626993632317
Step 5950: Generator loss: 53.64197952270508, critic loss: -17.934735801696778
Step 6000: Generator loss: 58.61811660766602, critic loss: -12.544874910593034
Step 6050: Generator loss: 57.9530167388916, critic loss: -12.869983579158779
Step 6100: Generator loss: 58.112417755126955, critic loss: -14.860800614833833
Step 6150: Generator loss: 59.45550857543945, critic loss: -16.21854728984833
Step 6200: Generator loss: 61.55990020751953, critic loss: -13.752459713578226
Step 6250: Generator loss: 63.91949012756348, critic loss: -15.32866345870495
Step 6300: Generator loss: 61.11529357910156, critic loss: -19.138810309529305
Step 6350: Generator loss: 68.78476165771484, critic loss: -3.858711770117282
Step 6400: Generator loss: 72.07508163452148, critic loss: -3.3317795319557204
Step 6450: Generator loss: 62.11038558959961, critic loss: -12.74781008577347
Step 6500: Generator loss: 66.10368064880372, critic loss: -13.66576182627678
Step 6550: Generator loss: 62.73857864379883, critic loss: -19.79733684468269
Step 6600: Generator loss: 64.86283889770507, critic loss: -15.91535943055153
Step 6650: Generator loss: 65.02771781921386, critic loss: -16.515603628635407
Step 6700: Generator loss: 73.10651649475098, critic loss: -7.974747009277344
Step 6750: Generator loss: 69.39200439453126, critic loss: -12.647881946563723
Step 6800: Generator loss: 70.61859390258789, critic loss: -14.981548887073998
Step 6850: Generator loss: 71.39209846496583, critic loss: -12.02037605035305
Step 6900: Generator loss: 68.91642692565918, critic loss: -17.377452049493794
Step 6950: Generator loss: 73.83714424133301, critic loss: -14.842290714025498
Step 7000: Generator loss: 76.0492682647705, critic loss: -4.022153543114662
Step 7050: Generator loss: 73.60314575195312, critic loss: -11.167652189731598
Step 7100: Generator loss: 73.69744178771973, critic loss: -16.215790304422377
Step 7150: Generator loss: 73.02161018371582, critic loss: -11.844917020320892
Step 7200: Generator loss: 84.43860961914062, critic loss: -4.338678442955016
Step 7250: Generator loss: 72.4216611480713, critic loss: -16.95018665671349
Step 7300: Generator loss: 75.08161041259766, critic loss: -13.94019952297211
Step 7350: Generator loss: 76.7044221496582, critic loss: -14.254385577440262
Step 7400: Generator loss: 81.03584564208984, critic loss: -3.171723330259324
Step 7450: Generator loss: 80.19454528808593, critic loss: -6.323260527610778
Step 7500: Generator loss: 74.55620361328126, critic loss: -8.62027923491597
Step 7550: Generator loss: 84.05591217041015, critic loss: -3.5706960783004775
Step 7600: Generator loss: 81.2258724975586, critic loss: -8.142396178722382
Step 7650: Generator loss: 73.19812255859375, critic loss: -16.196065732836722
Step 7700: Generator loss: 74.52944702148437, critic loss: -15.7419521817565
Step 7750: Generator loss: 80.32163719177247, critic loss: -7.413010147571564
Step 7800: Generator loss: 76.99493499755859, critic loss: -12.079633572757244
Step 7850: Generator loss: 81.32430145263672, critic loss: -2.8193510160446174
Step 7900: Generator loss: 80.63022003173828, critic loss: -3.1151746976375576
Step 7950: Generator loss: 75.89005561828613, critic loss: -8.688790566921234
Step 8000: Generator loss: 72.94720428466798, critic loss: -14.186805599212649
Step 8050: Generator loss: 80.84135955810547, critic loss: -11.586392744839191
Step 8100: Generator loss: 79.48079322814941, critic loss: -1.3788062819838527
Step 8150: Generator loss: 72.63796539306641, critic loss: -14.767250993669036
Step 8200: Generator loss: 76.29679145812989, critic loss: -16.04671211397648
Step 8250: Generator loss: 72.60974617004395, critic loss: -17.008654308915133
Step 8300: Generator loss: 75.25621772766114, critic loss: -12.109682399034496
Step 8350: Generator loss: 81.09654647827148, critic loss: -10.706179085254668
Step 8400: Generator loss: 77.28005485534668, critic loss: -4.09239830350876
Step 8450: Generator loss: 83.45014526367187, critic loss: -3.1862959499359125
Step 8500: Generator loss: 80.24715942382812, critic loss: -4.144565615177154
Step 8550: Generator loss: 76.43464157104492, critic loss: -9.53649512773752
Step 8600: Generator loss: 73.67140350341796, critic loss: -15.18680296653509
Step 8650: Generator loss: 75.6114599609375, critic loss: -10.128391755342484
Step 8700: Generator loss: 73.68272163391113, critic loss: -16.97586714470387
Step 8750: Generator loss: 83.1702619934082, critic loss: -0.6609140309095384
Step 8800: Generator loss: 80.41752578735351, critic loss: -4.212692310333251
Step 8850: Generator loss: 71.03237358093261, critic loss: -14.983835175275805
Step 8900: Generator loss: 75.80495880126954, critic loss: -12.667168443322183
Step 8950: Generator loss: 81.14228034973145, critic loss: 2.7472501730918872
Step 9000: Generator loss: 81.20193344116211, critic loss: -3.052738008499146
Step 9050: Generator loss: 73.43904174804688, critic loss: -7.423715700268742
Step 9100: Generator loss: 73.12181861877441, critic loss: -14.306883191585541
Step 9150: Generator loss: 76.89906158447266, critic loss: -13.396733086347583
Step 9200: Generator loss: 75.99712623596191, critic loss: -12.318668732821939
Step 9250: Generator loss: 77.78204513549805, critic loss: -6.621456883490087
Step 9300: Generator loss: 77.82661689758301, critic loss: -11.999425900220869
Step 9350: Generator loss: 81.48483535766601, critic loss: -11.480147421479224
Step 9400: Generator loss: 75.37383903503418, critic loss: -11.605070021390913
Step 9450: Generator loss: 83.24758972167969, critic loss: -1.770111013114451
Step 9500: Generator loss: 75.71745803833008, critic loss: -14.370290687352417
Step 9550: Generator loss: 80.75228134155273, critic loss: -12.244659341961144
Step 9600: Generator loss: 80.36522689819336, critic loss: -9.994889120757579
Step 9650: Generator loss: 79.76879989624024, critic loss: -12.11628355455398
Step 9700: Generator loss: 75.03965270996093, critic loss: -15.582087687492374
Step 9750: Generator loss: 78.26055725097656, critic loss: -9.227732668161394
Step 9800: Generator loss: 86.73946716308593, critic loss: -3.9114915781021113
Step 9850: Generator loss: 77.57634506225585, critic loss: -16.903033419966697
Step 9900: Generator loss: 79.62038360595703, critic loss: -13.387711975812913
Step 9950: Generator loss: 83.48049461364747, critic loss: 0.4212318459749224
Step 10000: Generator loss: 86.0385548400879, critic loss: -3.0202082567214954
Step 10050: Generator loss: 84.96556030273437, critic loss: -3.2984186277389527
Step 10100: Generator loss: 82.55163467407226, critic loss: -5.651416356563568
Step 10150: Generator loss: 72.47459297180175, critic loss: -16.2935069770813
Step 10200: Generator loss: 77.47050117492675, critic loss: -14.219993201971054
Step 10250: Generator loss: 82.40048095703125, critic loss: -9.951535837292676
Step 10300: Generator loss: 78.51686393737793, critic loss: -5.037457182884218
Step 10350: Generator loss: 79.00918548583985, critic loss: -10.983480290770531
Step 10400: Generator loss: 79.10479446411132, critic loss: -11.458023426651957
Step 10450: Generator loss: 79.01952590942383, critic loss: -13.550984252214432
Step 10500: Generator loss: 79.7324333190918, critic loss: -15.04755926167965
Step 10550: Generator loss: 83.25529792785645, critic loss: -10.678096773743627
Step 10600: Generator loss: 78.7729409790039, critic loss: -14.363517974853519
Step 10650: Generator loss: 83.80620101928712, critic loss: -12.4009742795825
Step 10700: Generator loss: 83.44554489135743, critic loss: -5.4632708239853365
Step 10750: Generator loss: 84.38950912475586, critic loss: -4.946207571595907
Step 10800: Generator loss: 84.90599151611327, critic loss: -10.688541789770127
Step 10850: Generator loss: 80.39469886779786, critic loss: -13.391746405303474
Step 10900: Generator loss: 79.68403381347656, critic loss: -14.792330410242082
Step 10950: Generator loss: 84.55435623168945, critic loss: -12.792006389081477
Step 11000: Generator loss: 85.3377848815918, critic loss: -1.002582928955554
Step 11050: Generator loss: 76.42176498413086, critic loss: -16.618346381425855
Step 11100: Generator loss: 82.8500619506836, critic loss: -10.213502784013746
Step 11150: Generator loss: 80.111083984375, critic loss: -16.506468793153765
Step 11200: Generator loss: 81.84511749267578, critic loss: -14.588824108004571
Step 11250: Generator loss: 82.36108421325683, critic loss: -14.826971750736238
Step 11300: Generator loss: 82.89525245666503, critic loss: -14.743118989944467
Step 11350: Generator loss: 78.9609211730957, critic loss: -6.072368972778322
Step 11400: Generator loss: 79.75704879760742, critic loss: -11.66915795624256
Step 11450: Generator loss: 92.73718231201173, critic loss: -8.626956017732619
Step 11500: Generator loss: 76.74110557556152, critic loss: -13.485125755786896
Step 11550: Generator loss: 86.92150177001953, critic loss: -11.96049699956179
Step 11600: Generator loss: 87.94025703430175, critic loss: -7.829241111636162
Step 11650: Generator loss: 78.58638778686523, critic loss: -13.818019400000573
Step 11700: Generator loss: 82.94163925170898, critic loss: -16.088717435359957
Step 11750: Generator loss: 82.20194442749023, critic loss: -13.443735618114472
Step 11800: Generator loss: 77.3590771484375, critic loss: -0.26538432469963885
Step 11850: Generator loss: 87.65712219238281, critic loss: -2.2925723257064816
Step 11900: Generator loss: 86.44266906738281, critic loss: -2.755362086296081
Step 11950: Generator loss: 85.7614064025879, critic loss: -2.9416364326477047
Step 12000: Generator loss: 84.22476821899414, critic loss: -3.100327790260315
Step 12050: Generator loss: 81.84705871582031, critic loss: -3.3889783926010137
Step 12100: Generator loss: 74.62463600158691, critic loss: -9.155223772525787
Step 12150: Generator loss: 83.41003746032715, critic loss: -7.312069640517238
Step 12200: Generator loss: 77.82574188232422, critic loss: -10.063361536026001
Step 12250: Generator loss: 77.09058532714843, critic loss: -15.389594004154203
Step 12300: Generator loss: 85.65135437011719, critic loss: -11.597671725511553
Step 12350: Generator loss: 79.91491325378418, critic loss: -0.8456090040206905
Step 12400: Generator loss: 83.31446044921876, critic loss: -3.5672192862033842
Step 12450: Generator loss: 80.4154541015625, critic loss: -9.493659735798834
Step 12500: Generator loss: 77.22660888671875, critic loss: -11.343838263094426
Step 12550: Generator loss: 76.51863540649414, critic loss: -15.957162732720372
Step 12600: Generator loss: 71.82434341430664, critic loss: -15.232202378749843
Step 12650: Generator loss: 81.55846801757812, critic loss: -12.02893185913563
Step 12700: Generator loss: 77.01351791381836, critic loss: -14.394531373143197
Step 12750: Generator loss: 82.79933059692382, critic loss: -10.995534277558324
Step 12800: Generator loss: 80.33022705078125, critic loss: -7.422801446437835
Step 12850: Generator loss: 77.88019416809082, critic loss: -10.48680070441961
Step 12900: Generator loss: 77.28355583190918, critic loss: -15.062006795048712
Step 12950: Generator loss: 72.02762420654297, critic loss: -18.125201426446434
Step 13000: Generator loss: 78.97825164794922, critic loss: -11.02606911355257
Step 13050: Generator loss: 76.02745002746582, critic loss: -13.242777463912965
Step 13100: Generator loss: 82.44893028259277, critic loss: -10.203380972802634
Step 13150: Generator loss: 80.63447105407715, critic loss: -11.436619911789894
Step 13200: Generator loss: 69.52673934936523, critic loss: -12.998723325610163
Step 13250: Generator loss: 75.26367416381837, critic loss: -12.58380482053757
Step 13300: Generator loss: 78.29216751098633, critic loss: 0.21028297042846839
Step 13350: Generator loss: 70.94842475891113, critic loss: -8.405993442416191
Step 13400: Generator loss: 77.60350791931153, critic loss: -12.201066960632803
Step 13450: Generator loss: 78.38650337219238, critic loss: -13.255251537919046
Step 13500: Generator loss: 72.39071220397949, critic loss: -13.91472595399618
Step 13550: Generator loss: 78.81595336914063, critic loss: -12.717635474145416
Step 13600: Generator loss: 69.23250061035156, critic loss: -15.01334501111508
Step 13650: Generator loss: 77.3666291809082, critic loss: -16.321711009979246
Step 13700: Generator loss: 73.45859939575195, critic loss: -17.17580293393135
Step 13750: Generator loss: 74.07134948730469, critic loss: -14.143001305580142
Step 13800: Generator loss: 68.98319381713867, critic loss: -18.013431072473526
Step 13850: Generator loss: 73.18379371643067, critic loss: -13.245033169150352
Step 13900: Generator loss: 73.70108238220215, critic loss: -15.747089947700497
Step 13950: Generator loss: 71.67143341064452, critic loss: -6.442092946648602
Step 14000: Generator loss: 74.99322380065918, critic loss: -5.310949310302733
Step 14050: Generator loss: 69.55456466674805, critic loss: -7.584069814443586
Step 14100: Generator loss: 68.11343818664551, critic loss: -15.932588892817499
Step 14150: Generator loss: 73.32868095397949, critic loss: -14.538219540774824
Step 14200: Generator loss: 71.54050506591797, critic loss: -6.507004916965961
Step 14250: Generator loss: 73.50055587768554, critic loss: -12.074983437180519
Step 14300: Generator loss: 75.37609176635742, critic loss: -12.215355042934414
Step 14350: Generator loss: 78.41978523254394, critic loss: -13.282461894750588
Step 14400: Generator loss: 69.06725090026856, critic loss: -8.44315874606371
Step 14450: Generator loss: 77.47375007629394, critic loss: -10.59642046368122
Step 14500: Generator loss: 72.112548828125, critic loss: -9.080148652315138
Step 14550: Generator loss: 71.41747200012207, critic loss: -12.610691975355143
Step 14600: Generator loss: 68.53853507995605, critic loss: -14.517420025825501
Step 14650: Generator loss: 71.00217765808105, critic loss: -16.055311642885208
Step 14700: Generator loss: 75.56183944702148, critic loss: -4.261986103117466
Step 14750: Generator loss: 68.21860916137695, critic loss: -14.03696541213989
Step 14800: Generator loss: 71.7959959411621, critic loss: -13.989702057063587
Step 14850: Generator loss: 76.38227409362793, critic loss: -10.939811514139176
Step 14900: Generator loss: 67.81556015014648, critic loss: -15.070325279712678
Step 14950: Generator loss: 71.62906150817871, critic loss: -12.239016912937165
Step 15000: Generator loss: 73.60893783569335, critic loss: -5.476252611890436
Step 15050: Generator loss: 64.43828086853027, critic loss: -11.680644391536712
Step 15100: Generator loss: 66.76135398864746, critic loss: -18.935012437820443
Step 15150: Generator loss: 64.45999412536621, critic loss: -16.77594568133354
Step 15200: Generator loss: 68.81907485961914, critic loss: -16.819265387773513
Step 15250: Generator loss: 71.44663459777831, critic loss: -14.780536164999004
Step 15300: Generator loss: 69.05639785766601, critic loss: -15.773872276782981
Step 15350: Generator loss: 72.00610313415527, critic loss: -12.428475862145426
Step 15400: Generator loss: 66.33817840576172, critic loss: -11.460507846534249
Step 15450: Generator loss: 73.98657371520996, critic loss: -12.046799251675607
Step 15500: Generator loss: 71.14604766845703, critic loss: -14.48868891143799
Step 15550: Generator loss: 72.55676879882813, critic loss: -9.285633412837981
Step 15600: Generator loss: 72.27706947326661, critic loss: -11.433179477930068
Step 15650: Generator loss: 70.14899436950684, critic loss: -14.64919223260879
Step 15700: Generator loss: 70.52759353637695, critic loss: -13.822800672113893
Step 15750: Generator loss: 66.5163092803955, critic loss: -13.497988208055496
Step 15800: Generator loss: 65.68713722229003, critic loss: -13.090139507174491
Step 15850: Generator loss: 68.86076667785645, critic loss: -12.112882311582563
Step 15900: Generator loss: 72.71573020935058, critic loss: -11.739855915784835
Step 15950: Generator loss: 69.23649925231933, critic loss: 1.5610642746686931
Step 16000: Generator loss: 63.27606719970703, critic loss: -6.625546929836272
Step 16050: Generator loss: 65.96758232116699, critic loss: -17.750343059539794
Step 16100: Generator loss: 62.09547576904297, critic loss: -17.4672027888298
Step 16150: Generator loss: 75.76868995666504, critic loss: -12.86666469740868
Step 16200: Generator loss: 64.08884880065918, critic loss: -7.587684287369252
Step 16250: Generator loss: 64.09755882263184, critic loss: -10.57423495966196
Step 16300: Generator loss: 66.86840660095214, critic loss: -3.5877239196300508
Step 16350: Generator loss: 70.60188285827637, critic loss: -7.692209842979907
Step 16400: Generator loss: 62.29129165649414, critic loss: -12.677523095130923
Step 16450: Generator loss: 62.98073165893555, critic loss: -13.866112356960771
Step 16500: Generator loss: 61.780632400512694, critic loss: -6.281874860048294
Step 16550: Generator loss: 62.74724609375, critic loss: -13.956338333368299
Step 16600: Generator loss: 61.48925178527832, critic loss: -16.810678883075717
Step 16650: Generator loss: 52.7329150390625, critic loss: -18.111987345457074
Step 16700: Generator loss: 60.26760322570801, critic loss: -17.837719259858133
Step 16750: Generator loss: 60.27441291809082, critic loss: -14.668455944180492
Step 16800: Generator loss: 64.81710945129394, critic loss: -8.937785160303115
Step 16850: Generator loss: 61.87463485717773, critic loss: -13.174851733446122
Step 16900: Generator loss: 66.52726516723632, critic loss: -17.641908020138743
Step 16950: Generator loss: 63.35795883178711, critic loss: -17.725372922539712
Step 17000: Generator loss: 67.46929817199707, critic loss: -13.343407141447067
Step 17050: Generator loss: 59.79177055358887, critic loss: -16.512492282271385
Step 17100: Generator loss: 66.42052528381348, critic loss: -9.183917128443717
Step 17150: Generator loss: 59.208996353149416, critic loss: -13.243339603602893
Step 17200: Generator loss: 63.88817520141602, critic loss: -13.442776112914084
Step 17250: Generator loss: 69.03452033996582, critic loss: -11.0614826682806
Step 17300: Generator loss: 57.58331108093262, critic loss: -13.529039879202841
Step 17350: Generator loss: 67.6368569946289, critic loss: -11.620229701399802
Step 17400: Generator loss: 60.044710845947264, critic loss: -9.055887681692841
Step 17450: Generator loss: 64.43620628356933, critic loss: -11.920627628207207
Step 17500: Generator loss: 56.046851272583005, critic loss: -22.301562000870714
Step 17550: Generator loss: 62.282958908081056, critic loss: -14.955312865734099
Step 17600: Generator loss: 65.897964553833, critic loss: -6.340100202620029
Step 17650: Generator loss: 58.14865257263184, critic loss: -10.649906709671022
Step 17700: Generator loss: 66.7437523651123, critic loss: -12.996105446338657
Step 17750: Generator loss: 63.59517837524414, critic loss: -12.67168800020218
Step 17800: Generator loss: 65.87414787292481, critic loss: -13.124171116769311
Step 17850: Generator loss: 66.03790901184082, critic loss: -13.294757736086847
Step 17900: Generator loss: 56.273787307739255, critic loss: -18.649981175422667
Step 17950: Generator loss: 68.3505224609375, critic loss: -13.77443748676777
Step 18000: Generator loss: 59.758854675292966, critic loss: -13.528435281991955
Step 18050: Generator loss: 70.61318840026856, critic loss: -12.014795050919052
Step 18100: Generator loss: 62.63155372619629, critic loss: -9.601117482304572
Step 18150: Generator loss: 58.44773849487305, critic loss: -9.71425095164776
Step 18200: Generator loss: 59.341090240478515, critic loss: -16.878086137115954
Step 18250: Generator loss: 57.848808708190916, critic loss: -19.507797758817674
Step 18300: Generator loss: 63.10433967590332, critic loss: -4.3734778246283526
Step 18350: Generator loss: 57.56446601867676, critic loss: -10.789146659135817
Step 18400: Generator loss: 51.76399398803711, critic loss: -15.076736944794657
Step 18450: Generator loss: 57.02366355895996, critic loss: -12.479052137970923
Step 18500: Generator loss: 62.833531875610355, critic loss: -12.99720428943634
Step 18550: Generator loss: 56.588841400146485, critic loss: -14.211519970417026
Step 18600: Generator loss: 61.620222854614255, critic loss: -14.894168957710265
Step 18650: Generator loss: 59.04514297485352, critic loss: -3.9987226614952096
Step 18700: Generator loss: 54.68501613616943, critic loss: -13.798751793980603
Step 18750: Generator loss: 60.477030181884764, critic loss: -13.97087904036045
Step 18800: Generator loss: 59.51054759979248, critic loss: -18.694762709856033
Step 18850: Generator loss: 53.82080192565918, critic loss: -14.210277070969342
Step 18900: Generator loss: 64.43251205444336, critic loss: -13.768319549560543
Step 18950: Generator loss: 56.9788289642334, critic loss: -10.571144456863403
Step 19000: Generator loss: 59.03595703125, critic loss: -12.603199533462528
Step 19050: Generator loss: 60.19775802612305, critic loss: -14.499388661146167
Step 19100: Generator loss: 59.592409973144534, critic loss: -8.202755635917187
Step 19150: Generator loss: 55.50546585083008, critic loss: -16.347688998579976
Step 19200: Generator loss: 61.19869083404541, critic loss: -18.950819284915923
Step 19250: Generator loss: 66.31558391571045, critic loss: -12.890463754177098
Step 19300: Generator loss: 57.29240139007568, critic loss: -18.10998232960701
Step 19350: Generator loss: 59.32999900817871, critic loss: -13.577078444600104
Step 19400: Generator loss: 65.96876052856446, critic loss: -11.8188825455904
Step 19450: Generator loss: 56.72755683898926, critic loss: -14.319641982913016
Step 19500: Generator loss: 57.38858169555664, critic loss: -17.450813733339313
Step 19550: Generator loss: 66.02516723632813, critic loss: -10.693548452854154
Step 19600: Generator loss: 54.7833975982666, critic loss: -13.142704640865325
Step 19650: Generator loss: 57.28132354736328, critic loss: -14.967523851156233
Step 19700: Generator loss: 59.98361915588379, critic loss: -16.183865994155408
Step 19750: Generator loss: 58.18478466033935, critic loss: -15.35918751955032
Step 19800: Generator loss: 63.54511661529541, critic loss: -10.502776491999626
Step 19850: Generator loss: 56.24938293457031, critic loss: -6.75664558506012
Step 19900: Generator loss: 60.40652961730957, critic loss: -13.489446130156516
Ended: 2021-04-23 17:33:03.659309
Elapsed: 0:48:26.572738
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-35-08d8bddbdcc4> in <module>
     14             for _ in range(crit_repeats):
     15                 ### Update critic ###
---> 16                 this_loss, fake = update_critic(crit, crit_opt, gen, gen_opt,
     17                                                 cur_batch_size, z_dim, real)
     18                 mean_iteration_critic_loss += this_loss

<ipython-input-33-696fefd91963> in update_critic(critic, critic_optimizer, generator, generator_optimizer, batch_size, z_dim, real)
      7 
      8     epsilon = torch.rand(len(real), 1, 1, 1, device=device, requires_grad=True)
----> 9     gradient = get_gradient(critic, real, fake.detach(), epsilon)
     10     gp = gradient_penalty(gradient)
     11     crit_loss = get_crit_loss(crit_fake_pred, crit_real_pred, gp, c_lambda)

<ipython-input-16-06dea6615386> in get_gradient(crit, real, fake, epsilon)
     19 
     20     # Take the gradient of the scores with respect to the images
---> 21     gradient = torch.autograd.grad(
     22         # Note: You need to take the gradient of outputs with respect to inputs.
     23         #### START CODE HERE ####

~/.conda/envs/neurotic-pytorch/lib/python3.9/site-packages/torch/autograd/__init__.py in grad(outputs, inputs, grad_outputs, retain_graph, create_graph, only_inputs, allow_unused)
    221         retain_graph = create_graph
    222 
--> 223     return Variable._execution_engine.run_backward(
    224         outputs, grad_outputs_, retain_graph, create_graph,
    225         inputs, allow_unused, accumulate_grad=False)

RuntimeError: CUDA out of memory. Tried to allocate 20.00 MiB (GPU 0; 7.92 GiB total capacity; 7.13 GiB already allocated; 24.62 MiB free; 7.13 GiB reserved in total by PyTorch)
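
So after about 48 minutes the run died with a CUDA out-of-memory error. My best guess (an assumption, I haven't verified it) is that the fakes.append(fake) line in the visualization block is the culprit: the fake images returned by update_critic are never detached, so every batch we save drags its whole autograd graph along with it and GPU memory keeps growing until it runs out. A minimal fix sketch for that line:

# detach and move the sample images off the GPU before saving them
fakes.append(fake.detach().cpu())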

End

  • Arjovsky M, Chintala S, Bottou L. Wasserstein Generative Adversarial Networks. International Conference on Machine Learning 2017 Jul 17 (pp. 214-223). PMLR. (arxiv.org)
  • Gulrajani I, Ahmed F, Arjovsky M, Dumoulin V, Courville A. Improved Training of Wasserstein GANs. arXiv preprint arXiv:1704.00028. 2017 Mar 31. (arxiv.org)