CNN GAN

Deep Convolutional GAN (DCGAN)

We're going to build a Generative Adversarial Network to generate handwritten digits. Instead of using fully-connected layers we'll use convolutional layers.

Here are the main features of a DCGAN.

  • Replace any pooling layers with strided convolutions (discriminator) and fractional-strided convolutions (generator).
  • Use BatchNorm in both the generator and the discriminator.
  • Remove fully connected hidden layers for deeper architectures.
  • Use ReLU activation in the generator for all layers except for the output, which uses Tanh.
  • Use LeakyReLU activation in the discriminator for all layers.
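To make the first bullet concrete, here's the output-size arithmetic for both kinds of layer. This is a minimal sketch that assumes no padding, dilation, or output padding (the settings used in this post). `conv_out` and `conv_transpose_out` are hypothetical helper names, and the formulas are the shape equations from the PyTorch documentation for `nn.Conv2d` and `nn.ConvTranspose2d` reduced to those settings.

```python
def conv_out(size: int, kernel: int, stride: int) -> int:
    """Spatial size after nn.Conv2d with padding=0 and dilation=1."""
    return (size - kernel) // stride + 1


def conv_transpose_out(size: int, kernel: int, stride: int) -> int:
    """Spatial size after nn.ConvTranspose2d with padding=0 and output_padding=0."""
    return (size - 1) * stride + kernel


# A stride-2 convolution roughly halves the image (downsampling, like pooling)...
print(conv_out(28, kernel=4, stride=2))            # 13
# ...while a stride-2 transposed convolution roughly doubles it (upsampling).
print(conv_transpose_out(13, kernel=4, stride=2))  # 28
```

So a strided convolution downsamples the way a pooling layer would, but with learned weights, and a strided transposed convolution runs the process in the other direction.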

Imports

# python
from collections import namedtuple
from functools import partial
from pathlib import Path

# conda
from torch import nn
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import MNIST
from torchvision.utils import make_grid

import holoviews
import hvplot.pandas
import matplotlib.pyplot as pyplot
import pandas
import torch
# my stuff
from graeae import EmbedHoloviews, Timer

Set Up

The Random Seed

torch.manual_seed(0)

Plotting and Timing

TIMER = Timer()
slug = "cnn-gan"

Embed = partial(EmbedHoloviews, folder_path=f"files/posts/gans/{slug}")

Plot = namedtuple("Plot", ["width", "height", "fontscale", "tan", "blue", "red"])
PLOT = Plot(
    width=900,
    height=750,
    fontscale=2,
    tan="#ddb377",
    blue="#4687b7",
    red="#ce7b6d",
 )

Helper Functions

A Plotter

def plot_image(image: torch.Tensor,
               filename: str,
               title: str,
               num_images: int=25,
               size: tuple=(1, 28, 28),
               folder: str=f"files/posts/gans/{slug}/") -> None:
    """Plot the image and save it

    Args:
     image: the tensor with the image to plot
     filename: name for the final image file
     title: title to put on top of the image
     num_images: how many images to put in the composite image
     size: the size for the image
     folder: sub-folder to save the file in
    """
    unflattened_image = image.detach().cpu().view(-1, *size)
    image_grid = make_grid(unflattened_image[:num_images], nrow=5)

    pyplot.title(title)
    pyplot.grid(False)
    pyplot.imshow(image_grid.permute(1, 2, 0).squeeze())

    pyplot.tick_params(bottom=False, top=False, labelbottom=False,
                       right=False, left=False, labelleft=False)
    pyplot.savefig(folder + filename)
    print(f"[[file:{filename}]]")
    return

A Noise Maker

def make_some_noise(n_samples: int, z_dim: int, device: str="cpu") -> torch.Tensor:
    """Create noise vectors

    Args:
       n_samples: the number of samples to generate, a scalar
       z_dim: the dimension of the noise vector, a scalar
       device: the device type (cpu or cuda)

    Returns:
     tensor with random numbers from the normal distribution.
    """

    return torch.randn(n_samples, z_dim, device=device)

Middle

The Generator

The first component you will make is the generator. You may notice that instead of passing in the image dimension, you pass the number of image channels to the generator. This is because DCGAN uses convolutions, and a convolution's weights don't depend on the number of pixels in an image. The number of channels, however, determines the size of the filters (each filter spans every input channel).
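A quick way to check this: a transposed convolution's weight tensor is shaped (in_channels, out_channels, kernel_size, kernel_size), with no height or width in it, so the same layer accepts inputs of any spatial size. A small sketch with arbitrarily chosen channel counts:

```python
import torch
from torch import nn

layer = nn.ConvTranspose2d(in_channels=10, out_channels=4, kernel_size=3, stride=2)

# The learned weights depend only on the channel counts and the kernel size.
print(tuple(layer.weight.shape))  # (10, 4, 3, 3)

# The same layer works on a 1x1 "image" and on a 5x5 one.
small = layer(torch.randn(2, 10, 1, 1))
large = layer(torch.randn(2, 10, 5, 5))
print(tuple(small.shape), tuple(large.shape))  # (2, 4, 3, 3) (2, 4, 11, 11)
```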

You will build a generator using 4 layers (3 hidden layers + 1 output layer). As before, you will need to write a function to create a single block for the generator's neural network. From the paper:

  • [u]se batchnorm in both the generator and the discriminator.
  • [u]se ReLU activation in generator for all layers except for the output, which uses Tanh.

Since the output layer in DCGAN uses a different activation (Tanh) and no batchnorm, the block-making function needs to be told whether it is creating the final layer.
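The kernel sizes and strides in the Generator class are chosen so that the 1x1 noise "image" grows to exactly 28x28. Here's a pure-Python trace of the spatial size through the four blocks; it's a sketch that assumes the unpadded transposed-convolution size formula (the `conv_transpose_out` name is just for illustration).

```python
def conv_transpose_out(size: int, kernel: int, stride: int) -> int:
    """Output size of a transposed convolution with no padding."""
    return (size - 1) * stride + kernel


# (kernel_size, stride) for each generator block, in order.
blocks = [(3, 2), (4, 1), (3, 2), (4, 2)]

size = 1  # the noise vector is reshaped to (z_dim, 1, 1)
sizes = [size]
for kernel, stride in blocks:
    size = conv_transpose_out(size, kernel, stride)
    sizes.append(size)

print(sizes)  # [1, 3, 6, 13, 28]
```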

At the end of the generator class is a forward method that takes in a noise vector, reshapes it (with unsqueeze_noise) into a 1x1 "image" with z_dim channels, and passes it through the network to generate an image. The noise vectors themselves come from the make_some_noise function defined above.

The Generator Class

class Generator(nn.Module):
    """The DCGAN Generator

    Args:
       z_dim: the dimension of the noise vector
       im_chan: the number of channels in the images, fitted for the dataset used
             (MNIST is black-and-white, so 1 channel is your default)
       hidden_dim: the inner dimension (the base number of hidden channels)
    """
    def __init__(self, z_dim: int=10, im_chan: int=1, hidden_dim: int=64):
        super().__init__()
        self.z_dim = z_dim
        # Build the neural network
        self.gen = nn.Sequential(
            self.make_gen_block(z_dim, hidden_dim * 4),
            self.make_gen_block(hidden_dim * 4, hidden_dim * 2, kernel_size=4, stride=1),
            self.make_gen_block(hidden_dim * 2, hidden_dim),
            self.make_gen_block(hidden_dim, im_chan, kernel_size=4, final_layer=True),
        )

    def make_gen_block(self, input_channels: int, output_channels: int,
                       kernel_size: int=3, stride: int=2,
                       final_layer: bool=False) -> nn.Sequential:
        """Creates a block for the generator (sub sequence)

       The parts
        - a transposed convolution
        - a batchnorm (except for in the last layer)
        - an activation.

       Args:
           input_channels: how many channels the input feature representation has
           output_channels: how many channels the output feature representation should have
           kernel_size: the size of each convolutional filter, equivalent to (kernel_size, kernel_size)
           stride: the stride of the convolution
           final_layer: a boolean, true if it is the final layer and false otherwise 
                     (affects activation and batchnorm)

       Returns:
        the sub-sequence of layers
       """
        if not final_layer:
            return nn.Sequential(
                nn.ConvTranspose2d(input_channels, output_channels, kernel_size, stride),
                nn.BatchNorm2d(output_channels),
                nn.ReLU()
            )
        else: # Final Layer
            return nn.Sequential(
                nn.ConvTranspose2d(input_channels, output_channels, kernel_size, stride),
                nn.Tanh()
            )

    def unsqueeze_noise(self, noise: torch.Tensor) -> torch.Tensor:
        """Reshapes the noise tensor

       Args:
           noise: a noise tensor with dimensions (n_samples, z_dim)

       Returns:
        copy of noise with width and height = 1 and channels = z_dim.
       """
        return noise.view(len(noise), self.z_dim, 1, 1)

    def forward(self, noise: torch.Tensor) -> torch.Tensor:
        """Complete a forward pass of the generator

       Args:
        noise: a noise tensor with dimensions (n_samples, z_dim)

       Returns:
        generated images.
       """
        x = self.unsqueeze_noise(noise)
        return self.gen(x)

Set Up Testing

gen = Generator()
num_test = 100

# Test the hidden block
test_hidden_noise = make_some_noise(num_test, gen.z_dim)
test_hidden_block = gen.make_gen_block(10, 20, kernel_size=4, stride=1)
test_uns_noise = gen.unsqueeze_noise(test_hidden_noise)
hidden_output = test_hidden_block(test_uns_noise)

# Check that it works with other strides
test_hidden_block_stride = gen.make_gen_block(20, 20, kernel_size=4, stride=2)

test_final_noise = make_some_noise(num_test, gen.z_dim) * 20
test_final_block = gen.make_gen_block(10, 20, final_layer=True)
test_final_uns_noise = gen.unsqueeze_noise(test_final_noise)
final_output = test_final_block(test_final_uns_noise)

# Test the whole thing:
test_gen_noise = make_some_noise(num_test, gen.z_dim)
test_uns_gen_noise = gen.unsqueeze_noise(test_gen_noise)
gen_output = gen(test_uns_gen_noise)

Unit Tests

assert tuple(hidden_output.shape) == (num_test, 20, 4, 4)
assert hidden_output.max() > 1
assert hidden_output.min() == 0
assert hidden_output.std() < 1
assert hidden_output.std() > 0.5

assert tuple(test_hidden_block_stride(hidden_output).shape) == (num_test, 20, 10, 10)

assert final_output.max().item() == 1
assert final_output.min().item() == -1

assert tuple(gen_output.shape) == (num_test, 1, 28, 28)
assert gen_output.std() > 0.5
assert gen_output.std() < 0.8
print("Success!")

The Discriminator

The second component you need to create is the discriminator.

You will use 3 layers in your discriminator's neural network. Like with the generator, you will need to create the method to create a single neural network block for the discriminator.
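With the default kernel size of 4 and stride of 2, three unpadded convolutions shrink a 28x28 image down to a single value per image. A sketch of the arithmetic (the `conv_out` helper is hypothetical):

```python
def conv_out(size: int, kernel: int = 4, stride: int = 2) -> int:
    """Output size of an unpadded convolution."""
    return (size - kernel) // stride + 1


sizes = [28]
for _ in range(3):  # three discriminator blocks
    sizes.append(conv_out(sizes[-1]))

print(sizes)  # [28, 13, 5, 1]
```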

From the paper:

  • [u]se LeakyReLU activation in the discriminator for all layers.
  • For the LeakyReLUs, "the slope of the leak was set to 0.2" in DCGAN.
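LeakyReLU passes positive values through unchanged and multiplies negative values by the slope. A plain-Python sketch with the DCGAN slope of 0.2 (`nn.LeakyReLU(0.2)` applies the same rule element-wise):

```python
def leaky_relu(x: float, negative_slope: float = 0.2) -> float:
    """x for non-negative inputs, negative_slope * x otherwise."""
    return x if x >= 0 else negative_slope * x


values = [-5.0, -1.0, 0.0, 2.0]
print([leaky_relu(v) for v in values])  # [-1.0, -0.2, 0.0, 2.0]
```

Since the batchnorm output feeding the activation is roughly symmetric around zero, the most negative activation ends up at roughly 0.2 times the largest positive one, which is presumably why the hidden-block unit test checks that -min/max falls between 0.15 and 0.25.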

The Discriminator Class

class Discriminator(nn.Module):
    """The DCGAN Discriminator

    Args:
     im_chan: the number of channels in the images, fitted for the dataset used
             (MNIST is black-and-white, so 1 channel is the default)
     hidden_dim: the inner dimension (the base number of hidden channels)
    """
    def __init__(self, im_chan: int=1, hidden_dim: int=16):
        super().__init__()
        self.disc = nn.Sequential(
            self.make_disc_block(im_chan, hidden_dim),
            self.make_disc_block(hidden_dim, hidden_dim * 2),
            self.make_disc_block(hidden_dim * 2, 1, final_layer=True),
        )
        return

    def make_disc_block(self, input_channels: int, output_channels: int,
                        kernel_size: int=4, stride: int=2,
                        final_layer: bool=False) -> nn.Sequential:
        """Make a sub-block of layers for the discriminator

        - a convolution
        - a batchnorm (except for in the last layer)
        - an activation.

       Args:
         input_channels: how many channels the input feature representation has
         output_channels: how many channels the output feature representation should have
         kernel_size: the size of each convolutional filter, equivalent to (kernel_size, kernel_size)
         stride: the stride of the convolution
         final_layer: if true it is the final layer and otherwise not
                     (affects activation and batchnorm)
       """        
        # Build the neural block
        if not final_layer:
            return nn.Sequential(
                nn.Conv2d(input_channels, output_channels, kernel_size, stride),
                nn.BatchNorm2d(output_channels),
                nn.LeakyReLU(0.2)
            )
        else: # Final Layer
            return nn.Sequential(
                nn.Conv2d(input_channels, output_channels, kernel_size, stride),
            )

    def forward(self, image: torch.Tensor) -> torch.Tensor:
        """Complete a forward pass of the discriminator

       Args:
         image: an image tensor with dimensions (n_samples, im_chan, height, width)

       Returns:
        a (n_samples, 1) tensor of logits scoring each image as real or fake.
       """
        disc_pred = self.disc(image)
        return disc_pred.view(len(disc_pred), -1)

Set Up Testing

num_test = 100

gen = Generator()
disc = Discriminator()
test_images = gen(make_some_noise(num_test, gen.z_dim))

# Test the hidden block
test_hidden_block = disc.make_disc_block(1, 5, kernel_size=6, stride=3)
hidden_output = test_hidden_block(test_images)

# Test the final block
test_final_block = disc.make_disc_block(1, 10, kernel_size=2, stride=5, final_layer=True)
final_output = test_final_block(test_images)

# Test the whole thing:
disc_output = disc(test_images)

Unit Testing

  • The Hidden Block
    assert tuple(hidden_output.shape) == (num_test, 5, 8, 8)
    # Because of the LeakyReLU slope
    assert -hidden_output.min() / hidden_output.max() > 0.15
    assert -hidden_output.min() / hidden_output.max() < 0.25
    assert hidden_output.std() > 0.5
    assert hidden_output.std() < 1
    
  • The Final Block
    assert tuple(final_output.shape) == (num_test, 10, 6, 6)
    assert final_output.max() > 1.0
    assert final_output.min() < -1.0
    assert final_output.std() > 0.3
    assert final_output.std() < 0.6
    
  • The Whole Thing
    assert tuple(disc_output.shape) == (num_test, 1)
    assert disc_output.std() > 0.25
    assert disc_output.std() < 0.5
    print("Success!")
    

Training The Model

These are the training parameters:

  • criterion: the loss function
  • n_epochs: the number of times you iterate through the entire dataset when training
  • z_dim: the dimension of the noise vector
  • display_step: how often to display/visualize the images
  • batch_size: the number of images per forward/backward pass
  • lr: the learning rate
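  • beta_1, beta_2: the momentum terms for the Adam optimizer (decay rates of its running averages of the gradient and of the squared gradient)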
  • device: the device type
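In Adam, beta_1 and beta_2 are the decay rates of two running averages: one of the gradient (the momentum-like term) and one of the squared gradient. Here's a scalar sketch of the standard Adam update for a single parameter, written from the textbook update rule rather than from PyTorch's source; `adam_steps` is a hypothetical name, and the defaults mirror the values used in this post.

```python
import math


def adam_steps(gradients, lr=0.0002, beta_1=0.5, beta_2=0.999, eps=1e-8):
    """Apply Adam updates to one scalar parameter that starts at 0.0."""
    theta, m, v = 0.0, 0.0, 0.0
    for t, g in enumerate(gradients, start=1):
        m = beta_1 * m + (1 - beta_1) * g        # running mean of gradients
        v = beta_2 * v + (1 - beta_2) * g * g    # running mean of squared gradients
        m_hat = m / (1 - beta_1 ** t)            # bias corrections for the
        v_hat = v / (1 - beta_2 ** t)            # zero-initialized averages
        theta -= lr * m_hat / (math.sqrt(v_hat) + eps)
    return theta


# The first step moves the parameter by about lr against the gradient's sign,
# whatever the gradient's magnitude.
print(adam_steps([10.0]))   # approximately -0.0002
print(adam_steps([0.001]))  # approximately -0.0002
```

A beta_1 of 0.5 (instead of the more common 0.9) means the momentum term forgets old gradients quickly.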

Set Up The Data

criterion = nn.BCEWithLogitsLoss()
z_dim = 64
batch_size = 128
# A learning rate of 0.0002 works well on DCGAN
lr = 0.0002

# These parameters control the Adam optimizer's momentum, which you can read
# more about here: https://distill.pub/2017/momentum/
beta_1 = 0.5
beta_2 = 0.999
device = "cuda" if torch.cuda.is_available() else "cpu"

# Transform the image values to be between -1 and 1 (the range of the tanh activation)
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])

path = Path("~/pytorch-data/MNIST").expanduser()
dataloader = DataLoader(
    MNIST(path, download=True, transform=transform),
    batch_size=batch_size,
    shuffle=True)
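transforms.ToTensor scales the pixel values into [0, 1], and Normalize((0.5,), (0.5,)) then computes (x - mean) / std for each channel, which maps [0, 1] onto the tanh range [-1, 1]. The arithmetic in plain Python (the `normalize` helper is only for illustration):

```python
def normalize(x: float, mean: float = 0.5, std: float = 0.5) -> float:
    """The per-pixel arithmetic done by transforms.Normalize."""
    return (x - mean) / std


print([normalize(p) for p in (0.0, 0.5, 1.0)])  # [-1.0, 0.0, 1.0]
```

To get displayable [0, 1] values back, invert it with x * std + mean.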

Set Up the GAN

gen = Generator(z_dim).to(device)
gen_opt = torch.optim.Adam(gen.parameters(), lr=lr, betas=(beta_1, beta_2))
disc = Discriminator().to(device) 
disc_opt = torch.optim.Adam(disc.parameters(), lr=lr, betas=(beta_1, beta_2))

A Weight Initializer

def initial_weights(m):
    """Initialize the weights to the normal distribution

     - mean 0
     - standard deviation 0.02

    Args:
     m: layer whose weights to initialize
    """
    if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d):
        torch.nn.init.normal_(m.weight, 0.0, 0.02)
    if isinstance(m, nn.BatchNorm2d):
        torch.nn.init.normal_(m.weight, 0.0, 0.02)
        torch.nn.init.constant_(m.bias, 0)
    return
gen = gen.apply(initial_weights)
disc = disc.apply(initial_weights)
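Module.apply calls the function on every submodule recursively (and on the module itself), which is why initial_weights only has to check each layer's type. A small sketch on a toy network; the layer sizes are arbitrary, and the initializer is restated so the block runs on its own.

```python
import torch
from torch import nn

torch.manual_seed(0)


def initial_weights(m: nn.Module) -> None:
    """Same rules as initial_weights above."""
    if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
        torch.nn.init.normal_(m.weight, 0.0, 0.02)
    if isinstance(m, nn.BatchNorm2d):
        torch.nn.init.normal_(m.weight, 0.0, 0.02)
        torch.nn.init.constant_(m.bias, 0)


toy = nn.Sequential(
    nn.Conv2d(1, 64, kernel_size=4, stride=2),
    nn.BatchNorm2d(64),
    nn.LeakyReLU(0.2),
)
toy.apply(initial_weights)

# With 1,024 conv weights the sample standard deviation should be close to 0.02.
print(toy[0].weight.std().item())
print(toy[1].bias.abs().max().item())  # 0.0
```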

Train it

For each epoch, you will process the entire dataset in batches. For every batch, you will update the discriminator and generator. Then, you can see DCGAN's results!

As a rough guide to timing: on a GPU this takes about 30 seconds per thousand steps, while on a CPU it can take about 8 hours per thousand steps. You might notice that around step 5000 the generator disproportionately produces things that look like ones. If the discriminator doesn't learn to detect this imbalance quickly enough, the generator can keep producing more ones, fooling the discriminator without learning to produce the other digits. When a generator gets stuck producing only a few kinds of outputs like this, it's known as mode collapse.

n_epochs = 100
cur_step = 0
display_step = 1000
mean_generator_loss = 0
mean_discriminator_loss = 0
generator_losses = []
discriminator_losses = []
steps = []

best_loss = float("inf")
best_step = 0
best_path = Path("~/models/gans/mnist-dcgan/best_model.pth").expanduser()

with TIMER:
    for epoch in range(n_epochs):
        # Dataloader returns the batches
        for real, _ in dataloader:
            cur_batch_size = len(real)
            real = real.to(device)

            ## Update discriminator ##
            disc_opt.zero_grad()
            fake_noise = make_some_noise(cur_batch_size, z_dim, device=device)
            fake = gen(fake_noise)
            disc_fake_pred = disc(fake.detach())
            disc_fake_loss = criterion(disc_fake_pred, torch.zeros_like(disc_fake_pred))
            disc_real_pred = disc(real)
            disc_real_loss = criterion(disc_real_pred, torch.ones_like(disc_real_pred))
            disc_loss = (disc_fake_loss + disc_real_loss) / 2

            # Keep track of the average discriminator loss
            mean_discriminator_loss += disc_loss.item() / display_step
            # Update gradients
            disc_loss.backward(retain_graph=True)
            # Update optimizer
            disc_opt.step()

            ## Update generator ##
            gen_opt.zero_grad()
            fake_noise_2 = make_some_noise(cur_batch_size, z_dim, device=device)
            fake_2 = gen(fake_noise_2)
            disc_fake_pred = disc(fake_2)
            gen_loss = criterion(disc_fake_pred, torch.ones_like(disc_fake_pred))
            gen_loss.backward()
            gen_opt.step()

            # Keep track of the average generator loss
            mean_generator_loss += gen_loss.item() / display_step

            ## Visualization code ##
            if cur_step % display_step == 0 and cur_step > 0:
                print(f"Epoch {epoch}, step {cur_step}: Generator loss:"
                      f" {mean_generator_loss}, discriminator loss:"
                      f" {mean_discriminator_loss}")

                steps.append(cur_step)
                generator_losses.append(mean_generator_loss)
                discriminator_losses.append(mean_discriminator_loss)

                # Only compare complete averages; a partially-accumulated
                # mean is always small and would look like a false "best".
                if mean_generator_loss < best_loss:
                    best_loss, best_step = mean_generator_loss, cur_step
                    best_path.parent.mkdir(parents=True, exist_ok=True)
                    with best_path.open("wb") as writer:
                        torch.save(gen, writer)

                mean_generator_loss = 0
                mean_discriminator_loss = 0
            cur_step += 1
Started: 2021-04-21 12:45:12.452739
Epoch 2, step 1000: Generator loss: 1.2671969079673289, discriminator loss: 0.43014343224465823
Epoch 4, step 2000: Generator loss: 1.1353899730443968, discriminator loss: 0.5306872705817226
Epoch 6, step 3000: Generator loss: 0.8764803466945883, discriminator loss: 0.611450107574464
Epoch 8, step 4000: Generator loss: 0.7747784045338618, discriminator loss: 0.6631499938964849
Epoch 10, step 5000: Generator loss: 0.7640163034200661, discriminator loss: 0.6734729865789411
Epoch 12, step 6000: Generator loss: 0.7452541967928404, discriminator loss: 0.6805261079072958
Epoch 14, step 7000: Generator loss: 0.7337032879889016, discriminator loss: 0.6874966211915009
Epoch 17, step 8000: Generator loss: 0.7245009585618979, discriminator loss: 0.6908933531045917
Epoch 19, step 9000: Generator loss: 0.7180560626983646, discriminator loss: 0.6936621717810626
Epoch 21, step 10000: Generator loss: 0.7115822317004211, discriminator loss: 0.695760274052621
Epoch 23, step 11000: Generator loss: 0.7090291924774644, discriminator loss: 0.6962701203227039
Epoch 25, step 12000: Generator loss: 0.7059894913136957, discriminator loss: 0.6973492541313167
Epoch 27, step 13000: Generator loss: 0.7030480077862743, discriminator loss: 0.6978999735713001
Epoch 29, step 14000: Generator loss: 0.7028095332086096, discriminator loss: 0.6974007876515396
Epoch 31, step 15000: Generator loss: 0.7027116653919212, discriminator loss: 0.6965595571994787
Epoch 34, step 16000: Generator loss: 0.7005282629728309, discriminator loss: 0.6962912415862079
Epoch 36, step 17000: Generator loss: 0.7007142878770828, discriminator loss: 0.6961965024471283
Epoch 38, step 18000: Generator loss: 0.699474583208561, discriminator loss: 0.6952810400128371
Epoch 40, step 19000: Generator loss: 0.6989677719473828, discriminator loss: 0.6954642050266268
Epoch 42, step 20000: Generator loss: 0.6977452509403238, discriminator loss: 0.695180906951427
Epoch 44, step 21000: Generator loss: 0.6973587237596515, discriminator loss: 0.6950308464765543
Epoch 46, step 22000: Generator loss: 0.6960379970669743, discriminator loss: 0.6949119175076485
Epoch 49, step 23000: Generator loss: 0.6957966268062581, discriminator loss: 0.6948324624896048
Epoch 51, step 24000: Generator loss: 0.6958502059578898, discriminator loss: 0.6945331234931943
Epoch 53, step 25000: Generator loss: 0.6954856168627734, discriminator loss: 0.6943869084119801
Epoch 55, step 26000: Generator loss: 0.6957543395757682, discriminator loss: 0.694317172288894
Epoch 57, step 27000: Generator loss: 0.6947923063635825, discriminator loss: 0.694082073867321
Epoch 59, step 28000: Generator loss: 0.6945026598572728, discriminator loss: 0.6939926172494871
Epoch 61, step 29000: Generator loss: 0.6947789136767392, discriminator loss: 0.6938506522774704
Epoch 63, step 30000: Generator loss: 0.6946699734926227, discriminator loss: 0.6937169924378406
Epoch 66, step 31000: Generator loss: 0.6944284628629694, discriminator loss: 0.6936815274357805
Epoch 68, step 32000: Generator loss: 0.6940396347641948, discriminator loss: 0.6935891906023032
Epoch 70, step 33000: Generator loss: 0.6946771386265761, discriminator loss: 0.6937210547327995
Epoch 72, step 34000: Generator loss: 0.693429798424244, discriminator loss: 0.6937174627780922
Epoch 74, step 35000: Generator loss: 0.6937471128702157, discriminator loss: 0.6935204346776015
Epoch 76, step 36000: Generator loss: 0.6938841561675072, discriminator loss: 0.6934832554459566
Epoch 78, step 37000: Generator loss: 0.6934520475268362, discriminator loss: 0.6934578058719627
Epoch 81, step 38000: Generator loss: 0.6936635475754732, discriminator loss: 0.6934186050295835
Epoch 83, step 39000: Generator loss: 0.6936795052289972, discriminator loss: 0.6935187472105031
Epoch 85, step 40000: Generator loss: 0.6933113215565679, discriminator loss: 0.6933534587025645
Epoch 87, step 41000: Generator loss: 0.6934976277351385, discriminator loss: 0.6933284662365923
Epoch 89, step 42000: Generator loss: 0.6933313971757892, discriminator loss: 0.693348657488824
Epoch 91, step 43000: Generator loss: 0.6937436528205883, discriminator loss: 0.6933502901792529
Epoch 93, step 44000: Generator loss: 0.6943431540131578, discriminator loss: 0.6933887023925772
Epoch 95, step 45000: Generator loss: 0.6938722513914105, discriminator loss: 0.6932663491368296
Epoch 98, step 46000: Generator loss: 0.6933276618123067, discriminator loss: 0.6934270900487906
Ended: 2021-04-21 13:06:00.256725
Elapsed: 0:20:47.803986
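Both running losses level off near 0.693. That's ln 2, the binary cross-entropy you get when the discriminator outputs a logit of 0 (a probability of 0.5) whatever the label, so the numbers are consistent with (though not proof of) a discriminator that can no longer tell real from fake on average. They don't by themselves say how good the samples look. A quick check of the arithmetic, using the naive form of the formula (nn.BCEWithLogitsLoss computes the same quantity in a numerically stabler way):

```python
import math


def bce_with_logits(logit: float, target: float) -> float:
    """Binary cross-entropy on a raw logit (naive form, fine for small logits)."""
    p = 1 / (1 + math.exp(-logit))  # sigmoid
    return -(target * math.log(p) + (1 - target) * math.log(1 - p))


print(bce_with_logits(0.0, 1.0))  # 0.693... (ln 2)
print(bce_with_logits(0.0, 0.0))  # 0.693... (ln 2)
```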

Looking at the Best Model

fake_noise = make_some_noise(cur_batch_size, z_dim, device=device)

best_model = torch.load(best_path)
fake = best_model(fake_noise)
plot_image(image=fake, filename="fake_digits.png", title="Fake Digits")

fake_digits.png

plot_image(real, filename="real_digits.png", title="Real Digits")

real_digits.png

plotting = pandas.DataFrame.from_dict({
    "Step": steps,
    "Generator Loss": generator_losses,
    "Discriminator Loss": discriminator_losses
})

best = plotting.iloc[plotting["Generator Loss"].argmin()]
best_line = holoviews.VLine(best.Step)
gen_plot = plotting.hvplot(x="Step", y="Generator Loss", color=PLOT.blue)
disc_plot = plotting.hvplot(x="Step", y="Discriminator Loss", color=PLOT.red)

plot = (gen_plot * disc_plot * best_line).opts(title="Training Losses",
                                               height=PLOT.height,
                                               width=PLOT.width,
                                               ylabel="Loss",
                                               fontscale=PLOT.fontscale)
output = Embed(plot=plot, file_name="losses")()
print(output)

Figure Missing

End

Sources

  • Radford A, Metz L, Chintala S. Unsupervised representation learning with deep convolutional generative adversarial networks. arXiv preprint arXiv:1511.06434. 2015 Nov 19.

PyTorch Linear Regression

Beginning

Imports

# python
from collections import namedtuple
from functools import partial

# pypi
from torch import nn
from torch.utils.data import Dataset, DataLoader

import hvplot.pandas
import numpy
import pandas
import torch

# local stuff
from graeae import EmbedHoloviews

Set Up

random_generator = numpy.random.default_rng(seed=2021)
slug = "pytorch-linear-regression"
Embed = partial(EmbedHoloviews, folder_path=f"files/posts/pytorch/{slug}")

Plot = namedtuple("Plot", ["width", "height", "fontscale", "tan", "blue", "red"])
PLOT = Plot(
    width=900,
    height=750,
    fontscale=2,
    tan="#ddb377",
    blue="#4687b7",
    red="#ce7b6d",
 )

def sample(start: float, stop: float, shape: tuple, uniform: bool=True) -> numpy.ndarray:
    """Create a random sample

    Args:
     start: lowest allowed value (uniform) or the mean (normal)
     stop: highest allowed value (uniform); for the normal, (stop - start) is the standard deviation
     shape: shape for the final array (just an int for single values)
     uniform: use the uniform distribution instead of the standard normal
    """
    if uniform:
        return (stop - start) * random_generator.random(shape) + start
    return (stop - start) * random_generator.standard_normal(shape) + start

Middle

SAMPLES = 200
X_RANGE = 5
x_values = sample(-X_RANGE, X_RANGE, SAMPLES)
SLOPE = sample(-5, 5, 1)
INTERCEPT = sample(-5, 5, 1)
noise = sample(-2, 2, SAMPLES, uniform=False)
y_values = SLOPE * x_values + INTERCEPT + noise
data_frame = pandas.DataFrame.from_dict(dict(X=x_values, Y=y_values))
first, last = x_values.min(), x_values.max()
line_frame = pandas.DataFrame.from_dict(
    dict(X=[first, last],
         Y=[SLOPE * first + INTERCEPT,
            SLOPE * last + INTERCEPT]))
line_plot = line_frame.hvplot(x="X", y="Y", color=PLOT.blue)
data_plot = data_frame.hvplot.scatter(x="X", y="Y", title="Sample Data",
                                      color=PLOT.tan)
plot = (data_plot * line_plot).opts(
    height=PLOT.height,
    width=PLOT.width,
    fontscale=PLOT.fontscale
)
output = Embed(plot=plot, file_name="sample_data")()

Figure Missing

class XY(Dataset):
    def __init__(self, x, y):
        self.x = x
        self.y = y
        return

    def __len__(self):
        return len(self.x)

    def __getitem__(self, index):
        if torch.is_tensor(index):
            index = index.tolist()

        return {"x": self.x[index], "y": self.y[index]}
dataset = XY(x_values, y_values)
loader = DataLoader(dataset, batch_size=4)
model = nn.Linear(1, 1)
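The post stops after creating the model. Here's a minimal sketch of how the pieces might be trained, assuming mean-squared-error loss and plain SGD (neither is chosen in the original). The numpy arrays above are float64 and one-dimensional, while nn.Linear expects float32 inputs shaped (batch, 1), so this variant of XY converts and reshapes them and returns (x, y) tuples instead of a dict. The data are regenerated with a made-up slope of 2 and intercept of -1 so the block runs on its own.

```python
import numpy
import torch
from torch import nn
from torch.utils.data import DataLoader, Dataset

torch.manual_seed(0)
rng = numpy.random.default_rng(seed=2021)

# Hypothetical data: y = 2x - 1 plus a little noise.
x_values = rng.uniform(-5, 5, 200)
y_values = 2 * x_values - 1 + rng.normal(0, 0.5, 200)


class XY(Dataset):
    """(x, y) pairs as float32 tensors of shape (1,)."""
    def __init__(self, x, y):
        self.x = torch.as_tensor(x, dtype=torch.float32).reshape(-1, 1)
        self.y = torch.as_tensor(y, dtype=torch.float32).reshape(-1, 1)

    def __len__(self):
        return len(self.x)

    def __getitem__(self, index):
        return self.x[index], self.y[index]


loader = DataLoader(XY(x_values, y_values), batch_size=4, shuffle=True)
model = nn.Linear(1, 1)
criterion = nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

for epoch in range(20):
    for x_batch, y_batch in loader:
        optimizer.zero_grad()
        loss = criterion(model(x_batch), y_batch)
        loss.backward()
        optimizer.step()

print(model.weight.item(), model.bias.item())  # should be near 2 and -1
```

The default collate function stacks the (1,)-shaped items into (4, 1) batches, which is the shape nn.Linear(1, 1) wants.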

MNIST GAN

Note: The current version of PyTorch (1.8.1) causes a Segmentation Fault in my nvidia-docker container (running CUDA 11, Python 3.9, and Ubuntu 20.04). The fault comes at different points in the code depending on what I do: sometimes it's the backward propagation, sometimes it's the PyTorch binary that causes it, sometimes it's the libcuda binary… Trying to debug it is probably beyond me, so to get this working I had to go back to the previous version of PyTorch (1.7.1).

Update: The previous error happened when I was using pip. I switched to using conda and that seems to have fixed it. One of the downsides to using conda is that it seems to download all the binary dependencies and it takes much longer to install everything than it does with pip, maybe because I don't have a really fat internet pipe. Oh, well.

Beginning

Imports

# python
from collections import namedtuple
from functools import partial

# from pypi
from torch import nn
from torchvision import transforms
from torchvision.datasets import MNIST
from torchvision.utils import make_grid
from torch.utils.data import DataLoader

import hvplot.pandas
import matplotlib.pyplot as pyplot
import pandas
import torch

# local code
from graeae import EmbedHoloviews, Timer

Some Setup

First we'll set the manual seed to make this reproducible.

torch.manual_seed(0)

This is a convenience object to time the training.

TIMER = Timer()

This is for plotting.

slug = "mnist-gan"

Embed = partial(EmbedHoloviews, folder_path=f"files/posts/gans/{slug}")

Plot = namedtuple("Plot", ["width", "height", "fontscale", "tan", "blue", "red"])
PLOT = Plot(
    width=900,
    height=750,
    fontscale=2,
    tan="#ddb377",
    blue="#4687b7",
    red="#ce7b6d",
 )

Middle

The MNIST Dataset

The training images we will be using are from a dataset called MNIST. The dataset contains 60,000 images of handwritten digits, from 0 to 9.

The images are 28 pixels x 28 pixels in size. The small size of its images makes MNIST ideal for simple training. Additionally, these images are grayscale, so only one dimension, or "color channel", is needed to represent them. PyTorch has a version of it ready-made, so we'll use theirs.

The Generator

The first step is to build the generator component.

We'll start by creating a function to make a single layer/block for the generator's neural network. Each block should include a linear transformation (\(y = xA^T + b\)) that maps the input to another shape, batch normalization for stabilization, and finally a non-linear activation function (ReLU in this case).
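nn.Linear stores A as its weight, shaped (out_features, in_features), which is why the formula transposes it. A sketch that checks the layer's output against the formula computed by hand (sizes chosen arbitrarily):

```python
import torch
from torch import nn

torch.manual_seed(0)

layer = nn.Linear(in_features=3, out_features=2)
x = torch.randn(5, 3)  # a batch of five 3-dimensional inputs

by_hand = x @ layer.weight.T + layer.bias  # y = x A^T + b
print(tuple(layer.weight.shape))           # (2, 3)
print(torch.allclose(layer(x), by_hand))   # True
```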

def generator_block(input_features: int, output_features: int) -> nn.Sequential:
    """
    Creates a block of the generator's neural network

    Args:
      input_features: the dimension of the input vector
      output_features: the dimension of the output vector

    Returns:
       a generator neural network layer, with a linear transformation 
         followed by a batch normalization and then a relu activation
    """
    return nn.Sequential(
        nn.Linear(input_features, output_features),
        nn.BatchNorm1d(output_features),
        nn.ReLU(inplace=True),
    )

Verify the generator block function

def test_gen_block(in_features: int, out_features: int,
                   test_rows: int=1000) -> None:
    """Test the generator block creator

    Args:
     in_features: number of features for the block input
     out_features: the final number of features for it to output
     test_rows: how many rows to put in the test Tensor

    Raises:
     AssertionError: something isn't right
    """
    block = generator_block(in_features, out_features)

    # Check the three parts
    assert len(block) == 3
    assert type(block[0]) == nn.Linear
    assert type(block[1]) == nn.BatchNorm1d
    assert type(block[2]) == nn.ReLU

    # Check the output shape
    test_output = block(torch.randn(test_rows, in_features))
    assert tuple(test_output.shape) == (test_rows, out_features)

    # check the normalization
    assert 0.65 > test_output.std() > 0.55
    return

test_gen_block(25, 12)
test_gen_block(15, 28)

Building the Generator Class

Now that we have the block-builder we can define our Generator network. It's going to contain a sequence of blocks output by our block-building function, followed by a final linear layer that doesn't apply normalization and a Sigmoid instead of the ReLU. After the first block, each block's output is double the size of the previous one's.
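With the default sizes (a 10-dimensional noise vector, a hidden dimension of 128, and a 784-pixel image), the feature widths flow through the network as in this plain-Python sketch, which only tracks the sizes:

```python
input_dimension, hidden_dimension, image_dimension = 10, 128, 784

# Four blocks: the first maps the noise to hidden_dimension,
# and each later block doubles the width of the one before it...
widths = [input_dimension] + [hidden_dimension * 2 ** i for i in range(4)]
# ...then a final linear layer maps back down to the flattened image size.
widths.append(image_dimension)

print(widths)  # [10, 128, 256, 512, 1024, 784]
print(image_dimension == 28 * 28)  # True
```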

generator.png

class Generator(nn.Module):
    """Generator Class

    Args:
      input_dimension: the dimension of the noise vector
      image_dimension: the dimension of the images, fitted for the dataset used
        (MNIST images are 28 x 28 = 784 so that is the default)
      hidden_dimension: the initial hidden-layer dimension
    """
    def __init__(self, input_dimension: int=10, image_dimension: int=784,
                 hidden_dimension: int=128):
        super().__init__()

        self.generator = nn.Sequential(
            generator_block(input_dimension, hidden_dimension),
            generator_block(hidden_dimension, hidden_dimension * 2),
            generator_block(hidden_dimension * 2, hidden_dimension * 4),
            generator_block(hidden_dimension * 4, hidden_dimension * 8),
            nn.Linear(hidden_dimension * 8, image_dimension),
            nn.Sigmoid()
        )
        return

    def forward(self, noise: torch.Tensor) -> torch.Tensor:
        """
       Method for a forward pass of the generator

       Args:
        noise: a noise tensor with dimensions (n_samples, z_dim)

       Returns: 
        generated images.
       """
        return self.generator(noise)

Verify the Generator Class

def test_generator(z_dim: int, im_dim: int, hidden_dim: int, 
                   num_test: int=10000) -> None:
    """Test the Generator Class

    Args:
     z_dim: the size of the input
     im_dim: the size of the image
     hidden_dim: the size of the initial hidden layer
     num_test: the number of noise vectors to pass through the generator

    Raises:
     AssertionError: something is wrong
    """
    gen = Generator(z_dim, im_dim, hidden_dim).generator

    # Check there are six modules in the sequential part
    assert len(gen) == 6
    test_input = torch.randn(num_test, z_dim)
    test_output = gen(test_input)

    # Check that the output shape is correct
    assert tuple(test_output.shape) == (num_test, im_dim)

    # Check the output
    assert 0 < test_output.max() < 1, "Make sure to use a sigmoid"
    assert test_output.min() < 0.5, "Don't use a block in your solution"
    assert 0.15 > test_output.std() > 0.05, "Don't use batchnorm here"
    return

test_generator(5, 10, 20)
test_generator(20, 8, 24)

Noise

To be able to use the generator, we will need to be able to create noise vectors. The noise vector z has the important role of making sure the images generated from the same class don't all look the same – think of it as a random seed. You will generate it randomly using PyTorch by sampling random numbers from the normal distribution. Since multiple images will be processed per pass, you will generate all the noise vectors at once.

Note that whenever you create a new tensor using torch.ones, torch.zeros, or torch.randn, you either need to create it on the target device, e.g. torch.ones(3, 3, device=device), or move it onto the target device using torch.ones(3, 3).to(device). You do not need to do this if you're creating a tensor by manipulating another tensor or by using a variation that defaults the device to the input, such as torch.ones_like. In general, use torch.ones_like and torch.zeros_like instead of torch.ones or torch.zeros where possible.
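Here is a minimal sketch of those device rules. It assumes a fallback to the CPU when no GPU is present (the post itself hard-codes "cuda"), so it runs on either kind of machine.

```python
import torch

# Assumption for this sketch: use the GPU if one is present, otherwise the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

a = torch.ones(3, 3, device=device)   # created directly on the target device
b = torch.zeros(3, 3).to(device)      # created on the CPU, then moved
c = torch.ones_like(a)                # inherits a's shape, dtype, and device
d = a * 2                             # results of tensor operations stay on a's device

print(a.device, b.device, c.device, d.device)
```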

def get_noise(n_samples: int, z_dim: int, device='cuda') -> torch.Tensor:
    """create noise vectors

    Args:
       n_samples: the number of samples to generate, a scalar
       z_dim: the dimension of the noise vector, a scalar
       device: the device type

    Returns:
       a tensor of shape (n_samples, z_dim) filled with standard-normal samples
    """
    return torch.randn(n_samples, z_dim, device=device)

Verify the noise vector function

def test_get_noise(n_samples, z_dim, device='cpu'):
    noise = get_noise(n_samples, z_dim, device)

    # Make sure a normal distribution was used
    assert tuple(noise.shape) == (n_samples, z_dim)
    assert torch.abs(noise.std() - torch.tensor(1.0)) < 0.01
    assert str(noise.device).startswith(device)

test_get_noise(1000, 32)

The Discriminator

The second component that you need to construct is the discriminator. As with the generator component, you will start by creating a function that builds a neural network block for the discriminator.

Note: We use leaky ReLUs to prevent the "dying ReLU" problem. A ReLU that consistently receives negative inputs has a zero gradient, so the parameters feeding into it stop changing.
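The difference shows up directly in the gradients. This sketch backpropagates through a plain ReLU and then through a LeakyReLU with the same 0.2 slope used below, for a few made-up inputs.

```python
import torch
from torch import nn

# Made-up inputs: two negative, one positive.
x = torch.tensor([-2.0, -0.5, 1.0], requires_grad=True)

nn.ReLU()(x).sum().backward()
relu_grad = x.grad.clone()     # 0 wherever x < 0, so nothing upstream would learn

x.grad = None                  # clear the accumulated gradient before the next pass
nn.LeakyReLU(negative_slope=0.2)(x).sum().backward()
leaky_grad = x.grad.clone()    # 0.2 wherever x < 0, so a small signal still flows

print(relu_grad)
print(leaky_grad)
```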

def get_discriminator_block(input_dim: int, output_dim: int,
                            negative_slope: float=0.2) -> nn.Sequential:
    """Create the Discriminator block

    Args:
      input_dim: the dimension of the input vector, a scalar
      output_dim: the dimension of the output vector, a scalar
      negative_slope: the slope of the LeakyReLU for negative inputs

    Returns:
       a discriminator neural network layer, with a linear transformation
         followed by an nn.LeakyReLU activation with the given negative slope
    """
    return nn.Sequential(
        nn.Linear(input_dim, output_dim),
        nn.LeakyReLU(negative_slope=negative_slope)
    )

Verify the discriminator block function

def test_disc_block(in_features, out_features, num_test=10000):
    block = get_discriminator_block(in_features, out_features)

    # Check there are two parts
    assert len(block) == 2
    test_input = torch.randn(num_test, in_features)
    test_output = block(test_input)

    # Check that the shape is right
    assert tuple(test_output.shape) == (num_test, out_features)

    # Check that the LeakyReLU slope is about 0.2
    assert -test_output.min() / test_output.max() > 0.1
    assert -test_output.min() / test_output.max() < 0.3
    assert test_output.std() > 0.3
    assert test_output.std() < 0.5

test_disc_block(25, 12)
test_disc_block(15, 28)

The Discriminator Class

The discriminator class holds 2 values:

  • The image dimension
  • The hidden dimension

The discriminator builds a neural network with 4 layers: three discriminator blocks and a final linear layer. It starts with the flattened image tensor and transforms it until it returns a single number per image (an output of shape (n_samples, 1)). This output classifies whether an image is fake or real. You don't need a sigmoid after the output layer because the loss function (nn.BCEWithLogitsLoss) applies it internally. Finally, to use the discriminator's neural network you are given a forward pass function that takes in an image tensor to be classified.
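To see why the sigmoid can be left out, this sketch compares nn.BCEWithLogitsLoss on raw outputs (logits) with nn.BCELoss on sigmoid-squashed outputs. The logits and labels are made up for illustration; the two losses agree up to floating-point error, and the logits version is the more numerically stable of the two.

```python
import torch
from torch import nn

# Made-up discriminator outputs (logits) and labels (1 = real, 0 = fake).
logits = torch.tensor([[2.0], [-1.0], [0.5]])
targets = torch.tensor([[1.0], [0.0], [1.0]])

with_logits = nn.BCEWithLogitsLoss()(logits, targets)   # sigmoid applied inside the loss
manual = nn.BCELoss()(torch.sigmoid(logits), targets)   # explicit sigmoid, then BCE

print(with_logits.item(), manual.item())
```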

class Discriminator(nn.Module):
    """The Discriminator Class

    Args:
       im_dim: the dimension of the images, fitted for the dataset used, a scalar
           (MNIST images are 28x28 = 784 so that is your default)
       hidden_dim: the inner dimension, a scalar
    """
    def __init__(self, im_dim: int=784, hidden_dim: int=128):
        super().__init__()
        self.disc = nn.Sequential(
            get_discriminator_block(im_dim, hidden_dim * 4),
            get_discriminator_block(hidden_dim * 4, hidden_dim * 2),
            get_discriminator_block(hidden_dim * 2, hidden_dim),
            nn.Linear(hidden_dim, 1)
        )

    def forward(self, image: torch.Tensor) -> torch.Tensor:
        """Forward pass of the discriminator

        Args:
            image: a batch of flattened images with dimensions (n_samples, im_dim)

        Returns:
            a tensor with dimensions (n_samples, 1) of real/fake logits
        """
        return self.disc(image)

Verify the Discriminator Class

def test_discriminator(im_dim: int, hidden_dim: int, num_test: int=100) -> None:
    disc = Discriminator(im_dim, hidden_dim).disc

    # Check there are four parts
    assert len(disc) == 4

    # Check the linear layer is correct
    test_input = torch.randn(num_test, im_dim)
    test_output = disc(test_input)
    assert tuple(test_output.shape) == (num_test, 1)

    # Don't use a block
    assert not isinstance(disc[-1], nn.Sequential)

test_discriminator(5, 10)
test_discriminator(20, 8)

Training

First, you will set your parameters:

  • criterion: the loss function (BCEWithLogitsLoss)
  • n_epochs: the number of times you iterate through the entire dataset when training
  • z_dim: the dimension of the noise vector
  • display_step: how often to display/visualize the images
  • batch_size: the number of images per forward/backward pass
  • lr: the learning rate
  • device: the device type, here using a GPU (which runs CUDA), not CPU

Next, you will load the MNIST dataset as tensors using a dataloader.

Set your parameters

criterion = nn.BCEWithLogitsLoss()
z_dim = 64
batch_size = 128
lr = 0.00001

Load MNIST dataset as tensors

dataloader = DataLoader(
    MNIST('.', download=True, transform=transforms.ToTensor()),
    batch_size=batch_size,
    shuffle=True)
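A bit of arithmetic helps when reading the training output later. Assuming the standard 60,000-image MNIST training split and the DataLoader default of keeping the last partial batch, this sketch computes how many batches make up an epoch, how big the last batch is (which is why the training loop uses len(real) rather than batch_size), and which epoch a given step falls in. The helper `epoch_of_step` is hypothetical.

```python
import math

# Assumption: the MNIST training split has 60,000 images.
num_train_images = 60_000
batch_size = 128

batches_per_epoch = math.ceil(num_train_images / batch_size)   # the last batch is partial
last_batch_size = num_train_images - (batches_per_epoch - 1) * batch_size

def epoch_of_step(step: int) -> int:
    """Hypothetical helper: epoch containing a zero-based global step like cur_step."""
    return step // batches_per_epoch

print(batches_per_epoch, last_batch_size)  # 469 96
print(epoch_of_step(2500))                 # 5
```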

Now, you can initialize your generator, discriminator, and optimizers. Note that each optimizer only takes the parameters of one particular model, since we want each optimizer to optimize only one of the models.

device = "cuda"
gen = Generator(z_dim).to(device)
gen_opt = torch.optim.Adam(gen.parameters(), lr=lr)
disc = Discriminator().to(device) 
disc_opt = torch.optim.Adam(disc.parameters(), lr=lr)

Before you train your GAN, you will need to create functions to calculate the discriminator's loss and the generator's loss. This is how the discriminator and generator will know how they are doing and improve themselves. Since the generator is needed when calculating the discriminator's loss, you will need to call .detach() on the generator result to ensure that only the discriminator is updated!
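Here is a small sketch of what .detach() does. Two tiny linear layers stand in for the generator and discriminator (an assumption; any two modules show the same effect). After backpropagating a discriminator loss through detached fakes, only the discriminator has gradients.

```python
import torch
from torch import nn

torch.manual_seed(0)
# Stand-ins for the real models, used only to show the effect of detach().
gen = nn.Linear(4, 3)
disc = nn.Linear(3, 1)

noise = torch.randn(8, 4)
fakes = gen(noise).detach()        # cut the graph between generator and discriminator
disc(fakes).mean().backward()

print(gen.weight.grad is None)       # True: no gradient reached the generator
print(disc.weight.grad is not None)  # True: the discriminator still gets gradients
```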

Remember that you have already defined a loss function earlier (criterion) and you are encouraged to use torch.ones_like and torch.zeros_like instead of torch.ones or torch.zeros. If you use torch.ones or torch.zeros, you'll need to pass device=device to them.

def get_disc_loss(gen: Generator, disc: Discriminator,
                  criterion: nn.BCEWithLogitsLoss,
                  real: torch.Tensor,
                  num_images: int, z_dim: int, 
                  device: str="cuda"):
    """
    Get the loss of the discriminator given inputs.

    Args:
       gen: the generator model, which returns an image given z-dimensional noise
       disc: the discriminator model, which returns a single-dimensional prediction of real/fake
       criterion: the loss function, which should be used to compare 
              the discriminator's predictions to the ground truth reality of the images 
              (e.g. fake = 0, real = 1)
       real: a batch of real images
       num_images: the number of images the generator should produce, 
               which is also the length of the real images
       z_dim: the dimension of the noise vector, a scalar
       device: the device type

    Returns:
       disc_loss: a torch scalar loss value for the current batch
    """
    noise = torch.randn(num_images, z_dim, device=device)
    fakes = gen(noise).detach()

    fake_prediction = disc(fakes)
    fake_loss = criterion(fake_prediction, torch.zeros_like(fake_prediction))

    real_prediction = disc(real)
    real_loss = criterion(real_prediction, torch.ones_like(real_prediction))
    disc_loss = (fake_loss + real_loss)/2
    return disc_loss

def test_disc_reasonable(num_images=10):
    # Don't use explicit casts to cuda - use the device argument
    import inspect, re
    lines = inspect.getsource(get_disc_loss)
    assert (re.search(r"to\(.cuda.\)", lines)) is None
    assert (re.search(r"\.cuda\(\)", lines)) is None

    z_dim = 64
    gen = torch.zeros_like
    disc = lambda x: x.mean(1)[:, None]
    criterion = torch.mul # Multiply
    real = torch.ones(num_images, z_dim)
    disc_loss = get_disc_loss(gen, disc, criterion, real, num_images, z_dim, 'cpu')
    assert torch.all(torch.abs(disc_loss.mean() - 0.5) < 1e-5)

    gen = torch.ones_like
    criterion = torch.mul # Multiply
    real = torch.zeros(num_images, z_dim)
    assert torch.all(torch.abs(get_disc_loss(gen, disc, criterion, real, num_images, z_dim, 'cpu')) < 1e-5)

    gen = lambda x: torch.ones(num_images, 10)
    disc = lambda x: x.mean(1)[:, None] + 10
    criterion = torch.mul # Multiply
    real = torch.zeros(num_images, 10)
    assert torch.all(torch.abs(get_disc_loss(gen, disc, criterion, real, num_images, z_dim, 'cpu').mean() - 5) < 1e-5)

    gen = torch.ones_like
    disc = nn.Linear(64, 1, bias=False)
    real = torch.ones(num_images, 64) * 0.5
    disc.weight.data = torch.ones_like(disc.weight.data) * 0.5
    disc_opt = torch.optim.Adam(disc.parameters(), lr=lr)
    criterion = lambda x, y: torch.sum(x) + torch.sum(y)
    disc_loss = get_disc_loss(gen, disc, criterion, real, num_images, z_dim, 'cpu').mean()
    disc_loss.backward()
    assert torch.isclose(torch.abs(disc.weight.grad.mean() - 11.25), torch.tensor(3.75))
    return

test_disc_reasonable()

def test_disc_loss(max_tests: int=10) -> None:
    z_dim = 64
    gen = Generator(z_dim).to(device)
    gen_opt = torch.optim.Adam(gen.parameters(), lr=lr)
    disc = Discriminator().to(device) 
    disc_opt = torch.optim.Adam(disc.parameters(), lr=lr)
    num_steps = 0
    for real, _ in dataloader:
        cur_batch_size = len(real)
        real = real.view(cur_batch_size, -1).to(device)

        ### Update discriminator ###
        # Zero out the gradient before backpropagation
        disc_opt.zero_grad()

        # Calculate discriminator loss
        disc_loss = get_disc_loss(gen, disc, criterion, real, cur_batch_size, z_dim, device)
        assert (disc_loss - 0.68).abs() < 0.05, disc_loss

        # Update gradients
        disc_loss.backward(retain_graph=True)

        # Check that they detached correctly
        assert gen.generator[0][0].weight.grad is None

        # Update optimizer
        old_weight = disc.disc[0][0].weight.data.clone()
        disc_opt.step()
        new_weight = disc.disc[0][0].weight.data

        # Check that some discriminator weights changed
        assert not torch.all(torch.eq(old_weight, new_weight))
        num_steps += 1
        if num_steps >= max_tests:
            break

test_disc_loss()

Generator Loss

def get_gen_loss(gen: Generator,
                 disc: Discriminator,
                 criterion: nn.BCEWithLogitsLoss,
                 num_images: int,
                 z_dim: int, device: str="cuda") -> torch.Tensor:
    """Calculates the loss for the generator

    Args:
       gen: the generator model, which returns an image given z-dimensional noise
       disc: the discriminator model, which returns a single-dimensional prediction of real/fake
       criterion: the loss function, which should be used to compare 
              the discriminator's predictions to the ground truth reality of the images 
              (e.g. fake = 0, real = 1)
       num_images: the number of images the generator should produce, 
               which is also the length of the real images
       z_dim: the dimension of the noise vector, a scalar
       device: the device type
    Returns:
       gen_loss: a torch scalar loss value for the current batch
    """
    noise = torch.randn(num_images, z_dim, device=device)
    fakes = gen(noise)
    fake_prediction = disc(fakes)
    gen_loss = criterion(fake_prediction, torch.ones_like(fake_prediction))
    return gen_loss

def test_gen_reasonable(num_images=10):
    # Don't use explicit casts to cuda - use the device argument
    import inspect, re
    lines = inspect.getsource(get_gen_loss)
    assert (re.search(r"to\(.cuda.\)", lines)) is None
    assert (re.search(r"\.cuda\(\)", lines)) is None

    z_dim = 64
    gen = torch.zeros_like
    disc = nn.Identity()
    criterion = torch.mul # Multiply
    gen_loss_tensor = get_gen_loss(gen, disc, criterion, num_images, z_dim, 'cpu')
    assert torch.all(torch.abs(gen_loss_tensor) < 1e-5)
    #Verify shape. Related to gen_noise parametrization
    assert tuple(gen_loss_tensor.shape) == (num_images, z_dim)

    gen = torch.ones_like
    disc = nn.Identity()
    criterion = torch.mul # Multiply
    real = torch.zeros(num_images, 1)
    gen_loss_tensor = get_gen_loss(gen, disc, criterion, num_images, z_dim, 'cpu')
    assert torch.all(torch.abs(gen_loss_tensor - 1) < 1e-5)
    #Verify shape. Related to gen_noise parametrization
    assert tuple(gen_loss_tensor.shape) == (num_images, z_dim)
    return
test_gen_reasonable(10)

def test_gen_loss(num_images: int) -> None:
    z_dim = 64
    gen = Generator(z_dim).to(device)
    gen_opt = torch.optim.Adam(gen.parameters(), lr=lr)
    disc = Discriminator().to(device) 
    disc_opt = torch.optim.Adam(disc.parameters(), lr=lr)

    gen_loss = get_gen_loss(gen, disc, criterion, num_images, z_dim, device)

    # Check that the loss is reasonable
    assert (gen_loss - 0.7).abs() < 0.1
    gen_loss.backward()
    old_weight = gen.generator[0][0].weight.clone()
    gen_opt.step()
    new_weight = gen.generator[0][0].weight
    assert not torch.all(torch.eq(old_weight, new_weight))

test_gen_loss(18)

All Together

For each epoch, you will process the entire dataset in batches. For every batch, you will update the discriminator and then the generator using their losses. A batch is a set of images that are all predicted on before the loss is calculated, rather than calculating the loss after each image. Note that you may see a loss greater than 1. This is okay: binary cross entropy has no upper bound, and it grows large when the model makes a confident wrong guess.
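A worked example of that last point: for a real image (label 1), the binary cross entropy is -ln(p), where p is the predicted probability of "real". This pure-Python sketch shows the loss passing 1 once p drops below about 0.37. Note also that a completely unsure prediction (p = 0.5) gives ln 2 ≈ 0.693, roughly where the loss checks above expect freshly initialized models to sit.

```python
import math

def bce(prediction: float, target: float) -> float:
    """Binary cross entropy for a single probability prediction in (0, 1)."""
    return -(target * math.log(prediction) + (1 - target) * math.log(1 - prediction))

# A real image (target 1) scored with decreasing confidence that it is real.
for p in (0.9, 0.5, 0.1, 0.01):
    print(f"p={p}: loss={bce(p, 1.0):.3f}")
```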

It’s also often the case that the discriminator will outperform the generator, especially at the start, because its job is easier. It's important that neither one gets too good (that is, near-perfect accuracy), which would cause the entire model to stop learning. Balancing the two models is actually remarkably hard to do in a standard GAN and something you will see more of in later lectures and assignments.

After you've submitted a working version with the original architecture, feel free to play around with the architecture if you want to see how different architectural choices can lead to better or worse GANs. For example, consider changing the size of the hidden dimension, or making the networks shallower or deeper by changing the number of layers.

cur_step = 0
mean_generator_loss = 0
mean_discriminator_loss = 0
test_generator = True # Whether the generator should be tested
gen_loss = False
error = False
n_epochs = 2000
display_step = 4100
generator_losses = []
discriminator_losses = []
steps = []

with TIMER:
    for epoch in range(n_epochs):

        # Dataloader returns the batches
        for real, _ in dataloader:
            cur_batch_size = len(real)

            # Flatten the batch of real images from the dataset
            real = real.view(cur_batch_size, -1).to(device)

            ### Update discriminator ###
            # Zero out the gradients before backpropagation
            disc_opt.zero_grad()

            # Calculate discriminator loss
            disc_loss = get_disc_loss(gen, disc, criterion, real, cur_batch_size, z_dim, device)

            # Update gradients
            disc_loss.backward(retain_graph=True)

            # Update optimizer
            disc_opt.step()

            # For testing purposes, to keep track of the generator weights
            if test_generator:
                old_generator_weights = gen.generator[0][0].weight.detach().clone()

            ### Update generator ###
            gen_opt.zero_grad()
            gen_loss = get_gen_loss(gen, disc, criterion, cur_batch_size, z_dim, device)
            gen_loss.backward(retain_graph=True)
            gen_opt.step()

            # For testing purposes, to check that your code changes the generator weights
            if test_generator:
                try:
                    assert lr > 0.0000002 or (gen.generator[0][0].weight.grad.abs().max() < 0.0005 and epoch == 0)
                    assert torch.any(gen.generator[0][0].weight.detach().clone() != old_generator_weights)
                except AssertionError:
                    error = True
                    print("Runtime tests have failed")

            # Keep track of the average discriminator loss
            mean_discriminator_loss += disc_loss.item() / display_step

            # Keep track of the average generator loss
            mean_generator_loss += gen_loss.item() / display_step

            if cur_step % display_step == 0 and cur_step > 0:
                print(f"Epoch {epoch}, step {cur_step}: Generator loss:"
                      f" {mean_generator_loss}, discriminator loss:"
                      f" {mean_discriminator_loss}")
                steps.append(cur_step)
                generator_losses.append(mean_generator_loss)
                discriminator_losses.append(mean_discriminator_loss)

                mean_generator_loss = 0
                mean_discriminator_loss = 0
            cur_step += 1
Started: 2021-04-08 19:11:09.461117
Epoch 5, step 2500: Generator loss: 1.706548154211052, discriminator loss: 0.2566282903790473
Epoch 10, step 5000: Generator loss: 4.417493268251426, discriminator loss: 0.37590573319792847
Epoch 15, step 7500: Generator loss: 8.217398338270204, discriminator loss: 0.4420946755893531
Epoch 21, step 10000: Generator loss: 12.379702310991277, discriminator loss: 0.49946175085604244
Epoch 26, step 12500: Generator loss: 16.363392679834355, discriminator loss: 0.577481897354875
Epoch 31, step 15000: Generator loss: 20.321313246965392, discriminator loss: 0.6705450104258994
Epoch 37, step 17500: Generator loss: 23.881395485830232, discriminator loss: 0.7909670361138917
Epoch 42, step 20000: Generator loss: 27.36178849205961, discriminator loss: 0.9245524749659035
Epoch 47, step 22500: Generator loss: 30.756254529428357, discriminator loss: 1.0683966985411961
Epoch 53, step 25000: Generator loss: 33.873566954183424, discriminator loss: 1.2374154507704083
Epoch 58, step 27500: Generator loss: 36.76653855376236, discriminator loss: 1.4335610504932743
Epoch 63, step 30000: Generator loss: 39.610195555067065, discriminator loss: 1.6299822613395802
Epoch 69, step 32500: Generator loss: 42.27110341444029, discriminator loss: 1.8545818868704136
Epoch 74, step 35000: Generator loss: 44.86730858569149, discriminator loss: 2.081926002264768
Epoch 79, step 37500: Generator loss: 47.34035383772865, discriminator loss: 2.3272732418782995
Epoch 85, step 40000: Generator loss: 49.69807465667742, discriminator loss: 2.5900223485894514
Epoch 90, step 42500: Generator loss: 51.95912191028614, discriminator loss: 2.856632189888516
Epoch 95, step 45000: Generator loss: 54.13774062051793, discriminator loss: 3.1388706683166423
Epoch 101, step 47500: Generator loss: 56.25892917881031, discriminator loss: 3.435482066709555
Epoch 106, step 50000: Generator loss: 58.2940666561604, discriminator loss: 3.742299897164866
Epoch 111, step 52500: Generator loss: 60.34588112130169, discriminator loss: 4.039923962772651
Epoch 117, step 55000: Generator loss: 62.3250578921796, discriminator loss: 4.3601536815710755
Epoch 122, step 57500: Generator loss: 64.21707550911917, discriminator loss: 4.693669865030843
Epoch 127, step 60000: Generator loss: 66.14931350994115, discriminator loss: 5.012754998887372
Epoch 133, step 62500: Generator loss: 68.01088003492343, discriminator loss: 5.350926510263262
Epoch 138, step 65000: Generator loss: 69.7833545449736, discriminator loss: 5.705678011608883
Epoch 143, step 67500: Generator loss: 71.56750503945366, discriminator loss: 6.058190715546184
Epoch 149, step 70000: Generator loss: 73.28055478563336, discriminator loss: 6.422111075831223
Epoch 154, step 72500: Generator loss: 74.93712217669513, discriminator loss: 6.801459926683468
Epoch 159, step 75000: Generator loss: 76.57140328321462, discriminator loss: 7.186894027460396
Epoch 165, step 77500: Generator loss: 78.11976942777646, discriminator loss: 7.59347033443528
Epoch 170, step 80000: Generator loss: 79.70259762425445, discriminator loss: 7.990671021056967
Epoch 175, step 82500: Generator loss: 81.29402320809406, discriminator loss: 8.38323437005357
Epoch 181, step 85000: Generator loss: 82.81459570746449, discriminator loss: 8.79464628630952
Epoch 186, step 87500: Generator loss: 84.37445686025625, discriminator loss: 9.187008984566525
Epoch 191, step 90000: Generator loss: 85.85529090266233, discriminator loss: 9.62971806451164
Epoch 197, step 92500: Generator loss: 87.28264569795147, discriminator loss: 10.0601266036578
Epoch 202, step 95000: Generator loss: 88.77136517236256, discriminator loss: 10.470760706252658
Epoch 207, step 97500: Generator loss: 90.20359932258185, discriminator loss: 10.896903645534154
Epoch 213, step 100000: Generator loss: 91.64949153683249, discriminator loss: 11.314317919439938
Epoch 218, step 102500: Generator loss: 93.04353729379224, discriminator loss: 11.754701970989366
Epoch 223, step 105000: Generator loss: 94.48434179880661, discriminator loss: 12.17579472559178
Epoch 229, step 107500: Generator loss: 95.90043624894685, discriminator loss: 12.609691903144922
Epoch 234, step 110000: Generator loss: 97.22790038921927, discriminator loss: 13.06728125928124
Epoch 239, step 112500: Generator loss: 98.54396488256565, discriminator loss: 13.530596924526252
Epoch 245, step 115000: Generator loss: 99.77238301303473, discriminator loss: 14.021221584183708
Epoch 250, step 117500: Generator loss: 100.95588274421799, discriminator loss: 14.52276291984987
Epoch 255, step 120000: Generator loss: 102.16950242385977, discriminator loss: 15.009217036432748
Epoch 261, step 122500: Generator loss: 103.34768993141779, discriminator loss: 15.510240219909676
Epoch 266, step 125000: Generator loss: 104.49850498325966, discriminator loss: 16.02100355718803
Epoch 271, step 127500: Generator loss: 105.63248440983429, discriminator loss: 16.529967280096464
Epoch 277, step 130000: Generator loss: 106.77215565025928, discriminator loss: 17.040578575879902
Epoch 282, step 132500: Generator loss: 107.84948109347918, discriminator loss: 17.577629218131907
Epoch 287, step 135000: Generator loss: 108.89120032596693, discriminator loss: 18.1230313348835
Epoch 293, step 137500: Generator loss: 109.94543989174456, discriminator loss: 18.66669545603442
Epoch 298, step 140000: Generator loss: 111.05784844493985, discriminator loss: 19.186080698353518
Epoch 303, step 142500: Generator loss: 112.11954127087729, discriminator loss: 19.72143556161576
Epoch 309, step 145000: Generator loss: 113.1505551135316, discriminator loss: 20.266378742129064
Epoch 314, step 147500: Generator loss: 114.14504244511286, discriminator loss: 20.82540130450163
Epoch 319, step 150000: Generator loss: 115.14619642219694, discriminator loss: 21.377254920214682
Epoch 325, step 152500: Generator loss: 116.18683508672831, discriminator loss: 21.91325990107692
Epoch 330, step 155000: Generator loss: 117.20365596852427, discriminator loss: 22.464628956920198
Epoch 335, step 157500: Generator loss: 118.18476380837096, discriminator loss: 23.02610586987158
Epoch 341, step 160000: Generator loss: 119.1733443738712, discriminator loss: 23.57863494028462
Epoch 346, step 162500: Generator loss: 120.0683158794178, discriminator loss: 24.17668376757525
Epoch 351, step 165000: Generator loss: 121.00822785911677, discriminator loss: 24.759683359057014
Epoch 357, step 167500: Generator loss: 122.01716691696792, discriminator loss: 25.306333132869433
Epoch 362, step 170000: Generator loss: 122.95615759954634, discriminator loss: 25.882448339409144
Epoch 367, step 172500: Generator loss: 123.8601865704076, discriminator loss: 26.472507165831818
Epoch 373, step 175000: Generator loss: 124.77679342136544, discriminator loss: 27.06009588730988
Epoch 378, step 177500: Generator loss: 125.68742787552021, discriminator loss: 27.650220775634565
Epoch 383, step 180000: Generator loss: 126.58067360401337, discriminator loss: 28.24680029824429
Epoch 389, step 182500: Generator loss: 127.47687033252896, discriminator loss: 28.841687268430586
Epoch 394, step 185000: Generator loss: 128.33945010419103, discriminator loss: 29.470305139536162
Epoch 399, step 187500: Generator loss: 129.2645899116776, discriminator loss: 30.056467109007148
Epoch 405, step 190000: Generator loss: 130.17483343897044, discriminator loss: 30.65010625776692
Epoch 410, step 192500: Generator loss: 131.07306051247323, discriminator loss: 31.24635483381194
Epoch 415, step 195000: Generator loss: 131.98270811326725, discriminator loss: 31.83840948063775
Epoch 421, step 197500: Generator loss: 132.89332210309777, discriminator loss: 32.427245197707826
Epoch 426, step 200000: Generator loss: 133.85536700406317, discriminator loss: 32.99699493227643
Epoch 431, step 202500: Generator loss: 134.78184310503255, discriminator loss: 33.582932585127054
Epoch 437, step 205000: Generator loss: 135.67147595069721, discriminator loss: 34.18765857823589
Epoch 442, step 207500: Generator loss: 136.59887173002065, discriminator loss: 34.770955285043314
Epoch 447, step 210000: Generator loss: 137.53319753001017, discriminator loss: 35.35211720700919
Epoch 453, step 212500: Generator loss: 138.4486942873986, discriminator loss: 35.94530614820166
Epoch 458, step 215000: Generator loss: 139.33902888264944, discriminator loss: 36.54999590321232
Epoch 463, step 217500: Generator loss: 140.27206793022475, discriminator loss: 37.13930196701909
Epoch 469, step 220000: Generator loss: 140.99687212999171, discriminator loss: 37.93871315395262
Epoch 474, step 222500: Generator loss: 141.7673044035706, discriminator loss: 38.59500126797581
Epoch 479, step 225000: Generator loss: 142.5946069137365, discriminator loss: 39.21644382466705
Epoch 485, step 227500: Generator loss: 143.47817903824173, discriminator loss: 39.82380570806908
Epoch 490, step 230000: Generator loss: 144.42988614442692, discriminator loss: 40.41459694806965
Epoch 495, step 232500: Generator loss: 145.41308410630532, discriminator loss: 40.99819621782919
Epoch 501, step 235000: Generator loss: 146.35154331310105, discriminator loss: 41.60116203337314
Epoch 506, step 237500: Generator loss: 147.3414293385067, discriminator loss: 42.182913084965904
Epoch 511, step 240000: Generator loss: 148.16208219452346, discriminator loss: 42.84799684844597
Epoch 517, step 242500: Generator loss: 149.01690332217666, discriminator loss: 43.4645494998036
Epoch 522, step 245000: Generator loss: 149.90361780090743, discriminator loss: 44.07395624421216
Epoch 527, step 247500: Generator loss: 150.81761934249764, discriminator loss: 44.67337528136382
Epoch 533, step 250000: Generator loss: 151.760788434009, discriminator loss: 45.27149211621905
Epoch 538, step 252500: Generator loss: 152.72171067755622, discriminator loss: 45.862570312196205
Epoch 543, step 255000: Generator loss: 153.70058994908774, discriminator loss: 46.446673588067306
Epoch 549, step 257500: Generator loss: 154.6540338594482, discriminator loss: 47.04409442761572
Epoch 554, step 260000: Generator loss: 155.61953176954265, discriminator loss: 47.64129806395185
Epoch 559, step 262500: Generator loss: 156.57851000073427, discriminator loss: 48.23621501955407
Epoch 565, step 265000: Generator loss: 157.57013796937912, discriminator loss: 48.81841204716583
Epoch 570, step 267500: Generator loss: 158.58745419824487, discriminator loss: 49.39728153505935
Epoch 575, step 270000: Generator loss: 159.61897925506034, discriminator loss: 49.97088020893944
Epoch 581, step 272500: Generator loss: 160.62414082653933, discriminator loss: 50.55845826935191
Epoch 586, step 275000: Generator loss: 161.6360176999136, discriminator loss: 51.13836716422452
Epoch 591, step 277500: Generator loss: 162.6937001745749, discriminator loss: 51.703352506631745
Epoch 597, step 280000: Generator loss: 163.713712522297, discriminator loss: 52.28744219620823
Epoch 602, step 282500: Generator loss: 164.74417761338188, discriminator loss: 52.86771041123264
Epoch 607, step 285000: Generator loss: 165.76970245681252, discriminator loss: 53.44563473485114
Epoch 613, step 287500: Generator loss: 166.811878925066, discriminator loss: 54.01528675569889
Epoch 618, step 290000: Generator loss: 167.8307581976232, discriminator loss: 54.590264578408025
Epoch 623, step 292500: Generator loss: 168.891059790879, discriminator loss: 55.156935418433086
Epoch 628, step 295000: Generator loss: 169.9474142856893, discriminator loss: 55.71990455088043
Epoch 634, step 297500: Generator loss: 171.0481772922807, discriminator loss: 56.27077199867391
Epoch 639, step 300000: Generator loss: 172.11923127556335, discriminator loss: 56.83505055757194
Epoch 644, step 302500: Generator loss: 173.16847663276675, discriminator loss: 57.40616466661116
Epoch 650, step 305000: Generator loss: 174.07904141102364, discriminator loss: 58.13510520734215
Epoch 655, step 307500: Generator loss: 174.9117115099017, discriminator loss: 58.74583771123328
Epoch 660, step 310000: Generator loss: 175.81728609905886, discriminator loss: 59.34154992217456
Epoch 666, step 312500: Generator loss: 176.7578212393589, discriminator loss: 59.92914648076939
Epoch 671, step 315000: Generator loss: 177.6959055232357, discriminator loss: 60.52539835767199
Epoch 676, step 317500: Generator loss: 178.65148061598035, discriminator loss: 61.114570335239605
Epoch 682, step 320000: Generator loss: 179.61467326215066, discriminator loss: 61.701995908111826
Epoch 687, step 322500: Generator loss: 180.57620892596987, discriminator loss: 62.314891455287345
Epoch 692, step 325000: Generator loss: 181.45410679765476, discriminator loss: 62.936326666588116
Epoch 698, step 327500: Generator loss: 182.3509983194426, discriminator loss: 63.536605342692496
Epoch 703, step 330000: Generator loss: 183.2767872462343, discriminator loss: 64.13359040188224
Epoch 708, step 332500: Generator loss: 184.24614562447758, discriminator loss: 64.71242399436842
Epoch 714, step 335000: Generator loss: 185.2339295223548, discriminator loss: 65.29031219341155
Epoch 719, step 337500: Generator loss: 186.22102292967574, discriminator loss: 65.87176483958355
Epoch 724, step 340000: Generator loss: 187.16576725860415, discriminator loss: 66.46805416325945
Epoch 730, step 342500: Generator loss: 188.17708793676715, discriminator loss: 67.02993150789138
Epoch 735, step 345000: Generator loss: 189.2133280149301, discriminator loss: 67.59322807443746
Epoch 740, step 347500: Generator loss: 190.23208963008815, discriminator loss: 68.16751613625887
Epoch 746, step 350000: Generator loss: 191.25920752848035, discriminator loss: 68.7402887957158
Epoch 751, step 352500: Generator loss: 192.28080943237188, discriminator loss: 69.31388918965479
Epoch 756, step 355000: Generator loss: 193.3113937983357, discriminator loss: 69.88720360422755
Epoch 762, step 357500: Generator loss: 194.36098039241693, discriminator loss: 70.46414442196532
Epoch 767, step 360000: Generator loss: 195.2757543816169, discriminator loss: 71.13062453675919
Epoch 772, step 362500: Generator loss: 196.20393474234035, discriminator loss: 71.72337642310346
Epoch 778, step 365000: Generator loss: 197.17895898321507, discriminator loss: 72.2954984176345
Epoch 783, step 367500: Generator loss: 198.16447116800182, discriminator loss: 72.87443887000731
Epoch 788, step 370000: Generator loss: 199.0904656322084, discriminator loss: 73.4706785914252
Epoch 794, step 372500: Generator loss: 200.0776066501466, discriminator loss: 74.04545851565092
Epoch 799, step 375000: Generator loss: 201.02695143798252, discriminator loss: 74.6362730719872
Epoch 804, step 377500: Generator loss: 201.990468873819, discriminator loss: 75.22039345236489
Epoch 810, step 380000: Generator loss: 203.01058207302907, discriminator loss: 75.78494052414334
Epoch 815, step 382500: Generator loss: 204.05225880001382, discriminator loss: 76.35463489610616
Epoch 820, step 385000: Generator loss: 205.01460667336755, discriminator loss: 76.94180617091138
Epoch 826, step 387500: Generator loss: 206.05344676898304, discriminator loss: 77.51117122195475
Epoch 831, step 390000: Generator loss: 207.07375686963422, discriminator loss: 78.09562644241576
Epoch 836, step 392500: Generator loss: 208.0923817031217, discriminator loss: 78.6823844633292
Epoch 842, step 395000: Generator loss: 209.13077087057394, discriminator loss: 79.25322038175547
Epoch 847, step 397500: Generator loss: 210.17313802006993, discriminator loss: 79.82443133078309
Epoch 852, step 400000: Generator loss: 211.24034579869058, discriminator loss: 80.39010535690288
Epoch 858, step 402500: Generator loss: 212.29597073660645, discriminator loss: 80.95657187560273
Epoch 863, step 405000: Generator loss: 213.36997850652477, discriminator loss: 81.5198167693678
Epoch 868, step 407500: Generator loss: 214.34861638153268, discriminator loss: 82.11812956745074
Epoch 874, step 410000: Generator loss: 215.34710292697525, discriminator loss: 82.68937682878448
Epoch 879, step 412500: Generator loss: 216.41345989072963, discriminator loss: 83.24355317315504
Epoch 884, step 415000: Generator loss: 217.28458412296192, discriminator loss: 83.92686493608393
Epoch 890, step 417500: Generator loss: 218.29466441413823, discriminator loss: 84.48695280879178
Epoch 895, step 420000: Generator loss: 219.3093764458604, discriminator loss: 85.05175313127683
Epoch 900, step 422500: Generator loss: 220.34153808440644, discriminator loss: 85.61656451819609
Epoch 906, step 425000: Generator loss: 221.36335393692892, discriminator loss: 86.20032007145346
Epoch 911, step 427500: Generator loss: 222.4004497556632, discriminator loss: 86.77878700938892
Epoch 916, step 430000: Generator loss: 223.45292786750198, discriminator loss: 87.35415507568788
Epoch 922, step 432500: Generator loss: 224.49835159925806, discriminator loss: 87.92856502636128
Epoch 927, step 435000: Generator loss: 225.5315329488691, discriminator loss: 88.51101210754575
Epoch 932, step 437500: Generator loss: 226.5669734054742, discriminator loss: 89.09301339096481
Epoch 938, step 440000: Generator loss: 227.48296755303724, discriminator loss: 89.70019514662671
Epoch 943, step 442500: Generator loss: 228.48196144422903, discriminator loss: 90.27654089993891
Epoch 948, step 445000: Generator loss: 229.51780844437448, discriminator loss: 90.83335400045588
Epoch 954, step 447500: Generator loss: 230.60036065892555, discriminator loss: 91.38414017357304
Epoch 959, step 450000: Generator loss: 231.67713544449177, discriminator loss: 91.94271953728851
Epoch 964, step 452500: Generator loss: 232.64530618559718, discriminator loss: 92.56010472605873
Epoch 970, step 455000: Generator loss: 233.68903298704024, discriminator loss: 93.11562871940767
Epoch 975, step 457500: Generator loss: 234.72853977193282, discriminator loss: 93.68664950763602
Epoch 980, step 460000: Generator loss: 235.69478953835426, discriminator loss: 94.26967884075088
Epoch 986, step 462500: Generator loss: 236.70701852056465, discriminator loss: 94.835702462167
Epoch 991, step 465000: Generator loss: 237.76209552497326, discriminator loss: 95.39685135828884
Epoch 996, step 467500: Generator loss: 238.79892911128448, discriminator loss: 95.97513031529836
Ended: 2021-04-08 20:55:49.865623
Elapsed: 1:44:40.404506

Looking at the final model.

def plot_image(image: torch.Tensor,
                filename: str,
                title: str,
                num_images: int=25,
                size: tuple=(1, 28, 28),
                folder: str=f"files/posts/gans/{slug}/") -> None:
    """Plot the image and save it

    Args:
     image: the tensor with the image to plot
     filename: name for the final image file
     title: title to put on top of the image
     num_images: how many images to put in the composite image
     size: the size for the image
     folder: sub-folder to save the file in
    """
    unflattened_image = image.detach().cpu().view(-1, *size)
    image_grid = make_grid(unflattened_image[:num_images], nrow=5)

    pyplot.title(title)
    pyplot.grid(False)
    pyplot.imshow(image_grid.permute(1, 2, 0).squeeze())

    pyplot.tick_params(bottom=False, top=False, labelbottom=False,
                       right=False, left=False, labelleft=False)
    pyplot.savefig(folder + filename)
    print(f"[[file:{filename}]]")
    return

fake_noise = get_noise(cur_batch_size, z_dim, device=device)
fake = gen(fake_noise)
plot_image(image=fake, filename="fake_digits.png", title="Fake Digits")

fake_digits.png

plot_image(real, filename="real_digits.png", title="Real Digits")

real_digits.png

plotting = pandas.DataFrame.from_dict({
    "Step": steps,
    "Generator Loss": generator_losses,
    "Discriminator Loss": discriminator_losses
})

gen_plot = plotting.hvplot(x="Step", y="Generator Loss", color=PLOT.blue)
disc_plot = plotting.hvplot(x="Step", y="Discriminator Loss", color=PLOT.red)

plot = (gen_plot * disc_plot).opts(title="Training Losses",
                                   height=PLOT.height,
                                   width=PLOT.width,
                                   ylabel="Loss",
                                   fontscale=PLOT.fontscale)
output = Embed(plot=plot, file_name="losses")()
print(output)

Figure Missing

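At first I thought something was wrong with the losses, since they go up steadily over time, and that suspicion is probably justified. With binary cross-entropy, a Generator or Discriminator that gets better at its job should see its loss go down, not up. The climb is also suspiciously regular: the generator loss rises by about 1.0 and the discriminator loss by about 0.58 every 2,500 steps. That is what you would expect if each reporting window's mean loss were being added onto a running total that never gets reset to zero after it's printed. If that's the case, those increments are the actual per-window average losses, and the rising curves don't tell us whether the Generator can still improve.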
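The training loop isn't reproduced here, so this is only a sketch of the suspected bug, assuming the loop accumulates a per-window mean with something like `mean_loss += loss / display_step`. The helper `report_means` and its `reset` flag are hypothetical. Feeding it a constant loss shows how skipping the reset turns a flat loss into a steadily climbing one.

```python
def report_means(losses, display_step, reset=True):
    """Return the "mean" loss a training loop would print every display_step steps."""
    reports = []
    mean_loss = 0.0
    for step, loss in enumerate(losses, start=1):
        mean_loss += loss / display_step
        if step % display_step == 0:
            reports.append(mean_loss)
            if reset:
                mean_loss = 0.0  # start the next window from zero
    return reports

constant_losses = [0.5] * 12
print(report_means(constant_losses, display_step=4))               # [0.5, 0.5, 0.5]
print(report_means(constant_losses, display_step=4, reset=False))  # [0.5, 1.0, 1.5]
```

With the reset, each report is that window's average. Without it, each report adds another window's average onto the previous total, so a constant loss produces a straight line going up.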

End

Neural Machine Translation: Helper Functions

Helper Functions

We will first implement a few functions that we will use later on. These will be for:

  • the input encoder
  • the pre-attention decoder
  • preparation of the queries, keys, values, and mask.

Imports

# from pypi
from trax import layers
from trax.fastmath import numpy as fastmath_numpy

import trax

Helper functions

Input encoder

The input encoder runs on the input tokens, creates their embeddings, and feeds them to an LSTM network. This outputs the activations that will be the keys and values for attention. It is a Serial network which uses:

  • tl.Embedding: Converts each token to its vector representation. In this case, its size is the size of the vocabulary by the dimension of the model: tl.Embedding(vocab_size, d_model). vocab_size is the number of entries in the given vocabulary. d_model is the number of elements in the word embedding.
  • tl.LSTM: LSTM layer of size d_model. We want to be able to configure how many encoder layers we have, so create one LSTM layer for each of the n_encoder_layers.

def input_encoder(input_vocab_size: int, d_model: int,
                  n_encoder_layers: int) -> layers.Serial:
    """ Input encoder runs on the input sentence and creates
    activations that will be the keys and values for attention.

    Args:
       input_vocab_size: vocab size of the input
       d_model:  depth of embedding (n_units in the LSTM cell)
       n_encoder_layers: number of LSTM layers in the encoder

    Returns:
       tl.Serial: The input encoder
    """
    input_encoder = layers.Serial( 
        layers.Embedding(input_vocab_size, d_model),
        [layers.LSTM(d_model) for _ in range(n_encoder_layers)]
    )
    return input_encoder
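Since the embedding table is vocab_size by d_model, looking up a batch of token IDs turns a (batch_size, sequence_length) array into a (batch_size, sequence_length, d_model) one, and the LSTM layers keep that last dimension at d_model. Here is a NumPy sketch of just the lookup, using the same toy sizes as the unit test. The table is random rather than learned, and this illustrates the idea rather than how Trax implements `Embedding` internally.

```python
import numpy as np

vocab_size, d_model = 10, 2

# One row per token ID; a real embedding layer would learn these values.
rng = np.random.default_rng(0)
embedding_table = rng.normal(size=(vocab_size, d_model))

# A batch of two padded sentences with three token IDs each.
token_ids = np.array([[1, 4, 0],
                      [7, 2, 9]])

# Fancy indexing replaces every ID with its row from the table.
embedded = embedding_table[token_ids]
print(embedded.shape)  # (2, 3, 2): (batch_size, sequence_length, d_model)
```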

def test_input_encoder_fn(input_encoder_fn):
    target = input_encoder_fn
    success = 0
    fails = 0

    input_vocab_size = 10
    d_model = 2
    n_encoder_layers = 6

    encoder = target(input_vocab_size, d_model, n_encoder_layers)

    lstms = "\n".join([f'  LSTM_{d_model}'] * n_encoder_layers)

    expected = f"Serial[\n  Embedding_{input_vocab_size}_{d_model}\n{lstms}\n]"

    proposed = str(encoder)

    # Test all layers are in the expected sequence
    try:
        assert(proposed.replace(" ", "") == expected.replace(" ", ""))
        success += 1
    except:
        fails += 1
        print("Wrong model. \nProposed:\n%s" %proposed, "\nExpected:\n%s" %expected)

    # Test the output type
    try:
        assert(isinstance(encoder, trax.layers.combinators.Serial))
        success += 1
        # Test the number of layers
        try:
            assert len(encoder.sublayers) == (n_encoder_layers + 1)
            success += 1
        except:
            fails += 1
            print('The number of sublayers does not match %s <>' %len(encoder.sublayers), " %s" %(n_encoder_layers + 1))
    except:
        fails += 1
        print("The encoder is not an object of ", trax.layers.combinators.Serial)


    if fails == 0:
        print("\033[92m All tests passed")
    else:
        print('\033[92m', success," Tests passed")
        print('\033[91m', fails, " Tests failed")

test_input_encoder_fn(input_encoder)
All tests passed

Pre-attention decoder

The pre-attention decoder runs on the targets and creates activations that are used as queries in attention. This is a Serial network which is composed of the following:

  • tl.ShiftRight: This pads the beginning of the target tokens with a 0 and drops the last token so the length stays the same (e.g. [8, 34, 12] shifted right is [0, 8, 34]). The 0 acts like a start-of-sentence token that will be the first input to the decoder. During training, this shift also allows the target tokens to be passed as input to do teacher forcing.
  • tl.Embedding: Like in the previous function, this converts each token to its vector representation. In this case, its size is the size of the vocabulary by the dimension of the model: tl.Embedding(vocab_size, d_model). vocab_size is the number of entries in the given vocabulary. d_model is the number of elements in the word embedding.
  • tl.LSTM: LSTM layer of size d_model.
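To make the shift concrete, here is a NumPy sketch of what tl.ShiftRight does to a batch of targets in 'train' mode, as I understand it: pad one 0 on the left of the length axis, then drop the last position so the shape is unchanged. `shift_right` is a hypothetical helper, not the Trax code.

```python
import numpy as np

def shift_right(tokens: np.ndarray, pad_id: int = 0) -> np.ndarray:
    """Prepend pad_id along the length axis and drop the last position."""
    # ((0, 0), (1, 0)): no padding on the batch axis, one slot before each sequence.
    padded = np.pad(tokens, ((0, 0), (1, 0)), constant_values=pad_id)
    return padded[:, :-1]

targets = np.array([[8, 34, 12]])
print(shift_right(targets))  # [[ 0  8 34]]
```

Because every target token moves one position to the right, the decoder sees token t-1 at position t, which is what lets the real targets be fed in for teacher forcing.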

def pre_attention_decoder(mode: str, target_vocab_size: int, d_model: int) -> layers.Serial:
    """ Pre-attention decoder runs on the targets and creates
    activations that are used as queries in attention.

    Args:
       mode: 'train' or 'eval'
       target_vocab_size: vocab size of the target
       d_model:  depth of embedding (n_units in the LSTM cell)
    Returns:
       tl.Serial: The pre-attention decoder
    """
    return layers.Serial(
        layers.ShiftRight(mode=mode),
        layers.Embedding(target_vocab_size, d_model),
        layers.LSTM(d_model)
    )

def test_pre_attention_decoder_fn(pre_attention_decoder_fn):
    target = pre_attention_decoder_fn
    success = 0
    fails = 0

    mode = 'train'
    target_vocab_size = 10
    d_model = 2

    decoder = target(mode, target_vocab_size, d_model)

    expected = f"Serial[\n  ShiftRight(1)\n  Embedding_{target_vocab_size}_{d_model}\n  LSTM_{d_model}\n]"

    proposed = str(decoder)

    # Test all layers are in the expected sequence
    try:
        assert(proposed.replace(" ", "") == expected.replace(" ", ""))
        success += 1
    except:
        fails += 1
        print("Wrong model. \nProposed:\n%s" %proposed, "\nExpected:\n%s" %expected)

    # Test the output type
    try:
        assert(isinstance(decoder, trax.layers.combinators.Serial))
        success += 1
        # Test the number of layers
        try:
            assert len(decoder.sublayers) == 3
            success += 1
        except:
            fails += 1
            print('The number of sublayers does not match %s <>' %len(decoder.sublayers), " %s" %3)
    except:
        fails += 1
        print("The decoder is not an object of ", trax.layers.combinators.Serial)


    if fails == 0:
        print("\033[92m All tests passed")
    else:
        print('\033[92m', success," Tests passed")
        print('\033[91m', fails, " Tests failed")

The Trax developers changed the behavior of the Fn (or something in there) so that ShiftRight is always wrapped in its own Serial layer, which means the model's string no longer matches the one the test expects. Testing string representations is kind of brittle anyway…

It looks like they're using a decorator to check the shape, and the decorator wraps the layer in a Serial. See trax.layers.assert_shape.AssertFunction.

test_pre_attention_decoder_fn(pre_attention_decoder)
Wrong model. 
Proposed:
Serial[
  Serial[
    ShiftRight(1)
  ]
  Embedding_10_2
  LSTM_2
] 
Expected:
Serial[
  ShiftRight(1)
  Embedding_10_2
  LSTM_2
]
2  Tests passed
1  Tests failed

Preparing the attention input

This function prepares the inputs to the attention layer. It takes in the encoder and pre-attention decoder activations, using the encoder activations as the keys and values and the decoder activations as the queries. It also returns a mask that distinguishes real tokens from padding tokens. Trax uses this mask internally when computing the softmax so that padding tokens have no effect on the computed probabilities. From the data preparation steps, the padding tokens are the ones with the ID 0.

def prepare_attention_input(encoder_activations: fastmath_numpy.array,
                            decoder_activations: fastmath_numpy.array,
                            inputs: fastmath_numpy.array) -> tuple:
    """Prepare queries, keys, values and mask for attention.

    Args:
       encoder_activations fastnp.array(batch_size, padded_input_length, d_model): output from the input encoder
       decoder_activations fastnp.array(batch_size, padded_input_length, d_model): output from the pre-attention decoder
       inputs fastnp.array(batch_size, padded_input_length): padded input tokens

    Returns:
       queries, keys, values and mask for attention.
    """
    keys = encoder_activations
    values = encoder_activations
    queries = decoder_activations    
    mask = inputs != 0

    mask = fastmath_numpy.reshape(mask, (mask.shape[0], 1, 1, mask.shape[1]))
    mask += fastmath_numpy.zeros((1, 1, decoder_activations.shape[1], 1))
    return queries, keys, values, mask
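The reshape and the addition of zeros are there for broadcasting. The mask starts out as (batch_size, encoder_length), the reshape adds two size-1 axes (the first of which, as far as I can tell, broadcasts across the attention heads), and adding zeros of shape (1, 1, decoder_length, 1) copies the mask once per query position, giving (batch_size, 1, decoder_length, encoder_length). A NumPy sketch using the unit test's inputs:

```python
import numpy as np

inputs = np.array([[1, 2, 3],
                   [1, 4, 0]])  # 0 is the padding ID
decoder_length = 3

mask = inputs != 0                                       # (2, 3) booleans
mask = mask.reshape(mask.shape[0], 1, 1, mask.shape[1])  # (2, 1, 1, 3)
# Use `mask + ...` rather than `+=`: NumPy can't add floats into a bool array in place.
mask = mask + np.zeros((1, 1, decoder_length, 1))        # (2, 1, 3, 3) floats

print(mask.shape)  # (2, 1, 3, 3)
print(mask[1, 0])  # every query row masks out the padded last key
```

The original function can use `+=` because JAX arrays are immutable, so `+=` rebinds the name to a new array instead of writing into the boolean one.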

def test_prepare_attention_input(prepare_attention_input):
    target = prepare_attention_input
    success = 0
    fails = 0

    # This unit test uses a batch size of 2, 3 tokens per sequence, and an embedding size of 4

    enc_act = fastmath_numpy.array([[[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0]],
               [[1, 0, 1, 0], [0, 1, 0, 1], [0, 0, 0, 0]]])
    dec_act = fastmath_numpy.array([[[2, 0, 0, 0], [0, 2, 0, 0], [0, 0, 2, 0]], 
               [[2, 0, 2, 0], [0, 2, 0, 2], [0, 0, 0, 0]]])
    inputs =  fastmath_numpy.array([[1, 2, 3], [1, 4, 0]])

    exp_mask = fastmath_numpy.array([[[[1., 1., 1.], [1., 1., 1.], [1., 1., 1.]]], 
                             [[[1., 1., 0.], [1., 1., 0.], [1., 1., 0.]]]])

    exp_type = type(enc_act)

    queries, keys, values, mask = target(enc_act, dec_act, inputs)

    try:
        assert(fastmath_numpy.allclose(queries, dec_act))
        success += 1
    except:
        fails += 1
        print("Queries does not match the decoder activations")
    try:
        assert(fastmath_numpy.allclose(keys, enc_act))
        success += 1
    except:
        fails += 1
        print("Keys does not match the encoder activations")
    try:
        assert(fastmath_numpy.allclose(values, enc_act))
        success += 1
    except:
        fails += 1
        print("Values does not match the encoder activations")
    try:
        assert(fastmath_numpy.allclose(mask, exp_mask))
        success += 1
    except:
        fails += 1
        print("Mask does not match expected tensor. \nExpected:\n%s" %exp_mask, "\nOutput:\n%s" %mask)

    # Test the output type
    try:
        assert(isinstance(queries, exp_type))
        assert(isinstance(keys, exp_type))
        assert(isinstance(values, exp_type))
        assert(isinstance(mask, exp_type))
        success += 1
    except:
        fails += 1
        print("One of the output objects is not of type ", exp_type)

    if fails == 0:
        print("\033[92m All tests passed")
    else:
        print('\033[92m', success," Tests passed")
        print('\033[91m', fails, " Tests failed")

test_prepare_attention_input(prepare_attention_input)
All tests passed

Neural Machine Translation: Testing the Model

Testing the Model

In the previous post we trained our machine translation model, so now it's time to test it and see how well it does.

End

The overview post with links to all the posts in this series is here.

Raw

# # Part 4:  Testing
# 
# We will now be using the model you just trained to translate English sentences to German. We will implement this with two functions: The first allows you to identify the next symbol (i.e. output token). The second one takes care of combining the entire translated string.
# 
# We will start by first loading in a pre-trained copy of the model you just coded. Please run the cell below to do just that.

# In[ ]:


# instantiate the model we built in eval mode
model = NMTAttn(mode='eval')

# initialize weights from a pre-trained model
model.init_from_file("model.pkl.gz", weights_only=True)
model = tl.Accelerate(model)


# <a name="4.1"></a>
# ## 4.1  Decoding
# 
# As discussed in the lectures, there are several ways to get the next token when translating a sentence. For instance, we can just get the most probable token at each step (i.e. greedy decoding) or get a sample from a distribution. We can generalize the implementation of these two approaches by using the `tl.logsoftmax_sample()` method. Let's briefly look at its implementation:
# 
# ```python
# def logsoftmax_sample(log_probs, temperature=1.0):  # pylint: disable=invalid-name
#   """Returns a sample from a log-softmax output, with temperature.
# 
#   Args:
#     log_probs: Logarithms of probabilities (often coming from LogSofmax)
#     temperature: For scaling before sampling (1.0 = default, 0.0 = pick argmax)
#   """
#   # This is equivalent to sampling from a softmax with temperature.
#   u = np.random.uniform(low=1e-6, high=1.0 - 1e-6, size=log_probs.shape)
#   g = -np.log(-np.log(u))
#   return np.argmax(log_probs + g * temperature, axis=-1)
# ```
# 
# The key things to take away here are: 1. it gets random samples with the same shape as your input (i.e. `log_probs`), and 2. the amount of "noise" added to the input by these random samples is scaled by a `temperature` setting. You'll notice that setting it to `0` will just make the return statement equal to getting the argmax of `log_probs`. This will come in handy later. 
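The quoted `logsoftmax_sample` uses the Gumbel-max trick. Adding Gumbel noise -log(-log(u)) to log probabilities and taking the argmax draws a sample from the softmax of those log probabilities, and scaling the noise by a temperature T > 0 is the same as sampling from softmax(log_probs / T). Below is a NumPy re-creation for experimenting outside Trax. The name `gumbel_max_sample` and the optional seeded `rng` argument are my additions.

```python
import numpy as np

def gumbel_max_sample(log_probs, temperature=1.0, rng=None):
    """Sample an index from softmax(log_probs / temperature) via Gumbel noise."""
    rng = np.random.default_rng() if rng is None else rng
    # Uniform draws kept away from 0 and 1 so the double log stays finite.
    u = rng.uniform(low=1e-6, high=1.0 - 1e-6, size=log_probs.shape)
    gumbel_noise = -np.log(-np.log(u))
    # temperature == 0 removes the noise entirely, leaving a plain argmax.
    return np.argmax(log_probs + gumbel_noise * temperature, axis=-1)

log_probs = np.log(np.array([0.1, 0.6, 0.3]))
print(gumbel_max_sample(log_probs, temperature=0.0))  # 1 (greedy)

# At temperature 1.0 the sample frequencies approach the original probabilities.
rng = np.random.default_rng(0)
samples = [int(gumbel_max_sample(log_probs, 1.0, rng)) for _ in range(10_000)]
print(np.bincount(samples, minlength=3) / len(samples))  # close to [0.1, 0.6, 0.3]
```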
# 
# <a name="ex06"></a>
# ### Exercise 06
# 
# **Instructions:** Implement the `next_symbol()` function that takes in the `input_tokens` and the `cur_output_tokens`, then return the index of the next word. You can click below for hints in completing this exercise.
# 
# <details>    
# <summary>
#     <font size="3" color="darkgreen"><b>Click Here for Hints</b></font>
# </summary>
# <p>
# <ul>
#     <li>To get the next power of two, you can compute <i>2^ceil(log_2(token_length + 1))</i>. We add 1 to avoid <i>log(0)</i>.</li>
#     <li>You can use <i>np.ceil()</i> to get the ceiling of a float.</li>
#     <li><i>np.log2()</i> will get the logarithm base 2 of a value</li>
#     <li><i>int()</i> will cast a value into an integer type</li>
#     <li>From the model diagram in part 2, you know that it takes two inputs. You can feed these with this syntax to get the model outputs: <i>model((input1, input2))</i>. It's up to you to determine which variables below to substitute for input1 and input2. Remember also from the diagram that the output has two elements: [log probabilities, target tokens]. You won't need the target tokens so we assigned it to _ below for you. </li>
#     <li> The log probabilities output will have the shape: (batch size, decoder length, vocab size). It will contain log probabilities for each token in the <i>cur_output_tokens</i> plus 1 for the start symbol introduced by the ShiftRight in the preattention decoder. For example, if cur_output_tokens is [1, 2, 5], the model will output an array of log probabilities each for tokens 0 (start symbol), 1, 2, and 5. To generate the next symbol, you just want to get the log probabilities associated with the last token (i.e. token 5 at index 3). You can slice the model output at [0, 3, :] to get this. It will be up to you to generalize this for any length of cur_output_tokens </li>
# </ul>
# 
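With the ceiling made explicit, the hint's formula gives the smallest power of two that is strictly greater than token_length. That guarantees at least one padding slot after the last real token, and that slot's position is the one whose log probabilities predict the next symbol. A quick check of the arithmetic (the helper name is hypothetical):

```python
import math

def next_power_of_two_above(token_length: int) -> int:
    """Smallest power of two strictly greater than token_length."""
    return 2 ** math.ceil(math.log2(token_length + 1))

lengths = {n: next_power_of_two_above(n) for n in [0, 1, 3, 4, 7, 8]}
print(lengths)  # {0: 1, 1: 2, 3: 4, 4: 8, 7: 8, 8: 16}
```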

# In[ ]:


# UNQ_C6
# GRADED FUNCTION
def next_symbol(NMTAttn, input_tokens, cur_output_tokens, temperature):
    """Returns the index of the next token.

    Args:
        NMTAttn (tl.Serial): An LSTM sequence-to-sequence model with attention.
        input_tokens (np.ndarray 1 x n_tokens): tokenized representation of the input sentence
        cur_output_tokens (list): tokenized representation of previously translated words
        temperature (float): parameter for sampling ranging from 0.0 to 1.0.
            0.0: same as argmax, always pick the most probable token
            1.0: sampling from the distribution (can sometimes say random things)

    Returns:
        int: index of the next token in the translated sentence
        float: log probability of the next symbol
    """

    ### START CODE HERE (REPLACE INSTANCES OF `None` WITH YOUR CODE) ###

    # set the length of the current output tokens
    token_length = None

    # calculate next power of 2 for padding length 
    padded_length = None

    # pad cur_output_tokens up to the padded_length
    padded = cur_output_tokens + None
    
    # model expects the output to have an axis for the batch size in front so
    # convert `padded` list to a numpy array with shape (x, <padded_length>) where the
    # x position is the batch axis. (hint: you can use np.expand_dims() with axis=0 to insert a new axis)
    padded_with_batch = None

    # get the model prediction. remember to use the `NMTAttn` argument defined above.
    # hint: the model accepts a tuple as input (e.g. `my_model((input1, input2))`)
    output, _ = None
    
    # get log probabilities from the last token output
    log_probs = output[None]

    # get the next symbol by getting a logsoftmax sample (*hint: cast to an int)
    symbol = None
    
    ### END CODE HERE ###

    return symbol, float(log_probs[symbol])


# In[ ]:


# BEGIN UNIT TEST
w1_unittest.test_next_symbol(next_symbol, model)
# END UNIT TEST


# Now you will implement the `sampling_decode()` function. This will call the `next_symbol()` function above several times until the next output is the end-of-sentence token (i.e. `EOS`). It takes in an input string and returns the translated version of that string.
# 
# <a name="ex07"></a>
# ### Exercise 07
# 
# **Instructions**: Implement the `sampling_decode()` function.

# In[ ]:


# UNQ_C7
# GRADED FUNCTION
def sampling_decode(input_sentence, NMTAttn = None, temperature=0.0, vocab_file=None, vocab_dir=None):
    """Returns the translated sentence.

    Args:
        input_sentence (str): sentence to translate.
        NMTAttn (tl.Serial): An LSTM sequence-to-sequence model with attention.
        temperature (float): parameter for sampling ranging from 0.0 to 1.0.
            0.0: same as argmax, always pick the most probable token
            1.0: sampling from the distribution (can sometimes say random things)
        vocab_file (str): filename of the vocabulary
        vocab_dir (str): path to the vocabulary file

    Returns:
        tuple: (list, float, str)
            list of int: tokenized version of the translated sentence
            float: log probability of the translated sentence
            str: the translated sentence
    """
    
    ### START CODE HERE (REPLACE INSTANCES OF `None` WITH YOUR CODE) ###
    
    # encode the input sentence
    input_tokens = None
    
    # initialize the list of output tokens
    cur_output_tokens = None
    
    # initialize an integer that represents the current output index
    cur_output = None
    
    # Set the encoding of the "end of sentence" as 1
    EOS = None
    
    # check that the current output is not the end of sentence token
    while cur_output != EOS:
        
        # update the current output token by getting the index of the next word (hint: use next_symbol)
        cur_output, log_prob = None
        
        # append the current output token to the list of output tokens
        cur_output_tokens.append(cur_output)
    
    # detokenize the output tokens
    sentence = None
    
    ### END CODE HERE ###
    
    return cur_output_tokens, log_prob, sentence


# In[ ]:


# Test the function above. Try varying the temperature setting with values from 0 to 1.
# Run it several times with each setting and see how often the output changes.
sampling_decode("I love languages.", model, temperature=0.0, vocab_file=VOCAB_FILE, vocab_dir=VOCAB_DIR)


# In[ ]:


# BEGIN UNIT TEST
w1_unittest.test_sampling_decode(sampling_decode, model)
# END UNIT TEST


# We have set a default value of `0` to the temperature setting in our implementation of `sampling_decode()` above. As you may have noticed in the `logsoftmax_sample()` method, this setting will ultimately result in greedy decoding. As mentioned in the lectures, this algorithm generates the translation by getting the most probable word at each step. It gets the argmax of the output array of your model and then returns that index. See the testing function and sample inputs below. You'll notice that the output will remain the same each time you run it.

# In[ ]:


def greedy_decode_test(sentence, NMTAttn=None, vocab_file=None, vocab_dir=None):
    """Prints the input and output of our NMTAttn model using greedy decode

    Args:
        sentence (str): a custom string.
        NMTAttn (tl.Serial): An LSTM sequence-to-sequence model with attention.
        vocab_file (str): filename of the vocabulary
        vocab_dir (str): path to the vocabulary file

    Returns:
        str: the translated sentence
    """
    
    _,_, translated_sentence = sampling_decode(sentence, NMTAttn, vocab_file=vocab_file, vocab_dir=vocab_dir)
    
    print("English: ", sentence)
    print("German: ", translated_sentence)
    
    return translated_sentence


# In[ ]:


# put a custom string here
your_sentence = 'I love languages.'

greedy_decode_test(your_sentence, model, vocab_file=VOCAB_FILE, vocab_dir=VOCAB_DIR);


# In[ ]:


greedy_decode_test('You are almost done with the assignment!', model, vocab_file=VOCAB_FILE, vocab_dir=VOCAB_DIR);


# <a name="4.2"></a>
# ## 4.2  Minimum Bayes-Risk Decoding
# 
# As mentioned in the lectures, getting the most probable token at each step may not necessarily produce the best results. Another approach is to do Minimum Bayes Risk Decoding or MBR. The general steps to implement this are:
# 
# 1. take several random samples
# 2. score each sample against all other samples
# 3. select the one with the highest score
# 
# You will be building helper functions for these steps in the following sections.
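Once there is a similarity function, steps 2 and 3 fit in a few lines. The sketch below scores each sample by its average unigram Jaccard similarity to every other sample and returns the index of the best one. `mbr_select` and the toy token lists are hypothetical, and this version weights every comparison equally.

```python
def unigram_jaccard(candidate, reference):
    """Intersection over union of the two token sets (0.0 if both are empty)."""
    can, ref = set(candidate), set(reference)
    union = can | ref
    return len(can & ref) / len(union) if union else 0.0

def mbr_select(samples, similarity=unigram_jaccard):
    """Index of the sample with the highest average similarity to the others."""
    best_index, best_score = 0, float("-inf")
    for i, candidate in enumerate(samples):
        others = [other for j, other in enumerate(samples) if j != i]
        score = sum(similarity(candidate, other) for other in others) / max(len(others), 1)
        if score > best_score:
            best_index, best_score = i, score
    return best_index

samples = [[5, 6, 7, 1], [5, 6, 8, 1], [9, 10, 1], [5, 6, 7, 8, 1]]
print(mbr_select(samples))  # 3: it shares the most tokens with the other samples
```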

# <a name='4.2.1'></a>
# ### 4.2.1 Generating samples
# 
# First, let's build a function to generate several samples. You can use the `sampling_decode()` function you developed earlier to do this easily. We want to record the token list and log probability for each sample as these will be needed in the next step.

# In[ ]:


def generate_samples(sentence, n_samples, NMTAttn=None, temperature=0.6, vocab_file=None, vocab_dir=None):
    """Generates samples using sampling_decode()

    Args:
        sentence (str): sentence to translate.
        n_samples (int): number of samples to generate
        NMTAttn (tl.Serial): An LSTM sequence-to-sequence model with attention.
        temperature (float): parameter for sampling ranging from 0.0 to 1.0.
            0.0: same as argmax, always pick the most probable token
            1.0: sampling from the distribution (can sometimes say random things)
        vocab_file (str): filename of the vocabulary
        vocab_dir (str): path to the vocabulary file
        
    Returns:
        tuple: (list, list)
            list of lists: token list per sample
            list of floats: log probability per sample
    """
    # define lists to contain samples and probabilities
    samples, log_probs = [], []

    # run a for loop to generate n samples
    for _ in range(n_samples):
        
        # get a sample using the sampling_decode() function
        sample, logp, _ = sampling_decode(sentence, NMTAttn, temperature, vocab_file=vocab_file, vocab_dir=vocab_dir)
        
        # append the token list to the samples list
        samples.append(sample)
        
        # append the log probability to the log_probs list
        log_probs.append(logp)
                
    return samples, log_probs


# In[ ]:


# generate 4 samples with the default temperature (0.6)
generate_samples('I love languages.', 4, model, vocab_file=VOCAB_FILE, vocab_dir=VOCAB_DIR)


# ### 4.2.2 Comparing overlaps
# 
# Let us now build our functions to compare a sample against another. There are several metrics available as shown in the lectures and you can try experimenting with any one of these. For this assignment, we will be calculating scores for unigram overlaps. One of the simpler metrics is the [Jaccard similarity](https://en.wikipedia.org/wiki/Jaccard_index), which gets the intersection over union of two sets. We've already implemented it below for your perusal.

# In[ ]:


def jaccard_similarity(candidate, reference):
    """Returns the Jaccard similarity between two token lists

    Args:
        candidate (list of int): tokenized version of the candidate translation
        reference (list of int): tokenized version of the reference translation

    Returns:
        float: overlap between the two token lists
    """
    
    # convert the lists to a set to get the unique tokens
    can_unigram_set, ref_unigram_set = set(candidate), set(reference)  
    
    # get the set of tokens common to both candidate and reference
    joint_elems = can_unigram_set.intersection(ref_unigram_set)
    
    # get the set of all tokens found in either candidate or reference
    all_elems = can_unigram_set.union(ref_unigram_set)
    
    # divide the number of joint elements by the number of all elements
    overlap = len(joint_elems) / len(all_elems)
    
    return overlap


# In[ ]:


# Let's try using the function. Remember the result here and compare it with the next function below.
jaccard_similarity([1, 2, 3], [1, 2, 3, 4])


# One of the more commonly used metrics in machine translation is the ROUGE score. For unigrams, this is called ROUGE-1 and as shown in class, you can output the scores for both precision and recall when comparing two samples. To get the final score, you will want to compute the F1-score as given by:
# 
# $$score = 2* \frac{(precision * recall)}{(precision + recall)}$$
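# 
# As a hand-worked check (not output from the notebook), take the pair used in the cell below: system tokens [1, 2, 3] and reference tokens [1, 2, 3, 4]. Three unigrams overlap, so
# 
# $$precision = \frac{3}{3} = 1, \qquad recall = \frac{3}{4} = 0.75, \qquad score = 2 * \frac{1 * 0.75}{1 + 0.75} = \frac{6}{7} \approx 0.857$$
# 
# For the same pair the Jaccard similarity is $\frac{3}{4} = 0.75$, which is why the two functions return different values.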
# 
# <a name="ex08"></a>
# ### Exercise 08
# 
# **Instructions**: Implement the `rouge1_similarity()` function.

# In[ ]:


# UNQ_C8
# GRADED FUNCTION

# for making a frequency table easily
from collections import Counter

def rouge1_similarity(system, reference):
    """Returns the ROUGE-1 score between two token lists

    Args:
        system (list of int): tokenized version of the system translation
        reference (list of int): tokenized version of the reference translation

    Returns:
        float: overlap between the two token lists
    """    
    
    ### START CODE HERE (REPLACE INSTANCES OF `None` WITH YOUR CODE) ###
    
    # make a frequency table of the system tokens (hint: use the Counter class)
    sys_counter = Counter(system)
    
    # make a frequency table of the reference tokens (hint: use the Counter class)
    ref_counter = Counter(reference)
    
    # initialize overlap to 0
    overlap = 0
    
    # run a for loop over the sys_counter object (can be treated as a dictionary)
    for token in sys_counter:
        
        # lookup the value of the token in the sys_counter dictionary (hint: use the get() method)
        token_count_sys = sys_counter.get(token, 0)
        
        # lookup the value of the token in the ref_counter dictionary (hint: use the get() method)
        token_count_ref = ref_counter.get(token, 0)
        
        # update the overlap by getting the smaller number between the two token counts above
        overlap += min(token_count_sys, token_count_ref)
    
    # get the precision (i.e. number of overlapping tokens / number of system tokens)
    precision = overlap / sum(sys_counter.values())
    
    # get the recall (i.e. number of overlapping tokens / number of reference tokens)
    recall = overlap / sum(ref_counter.values())
    
    if precision + recall != 0:
        # compute the f1-score
        rouge1_score = 2 * ((precision * recall) / (precision + recall))
    else:
        rouge1_score = 0 
    ### END CODE HERE ###
    
    return rouge1_score
    


# In[ ]:


# notice that this produces a different value from the jaccard similarity earlier
rouge1_similarity([1, 2, 3], [1, 2, 3, 4])


# In[ ]:


# BEGIN UNIT TEST
w1_unittest.test_rouge1_similarity(rouge1_similarity)
# END UNIT TEST


# ### 4.2.3 Overall score
# 
# We will now build a function to generate the overall score for a particular sample. As mentioned earlier, we need to compare each sample with all other samples. For instance, if we generated 30 sentences, we will need to compare sentence 1 to sentences 2 to 30. Then, we compare sentence 2 to sentences 1 and 3 to 30, and so forth. At each step, we get the average score of all comparisons to get the overall score for a particular sample. To illustrate, these will be the steps to generate the scores of a 4-sample list.
# 
# 1. Get similarity score between sample 1 and sample 2
# 2. Get similarity score between sample 1 and sample 3
# 3. Get similarity score between sample 1 and sample 4
# 4. Get average score of the first 3 steps. This will be the overall score of sample 1.
# 5. Iterate and repeat until samples 1 to 4 have overall scores.
# 
# We will be storing the results in a dictionary for easy lookups.
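# 
# The steps above can be summarized in one formula. Assuming $n$ samples $y_1, \ldots, y_n$ and a similarity function $sim$ (notation introduced here for illustration), the arithmetic-mean score of sample $i$ is
# 
# $$score_i = \frac{1}{n - 1} \sum_{j \neq i} sim(y_i, y_j)$$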
# 
# <a name="ex09"></a>
# ### Exercise 09
# 
# **Instructions**: Implement the `average_overlap()` function.

# In[ ]:


# UNQ_C9
# GRADED FUNCTION
def average_overlap(similarity_fn, samples, *ignore_params):
    """Returns the arithmetic mean of each candidate sentence in the samples

    Args:
        similarity_fn (function): similarity function used to compute the overlap
        samples (list of lists): tokenized version of the translated sentences
        *ignore_params: additional parameters will be ignored

    Returns:
        dict: scores of each sample
            key: index of the sample
            value: score of the sample
    """  
    
    # initialize dictionary
    scores = {}
    
    # run a for loop for each sample
    for index_candidate, candidate in enumerate(samples):    
        
        ### START CODE HERE (REPLACE INSTANCES OF `None` WITH YOUR CODE) ###
        
        # initialize overlap to 0.0
        overlap = 0.0
        
        # run a for loop for each sample
        for index_sample, sample in enumerate(samples): 

            # skip if the candidate index is the same as the sample index
            if index_candidate == index_sample:
                continue
                
            # get the overlap between candidate and sample using the similarity function
            sample_overlap = similarity_fn(candidate, sample)
            
            # add the sample overlap to the total overlap
            overlap += sample_overlap
            
        # get the score for the candidate by computing the average
        score = overlap / (len(samples) - 1)
        
        # save the score in the dictionary. use index as the key.
        scores[index_candidate] = score
        
        ### END CODE HERE ###
    return scores


# In[ ]:


average_overlap(jaccard_similarity, [[1, 2, 3], [1, 2, 4], [1, 2, 4, 5]], [0.4, 0.2, 0.5])


# In[ ]:


# BEGIN UNIT TEST
w1_unittest.test_average_overlap(average_overlap)
# END UNIT TEST


# In practice, it is also common to see the weighted mean used to calculate the overall score instead of just the arithmetic mean. We have implemented it below, and you can use it in your experiments to see which one gives better results.
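# 
# Written out, with $\ell_j$ the log probability of sample $y_j$ and $p_j = e^{\ell_j}$ its probability (notation introduced here for illustration), the weighted score computed below is
# 
# $$score_i = \frac{\sum_{j \neq i} p_j \, sim(y_i, y_j)}{\sum_{j \neq i} p_j}$$
# 
# Samples the model considers more likely therefore count more towards each candidate's score; when all $p_j$ are equal this reduces to the arithmetic mean.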

# In[ ]:


def weighted_avg_overlap(similarity_fn, samples, log_probs):
    """Returns the weighted mean of each candidate sentence in the samples

    Args:
        similarity_fn (function): similarity function used to compute the overlap
        samples (list of lists): tokenized version of the translated sentences
        log_probs (list of float): log probability of the translated sentences

    Returns:
        dict: scores of each sample
            key: index of the sample
            value: score of the sample
    """
    
    # initialize dictionary
    scores = {}
    
    # run a for loop for each sample
    for index_candidate, candidate in enumerate(samples):    
        
        # initialize overlap and weighted sum
        overlap, weight_sum = 0.0, 0.0
        
        # run a for loop for each sample
        for index_sample, (sample, logp) in enumerate(zip(samples, log_probs)):

            # skip if the candidate index is the same as the sample index            
            if index_candidate == index_sample:
                continue
                
            # convert log probability to linear scale
            sample_p = float(np.exp(logp))

            # update the weighted sum
            weight_sum += sample_p

            # get the unigram overlap between candidate and sample
            sample_overlap = similarity_fn(candidate, sample)
            
            # update the overlap
            overlap += sample_p * sample_overlap
            
        # get the score for the candidate
        score = overlap / weight_sum
        
        # save the score in the dictionary. use index as the key.
        scores[index_candidate] = score
    
    return scores


# In[ ]:


weighted_avg_overlap(jaccard_similarity, [[1, 2, 3], [1, 2, 4], [1, 2, 4, 5]], [0.4, 0.2, 0.5])


# ### 4.2.4 Putting it all together
# 
# We will now put everything together and develop the `mbr_decode()` function. Please use the helper functions you just developed to complete this. You will want to generate samples, get the score for each sample, get the highest score among all samples, then detokenize this sample to get the translated sentence.
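# 
# In symbols (notation introduced here for illustration), if $score_i$ is the score that the scoring function assigns to sample $y_i$, the decoder returns the sample that agrees most with the other samples:
# 
# $$\hat{y} = y_{i^*}, \qquad i^* = \arg\max_{i} \; score_i$$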
# 
# <a name="ex10"></a>
# ### Exercise 10
# 
# **Instructions**: Implement the `mbr_decode()` function.

# In[ ]:


# UNQ_C10
# GRADED FUNCTION
def mbr_decode(sentence, n_samples, score_fn, similarity_fn, NMTAttn=None, temperature=0.6, vocab_file=None, vocab_dir=None):
    """Returns the translated sentence using Minimum Bayes Risk decoding

    Args:
        sentence (str): sentence to translate.
        n_samples (int): number of samples to generate
        score_fn (function): function that generates the score for each sample
        similarity_fn (function): function used to compute the overlap between a pair of samples
        NMTAttn (tl.Serial): An LSTM sequence-to-sequence model with attention.
        temperature (float): parameter for sampling ranging from 0.0 to 1.0.
            0.0: same as argmax, always pick the most probable token
            1.0: sampling from the distribution (can sometimes say random things)
        vocab_file (str): filename of the vocabulary
        vocab_dir (str): path to the vocabulary file

    Returns:
        str: the translated sentence
    """
    
    ### START CODE HERE (REPLACE INSTANCES OF `None` WITH YOUR CODE) ###
    # generate samples
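    samples, log_probs = generate_samples(sentence, n_samples, NMTAttn, temperature=temperature, vocab_file=vocab_file, vocab_dir=vocab_dir)
    
    # use the scoring function to get a dictionary of scores
    # pass in the relevant parameters as shown in the function definition of 
    # the mean methods you developed earlier
    scores = score_fn(similarity_fn, samples, log_probs)
    
    # find the key with the highest score
    max_index = max(scores, key=scores.get)
    
    # detokenize the token list associated with the max_index
    # (assumes the detokenize() helper defined earlier in the assignment)
    translated_sentence = detokenize(samples[max_index], vocab_file=vocab_file, vocab_dir=vocab_dir)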
    
    ### END CODE HERE ###
    return (translated_sentence, max_index, scores)


# In[ ]:


TEMPERATURE = 1.0

# put a custom string here
your_sentence = 'She speaks English and German.'


# In[ ]:


mbr_decode(your_sentence, 4, weighted_avg_overlap, jaccard_similarity, model, TEMPERATURE, vocab_file=VOCAB_FILE, vocab_dir=VOCAB_DIR)[0]


# In[ ]:


mbr_decode('Congratulations!', 4, average_overlap, rouge1_similarity, model, TEMPERATURE, vocab_file=VOCAB_FILE, vocab_dir=VOCAB_DIR)[0]


# In[ ]:


mbr_decode('You have completed the assignment!', 4, average_overlap, rouge1_similarity, model, TEMPERATURE, vocab_file=VOCAB_FILE, vocab_dir=VOCAB_DIR)[0]


# **This unit test takes a while to run. Please be patient.**

# In[ ]:


# BEGIN UNIT TEST
w1_unittest.test_mbr_decode(mbr_decode, model)
# END UNIT TEST


# #### Congratulations! Next week, you'll dive deeper into attention models and study the Transformer architecture. You will build another network but without the recurrent part. It will show that attention is all you need! It should be fun!


Neural Machine Translation: Training the Model

Training Our Model

In the previous post we defined our model for machine translation. In this post we'll train the model on our data.

Doing supervised training in Trax is pretty straightforward (short example here). We will be instantiating three classes for this: TrainTask, EvalTask, and Loop. Let's take a closer look at each of these in the sections below.

Imports

# python
from collections import namedtuple
from contextlib import redirect_stdout
from functools import partial
from pathlib import Path

import sys

# pypi
from holoviews import opts
from trax import layers, optimizers
from trax.supervised import lr_schedules, training

import holoviews
import hvplot.pandas
import pandas

# this project
from neurotic.nlp.machine_translation import DataGenerator, NMTAttn

# related
from graeae import EmbedHoloviews, Timer

Set Up

train_batch_stream = DataGenerator().batch_generator
eval_batch_stream = DataGenerator(training=False).batch_generator
SLUG = "neural-machine-translation-training-the-model"
Embed = partial(EmbedHoloviews, folder_path=f"files/posts/nlp/{SLUG}")

Plot = namedtuple("Plot", ["width", "height", "fontscale", "tan", "blue", "red"])
PLOT = Plot(
    width=900,
    height=750,
    fontscale=2,
    tan="#ddb377",
    blue="#4687b7",
    red="#ce7b6d",
 )
TIMER = Timer()

Training

TrainTask

The TrainTask class allows us to define the labeled data to use for training and the feedback mechanisms to compute the loss and update the weights.

train_task = training.TrainTask(

    # use the train batch stream as labeled data
    labeled_data = train_batch_stream,

    # use the cross entropy loss
    loss_layer = layers.WeightedCategoryCrossEntropy(),

    # use the Adam optimizer with learning rate of 0.01
    optimizer = optimizers.Adam(0.01),

    # use the `trax.lr.warmup_and_rsqrt_decay` as the learning rate schedule
    # have 1000 warmup steps with a max value of 0.01
    lr_schedule = lr_schedules.warmup_and_rsqrt_decay(1000, 0.01),

    # have a checkpoint every 10 steps
    n_steps_per_checkpoint= 10,
)
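To get a feel for the warmup_and_rsqrt_decay(1000, 0.01) schedule, here is a minimal sketch of the shape it is meant to have: a linear ramp up to the maximum learning rate over the warmup steps, followed by a decay proportional to the inverse square root of the step. This is an approximation written for illustration (the name warmup_then_rsqrt is hypothetical), not Trax's implementation, so the exact values at the boundary steps may differ slightly from what Trax produces.

```python
import math


def warmup_then_rsqrt(step: int, n_warmup_steps: int = 1000,
                      max_value: float = 0.01) -> float:
    """Approximate learning rate at a given step (hypothetical helper, not Trax code).

    Ramps linearly from 0 to max_value over n_warmup_steps, then decays
    as max_value * sqrt(n_warmup_steps / step).
    """
    if step < n_warmup_steps:
        # linear warmup
        return max_value * step / n_warmup_steps
    # inverse square root decay after the warmup
    return max_value * math.sqrt(n_warmup_steps / step)


for step in (1, 500, 1000, 4000, 16000):
    print(f"step {step:>5}: learning rate {warmup_then_rsqrt(step):.6f}")
```

With these settings the rate peaks at 0.01 at step 1000 and then halves every time the step count quadruples.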
def test_train_task(train_task):
    target = train_task
    success = 0
    fails = 0

    # Test the labeled data parameter
    try:
        strlabel = str(target._labeled_data)
        assert(strlabel.find("generator") and strlabel.find('add_loss_weights'))
        success += 1
    except:
        fails += 1
        print("Wrong labeled data parameter")

    # Test the cross entropy loss data parameter
    try:
        strlabel = str(target._loss_layer)
        assert(strlabel == "CrossEntropyLoss_in3")
        success += 1
    except:
        fails += 1
        print("Wrong loss functions. CrossEntropyLoss_in3 was expected")

     # Test the optimizer parameter
    try:
        assert(isinstance(target.optimizer, trax.optimizers.adam.Adam))
        success += 1
    except:
        fails += 1
        print("Wrong optimizer")

    # Test the schedule parameter
    try:
        assert(isinstance(target._lr_schedule,trax.supervised.lr_schedules._BodyAndTail))
        success += 1
    except:
        fails += 1
        print("Wrong learning rate schedule type")

    # Test the _n_steps_per_checkpoint parameter
    try:
        assert(target._n_steps_per_checkpoint==10)
        success += 1
    except:
        fails += 1
        print("Wrong checkpoint step frequency")

    if fails == 0:
        print("\033[92m All tests passed")
    else:
        print('\033[92m', success," Tests passed")
        print('\033[91m', fails, " Tests failed")
    return
test_train_task(train_task)
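Wrong loss functions. CrossEntropyLoss_in3 was expected
Wrong optimizer
Wrong learning rate schedule type
2 Tests passed
3 Tests failed

The code has changed a bit since the test was written, so it won't pass without updates; for example, the loss layer used here is WeightedCategoryCrossEntropy rather than the CrossEntropyLoss the test expects. Note also that the optimizer and schedule checks refer to trax.optimizers and trax.supervised, but the imports in this post never bind the name trax. If it isn't defined elsewhere in the session, those lookups raise a NameError that the bare except reports as a failed test.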

EvalTask

The EvalTask, on the other hand, lets us see how the model is doing while training. For our application, we want it to report the cross entropy loss and the accuracy.

eval_task = training.EvalTask(

    ## use the eval batch stream as labeled data
    labeled_data=eval_batch_stream,

    ## use the cross entropy loss and accuracy as metrics
    metrics=[layers.WeightedCategoryCrossEntropy(), layers.Accuracy()],
)
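Both tasks use WeightedCategoryCrossEntropy, and the eval task also reports Accuracy. The "weighted" part matters for translation because batches are padded, so padding positions should not count toward the metric. The sketch below shows the idea with NumPy, assuming the model outputs log-probabilities (as NMTAttn does with its final LogSoftmax) and that the weights are 1 for real tokens and 0 for padding. The function names are hypothetical and this is not Trax's code; for instance, it ignores options such as label smoothing.

```python
import numpy as np


def weighted_cross_entropy(log_probs: np.ndarray, targets: np.ndarray,
                           weights: np.ndarray) -> float:
    """Mean negative log-probability of the targets, ignoring zero-weight positions.

    log_probs: (batch, length, vocab) log-probabilities from the model
    targets:   (batch, length) integer token ids
    weights:   (batch, length) 1.0 for real tokens, 0.0 for padding
    """
    # pick out the log-probability the model gave to each target token
    target_log_probs = np.take_along_axis(log_probs, targets[..., None], axis=-1)[..., 0]
    return float(-np.sum(target_log_probs * weights) / np.sum(weights))


def weighted_accuracy(log_probs: np.ndarray, targets: np.ndarray,
                      weights: np.ndarray) -> float:
    """Fraction of non-padding positions where the most likely token is the target."""
    predictions = np.argmax(log_probs, axis=-1)
    return float(np.sum((predictions == targets) * weights) / np.sum(weights))


# one sentence of three positions over a two-token vocabulary; the last position is padding
probabilities = np.array([[[0.9, 0.1], [0.2, 0.8], [0.5, 0.5]]])
log_probs = np.log(probabilities)
targets = np.array([[0, 1, 0]])
weights = np.array([[1.0, 1.0, 0.0]])

print(weighted_cross_entropy(log_probs, targets, weights))
print(weighted_accuracy(log_probs, targets, weights))
```

Dividing by the sum of the weights rather than the number of positions is what keeps padded batches from looking artificially good or bad.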

Loop

The Loop class defines the model we will train as well as the train and eval tasks to execute. Its run() method allows us to execute the training for a specified number of steps.

output_dir = Path("~/models/machine_translation/").expanduser()

Define the training loop.

training_loop = training.Loop(NMTAttn(mode='train'),
                              train_task,
                              eval_tasks=[eval_task],
                              output_dir=output_dir)
train_steps = 1000

with TIMER, \
     open("/tmp/machine_translation_training.log", "w") as temp_file, \
     redirect_stdout(temp_file):
    training_loop.run(train_steps)
Started: 2021-03-09 18:31:58.844878
Ended: 2021-03-09 20:14:43.090358
Elapsed: 1:42:44.245480
frame = pandas.DataFrame(
    training_loop.history.get("eval", "metrics/WeightedCategoryCrossEntropy"),
    columns="Batch CrossEntropy".split())

minimum = frame.loc[frame.CrossEntropy.idxmin()]
vline = holoviews.VLine(minimum.Batch).opts(opts.VLine(color=PLOT.red))
hline = holoviews.HLine(minimum.CrossEntropy).opts(opts.HLine(color=PLOT.red))
line = frame.hvplot(x="Batch", y="CrossEntropy").opts(opts.Curve(color=PLOT.blue))

plot = (line * hline * vline).opts(
    width=PLOT.width, height=PLOT.height,
    title="Evaluation Batch Cross Entropy Loss",
                                   )
output = Embed(plot=plot, file_name="evaluation_cross_entropy")()
print(output)

Figure Missing

frame = pandas.DataFrame(
    training_loop.history.get("eval", "metrics/Accuracy"),
    columns="Batch Accuracy".split())

best = frame.loc[frame.Accuracy.idxmax()]
vline = holoviews.VLine(best.Batch).opts(opts.VLine(color=PLOT.red))
hline = holoviews.HLine(best.Accuracy).opts(opts.HLine(color=PLOT.red))
line = frame.hvplot(x="Batch", y="Accuracy").opts(opts.Curve(color=PLOT.blue))

plot = (line * hline * vline).opts(
    width=PLOT.width, height=PLOT.height,
    title="Evaluation Batch Accuracy",
                                   )
output = Embed(plot=plot, file_name="evaluation_accuracy")()
print(output)

Figure Missing

It seems to be stuck…

End

Now that we've trained the model, in the next post we'll test it to see how well it does. The overview post with links to all the posts in this series is here.

Raw

Neural Machine Translation: The Attention Model

<<imports>>

<<attention-model>>

Defining the Model

In the previous post we made some helper functions to prepare inputs for some of the layers in the model. In this post we'll define the model itself.

Attention Overview

The model we will be building uses an encoder-decoder architecture. This Recurrent Neural Network (RNN) takes in a tokenized version of a sentence in its encoder, then passes it on to the decoder for translation. A regular sequence-to-sequence model with LSTMs works well for short to medium-length sentences but starts to degrade for longer ones. You can picture it like the figure below, where all of the context of the input sentence is compressed into one vector that is passed into the decoder block. You can see how this will be an issue for very long sentences (e.g. 100 tokens or more) because the context of the first parts of the input will have very little effect on the final vector passed to the decoder.

Adding an attention layer to this model avoids this problem by giving the decoder access to all parts of the input sentence. To illustrate, let's use a 4-word input sentence as shown below. Remember that a hidden state is produced at each timestep of the encoder (represented by the orange rectangles). These are all passed to the attention layer, and each one is scored against the current activation (i.e. hidden state) of the decoder. For instance, consider the figure below, where the first prediction "Wie" has already been made. To produce the next prediction, the attention layer first receives all the encoder hidden states (i.e. the orange rectangles) as well as the decoder hidden state from producing the word "Wie" (i.e. the first green rectangle). Given this information, it scores each of the encoder hidden states to determine which one the decoder should focus on to produce the next word. The trained model might have learned that it should align with the second encoder hidden state and subsequently assign a high probability to the word "geht". If we are using greedy decoding, we output that word as the next symbol, then repeat the process to produce the next word until we reach an end-of-sentence prediction.

There are different ways to implement attention and the one we'll use is the Scaled Dot Product Attention which has the form:

\[ Attention(Q, K, V) = softmax \left(\frac{QK^T}{\sqrt{d_k}} \right)V \]

You can think of it as computing scores using queries (Q) and keys (K), followed by a multiplication with the values (V) to get a context vector at a particular timestep of the decoder. This context vector is fed to the decoder RNN to get a set of probabilities for the next predicted word. The division by the square root of the keys' dimensionality (\(\sqrt{d_k}\)) keeps the dot products from growing with \(d_k\), which would otherwise push the softmax into regions with very small gradients; you'll learn more about it next week. For our machine translation application, the encoder activations (i.e. encoder hidden states) will be the keys and values, while the decoder activations (i.e. decoder hidden states) will be the queries.
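To make the formula concrete, here is a minimal NumPy sketch of scaled dot-product attention for a single head with no batch dimension. It assumes queries of shape (n_queries, d_k) and keys and values with one row per encoder position; the optional boolean mask (True where a key may be attended to) is an illustrative addition. It is not the Trax implementation, which also handles batches, multiple heads and dropout.

```python
from typing import Optional

import numpy as np


def scaled_dot_product_attention(queries: np.ndarray, keys: np.ndarray,
                                 values: np.ndarray,
                                 mask: Optional[np.ndarray] = None):
    """Return (context vectors, attention weights) for softmax(QK^T / sqrt(d_k)) V.

    queries: (n_queries, d_k), e.g. decoder hidden states
    keys:    (n_keys, d_k), e.g. encoder hidden states
    values:  (n_keys, d_v), e.g. encoder hidden states
    mask:    optional (n_queries, n_keys) boolean array, True = may attend
    """
    d_k = queries.shape[-1]
    # similarity of every query with every key, scaled by sqrt(d_k)
    scores = queries @ keys.T / np.sqrt(d_k)
    if mask is not None:
        # a very negative score becomes (almost) zero weight after the softmax
        scores = np.where(mask, scores, -1e9)
    # numerically stable softmax over the keys
    scores = scores - scores.max(axis=-1, keepdims=True)
    weights = np.exp(scores)
    weights = weights / weights.sum(axis=-1, keepdims=True)
    # each context vector is a weighted average of the values
    return weights @ values, weights


rng = np.random.default_rng(0)
decoder_states = rng.normal(size=(2, 4))  # 2 queries
encoder_states = rng.normal(size=(3, 4))  # 3 input positions
context, weights = scaled_dot_product_attention(decoder_states, encoder_states, encoder_states)
print(context.shape, weights.shape)
print(weights.sum(axis=-1))
```

Each row of weights sums to one, so every context vector is a convex combination of the encoder states, which is what lets the decoder "focus" on particular input positions.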

You will see in the upcoming sections that this complex architecture and mechanism can be implemented with just a few lines of code.

Imports

# pypi
from trax import layers

import trax

# this project
from neurotic.nlp.machine_translation import (
    NMTAttn)

Implementation

Overview

We are now ready to implement our sequence-to-sequence model with attention. This will be a Serial network and is illustrated in the diagram below. It shows the layers you'll be using in Trax, and you'll see that each step can be implemented quite easily with a one-line command. We've placed several links to the documentation for each relevant layer in the discussion after the figure below.

  • Step 0: Prepare the input encoder and pre-attention decoder branches. We've already defined these earlier as helper functions, so it's just a matter of calling those functions and assigning the results to variables.
  • Step 1: Create a Serial network. This will stack the layers in the next steps one after the other. As before, we'll use tl.Serial.
  • Step 2: Make a copy of the input and target tokens. As you see in the diagram above, the input and target tokens will be fed into different layers of the model. We'll use tl.Select layer to create copies of these tokens, arranging them as [input tokens, target tokens, input tokens, target tokens].
  • Step 3: Create a parallel branch to feed the input tokens to the input_encoder and the target tokens to the pre_attention_decoder. We'll use tl.Parallel to create these sublayers in parallel, remembering to pass the variables defined in Step 0 as parameters to this layer.
  • Step 4: Next, call the `prepare_attention_input` function to convert the encoder and pre-attention decoder activations to a format that the attention layer will accept. You can use tl.Fn to call this function. Note: Pass the prepare_attention_input function as the f parameter in tl.Fn without any arguments or parentheses.
  • Step 5: We will now feed the (queries, keys, values, and mask) to the tl.AttentionQKV layer. This computes the scaled dot product attention and outputs the attention activations and the mask. Take note that although it is a one-liner, this layer is actually composed of a deep network made up of several branches. Its implementation (from the Trax source on GitHub) is shown here to illustrate the different layers used.
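# imports matching those used in trax/layers/attention.py
from trax.layers import combinators as cb
from trax.layers import core
from trax.layers.attention import PureAttention


def AttentionQKV(d_feature, n_heads=1, dropout=0.0, mode='train'):
  """Returns a layer that maps (q, k, v, mask) to (activations, mask).

  See `Attention` in trax/layers/attention.py for further context/details.

  Args:
    d_feature: Depth/dimensionality of feature embedding.
    n_heads: Number of attention heads.
    dropout: Probabilistic rate for internal dropout applied to attention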
        activations (based on query-key pairs) before dotting them with values.
    mode: Either 'train' or 'eval'.
  """
  return cb.Serial(
      cb.Parallel(
          core.Dense(d_feature),
          core.Dense(d_feature),
          core.Dense(d_feature),
      ),
      PureAttention(  # pylint: disable=no-value-for-parameter
          n_heads=n_heads, dropout=dropout, mode=mode),
      core.Dense(d_feature),
  )

Having deep layers poses the risk of vanishing gradients during training, and we want to mitigate that. To improve the ability of the network to learn, we can insert a tl.Residual layer to add the output of AttentionQKV to the queries input. You can do this in Trax by simply nesting the AttentionQKV layer inside the Residual layer. The library will take care of branching and adding for you.

  • Step 6: We will not need the mask for the model we're building so we can safely drop it. At this point in the network, the signal stack currently has [attention activations, mask, target tokens] and you can use tl.Select to output just [attention activations, target tokens].
  • Step 7: We can now feed the attention-weighted output to the LSTM decoder. We can stack multiple tl.LSTM layers to improve the output, so remember to append as many LSTMs as the n_decoder_layers parameter specifies.
  • Step 8: We want to determine the probabilities of each subword in the vocabulary and you can set this up easily with a tl.Dense layer by making its size equal to the size of our vocabulary.
  • Step 9: Normalize the output to log probabilities by passing the activations in Step 8 to a tl.LogSoftmax layer.
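The Select bookkeeping in Steps 2 through 6 is easy to get wrong, so here is a small pure-Python sketch that tracks the data stack as a list of names (top of the stack first). The select helper is hypothetical: it mimics the way a Trax Select layer consumes the top n_in items and pushes copies in the order given by the indices. The other steps are written out by hand using the input and output counts that appear in the printed model later in this post (for example PrepareAttentionInput_in3_out4).

```python
from typing import List, Optional


def select(stack: List[str], indices: List[int], n_in: Optional[int] = None) -> List[str]:
    """Mimic the stack semantics of a Select layer (hypothetical helper, not Trax code)."""
    n_in = max(indices) + 1 if n_in is None else n_in
    consumed, rest = stack[:n_in], stack[n_in:]
    return [consumed[i] for i in indices] + rest


stack = ["input_tokens", "target_tokens"]
# Step 2: Select([0, 1, 0, 1]) duplicates the inputs and targets
stack = select(stack, [0, 1, 0, 1])
# Step 3: Parallel(input_encoder, pre_attention_decoder) consumes the top two items
stack = ["encoder_activations", "decoder_activations"] + stack[2:]
# Step 4: PrepareAttentionInput consumes three items and produces four
stack = ["queries", "keys", "values", "mask"] + stack[3:]
# Step 5: Residual(AttentionQKV) consumes four items and produces two
stack = ["attention_activations", "mask"] + stack[4:]
# Step 6: Select([0, 2]) over three items drops the mask
stack = select(stack, [0, 2], n_in=3)
print(stack)
```

What remains on the stack is the attention output for the LSTM decoder plus the target tokens, which Trax keeps around for the loss.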

The Implementation

# pypi
from trax import layers

# this project
from .help_me import input_encoder as input_encoder_fn
from .help_me import pre_attention_decoder as pre_attention_decoder_fn
from .help_me import prepare_attention_input as prepare_attention_input_fn
def NMTAttn(input_vocab_size: int=33300,
            target_vocab_size: int=33300,
            d_model: int=1024,
            n_encoder_layers: int=2,
            n_decoder_layers: int=2,
            n_attention_heads: int=4,
            attention_dropout: float=0.0,
            mode: str='train') -> layers.Serial:
    """Returns an LSTM sequence-to-sequence model with attention.

    The input to the model is a pair (input tokens, target tokens), e.g.,
    an English sentence (tokenized) and its translation into German (tokenized).

    Args:
    input_vocab_size: int: vocab size of the input
    target_vocab_size: int: vocab size of the target
    d_model: int:  depth of embedding (n_units in the LSTM cell)
    n_encoder_layers: int: number of LSTM layers in the encoder
    n_decoder_layers: int: number of LSTM layers in the decoder after attention
    n_attention_heads: int: number of attention heads
    attention_dropout: float, dropout for the attention layer
    mode: str: 'train', 'eval' or 'predict', predict mode is for fast inference

    Returns:
    An LSTM sequence-to-sequence model with attention.
    """
    # Step 0: call the helper function to create layers for the input encoder
    input_encoder = input_encoder_fn(input_vocab_size, d_model, n_encoder_layers)

    # Step 0: call the helper function to create layers for the pre-attention decoder
    pre_attention_decoder = pre_attention_decoder_fn(mode, target_vocab_size, d_model)

    # Step 1: create a serial network
    model = layers.Serial( 

      # Step 2: copy input tokens and target tokens as they will be needed later.
      layers.Select([0, 1, 0, 1]),

      # Step 3: run input encoder on the input and pre-attention decoder on the target.
      layers.Parallel(input_encoder, pre_attention_decoder),

      # Step 4: prepare queries, keys, values and mask for attention.
      layers.Fn('PrepareAttentionInput', prepare_attention_input_fn, n_out=4),

      # Step 5: run the AttentionQKV layer
      # nest it inside a Residual layer to add to the pre-attention decoder activations(i.e. queries)
      layers.Residual(layers.AttentionQKV(d_model,
                                          n_heads=n_attention_heads,
                                          dropout=attention_dropout, mode=mode)),

      # Step 6: drop the attention mask (i.e. index 1)
      layers.Select([0, 2]),

      # Step 7: run the rest of the RNN decoder
      [layers.LSTM(d_model) for _ in range(n_decoder_layers)],

      # Step 8: prepare output by making it the right size
      layers.Dense(target_vocab_size),

      # Step 9: Log-softmax for output
      layers.LogSoftmax()
    )
    return model
def test_NMTAttn(NMTAttn):
    test_cases = [
                {
                    "name":"simple_test_check",
                    "expected":"Serial_in2_out2[\n  Select[0,1,0,1]_in2_out4\n  Parallel_in2_out2[\n    Serial[\n      Embedding_33300_1024\n      LSTM_1024\n      LSTM_1024\n    ]\n    Serial[\n      ShiftRight(1)\n      Embedding_33300_1024\n      LSTM_1024\n    ]\n  ]\n  PrepareAttentionInput_in3_out4\n  Serial_in4_out2[\n    Branch_in4_out3[\n      None\n      Serial_in4_out2[\n        Parallel_in3_out3[\n          Dense_1024\n          Dense_1024\n          Dense_1024\n        ]\n        PureAttention_in4_out2\n        Dense_1024\n      ]\n    ]\n    Add_in2\n  ]\n  Select[0,2]_in3_out2\n  LSTM_1024\n  LSTM_1024\n  Dense_33300\n  LogSoftmax\n]",
                    "error":"The NMTAttn is not defined properly."
                },
                {
                    "name":"layer_len_check",
                    "expected":9,
                    "error":"We found {} layers in your model. It should be 9.\nCheck the LSTM stack before the dense layer"
                },
                {
                    "name":"selection_layer_check",
                    "expected":["Select[0,1,0,1]_in2_out4", "Select[0,2]_in3_out2"],
                    "error":"Look at your selection layers."
                }
            ]

    success = 0
    fails = 0

    for test_case in test_cases:
        try:
            if test_case['name'] == "simple_test_check":
                assert test_case["expected"] == str(NMTAttn())
                success += 1
            if test_case['name'] == "layer_len_check":
                if test_case["expected"] == len(NMTAttn().sublayers):
                    success += 1
                else:
                    print(test_case["error"].format(len(NMTAttn().sublayers))) 
                    fails += 1
            if test_case['name'] == "selection_layer_check":
                model = NMTAttn()
                output = [str(model.sublayers[0]),str(model.sublayers[4])]
                check_count = 0
                for i in range(2):
                    if test_case["expected"][i] != output[i]:
                        print(test_case["error"])
                        fails += 1
                        break
                    else:
                        check_count += 1
                if check_count == 2:
                    success += 1
        except:
            print(test_case['error'])
            fails += 1

    if fails == 0:
        print("\033[92m All tests passed")
    else:
        print('\033[92m', success," Tests passed")
        print('\033[91m', fails, " Tests failed")
    return test_cases
test_cases = test_NMTAttn(NMTAttn)
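The NMTAttn is not defined properly.
2 Tests passed
1 Tests failed

The check that fails compares the printed model against a hard-coded string. Comparing that string with the model printed below shows that this version of Trax wraps ShiftRight(1) in its own Serial and adds two unnamed layers (_in4_out4 and _in2_out2) inside the attention residual, so the string no longer matches even though the layer count and the two Select layers pass their checks.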
model = NMTAttn()
print(model)
Serial_in2_out2[
  Select[0,1,0,1]_in2_out4
  Parallel_in2_out2[
    Serial[
      Embedding_33300_1024
      LSTM_1024
      LSTM_1024
    ]
    Serial[
      Serial[
        ShiftRight(1)
      ]
      Embedding_33300_1024
      LSTM_1024
    ]
  ]
  PrepareAttentionInput_in3_out4
  Serial_in4_out2[
    Branch_in4_out3[
      None
      Serial_in4_out2[
        _in4_out4
        Serial_in4_out2[
          Parallel_in3_out3[
            Dense_1024
            Dense_1024
            Dense_1024
          ]
          PureAttention_in4_out2
          Dense_1024
        ]
        _in2_out2
      ]
    ]
    Add_in2
  ]
  Select[0,2]_in3_out2
  LSTM_1024
  LSTM_1024
  Dense_33300
  LogSoftmax
]

End

Now that we have the model defined, in the next post we'll train the model. The overview post with links to all the posts in this series is here.

Neural Machine Translation: The Data

The Data

This is the first post in a series that will look at creating a Long Short-Term Memory (LSTM) model with attention for machine translation. The previous post was an overview that holds the links to all the posts in the series.

Imports

# python
from pathlib import Path

import random

# pypi
from termcolor import colored

import numpy
import trax

Middle

Loading the Data

Next, we will import the dataset we will use to train the model. If you are short on disk space, you can use a small dataset from Opus, a growing collection of translated texts from the web. In particular, we will get the English to German translation subset specified as opus/medical, which has medical-related texts. If storage is not an issue, you can opt for a larger corpus such as the English to German translation dataset from ParaCrawl, a large multi-lingual translation dataset created by the European Union. Both of these datasets are available via TensorFlow Datasets (TFDS), and you can browse through the other available datasets here. As you'll see below, you can access this dataset from TFDS with trax.data.TFDS. The result is a Python generator function that yields tuples. Use the keys argument to select what appears at which position in the tuple. For example, keys=('en', 'de') below will return pairs as (English sentence, German sentence).
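To make the keys argument concrete, here is a plain-Python sketch (not trax's implementation) of a generator function that turns dictionary-style examples into tuples ordered by keys. The records and the name make_stream_fn are hypothetical stand-ins chosen for illustration.

```python
from typing import Callable, Dict, Iterator, List, Sequence, Tuple

# Hypothetical records standing in for dataset examples (feature name -> bytes).
RECORDS: List[Dict[str, bytes]] = [
    {"en": b"Subcutaneous use.\n", "de": b"Subkutane Anwendung\n"},
    {"en": b"Tablet\n", "de": b"Tablette\n"},
]


def make_stream_fn(records: Sequence[Dict[str, bytes]],
                   keys: Tuple[str, ...]) -> Callable[[], Iterator[Tuple[bytes, ...]]]:
    """Return a generator *function* whose tuples follow the order in `keys`."""
    def stream() -> Iterator[Tuple[bytes, ...]]:
        for record in records:
            # Position i of each tuple holds the feature named keys[i].
            yield tuple(record[key] for key in keys)
    return stream


english_first = make_stream_fn(RECORDS, keys=("en", "de"))
german_first = make_stream_fn(RECORDS, keys=("de", "en"))
print(next(english_first()))  # (b'Subcutaneous use.\n', b'Subkutane Anwendung\n')
print(next(german_first()))   # (b'Subkutane Anwendung\n', b'Subcutaneous use.\n')
```

Swapping the order of keys is all it takes to flip which language comes first in every pair.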

The para_crawl/ende dataset is 4.04 GiB while the opus/medical dataset is 188.85 MiB.

Note: Trying to download the ParaCrawl dataset using trax raised an out-of-resources error. You can try downloading the source from:

https://s3.amazonaws.com/web-language-models/paracrawl/release4/en-de.bicleaner07.txt.gz

I haven't figured out how to get it into the trax data pipeline yet, though, so I'm sticking with the smaller dataset.

The Training Data

The first time you run this it will download the dataset; after that it will just load it from the downloaded files.

path = Path("~/data/tensorflow/translation/").expanduser()

data_set = "opus/medical"
# data_set = "para_crawl/ende"

train_stream_fn = trax.data.TFDS(data_set,
                                 data_dir=path,
                                 keys=('en', 'de'),
                                 eval_holdout_size=0.01,
                                 train=True)
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-6-fb62d04026f5> in <module>
      4 # data_set = "para_crawl/ende"
      5 
----> 6 train_stream_fn = trax.data.TFDS(data_set,
      7                                  data_dir=path,
      8                                  keys=('en', 'de'),

/usr/local/lib/python3.8/dist-packages/gin/config.py in gin_wrapper(*args, **kwargs)
   1067       scope_info = " in scope '{}'".format(scope_str) if scope_str else ''
   1068       err_str = err_str.format(name, fn_or_cls, scope_info)
-> 1069       utils.augment_exception_message_and_reraise(e, err_str)
   1070 
   1071   return gin_wrapper

/usr/local/lib/python3.8/dist-packages/gin/utils.py in augment_exception_message_and_reraise(exception, message)
     39   proxy = ExceptionProxy()
     40   ExceptionProxy.__qualname__ = type(exception).__qualname__
---> 41   raise proxy.with_traceback(exception.__traceback__) from None
     42 
     43 

/usr/local/lib/python3.8/dist-packages/gin/config.py in gin_wrapper(*args, **kwargs)
   1044 
   1045     try:
-> 1046       return fn(*new_args, **new_kwargs)
   1047     except Exception as e:  # pylint: disable=broad-except
   1048       err_str = ''

/usr/local/lib/python3.8/dist-packages/gin/config.py in gin_wrapper(*args, **kwargs)
   1067       scope_info = " in scope '{}'".format(scope_str) if scope_str else ''
   1068       err_str = err_str.format(name, fn_or_cls, scope_info)
-> 1069       utils.augment_exception_message_and_reraise(e, err_str)
   1070 
   1071   return gin_wrapper

/usr/local/lib/python3.8/dist-packages/gin/utils.py in augment_exception_message_and_reraise(exception, message)
     39   proxy = ExceptionProxy()
     40   ExceptionProxy.__qualname__ = type(exception).__qualname__
---> 41   raise proxy.with_traceback(exception.__traceback__) from None
     42 
     43 

/usr/local/lib/python3.8/dist-packages/gin/config.py in gin_wrapper(*args, **kwargs)
   1044 
   1045     try:
-> 1046       return fn(*new_args, **new_kwargs)
   1047     except Exception as e:  # pylint: disable=broad-except
   1048       err_str = ''

~/trax/trax/data/tf_inputs.py in TFDS(dataset_name, data_dir, tfds_preprocess_fn, keys, train, shuffle_train, host_id, n_hosts, eval_holdout_size)
    279   else:
    280     subsplit = None
--> 281   (train_data, eval_data, _) = _train_and_eval_dataset(
    282       dataset_name, data_dir, eval_holdout_size,
    283       train_shuffle_files=shuffle_train, subsplit=subsplit)

~/trax/trax/data/tf_inputs.py in _train_and_eval_dataset(dataset_name, data_dir, eval_holdout_size, train_shuffle_files, eval_shuffle_files, subsplit)
    224   if eval_holdout_examples > 0 or subsplit is not None:
    225     n_train = train_examples - eval_holdout_examples
--> 226     train_start = int(n_train * subsplit[0])
    227     train_end = int(n_train * subsplit[1])
    228     if train_end - train_start < 1:

TypeError: 'NoneType' object is not subscriptable
  In call to configurable 'TFDS' (<function TFDS at 0x7f960c527280>)
  In call to configurable 'TFDS' (<function TFDS at 0x7f960c526f70>)

The Evaluation Data

Since we already downloaded the data in the previous code block, this will just load the evaluation set from the downloaded data.

eval_stream_fn = trax.data.TFDS('opus/medical',
                                data_dir=path,
                                keys=('en', 'de'),
                                eval_holdout_size=0.01,
                                train=False)

Notice that TFDS returns a generator function, not a generator. This is because in Python you cannot reset a generator, so you cannot go back to a previously yielded value. During deep learning training you use Stochastic Gradient Descent and don't actually need to go back, but it is sometimes good to be able to, and that's where the generator functions come in: calling the function again gives you a fresh generator. Let's print a sample pair from our train and eval data. Notice that the raw output is represented in bytes (denoted by the b' prefix); these will be converted to strings internally in the next steps.

train_stream = train_stream_fn()
print(colored('train data (en, de) tuple:', 'red'), next(train_stream))
print()
train data (en, de) tuple: (b'Tel: +421 2 57 103 777\n', b'Tel: +421 2 57 103 777\n')

eval_stream = eval_stream_fn()
print(colored('eval data (en, de) tuple:', 'red'), next(eval_stream))
eval data (en, de) tuple: (b'Lutropin alfa Subcutaneous use.\n', b'Pulver zur Injektion Lutropin alfa Subkutane Anwendung\n')
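Going back to the point about generator functions versus generators, here is a tiny self-contained sketch. The sentence_stream function is a hypothetical stand-in for the function TFDS returns; it shows that an exhausted generator stays exhausted, while calling the function again starts over.

```python
from typing import Iterator


def sentence_stream() -> Iterator[str]:
    """A tiny generator function standing in for the one TFDS returns."""
    yield "first sentence"
    yield "second sentence"


generator = sentence_stream()
print(next(generator))          # first sentence
print(next(generator))          # second sentence
print(next(generator, "done"))  # done (this generator is exhausted and can't be rewound)

# Calling the generator function again gives a brand-new generator.
fresh = sentence_stream()
print(next(fresh))              # first sentence
```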

Tokenization and Formatting

Now that we have imported our corpus, we will preprocess the sentences into a format that our model can accept. This will be composed of several steps:

Tokenizing the sentences using subword representations: We want to represent each sentence as an array of integers instead of strings. For our application, we will use subword representations to tokenize our sentences. This is a common technique to avoid out-of-vocabulary words by allowing parts of words to be represented separately. For example, instead of having separate entries in your vocabulary for "fear", "fearless", "fearsome", "some", and "less", you can simply store "fear", "some", and "less" and let the tokenizer combine these subwords when needed. This makes it more flexible, so you won't have to save uncommon words explicitly in your vocabulary (e.g. stylebender, nonce, etc.). Tokenizing is done with the `trax.data.Tokenize()` command, and we'll use the combined subword vocabulary for English and German (`ende_32k.subword`) from https://storage.googleapis.com/trax-ml/vocabs/ende_32k.subword (I'm reading it directly from Google Storage, but you could also download it and put it in a local directory).
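For intuition only, here is a toy greedy longest-match tokenizer over a hand-picked, hypothetical vocabulary. It is a simplification: the ende_32k.subword vocabulary and the trax tokenizer have their own encoding rules, so treat this as an illustration of the subword idea rather than what trax does.

```python
from typing import List, Set

# Hypothetical, hand-picked subword vocabulary.
VOCABULARY: Set[str] = {"fear", "some", "less"}


def greedy_subwords(word: str, vocabulary: Set[str]) -> List[str]:
    """Split `word` by repeatedly taking the longest vocabulary prefix.

    Falls back to single characters when no vocabulary entry matches.
    """
    pieces = []
    start = 0
    while start < len(word):
        # Try the longest candidate first, shrinking until something matches.
        for end in range(len(word), start, -1):
            if word[start:end] in vocabulary:
                pieces.append(word[start:end])
                start = end
                break
        else:
            # No match: emit one character so the loop always makes progress.
            pieces.append(word[start])
            start += 1
    return pieces


for word in ("fearless", "fearsome", "some"):
    print(word, "->", greedy_subwords(word, VOCABULARY))
# fearless -> ['fear', 'less']
# fearsome -> ['fear', 'some']
# some -> ['some']
```

Three stored entries are enough to cover all five words from the example above.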

VOCAB_FILE = 'ende_32k.subword'
VOCAB_DIR = "gs://trax-ml/vocabs/" # google storage

# Tokenize the dataset.
tokenized_train_stream = trax.data.Tokenize(vocab_file=VOCAB_FILE, vocab_dir=VOCAB_DIR)(train_stream)
tokenized_eval_stream = trax.data.Tokenize(vocab_file=VOCAB_FILE, vocab_dir=VOCAB_DIR)(eval_stream)

Append an end-of-sentence token to each sentence: we will assign a token (in this case 1) to mark the end of a sentence. This will be useful during inference/prediction so we'll know that the model has completed the translation.

Integer assigned as end-of-sentence (EOS)

EOS = 1


def append_eos(stream):
    """helper to add end of sentence token to sentences in the stream

    Yields:
     next tuple of numpy arrays with EOS token added (inputs, targets)
    """
    for (inputs, targets) in stream:
        inputs_with_eos = list(inputs) + [EOS]
        targets_with_eos = list(targets) + [EOS]
        yield numpy.array(inputs_with_eos), numpy.array(targets_with_eos)


tokenized_train_stream = append_eos(tokenized_train_stream)
tokenized_eval_stream = append_eos(tokenized_eval_stream)

Filter long sentences

We will place a limit on the number of tokens per sentence to ensure we won't run out of memory. This is done with trax.data.FilterByLength(), and you can see its syntax below.

length_keys=[0, 1] means we filter on both the English and the German sentences, so neither may be longer than 256 tokens for training or 512 tokens for evaluation.

filtered_train_stream = trax.data.FilterByLength(
    max_length=256, length_keys=[0, 1])(tokenized_train_stream)
filtered_eval_stream = trax.data.FilterByLength(
    max_length=512, length_keys=[0, 1])(tokenized_eval_stream)
train_input, train_target = next(filtered_train_stream)
print(colored(f'Single tokenized example input:', 'red' ), train_input)
print(colored(f'Single tokenized example target:', 'red'), train_target)
Single tokenized example input: [ 2538  2248    30 12114 23184 16889     5     2 20852  6456 20592  5812
  3932    96  5178  3851    30  7891  3550 30650  4729   992     1]
Single tokenized example target: [ 1872    11  3544    39  7019 17877 30432    23  6845    10 14222    47
  4004    18 21674     5 27467  9513   920   188 10630    18  3550 30650
  4729   992     1]
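The filtering rule described above is simple enough to sketch in plain Python. This is a hypothetical helper, not trax's FilterByLength: it keeps an example only if every element picked out by length_keys is at most max_length tokens long.

```python
from typing import Iterable, Iterator, Sequence, Tuple

import numpy


def filter_by_length(stream: Iterable[Tuple[numpy.ndarray, ...]],
                     max_length: int,
                     length_keys: Sequence[int] = (0, 1)) -> Iterator[Tuple[numpy.ndarray, ...]]:
    """Yield only the examples whose selected elements all fit in max_length."""
    for example in stream:
        longest = max(len(example[key]) for key in length_keys)
        if longest <= max_length:
            yield example


examples = [
    (numpy.arange(5), numpy.arange(7)),   # both short: kept
    (numpy.arange(5), numpy.arange(12)),  # German side too long: dropped
    (numpy.arange(11), numpy.arange(3)),  # English side too long: dropped
]
kept = list(filter_by_length(examples, max_length=10))
print(len(kept))  # 1
```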

tokenize & detokenize helper functions

Given any data set, you have to be able to map words to their indices, and indices to their words. The inputs and outputs to your trax models are usually tensors of numbers where each number corresponds to a word. If you were to process your data manually, you would have to make use of the following:

  • word2Ind: a dictionary mapping the word to its index.
  • ind2Word: a dictionary mapping the index to its word.
  • word2Count: a dictionary mapping the word to the number of times it appears.
  • num_words: total number of words that have appeared.
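For comparison, here is what building those four structures by hand might look like for a hypothetical two-sentence corpus. Indices are assigned in order of first appearance, which is just one possible convention, and num_words is read here as the total count of word occurrences (some code uses the number of distinct words instead).

```python
from collections import Counter
from typing import Dict, List

# A hypothetical toy corpus; real data would come from the training stream.
corpus: List[str] = ["the cat sat", "the dog sat down"]

word2Count: Dict[str, int] = Counter(
    word for sentence in corpus for word in sentence.split())

# Assign each new word the next free index, in order of first appearance.
word2Ind: Dict[str, int] = {}
for sentence in corpus:
    for word in sentence.split():
        word2Ind.setdefault(word, len(word2Ind))

ind2Word: Dict[int, str] = {index: word for word, index in word2Ind.items()}

# Interpretation used here: every word occurrence counts.
num_words: int = sum(word2Count.values())

print(word2Ind)     # {'the': 0, 'cat': 1, 'sat': 2, 'dog': 3, 'down': 4}
print(ind2Word[3])  # dog
print(num_words)    # 7
```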
def tokenize(input_str: str,
             vocab_file: str=None, vocab_dir: str=None, EOS: int=EOS) -> numpy.ndarray:
    """Encodes a string to an array of integers

    Args:
       input_str: human-readable string to encode
       vocab_file: filename of the vocabulary text file
       vocab_dir: path to the vocabulary file
       EOS: integer token appended to mark the end of the sentence

    Returns:
       tokenized version of the input string
    """
    # Use the trax.data.tokenize method. It takes streams and returns streams,
    # we get around it by making a 1-element stream with `iter`.
    inputs =  next(trax.data.tokenize(iter([input_str]),
                                      vocab_file=vocab_file,
                                      vocab_dir=vocab_dir))

    # Mark the end of the sentence with EOS
    inputs = list(inputs) + [EOS]

    # Adding the batch dimension to the front of the shape
    batch_inputs = numpy.reshape(numpy.array(inputs), [1, -1])

    return batch_inputs
def detokenize(integers: numpy.ndarray,
               vocab_file: str=None,
               vocab_dir: str=None,
               EOS: int=EOS) -> str:
    """Decodes an array of integers to a human readable string

    Args:
       integers: array of integers to decode
       vocab_file: filename of the vocabulary text file
       vocab_dir: path to the vocabulary file
       EOS: integer token that marks the end of the sentence

    Returns:
       str: the decoded sentence.
    """
    # Remove the dimensions of size 1
    integers = list(numpy.squeeze(integers))

    # Remove the EOS to decode only the original tokens
    if EOS in integers:
        integers = integers[:integers.index(EOS)] 

    return trax.data.detokenize(integers, vocab_file=vocab_file, vocab_dir=vocab_dir)

Let's see how we might use these functions:

Detokenize an input-target pair of tokenized sentences

print(colored(f'Single detokenized example input:', 'red'), detokenize(train_input, vocab_file=VOCAB_FILE, vocab_dir=VOCAB_DIR))
print(colored(f'Single detokenized example target:', 'red'), detokenize(train_target, vocab_file=VOCAB_FILE, vocab_dir=VOCAB_DIR))
print()
Single detokenized example input: During treatment with olanzapine, adolescents gained significantly more weight compared with adults.

Single detokenized example target: Während der Behandlung mit Olanzapin nahmen die Jugendlichen im Vergleich zu Erwachsenen signifikant mehr Gewicht zu.

Tokenize and detokenize a word that is not explicitly saved in the vocabulary file. See how it combines the subwords 'hell' and 'o' to form the word 'hello'.

print(colored("tokenize('hello'): ", 'green'), tokenize('hello', vocab_file=VOCAB_FILE, vocab_dir=VOCAB_DIR))
print(colored("detokenize([17332, 140, 1]): ", 'green'), detokenize([17332, 140, 1], vocab_file=VOCAB_FILE, vocab_dir=VOCAB_DIR))
tokenize('hello'):  [[17332   140     1]]
detokenize([17332, 140, 1]):  hello

Bucketing

Bucketing the tokenized sentences is an important technique used to speed up training in NLP. Here is a nice article describing it in detail, but the gist is very simple. Our inputs have variable lengths, and you want to make them the same length when batching groups of sentences together. One way to do that is to pad each sentence to the length of the longest sentence in the dataset. This can waste computation, though. For example, if there are multiple short sentences with just two tokens, do we want to pad these when the longest sentence is composed of 100 tokens? Instead of padding with 0s to the maximum length of a sentence each time, we can group our tokenized sentences into buckets by length.

We batch the sentences with similar length together and only add minimal padding to make them have equal length (usually up to the nearest power of two). This allows us to waste less computation when processing padded sequences.
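A quick back-of-the-envelope sketch of the savings, using hypothetical sentence lengths (four short sentences and one 100-token sentence). It counts the padding tokens needed when everything is padded to the longest sentence versus when each sentence is padded only up to the nearest power of two.

```python
from typing import List


def next_power_of_two(length: int) -> int:
    """Smallest power of two that is >= length."""
    power = 1
    while power < length:
        power *= 2
    return power


# Hypothetical sentence lengths: several short ones and a single long one.
lengths: List[int] = [2, 2, 3, 3, 100]

# Strategy 1: pad every sentence to the longest sentence.
pad_to_longest = sum(max(lengths) - length for length in lengths)

# Strategy 2: group by bucket and pad only up to that bucket's power of two.
pad_per_bucket = sum(next_power_of_two(length) - length for length in lengths)

print(pad_to_longest)  # 390
print(pad_per_bucket)  # 30
```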

In Trax, it is implemented in the bucket_by_length function.

Bucketing to create streams of batches.

Buckets are defined in terms of boundaries and batch sizes. batch_sizes[i] determines the batch size for items with length < boundaries[i]. So below, we'll take a batch of 256 sentences of length < 8, 128 if the length is between 8 and 16, and so on, and only 2 if the length is over 512 (batch_sizes has one more entry than boundaries to cover that last bucket). We'll do the bucketing using BucketByLength.

boundaries = [2**power_of_two for power_of_two in range(3, 10)]
batch_sizes = [2**power_of_two for power_of_two in range(8, 0, -1)]
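To see how a sentence length maps onto these two lists, here is a hypothetical bucket_for helper built on bisect. It is a sketch, not trax's code; in particular, placing a length that equals a boundary into the lower bucket is an assumption, so check the trax source if that edge case matters.

```python
import bisect
from typing import List, Tuple

boundaries: List[int] = [2**power for power in range(3, 10)]     # 8, 16, ..., 512
batch_sizes: List[int] = [2**power for power in range(8, 0, -1)]  # 256, 128, ..., 2

# One more batch size than boundaries: the extra one is for sentences longer
# than the last boundary.
assert len(batch_sizes) == len(boundaries) + 1


def bucket_for(length: int) -> Tuple[int, int]:
    """Return (bucket index, batch size) for a sentence of `length` tokens.

    Assumption: a length equal to a boundary goes into the lower bucket.
    """
    index = bisect.bisect_left(boundaries, length)
    return index, batch_sizes[index]


for length in (5, 12, 60, 600):
    print(length, bucket_for(length))
# 5 (0, 256)
# 12 (1, 128)
# 60 (3, 32)
# 600 (7, 2)
```

Short sentences get large batches and long ones get small batches, which keeps the number of tokens per batch roughly in the same range.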

Create the generators.

train_batch_stream = trax.data.BucketByLength(
    boundaries, batch_sizes,
    length_keys=[0, 1]  # As before: count inputs and targets to length.
)(filtered_train_stream)

eval_batch_stream = trax.data.BucketByLength(
    boundaries, batch_sizes,
    length_keys=[0, 1]
)(filtered_eval_stream)

Add masking for the padding (0s) using add_loss_weights (we're using AddLossWeights, but the documentation for that just says "see add_loss_weights"). I can't find any documentation for it, but I think the 0s are what BucketByLength uses for padding.

train_batch_stream = trax.data.AddLossWeights(id_to_mask=0)(train_batch_stream)
eval_batch_stream = trax.data.AddLossWeights(id_to_mask=0)(eval_batch_stream)
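A minimal NumPy sketch of the masking idea (not trax's code). The assumption is that the weights are computed from the target sentences: 1 for real tokens and 0 wherever the padding id appears. The padded targets below are hypothetical.

```python
import numpy

PADDING = 0  # the id_to_mask value used above

# A hypothetical padded batch of two target sentences (1 is the EOS token).
targets = numpy.array([
    [5135, 14970, 2920, 1, 0, 0],
    [3916, 29551, 13504, 5020, 4094, 1],
])

# Weight 1.0 for real tokens (EOS included) and 0.0 for padding, so padded
# positions contribute nothing to a weighted loss.
weights = (targets != PADDING).astype(numpy.float32)
print(weights)
# [[1. 1. 1. 1. 0. 0.]
#  [1. 1. 1. 1. 1. 1.]]
```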

Exploring the data

Now we'll display some of our data. The functions defined above (tokenize() and detokenize()) handle the mapping between words and integer indices for us, so we don't have to build it by hand. Let's first get one batch of data from the generator.

input_batch, target_batch, mask_batch = next(train_batch_stream)

Let's see the data type of a batch.

print("input_batch data type: ", type(input_batch))
print("target_batch data type: ", type(target_batch))
input_batch data type:  <class 'numpy.ndarray'>
target_batch data type:  <class 'numpy.ndarray'>

Let's see the shape of this particular batch (batch size, sentence length).

print("input_batch shape: ", input_batch.shape)
print("target_batch shape: ", target_batch.shape)
input_batch shape:  (32, 64)
target_batch shape:  (32, 64)

The input_batch and target_batch are NumPy arrays consisting of tokenized English sentences and German sentences respectively. These tokens will later be used to produce embedding vectors for each word in the sentence (so the embedding for a sentence will be a matrix). The number of sentences in each batch is usually a power of 2 for optimal computer memory usage.
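To make "the embedding for a sentence will be a matrix" concrete, here is a minimal NumPy sketch. The embedding table is random and only 4 wide for display (the model printed earlier uses Embedding_33300_1024), and the padded token ids are a hypothetical sentence.

```python
import numpy

vocabulary_size = 33300  # matches the Embedding_33300_1024 layer printed earlier
embedding_size = 4       # tiny, for display; the model uses 1024

rng = numpy.random.default_rng(0)
# A randomly initialized embedding table: one row per token id.
embeddings = rng.normal(size=(vocabulary_size, embedding_size))

# One hypothetical padded sentence of token ids.
sentence = numpy.array([5135, 14970, 2920, 1, 0, 0])

# Indexing with the token ids picks out one row per token, giving a matrix.
sentence_matrix = embeddings[sentence]
print(sentence_matrix.shape)  # (6, 4)

# A whole batch of sentences becomes a 3-D array: (batch, length, embedding).
batch = numpy.stack([sentence, sentence])
print(embeddings[batch].shape)  # (2, 6, 4)
```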

We can now visually inspect some of the data. You can run the cell below several times to shuffle through the sentences. Just to note, while this is a standard data set that is used widely, it does have some known wrong translations. With that, let's pick a random sentence and print its tokenized representation.

Pick a random index less than the batch size.

index = random.randrange(len(input_batch))

Use the index to grab an entry from the input and target batch.

print(colored('THIS IS THE ENGLISH SENTENCE: \n', 'red'), detokenize(input_batch[index], vocab_file=VOCAB_FILE, vocab_dir=VOCAB_DIR), '\n')
print(colored('THIS IS THE TOKENIZED VERSION OF THE ENGLISH SENTENCE: \n ', 'red'), input_batch[index], '\n')
print(colored('THIS IS THE GERMAN TRANSLATION: \n', 'red'), detokenize(target_batch[index], vocab_file=VOCAB_FILE, vocab_dir=VOCAB_DIR), '\n')
print(colored('THIS IS THE TOKENIZED VERSION OF THE GERMAN TRANSLATION: \n', 'red'), target_batch[index], '\n')
THIS IS THE ENGLISH SENTENCE:
 Kidneys and urinary tract (no effects were found to be common); uncommon: blood in the urine, proteins in the urine, sugar in the urine; rare: urge to pass urine, kidney pain, passing urine frequently.
 

THIS IS THE TOKENIZED VERSION OF THE ENGLISH SENTENCE:
 [ 5381 17607  3093     8  8670  6086   105 19166     5    50   154  1743
   152  1103     9    32   568  8076 19124  6847    64  6196     6     4
  8670   510     2 13355   823     6     4  8670   510     2  4968     6
     4  8670   510   115  7227    64  7628     9  2685  8670   510     2
 12220  5509 12095     2 19632  8670   510  7326  3550 30650  4729   992
     1     0     0     0] 

THIS IS THE GERMAN TRANSLATION:
 Harndrang, Nierenschmerzen, häufiges Wasserlassen.
 

THIS IS THE TOKENIZED VERSION OF THE GERMAN TRANSLATION:
 [ 5135 14970  2920     2  6262  4594 27552    28     2 20052    33  3736
   530  3550 30650  4729   992     1     0     0     0     0     0     0
     0     0     0     0     0     0     0     0     0     0     0     0
     0     0     0     0     0     0     0     0     0     0     0     0
     0     0     0     0     0     0     0     0     0     0     0     0
     0     0     0     0] 

Bundle it Up

Imports

# python
from collections import namedtuple
from pathlib import Path

# pypi
import attr
import numpy
import trax

Constants

DataDefaults = namedtuple("DataDefaults",
                          ["path",
                           "dataset",
                           "keys",
                           "evaluation_size",
                           "end_of_sentence",
                           "vocabulary_file",
                           "vocabulary_path",
                           "length_keys",
                           "boundaries",
                           "batch_sizes",
                           "padding_token"])

DEFAULTS = DataDefaults(
    path=Path("~/data/tensorflow/translation/").expanduser(),
    dataset="opus/medical",
    keys=("en", "de"),
    evaluation_size=0.01,
    end_of_sentence=1,
    vocabulary_file="ende_32k.subword",
    vocabulary_path="gs://trax-ml/vocabs/",
    length_keys=[0, 1],
    boundaries=[2**power_of_two for power_of_two in range(3, 10)],
    batch_sizes=[2**power_of_two for power_of_two in range(8, 0, -1)],
    padding_token=0,
)

MaxLength = namedtuple("MaxLength", "train evaluate".split())
MAX_LENGTH = MaxLength(train=256, evaluate=512)
END_OF_SENTENCE = 1

Tokenizer/Detokenizer

Tokenizer

def tokenize(input_str: str,
             vocab_file: str=None, vocab_dir: str=None,
             end_of_sentence: int=DEFAULTS.end_of_sentence) -> numpy.ndarray:
    """Encodes a string to an array of integers

    Args:
       input_str: human-readable string to encode
       vocab_file: filename of the vocabulary text file
       vocab_dir: path to the vocabulary file
       end_of_sentence: token for the end of sentence
    Returns:
       tokenized version of the input string
    """
    # The trax.data.tokenize method takes streams and returns streams,
    # we get around it by making a 1-element stream with `iter`.
    inputs =  next(trax.data.tokenize(iter([input_str]),
                                      vocab_file=vocab_file,
                                      vocab_dir=vocab_dir))

    # Mark the end of the sentence with EOS
    inputs = list(inputs) + [end_of_sentence]

    # Adding the batch dimension to the front of the shape
    batch_inputs = numpy.reshape(numpy.array(inputs), [1, -1])
    return batch_inputs

Detokenizer

def detokenize(integers: numpy.ndarray,
               vocab_file: str=None,
               vocab_dir: str=None,
               end_of_sentence: int=DEFAULTS.end_of_sentence) -> str:
    """Decodes an array of integers to a human readable string

    Args:
       integers: array of integers to decode
       vocab_file: filename of the vocabulary text file
       vocab_dir: path to the vocabulary file
       end_of_sentence: token to mark the end of a sentence
    Returns:
       str: the decoded sentence.
    """
    # Remove the dimensions of size 1
    integers = list(numpy.squeeze(integers))

    # Remove the EOS to decode only the original tokens
    if end_of_sentence in integers:
        integers = integers[:integers.index(end_of_sentence)] 

    return trax.data.detokenize(integers, vocab_file=vocab_file, vocab_dir=vocab_dir)

Data Generator

@attr.s(auto_attribs=True)
class DataGenerator:
    """Generates the streams of data

    Args:
     training: whether this generates training data or not
     path: path to the data set
     data_set: name of the data set (from tensorflow datasets)
     keys: the names of the data
     max_length: longest allowed set of tokens
     evaluation_fraction: how much of the data is saved for evaluation
     length_keys: keys (indexes) to use when setting length
     boundaries: upper limits on sentence length for each bucket
     batch_sizes: batch_size for each boundary
     padding_token: which token is used for padding
     vocabulary_file: name of the sub-words vocabulary file
     vocabulary_path: where to find the vocabulary file
     end_of_sentence: token to indicate the end of a sentence
    """
    training: bool=True
    path: Path=DEFAULTS.path
    data_set: str=DEFAULTS.dataset
    keys: tuple=DEFAULTS.keys
    max_length: int=MAX_LENGTH.train
    length_keys: list=DEFAULTS.length_keys
    boundaries: list=DEFAULTS.boundaries
    batch_sizes: list=DEFAULTS.batch_sizes
    evaluation_fraction: float=DEFAULTS.evaluation_size
    vocabulary_file: str=DEFAULTS.vocabulary_file
    vocabulary_path: str=DEFAULTS.vocabulary_path
    padding_token: int=DEFAULTS.padding_token
    end_of_sentence: int=DEFAULTS.end_of_sentence
    _generator_function: type=None
    _batch_generator: type=None

Append End of Sentence

def end_of_sentence_generator(self, original):
    """Generator that adds end of sentence tokens

    Args:
     original: generator to add the end of sentence tokens to

    Yields:
     next tuple of arrays with EOS token added
    """
    for inputs, targets in original:
        inputs = list(inputs) + [self.end_of_sentence]
        targets = list(targets) + [self.end_of_sentence]
        yield numpy.array(inputs), numpy.array(targets)
    return 

Generator Function

@property
def generator_function(self):
    """Function to create the data generator"""
    if self._generator_function is None:
        self._generator_function = trax.data.TFDS(self.data_set,
                                                  data_dir=self.path,
                                                  keys=self.keys,
                                                  eval_holdout_size=self.evaluation_fraction,
                                                  train=self.training)
    return self._generator_function

Batch Stream

@property
def batch_generator(self):
    """batch data generator"""
    if self._batch_generator is None:
        generator = self.generator_function()
        generator = trax.data.Tokenize(
            vocab_file=self.vocabulary_file,
            vocab_dir=self.vocabulary_path)(generator)
        generator = self.end_of_sentence_generator(generator)
        generator = trax.data.FilterByLength(
            max_length=self.max_length,
            length_keys=self.length_keys)(generator)
        generator = trax.data.BucketByLength(
            self.boundaries, self.batch_sizes,
            length_keys=self.length_keys
        )(generator)
        self._batch_generator = trax.data.AddLossWeights(
            id_to_mask=self.padding_token)(generator)
    return self._batch_generator
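Every step in batch_generator takes a stream and returns a new stream, so the property is really one long function composition. Here is a plain-Python sketch of that pattern; compose and the three stand-in steps are hypothetical and only mimic the shape of Tokenize, the end-of-sentence generator, and FilterByLength.

```python
from functools import reduce
from typing import Callable, Iterable, Iterator

Step = Callable[[Iterable], Iterator]


def compose(*steps: Step) -> Step:
    """Chain stream-to-stream steps, left to right, into a single step."""
    return lambda stream: reduce(lambda current, step: step(current), steps, stream)


def fake_tokenize(stream):
    """Stand-in tokenizer: each word becomes its length."""
    for english, german in stream:
        yield [len(word) for word in english.split()], [len(word) for word in german.split()]


def add_end_of_sentence(stream, end_of_sentence=1):
    for inputs, targets in stream:
        yield inputs + [end_of_sentence], targets + [end_of_sentence]


def keep_short(stream, max_length=4):
    for inputs, targets in stream:
        if max(len(inputs), len(targets)) <= max_length:
            yield inputs, targets


pipeline = compose(fake_tokenize, add_end_of_sentence, keep_short)
raw = [("a bb", "ccc d"), ("one two three four", "eins zwei drei vier")]
print(list(pipeline(raw)))  # [([1, 2, 1], [3, 1, 1])]
```

Because every step is lazy, nothing is read from the source until the first batch is requested, which is also why the property can build the whole chain up front.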

Try It Out

import random

from termcolor import colored

from neurotic.nlp.machine_translation import DataGenerator, detokenize

generator = DataGenerator().batch_generator
input_batch, target_batch, mask_batch = next(generator)
# pick a random sentence from the batch
index = random.randrange(len(input_batch))

print(colored('THIS IS THE ENGLISH SENTENCE: \n', 'red'), detokenize(input_batch[index], vocab_file=VOCAB_FILE, vocab_dir=VOCAB_DIR), '\n')
print(colored('THIS IS THE TOKENIZED VERSION OF THE ENGLISH SENTENCE: \n ', 'red'), input_batch[index], '\n')
print(colored('THIS IS THE GERMAN TRANSLATION: \n', 'red'), detokenize(target_batch[index], vocab_file=VOCAB_FILE, vocab_dir=VOCAB_DIR), '\n')
print(colored('THIS IS THE TOKENIZED VERSION OF THE GERMAN TRANSLATION: \n', 'red'), target_batch[index], '\n')
THIS IS THE ENGLISH SENTENCE:
 Signs of hypersensitivity reactions include hives, generalised urticaria, tightness of the chest, wheezing, hypotension and anaphylaxis.
 

THIS IS THE TOKENIZED VERSION OF THE ENGLISH SENTENCE:
 [10495    14     7 10224 19366 10991  1020  3481  2486     2  9547  7417
   103  4572 11927  9371     2 13197  1496     7     4 24489    62     2
 16402 24010   211     2  4814 23010 12122    22     8  4867 19606  6457
  5175    14  3550 30650  4729   992     1     0     0     0     0     0
     0     0     0     0     0     0     0     0     0     0     0     0
     0     0     0     0] 

THIS IS THE GERMAN TRANSLATION:
 Überempfindlichkeitsreaktionen können sich durch Anzeichen wie Nesselausschlag, generalisierte Urtikaria, Engegefühl im Brustkorb, Pfeifatmung, Blutdruckabfall und Anaphylaxie äußern.
 

THIS IS THE TOKENIZED VERSION OF THE GERMAN TRANSLATION:
 [ 3916 29551 13504  5020  4094 13522   119    51   121  8602    93 31508
  6050 30327  6978     2  9547  7417  2446  5618  4581  5530  1384     2
 26006  7831 13651     5    47  8584  4076  5262   868     2 25389  8898
 28268     2  9208 29697 17944    83    12  9925 19606  6457 16384     5
 11790  3550 30650  4729   992     1     0     0     0     0     0     0
     0     0     0     0] 

End

Now that we have our data prepared, it's time to move on to defining the Attention Model.

Neural Machine Translation

Neural Machine Translations

Here, we will build an English-to-German neural machine translation (NMT) model using Long Short-Term Memory (LSTM) networks with attention. Machine translation is an important task in natural language processing and could be useful not only for translating one language to another but also for word sense disambiguation (e.g. determining whether the word "bank" refers to the financial bank, or the land alongside a river). Implementing this using just a Recurrent Neural Network (RNN) with LSTMs can work for short to medium-length sentences but can result in vanishing gradients for very long sequences. To solve this, we will add an attention mechanism that allows the decoder to access all relevant parts of the input sentence regardless of its length. In this series, we will:

  • learn how to preprocess your training and evaluation data
  • implement an encoder-decoder system with attention
  • understand how attention works
  • build the NMT model from scratch using Trax
  • generate translations using greedy and Minimum Bayes Risk (MBR) decoding

The Posts

This will be broken up into the following posts.

First - a look at the data.

Stack Semantics

Stack Semantics in Trax

This will help in understanding how to use layers like Select and Residual, which operate on elements in the stack. If you've taken a computer science class before, you will recall that a stack is a data structure that follows the Last In, First Out (LIFO) principle. That is, the latest element pushed onto the stack will be the first one popped off. If you're not yet familiar with stacks, then you may find this short tutorial useful. In a nutshell, all you really need to remember is that a stack puts elements one on top of the other, so you should be aware of what is on top of the stack to know which element you will be popping.

Imports

# pypi
import numpy
from trax import fastmath, layers, shapes

Middle

The Serial Combinator is Stack Oriented.

To understand how stack orientation works in Trax, we'll use the Serial layer, which is the combinator you'll be using most of the time. We will define two simple Function layers:

  1. Addition
  2. Multiplication

Suppose we want to make the simple calculation \((3 + 4) \times 15 + 3\). We'll use Serial to perform the calculations in the following order: 3 4 add 15 mul 3 add. The steps of the calculation are shown in the table below.

Stack Operations          Stack
Push(4)                   4
Push(3)                   4 3
Push(Add Pop() Pop())     7
Push(15)                  7 15
Push(Mul Pop() Pop())     105
Push(3)                   105 3
Push(Add Pop() Pop())     108

The first column shows the operations made on the stack and the second column shows what's on the stack after each operation. The rightmost element in the second column is the top of the stack (e.g. in the second row, Push(3) pushes 3 on top of the stack and 4 is now under it).

After finishing the steps the stack contains 108 which is the answer to our simple computation.

From this, we can conclude that a stack-based layer has only one way to handle data: taking one piece of data from the top of the stack, called popping, and putting data back on top of the stack, called pushing. Any expression that can be written conventionally can be written this way, and so can be interpreted by a stack-oriented layer like Serial.
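The same idea in plain Python: a small postfix (reverse Polish) evaluator that keeps the stack in a list whose last element is the top. It is only a sketch that mirrors the table above, not how Serial is implemented; the OPERATIONS table and evaluate function are hypothetical names.

```python
from typing import Callable, Dict, List

OPERATIONS: Dict[str, Callable[[int, int], int]] = {
    "add": lambda a, b: a + b,
    "mul": lambda a, b: a * b,
}


def evaluate(program: str) -> int:
    """Evaluate a space-separated postfix program like "3 4 add" with a stack."""
    stack: List[int] = []
    for token in program.split():
        if token in OPERATIONS:
            # Pop the top two items, apply the operation, push the result.
            top = stack.pop()
            below = stack.pop()
            stack.append(OPERATIONS[token](top, below))
        else:
            stack.append(int(token))
    # A well-formed program leaves exactly one value on the stack.
    assert len(stack) == 1, stack
    return stack[0]


print(evaluate("3 4 add 15 mul 3 add"))  # 108
```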

Defining addition

We're going to define a Trax function layer (Fn) for addition.

def Addition():
    layer_name = "Addition" 

    def func(x, y):
        return x + y

    return layers.Fn(layer_name, func)

Test it out.

add = Addition()
print(type(add))
<class 'trax.layers.base.PureLayer'>
print("name :", add.name)
print("expected inputs :", add.n_in)
print("promised outputs :", add.n_out)
name : Addition
expected inputs : 2
promised outputs : 1
x = numpy.array([3])
y = numpy.array([4])

print(f"{x} + {y} = {add((x, y))}")
[3] + [4] = [7]

Defining multiplication

def Multiplication():
    layer_name = "Multiplication"

    def func(x, y):
        return x * y

    return layers.Fn(layer_name, func)

Test it out.

mul = Multiplication()

The properties.

print("name :", mul.name)
print("expected inputs :", mul.n_in)
print("promised outputs :", mul.n_out, "\n")
name : Multiplication
expected inputs : 2
promised outputs : 1 

Some Inputs.

x = numpy.array([7])
y = numpy.array([15])
print("x :", x)
print("y :", y)
x : [7]
y : [15]

The Output

z = mul((x, y))
print(f"{x} * {y} = {z}")
[7] * [15] = [105]

Implementing the computations using the Serial combinator

serial = layers.Serial(
    Addition(), Multiplication(), Addition()
)
inputs = (numpy.array([3]), numpy.array([4]), numpy.array([15]), numpy.array([3]))

serial.init(shapes.signature(inputs))
print(serial, "\n")
print("name :", serial.name)
print("sublayers :", serial.sublayers)
print("expected inputs :", serial.n_in)
print("promised outputs :", serial.n_out, "\n")
Serial_in4[
  Addition_in2
  Multiplication_in2
  Addition_in2
] 

name : Serial
sublayers : [Addition_in2, Multiplication_in2, Addition_in2]
expected inputs : 4
promised outputs : 1 
print(f"{inputs} -> {serial(inputs)}")
(array([3]), array([4]), array([15]), array([3])) -> [108]

This example, with the two simple addition and multiplication functions chained together by the Serial combinator, shows how stack semantics work in Trax.
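To make the stack threading concrete, here is a toy model of Serial in plain Python. It is a sketch under our own assumptions, not Trax's code: a "layer" is just a hypothetical `ToyLayer` pairing an input count with a function, the stack is a tuple whose element 0 is the top (matching how `serial(inputs)` consumed 3 and 4 first), and each layer pops `n_in` items and pushes its outputs.

```python
from typing import Callable, NamedTuple, Sequence, Tuple


class ToyLayer(NamedTuple):
    """Hypothetical stand-in for a Trax layer: an input count and a function."""
    n_in: int
    fn: Callable[..., tuple]  # returns a tuple of outputs to push


def toy_serial(sublayers: Sequence[ToyLayer], stack: Tuple) -> Tuple:
    """Apply each layer in order; element 0 of the tuple is the stack top."""
    for layer in sublayers:
        popped, rest = stack[:layer.n_in], stack[layer.n_in:]
        # The layer consumes the popped items and its outputs go on top.
        stack = tuple(layer.fn(*popped)) + rest
    return stack


add = ToyLayer(2, lambda x, y: (x + y,))
mul = ToyLayer(2, lambda x, y: (x * y,))

print(toy_serial([add, mul, add], (3, 4, 15, 3)))  # (108,)
```

The intermediate stacks are (7, 15, 3), then (105, 3), then (108,), which is the table read with the top on the left instead of the right.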

The tl.Select combinator in the context of the Serial combinator

Having understood how stack semantics work in Trax, we will demonstrate how the tl.Select combinator works.

First example of tl.Select

Suppose we want to compute \((3 + 4) \times 3 + 4\). We can use Select to perform the calculation in the following manner:

  1. input 3 4
  2. tl.Select([0, 1, 0, 1])
  3. add
  4. mul
  5. add

tl.Select takes a list or tuple of 0-based indices that pick elements relative to the top of the stack. In our example the top of the stack is 3 (index 0), followed by 4 (index 1). Select([0, 1, 0, 1]) pops these two elements and pushes back copies of the elements at indices 0, 1, 0 and 1, so after the command executes the stack contains 3 4 3 4 (listed from the top). The steps of the calculation for our example are shown in the table below. As in the previous table, the first column shows the operation and the second column shows the contents of the stack after it is carried out (rightmost element on top).

| Stack Operations           | Stack   |
|----------------------------|---------|
| Push(4)                    | 4       |
| Push(3)                    | 4 3     |
| Push(Select([0, 1, 0, 1])) | 4 3 4 3 |
| Push(Add Pop() Pop())      | 4 3 7   |
| Push(Mul Pop() Pop())      | 4 21    |
| Push(Add Pop() Pop())      | 25      |

After processing all the inputs the stack contains 25 which is the result of the calculations.
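Here is the same trace in a toy tuple model (plain Python, element 0 is the top). `toy_select` and `apply_binary` are hypothetical helpers, not Trax functions; we assume that when n_in is omitted Select pops just enough items to reach every index, i.e. max(indices) + 1, which is consistent with the `Select[0,1,0,1]_in2_out4` that Trax prints for this layer.

```python
import operator


def toy_select(stack, indices, n_in=None):
    """Hypothetical model of Select: pop n_in items, push the chosen copies.

    Index 0 is the top of the stack. If n_in is omitted we assume it is
    just large enough to reach every index, i.e. max(indices) + 1.
    """
    if n_in is None:
        n_in = max(indices) + 1
    popped, rest = stack[:n_in], stack[n_in:]
    return tuple(popped[i] for i in indices) + rest


def apply_binary(stack, op):
    """Pop the top two items and push op(top, second)."""
    return (op(stack[0], stack[1]),) + stack[2:]


stack = toy_select((3, 4), [0, 1, 0, 1])
print(stack)                               # (3, 4, 3, 4)
stack = apply_binary(stack, operator.add)  # (7, 3, 4)
stack = apply_binary(stack, operator.mul)  # (21, 4)
stack = apply_binary(stack, operator.add)  # (25,)
print(stack)                               # (25,)
```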

serial = layers.Serial(
    layers.Select([0, 1, 0, 1]),
    Addition(),
    Multiplication(),
    Addition()
)

Now we'll create the input.

x = (numpy.array([3]), numpy.array([4]))
serial.init(shapes.signature(x))
print(serial, "\n")
print("name :", serial.name)
print("sublayers :", serial.sublayers)
print("expected inputs :", serial.n_in)
print("promised outputs :", serial.n_out, "\n")
Serial_in2[
  Select[0,1,0,1]_in2_out4
  Addition_in2
  Multiplication_in2
  Addition_in2
] 

name : Serial
sublayers : [Select[0,1,0,1]_in2_out4, Addition_in2, Multiplication_in2, Addition_in2]
expected inputs : 2
promised outputs : 1 
print(f"{x} -> {serial(x)}")
(array([3]), array([4])) -> [25]

Select Makes It More Like a Collection

Note that because you pass indices to Select, it can reach any of its inputs, not just the one on top, so you aren't really using it like a stack, even if behind the scenes it's using push and pop.

serial = layers.Serial(
    layers.Select([2, 1, 1, 2]),
    Addition(),
    Multiplication(),
    Addition()
)

x = (numpy.array([3]), numpy.array([4]), numpy.array([5]))
serial.init(shapes.signature(x))

print(f"{x} -> {serial(x)}")
(array([3]), array([4]), array([5])) -> [41]
print((5 + 4) * 4 + 5)
41

Another example of tl.Select

Suppose we want to compute \((3 + 4) \times 4\). We can use Select to perform the calculation in the following manner:

  1. 4
  2. 3
  3. tl.Select([0,1,0,1])
  4. add
  5. tl.Select([0], n_in=2)
  6. mul

The example is a bit contrived, but it demonstrates the flexibility of the command. The second tl.Select pops two elements (specified by n_in) from the stack, starting from index 0 (i.e. the top of the stack). This means that 7 and 3 are popped off (because n_in=2), but only 7 is pushed back on top because Select only selects index 0; the 4 underneath is left untouched. As in the previous tables, the first column shows the operation and the second column shows the contents of the stack after it is carried out.

| Stack Operations           | Stack   |
|----------------------------|---------|
| Push(4)                    | 4       |
| Push(3)                    | 4 3     |
| Push(Select([0, 1, 0, 1])) | 4 3 4 3 |
| Push(Add Pop() Pop())      | 4 3 7   |
| Push(Select([0], n_in=2))  | 4 7     |
| Push(Mul Pop() Pop())      | 28      |

After processing all the inputs, the stack contains 28, which is \((3 + 4) \times 4\).
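The same trace in a toy tuple model (plain Python, element 0 is the top) shows that the 4 left over from the first Select is what the final multiplication uses. `toy_select` and `apply_binary` are hypothetical helpers, not Trax functions, and the default n_in of max(indices) + 1 is our assumption.

```python
import operator


def toy_select(stack, indices, n_in=None):
    """Hypothetical model of Select: pop n_in items (index 0 = top) and
    push copies of the items at `indices`; anything below is untouched."""
    if n_in is None:
        n_in = max(indices) + 1  # assumed default
    popped, rest = stack[:n_in], stack[n_in:]
    return tuple(popped[i] for i in indices) + rest


def apply_binary(stack, op):
    """Pop the top two items and push op(top, second)."""
    return (op(stack[0], stack[1]),) + stack[2:]


stack = toy_select((3, 4), [0, 1, 0, 1])   # (3, 4, 3, 4)
stack = apply_binary(stack, operator.add)  # (7, 3, 4)
stack = toy_select(stack, [0], n_in=2)     # pops 7 and 3, pushes 7
print(stack)                               # (7, 4)
stack = apply_binary(stack, operator.mul)
print(stack)                               # (28,)
```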

serial = layers.Serial(
    layers.Select([0, 1, 0, 1]),
    Addition(),
    layers.Select([0], n_in=2),
    Multiplication()
)
inputs = (numpy.array([3]), numpy.array([4]))
serial.init(shapes.signature(inputs))
print(serial, "\n")
print("name :", serial.name)
print("sublayers :", serial.sublayers)
print("expected inputs :", serial.n_in)
print("promised outputs :", serial.n_out)
Serial_in2[
  Select[0,1,0,1]_in2_out4
  Addition_in2
  Select[0]_in2
  Multiplication_in2
] 

name : Serial
sublayers : [Select[0,1,0,1]_in2_out4, Addition_in2, Select[0]_in2, Multiplication_in2]
expected inputs : 2
promised outputs : 1
print(f"{inputs} -> {serial(inputs)}")
(array([3]), array([4])) -> [28]

In summary, what Select does in this example is make a copy of the inputs in order to be used further along in the stack of operations.

The tl.Residual combinator in the context of the Serial combinator

tl.Residual

Residual networks are frequently used to make deep models easier to train. Trax already has a built-in layer for this. The Residual layer computes the element-wise sum of the stack-top input and the output of the layer series. Let's first see how it is used in the code below:

serial = layers.Serial(
    layers.Select([0, 1, 0, 1]),
    layers.Residual(Addition())
)

print(serial, "\n")
print("name :", serial.name)
print("expected inputs :", serial.n_in)
print("promised outputs :", serial.n_out)
Serial_in2_out3[
  Select[0,1,0,1]_in2_out4
  Serial_in2[
    Branch_in2_out2[
      None
      Addition_in2
    ]
    Add_in2
  ]
] 

name : Serial
expected inputs : 2
promised outputs : 3

Here we use the Serial combinator to define our model. The inputs first go through a Select layer, followed by a Residual layer to which we pass the Fn: Addition() layer as an argument. This means the Residual layer will take the stack-top input at that point and add it to the output of the Fn: Addition() layer. You can picture it like the diagram below, where x1 and x2 are the inputs to the model:

Now, let's try running our model with some sample inputs and see the result:

x1 = numpy.array([3])
x2 = numpy.array([4])

print(f"{x1} + {x2} -> {serial((x1, x2))}")
[3] + [4] -> (array([10]), array([3]), array([4]))

As you can see, the Residual layer remembers the stack-top input (i.e. 3) and adds it to the result of the Fn: Addition() layer (i.e. 3 + 4 = 7). The output of Residual(Addition()) is then 3 + 7 = 10, which is pushed onto the stack.

On a different note, you'll notice that the Select layer has 4 outputs but the Fn: Addition() layer only pops 2 inputs from the stack. This means the duplicate inputs (i.e. the 2 rightmost arrows of the Select outputs in the figure above) remain on the stack. This is why you still see them in the output of our simple serial network (i.e. array([3]), array([4])), and why the network promises 3 outputs. This is useful if you want to use these duplicate inputs in another layer further down the network.
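Based on the structure Trax printed above (Serial[Branch[None, Addition], Add]), Residual can be modelled in plain Python as follows. This is a hypothetical sketch, not Trax's implementation: the Branch sees the top n_in items, the None (shortcut) branch passes the top item through unchanged, the other branch applies the wrapped function, and Add sums the two results. Items below the first n_in are left alone, which is where the trailing 3 and 4 come from. Swapping in multiplication gives the 3 + (3 * 4) = 15 of the modified network.

```python
import operator


def toy_residual(stack, fn, n_in=2):
    """Hypothetical model of Residual(fn): Branch[None, fn], then Add.

    Element 0 of the tuple is the top of the stack.
    """
    inputs, rest = stack[:n_in], stack[n_in:]
    shortcut = inputs[0]          # the None branch: identity on the top item
    transformed = fn(*inputs)     # the wrapped layer's branch
    return (shortcut + transformed,) + rest


after_select = (3, 4, 3, 4)  # what Select([0, 1, 0, 1]) leaves for inputs 3, 4
print(toy_residual(after_select, operator.add))  # (10, 3, 4)
print(toy_residual(after_select, operator.mul))  # (15, 3, 4)
```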

Modifying the network

To strengthen your understanding, you can modify the network above and examine the outputs you get. For example, you can pass the Fn: Multiplication() layer to the Residual block instead:

serial = layers.Serial(
    layers.Select([0, 1, 0, 1]),
    layers.Residual(Multiplication())
)

print(serial, "\n")
print("name :", serial.name)
print("expected inputs :", serial.n_in)
print("promised outputs :", serial.n_out)
Serial_in2_out3[
  Select[0,1,0,1]_in2_out4
  Serial_in2[
    Branch_in2_out2[
      None
      Multiplication_in2
    ]
    Add_in2
  ]
] 

name : Serial
expected inputs : 2
promised outputs : 3

This means you'll get a different output, which is added to the stack-top input saved by the Residual block. The diagram becomes:

And you'll get 3 + (3 * 4) = 15 as the output of the Residual block:

x1 = numpy.array([3])
x2 = numpy.array([4])

y = serial((x1, x2))
print(f"{x1} * {x2} -> {y}")
[3] * [4] -> (array([15]), array([3]), array([4]))