Siamese Networks: Training the Model
Beginning
Now we are going to train the Siamese Network model. As usual, we have to define the cost function and the optimizer, and feed in the model we built. Before getting into the training we need a special data setup: we define the inputs using the data generator built previously. Because it's a generator, it keeps its state between calls, so each batch picks up where the last one left off. Run the cells below to set up the question-pair inputs.
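To see why that matters, here's a toy illustration (not the project's DataGenerator) of the fact that a Python generator keeps its position between calls, so each call hands back the next batch.

# a toy stand-in for the real data generator to show the statefulness
def toy_batches(rows, batch_size):
    """Yield successive slices of `rows`, remembering where it stopped."""
    index = 0
    while index < len(rows):
        yield rows[index: index + batch_size]
        index += batch_size

generator = toy_batches(list(range(10)), batch_size=4)
print(next(generator))  # [0, 1, 2, 3]
print(next(generator))  # [4, 5, 6, 7] - resumes where the last call left off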
Imports
# python
from collections import namedtuple
from functools import partial
from pathlib import Path
from tempfile import TemporaryFile
import sys
# pypi
from holoviews import opts
import holoviews
import hvplot.pandas
import jax
import numpy
import pandas
import trax
# this project
from neurotic.nlp.siamese_networks import (
    DataGenerator,
    DataLoader,
    SiameseModel,
    TOKENS,
    triplet_loss_layer,
)
from graeae import Timer, EmbedHoloviews
Set Up
The Timer And Plotting
TIMER = Timer()
slug = "siamese-networks-training-the-model"
Embed = partial(EmbedHoloviews, folder_path=f"files/posts/nlp/{slug}")
Plot = namedtuple("Plot", ["width", "height", "fontscale", "tan", "blue", "red"])
PLOT = Plot(
    width=900,
    height=750,
    fontscale=2,
    tan="#ddb377",
    blue="#4687b7",
    red="#ce7b6d",
)
The Data
loader = DataLoader()
data = loader.data
The Data Generator
batch_size = 256
train_generator = DataGenerator(data.train.question_one, data.train.question_two,
                                batch_size=batch_size)
validation_generator = DataGenerator(data.validate.question_one,
                                     data.validate.question_two,
                                     batch_size=batch_size)
print(f"training question 1 rows: {len(data.train.question_one):,}")
print(f"validation question 1 rows: {len(data.validate.question_one):,}")
training question 1 rows: 89,179
validation question 1 rows: 22,295
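As a sanity check, you can peek at one batch. This assumes (as in the post where it was built) that a DataGenerator instance supports next and yields a pair of padded token-ID arrays; the variable names here are just for illustration.

# a throwaway generator so we don't consume batches from `train_generator`
peek = DataGenerator(data.train.question_one, data.train.question_two,
                     batch_size=batch_size)
# assumes the generator yields (question_one_batch, question_two_batch)
batch_one, batch_two = next(peek)
print(f"question one batch shape: {batch_one.shape}")
print(f"question two batch shape: {batch_two.shape}")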
Middle
Training the Model
We will now write a function that takes in the model and trains it. To train the model we have to decide how many times to iterate over the entire data set; each full pass over the data is called an epoch. For each epoch, we go over all of the data using the training iterator.
- Create the TrainTask and EvalTask.
- Create the training loop using trax.supervised.training.Loop.
- Pass in the following, depending on the context (train_task or eval_task):
  - labeled_data=generator
  - metrics=[TripletLoss()]
  - loss_layer=TripletLoss()
  - optimizer=trax.optimizers.Adam with a learning rate of 0.01
  - lr_schedule=lr_schedule
  - output_dir=output_dir
We will be using the triplet loss function with the Adam optimizer. Please read the trax Adam documentation to get a full understanding. The function should return a training.Loop object; to read more about it, check the training.Loop documentation.
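Since the loss drives the whole training setup, here's a simplified sketch of the idea behind the triplet loss. The triplet_loss_layer imported above is the batch-based version built in an earlier post; this stripped-down form only shows the core comparison between an anchor, a positive (duplicate) example, and a negative (non-duplicate) example.

# a simplified, illustrative triplet loss - not the project's triplet_loss_layer
def cosine_similarity(vector_one: numpy.ndarray, vector_two: numpy.ndarray) -> float:
    """Cosine of the angle between two vectors."""
    return numpy.dot(vector_one, vector_two)/(
        numpy.linalg.norm(vector_one) * numpy.linalg.norm(vector_two))

def basic_triplet_loss(anchor, positive, negative, margin: float=0.25) -> float:
    """max(0, margin - cos(anchor, positive) + cos(anchor, negative))"""
    return max(0.0,
               margin
               - cosine_similarity(anchor, positive)
               + cosine_similarity(anchor, negative))

The loss is zero once the anchor is at least margin more similar to the positive example than to the negative one, so minimizing it pushes duplicate questions together and non-duplicates apart.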
lr_schedule = trax.lr.warmup_and_rsqrt_decay(400, 0.01)
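The name describes the schedule's shape: the learning rate ramps up over the first 400 warmup steps to a peak of 0.01 and then decays roughly as the reciprocal square root of the step number. Here's a hand-rolled sketch of that shape (an illustration of the idea, not trax's exact formula).

# illustrative only - the exact trax formula may differ in its details
def warmup_then_rsqrt(step: int, warmup_steps: int=400, max_value: float=0.01) -> float:
    """Linear warmup to max_value, then 1/sqrt(step) decay."""
    if step < warmup_steps:
        return max_value * step/warmup_steps
    return max_value * (warmup_steps/step)**0.5

for step in (100, 400, 1600):
    print(f"step {step:5,}: learning rate {warmup_then_rsqrt(step):0.5f}")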
def train_model(Siamese, TripletLoss, lr_schedule,
                train_generator=train_generator,
                val_generator=validation_generator,
                output_dir="~/models/siamese_networks/",
                steps_per_checkpoint=100):
    """Training the Siamese Model

    Args:
      Siamese (function): Function that returns the Siamese model.
      TripletLoss (function): Function that defines the TripletLoss loss function.
      lr_schedule (function): Trax multifactor schedule function.
      train_generator (generator, optional): Training generator. Defaults to train_generator.
      val_generator (generator, optional): Validation generator. Defaults to validation_generator.
      output_dir (str, optional): Path to save the model to. Defaults to '~/models/siamese_networks/'.
      steps_per_checkpoint (int, optional): Training steps between checkpoints. Defaults to 100.

    Returns:
      trax.supervised.training.Loop: Training loop for the model.
    """
    output_dir = Path(output_dir).expanduser()

    ### START CODE HERE (Replace instances of 'None' with your code) ###
    train_task = trax.supervised.training.TrainTask(
        labeled_data=train_generator,  # Use generator (train)
        loss_layer=TripletLoss(),  # Use triplet loss (instantiated)
        optimizer=trax.optimizers.Adam(0.01),  # Adam with the learning rate parameter
        lr_schedule=lr_schedule,  # Use Trax multifactor schedule function
        n_steps_per_checkpoint=steps_per_checkpoint,
    )

    eval_task = trax.supervised.training.EvalTask(
        labeled_data=val_generator,  # Use generator (val)
        metrics=[TripletLoss()],  # Use triplet loss (instantiated)
    )
    ### END CODE HERE ###

    training_loop = trax.supervised.training.Loop(Siamese,
                                                  [train_task],
                                                  eval_tasks=[eval_task],
                                                  output_dir=output_dir)
    return training_loop
Training
Trial Two
Note: I re-ran this next code block so it's actually the second run.
train_steps = 2000
siamese = SiameseModel(len(loader.vocabulary))
training_loop = train_model(siamese.model, triplet_loss_layer, lr_schedule, steps_per_checkpoint=5)
real_stdout = sys.stdout
TIMER.emit = False
TIMER.start()
with TemporaryFile("w") as temp_file:
    sys.stdout = temp_file
    training_loop.run(train_steps)
TIMER.stop()
sys.stdout = real_stdout
print(f"{TIMER.ended - TIMER.started}")
0:19:46.056057
for mode in training_loop.history.modes:
    print(mode)
    print(training_loop.history.metrics_for_mode(mode))
eval
['metrics/TripletLoss']
train
['metrics/TripletLoss', 'training/gradients_l2', 'training/learning_rate', 'training/loss', 'training/steps per second', 'training/weights_l2']
- Plotting the Metrics
Note: As of February 2021, the version of trax on pypi doesn't have a history attribute - to get it you have to install the code from the github repository.
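Something like pip install git+https://github.com/google/trax.git should install it straight from the repository, although it's worth checking the repository's README for the current instructions.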
frame = pandas.DataFrame(
    training_loop.history.get("eval", "metrics/TripletLoss"),
    columns="Batch TripletLoss".split())
minimum = frame.loc[frame.TripletLoss.idxmin()]
vline = holoviews.VLine(minimum.Batch).opts(opts.VLine(color=PLOT.red))
hline = holoviews.HLine(minimum.TripletLoss).opts(opts.HLine(color=PLOT.red))
line = frame.hvplot(x="Batch", y="TripletLoss").opts(opts.Curve(color=PLOT.blue))
plot = (line * hline * vline).opts(
    width=PLOT.width, height=PLOT.height,
    title="Evaluation Batch Triplet Loss",
)
output = Embed(plot=plot, file_name="evaluation_triplet_loss")()
print(output)
It looks like the loss is stabilizing. If the model doesn't perform well I'll re-train it.
Trial Three
Let's see if the loss continues going down.
train_steps = 2000
siamese = SiameseModel(len(loader.vocabulary))
training_loop = train_model(siamese.model, triplet_loss_layer, lr_schedule, steps_per_checkpoint=5)
real_stdout = sys.stdout
TIMER.emit = False
TIMER.start()
with TemporaryFile("w") as temp_file:
    sys.stdout = temp_file
    training_loop.run(train_steps)
TIMER.stop()
sys.stdout = real_stdout
print(f"{TIMER.ended - TIMER.started}")
0:17:41.167719
- Plotting the Metrics
frame = pandas.DataFrame(
    training_loop.history.get("eval", "metrics/TripletLoss"),
    columns="Batch TripletLoss".split())
minimum = frame.loc[frame.TripletLoss.idxmin()]
vline = holoviews.VLine(minimum.Batch).opts(opts.VLine(color=PLOT.red))
hline = holoviews.HLine(minimum.TripletLoss).opts(opts.HLine(color=PLOT.red))
line = frame.hvplot(x="Batch", y="TripletLoss").opts(opts.Curve(color=PLOT.blue))
plot = (line * hline * vline).opts(
    width=PLOT.width, height=PLOT.height,
    title="Evaluation Batch Triplet Loss (Third Run)",
)
output = Embed(plot=plot, file_name="evaluation_triplet_loss_third")()
print(output)
It looks like the loss has stopped improving, so it's probably time to stop.