Siamese Networks: Training the Model

Beginning

Now we are going to train the Siamese Network model. As usual, we have to define the loss function and the optimizer, and then feed in the model we built. Before getting into the training we'll do some special data setup: the inputs are defined using the data generator we built previously, which keeps track of where it left off so that each call produces the next batch. Run the cells below to build the question-pair inputs.

Imports

# python
from collections import namedtuple
from functools import partial
from pathlib import Path
from tempfile import TemporaryFile

import sys

# pypi
from holoviews import opts

import holoviews
import hvplot.pandas
import jax
import numpy
import pandas
import trax

# this project
from neurotic.nlp.siamese_networks import (
    DataGenerator,
    DataLoader,
    SiameseModel,
    TOKENS,
    triplet_loss_layer,
)

from graeae import Timer, EmbedHoloviews

Set Up

The Timer And Plotting

TIMER = Timer()
slug = "siamese-networks-training-the-model"
Embed = partial(EmbedHoloviews, folder_path=f"files/posts/nlp/{slug}")

Plot = namedtuple("Plot", ["width", "height", "fontscale", "tan", "blue", "red"])
PLOT = Plot(
    width=900,
    height=750,
    fontscale=2,
    tan="#ddb377",
    blue="#4687b7",
    red="#ce7b6d",
)

The Data

loader = DataLoader()

data = loader.data

The Data Generator

batch_size = 256
train_generator = DataGenerator(data.train.question_one, data.train.question_two,
                                batch_size=batch_size)
validation_generator = DataGenerator(data.validate.question_one,
                                     data.validate.question_two,
                                     batch_size=batch_size)
print(f"training question 1 rows: {len(data.train.question_one):,}")
print(f"validation question 1 rows: {len(data.validate.question_one):,}")
training question 1 rows: 89,179
validation question 1 rows: 22,295
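
Before training, it's worth peeking at what the generator actually hands the model. This is a sketch built on my assumptions about the DataGenerator - that it follows the iterator protocol and yields a (question one, question two) pair of equal-shape, padded arrays - and it uses a throwaway generator so the real train_generator isn't advanced.

# a hypothetical peek at one batch (assumes DataGenerator is an iterator
# yielding a (question_one, question_two) pair of padded arrays)
peek = DataGenerator(data.train.question_one, data.train.question_two,
                     batch_size=batch_size)
batch_one, batch_two = next(peek)
print(batch_one.shape, batch_two.shape)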

Middle

Training the Model

We will now write a function that takes in the model and trains it. To train the model we have to decide how many times to iterate over the entire data set; each full pass over the data is an epoch, and within each epoch we work through the data one batch at a time using the training iterator.

  • Create the TrainTask and EvalTask
  • Create the training loop trax.supervised.training.Loop
  • Pass in the following, depending on the context (train_task or eval_task):
    • labeled_data=generator
    • metrics=[TripletLoss()]
    • loss_layer=TripletLoss()
    • optimizer=trax.optimizers.Adam with a learning rate of 0.01
    • lr_schedule=lr_schedule
    • output_dir=output_dir

We will be using the triplet loss function with the Adam optimizer. Please read the trax Adam documentation to get a full understanding of how it works.
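
The triplet_loss_layer we imported isn't shown in this post, so as a refresher here's a minimal numpy sketch of the kind of batch triplet loss this setup uses (the details - the hard-negative mining, the margin value - are my assumptions, not necessarily what the project's layer does). Matching question pairs sit on the diagonal of the similarity matrix, and every off-diagonal entry serves as a negative example.

# a sketch, not the project's actual triplet_loss_layer: batch triplet
# loss over cosine similarities, with matching pairs on the diagonal
def triplet_loss_sketch(v1, v2, margin=0.25):
    scores = numpy.dot(v1, v2.T)          # cosine similarities (rows are L2-normalized)
    batch_size = len(scores)
    positive = numpy.diagonal(scores)     # similarity of the true pairs
    # zero out the diagonal so only negatives contribute to the mean
    negatives = scores * (1.0 - numpy.eye(batch_size))
    mean_negative = negatives.sum(axis=1) / (batch_size - 1)
    # push the diagonal below everything else so max finds the
    # hardest (most similar) negative in each row
    closest_negative = (scores - 2.0 * numpy.eye(batch_size)).max(axis=1)
    loss_one = numpy.maximum(0.0, margin - positive + closest_negative)
    loss_two = numpy.maximum(0.0, margin - positive + mean_negative)
    return numpy.mean(loss_one + loss_two)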

This function should return a training.Loop object. To read more about this check the training.Loop documentation.

lr_schedule = trax.lr.warmup_and_rsqrt_decay(400, 0.01)
def train_model(Siamese, TripletLoss, lr_schedule,
                train_generator=train_generator,
                val_generator=validation_generator,
                output_dir="~/models/siamese_networks/",
                steps_per_checkpoint=100):
    """Training the Siamese Model

    Args:
       Siamese (function): Function that returns the Siamese model.
       TripletLoss (function): Function that defines the TripletLoss loss function.
       lr_schedule (function): Trax multifactor schedule function.
       train_generator (generator, optional): Training generator. Defaults to train_generator.
       val_generator (generator, optional): Validation generator. Defaults to validation_generator.
       output_dir (str, optional): Path to save the model to. Defaults to "~/models/siamese_networks/".
       steps_per_checkpoint (int, optional): Steps between checkpoints. Defaults to 100.

    Returns:
       trax.supervised.training.Loop: Training loop for the model.
    """
    output_dir = Path(output_dir).expanduser()

    ### START CODE HERE (Replace instances of 'None' with your code) ###

    train_task = trax.supervised.training.TrainTask(
        labeled_data=train_generator,          # training batch generator
        loss_layer=TripletLoss(),              # instance of the triplet-loss layer
        optimizer=trax.optimizers.Adam(0.01),  # Adam, with the peak learning rate
        lr_schedule=lr_schedule,               # trax warmup-and-decay schedule
        n_steps_per_checkpoint=steps_per_checkpoint,
    )

    eval_task = trax.supervised.training.EvalTask(
        labeled_data=val_generator,            # validation batch generator
        metrics=[TripletLoss()],               # evaluate with the same triplet loss
    )

    ### END CODE HERE ###

    training_loop = trax.supervised.training.Loop(Siamese,
                                                  [train_task],
                                                  eval_tasks=[eval_task],
                                                  output_dir=output_dir)

    return training_loop
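
Since the learning-rate schedule drives the optimizer, it can be worth eyeballing warmup_and_rsqrt_decay before training: it ramps up to the peak rate (0.01) over the first 400 warm-up steps and then decays with the inverse square root of the step. This sketch assumes the schedule object is a plain callable from step number to learning rate, which I believe is how trax implements its schedules.

# sanity check: the rate should climb for 400 steps, then decay
for step in (1, 100, 400, 800, 2000):
    print(f"step {step:>4}: learning rate {lr_schedule(step):0.6f}")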

Training

Trial Two

Note: I re-ran this next code block so it's actually the second run.

train_steps = 2000
siamese = SiameseModel(len(loader.vocabulary))
training_loop = train_model(siamese.model, triplet_loss_layer, lr_schedule, steps_per_checkpoint=5)

real_stdout = sys.stdout

# trax's training loop prints a lot of output, so temporarily point
# stdout at a throwaway file to keep things tidy while it runs
TIMER.emit = False
TIMER.start()
with TemporaryFile("w") as temp_file:
    sys.stdout = temp_file
    training_loop.run(train_steps)
TIMER.stop()
sys.stdout = real_stdout
print(f"{TIMER.ended - TIMER.started}")
0:19:46.056057
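
As an aside, the stdout juggling above has a wart: if training_loop.run raises, sys.stdout stays pointed at the (closed) temporary file. contextlib's redirect_stdout guarantees restoration, so a slightly safer version of the silencing would look like this:

from contextlib import redirect_stdout

# stdout is restored even if training raises
with TemporaryFile("w") as temp_file, redirect_stdout(temp_file):
    training_loop.run(train_steps)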
for mode in training_loop.history.modes:
    print(mode)
    print(training_loop.history.metrics_for_mode(mode))
eval
['metrics/TripletLoss']
train
['metrics/TripletLoss', 'training/gradients_l2', 'training/learning_rate', 'training/loss', 'training/steps per second', 'training/weights_l2']
  • Plotting the Metrics

    Note: As of February 2021, the version of trax on pypi doesn't have a history attribute - to get it you have to install the code from the github repository.
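
    In case it helps, pip can install straight from the repository, something like:

    pip install git+https://github.com/google/trax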

    frame = pandas.DataFrame(training_loop.history.get("eval", "metrics/TripletLoss"), columns="Batch TripletLoss".split())
    
    minimum = frame.loc[frame.TripletLoss.idxmin()]
    vline = holoviews.VLine(minimum.Batch).opts(opts.VLine(color=PLOT.red))
    hline = holoviews.HLine(minimum.TripletLoss).opts(opts.HLine(color=PLOT.red))
    line = frame.hvplot(x="Batch", y="TripletLoss").opts(opts.Curve(color=PLOT.blue))
    
    plot = (line * hline * vline).opts(
        width=PLOT.width, height=PLOT.height,
        title="Evaluation Batch Triplet Loss",
                                       )
    output = Embed(plot=plot, file_name="evaluation_triplet_loss")()
    
    print(output)
    

    Figure Missing

    It looks like the loss is stabilizing. If the model doesn't perform well I'll re-train it.

Trial Three

Let's see if the loss continues going down.

train_steps = 2000
siamese = SiameseModel(len(loader.vocabulary))
training_loop = train_model(siamese.model, triplet_loss_layer, lr_schedule, steps_per_checkpoint=5)

real_stdout = sys.stdout

TIMER.emit = False
TIMER.start()
with TemporaryFile("w") as temp_file:
    sys.stdout = temp_file
    training_loop.run(train_steps)
TIMER.stop()
sys.stdout = real_stdout
print(f"{TIMER.ended - TIMER.started}")
0:17:41.167719
  • Plotting the Metrics
    frame = pandas.DataFrame(
        training_loop.history.get("eval", "metrics/TripletLoss"),
        columns="Batch TripletLoss".split())
    
    minimum = frame.loc[frame.TripletLoss.idxmin()]
    vline = holoviews.VLine(minimum.Batch).opts(opts.VLine(color=PLOT.red))
    hline = holoviews.HLine(minimum.TripletLoss).opts(opts.HLine(color=PLOT.red))
    line = frame.hvplot(x="Batch", y="TripletLoss").opts(opts.Curve(color=PLOT.blue))
    
    plot = (line * hline * vline).opts(
        width=PLOT.width, height=PLOT.height,
        title="Evaluation Batch Triplet Loss (Third Run)",
                                       )
    output = Embed(plot=plot, file_name="evaluation_triplet_loss_third")()
    
    print(output)
    

    Figure Missing

    It looks like it stopped improving. Probably time to stop.
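
    Since the Loop checkpoints to output_dir as it goes, the trained weights can be reloaded into a fresh model later - a sketch, assuming trax's default checkpoint name (model.pkl.gz) and its init_from_file layer method:

    # a sketch: reload the checkpointed weights into a fresh model
    # (assumes trax's Loop saved them as model.pkl.gz in output_dir)
    model = SiameseModel(len(loader.vocabulary)).model
    model.init_from_file(
        str(Path("~/models/siamese_networks/model.pkl.gz").expanduser()),
        weights_only=True)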