Neural Machine Translation: Training the Model
Training Our Model
In the previous post we defined our model for machine translation. In this post we'll train the model on our data.
Doing supervised training in Trax is pretty straightforward (short example here). We will be instantiating three classes for this: TrainTask, EvalTask, and Loop. Let's take a closer look at each of these in the sections below.
Imports
# python
from collections import namedtuple
from contextlib import redirect_stdout
from functools import partial
from pathlib import Path
import sys
# pypi
from holoviews import opts
from trax import layers, optimizers
from trax.supervised import lr_schedules, training
import holoviews
import hvplot.pandas
import pandas
import trax
# this project
from neurotic.nlp.machine_translation import DataGenerator, NMTAttn
# related
from graeae import EmbedHoloviews, Timer
Set Up
train_batch_stream = DataGenerator().batch_generator
eval_batch_stream = DataGenerator(training=False).batch_generator
SLUG = "neural-machine-translation-training-the-model"
Embed = partial(EmbedHoloviews, folder_path=f"files/posts/nlp/{SLUG}")
Plot = namedtuple("Plot", ["width", "height", "fontscale", "tan", "blue", "red"])
PLOT = Plot(
    width=900,
    height=750,
    fontscale=2,
    tan="#ddb377",
    blue="#4687b7",
    red="#ce7b6d",
)
TIMER = Timer()
Training
TrainTask
The TrainTask class allows us to define the labeled data to use for training and the feedback mechanisms to compute the loss and update the weights.
train_task = training.TrainTask(
    # use the train batch stream as labeled data
    labeled_data=train_batch_stream,
    # use the cross-entropy loss
    loss_layer=layers.WeightedCategoryCrossEntropy(),
    # use the Adam optimizer with a learning rate of 0.01
    optimizer=optimizers.Adam(0.01),
    # use trax's warmup_and_rsqrt_decay as the learning rate schedule,
    # with 1,000 warmup steps and a max value of 0.01
    lr_schedule=lr_schedules.warmup_and_rsqrt_decay(1000, 0.01),
    # checkpoint every 10 steps
    n_steps_per_checkpoint=10,
)
def test_train_task(train_task):
    target = train_task
    success = 0
    fails = 0

    # Test the labeled data parameter
    try:
        strlabel = str(target._labeled_data)
        assert(strlabel.find("generator") and strlabel.find('add_loss_weights'))
        success += 1
    except:
        fails += 1
        print("Wrong labeled data parameter")

    # Test the cross entropy loss data parameter
    try:
        strlabel = str(target._loss_layer)
        assert(strlabel == "CrossEntropyLoss_in3")
        success += 1
    except:
        fails += 1
        print("Wrong loss functions. CrossEntropyLoss_in3 was expected")

    # Test the optimizer parameter
    try:
        assert(isinstance(target.optimizer, trax.optimizers.adam.Adam))
        success += 1
    except:
        fails += 1
        print("Wrong optimizer")

    # Test the schedule parameter
    try:
        assert(isinstance(target._lr_schedule, trax.supervised.lr_schedules._BodyAndTail))
        success += 1
    except:
        fails += 1
        print("Wrong learning rate schedule type")

    # Test the _n_steps_per_checkpoint parameter
    try:
        assert(target._n_steps_per_checkpoint == 10)
        success += 1
    except:
        fails += 1
        print("Wrong checkpoint step frequency")

    if fails == 0:
        print("\033[92m All tests passed")
    else:
        print('\033[92m', success, " Tests passed")
        print('\033[91m', fails, " Tests failed")
    return
test_train_task(train_task)
Wrong loss functions. CrossEntropyLoss_in3 was expected
Wrong optimizer
Wrong learning rate schedule type
 2  Tests passed
 3  Tests failed
The code has changed a bit since the test was written, so it won't pass without updates.
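Rather than guessing at the new expectations, we can print the attributes that the failing assertions inspect (the attribute names come straight from the test above) and then update the expected strings and types to match what the current version of trax reports:

# Print the values the failing assertions check so the test's
# expectations can be updated to match the current trax API.
print(str(train_task._loss_layer))
print(type(train_task.optimizer))
print(type(train_task._lr_schedule))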
EvalTask
The EvalTask on the other hand allows us to see how the model is doing while training. For our application, we want it to report the cross entropy loss and accuracy.
eval_task = training.EvalTask(
    # use the eval batch stream as labeled data
    labeled_data=eval_batch_stream,
    # use the cross entropy loss and accuracy as metrics
    metrics=[layers.WeightedCategoryCrossEntropy(), layers.Accuracy()],
)
Loop
The Loop class defines the model we will train as well as the train and eval tasks to execute. Its run() method allows us to execute the training for a specified number of steps.
output_dir = Path("~/models/machine_translation/").expanduser()
Define the training loop.
training_loop = training.Loop(NMTAttn(mode='train'),
                              train_task,
                              eval_tasks=[eval_task],
                              output_dir=output_dir)
train_steps = 1000

with TIMER, \
     open("/tmp/machine_translation_training.log", "w") as temp_file, \
     redirect_stdout(temp_file):
    training_loop.run(train_steps)
Started: 2021-03-09 18:31:58.844878
Ended: 2021-03-09 20:14:43.090358
Elapsed: 1:42:44.245480
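One nice side-effect of checkpointing every 10 steps: if the model needs more training later, we shouldn't have to start over. As I understand trax's behavior, a Loop pointed at the same output_dir restores the most recent checkpoint, so calling run again continues from the saved step (a sketch, assuming that restore behavior):

# Assumes trax's Loop restores the latest checkpoint it finds in output_dir.
resumed_loop = training.Loop(NMTAttn(mode='train'),
                             train_task,
                             eval_tasks=[eval_task],
                             output_dir=output_dir)
resumed_loop.run(10)  # train 10 more steps past the restored step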
Plotting the evaluation cross entropy over the run (training_loop.history stores (step, value) pairs for each metric, so it converts neatly to a two-column DataFrame):

frame = pandas.DataFrame(
    training_loop.history.get("eval", "metrics/WeightedCategoryCrossEntropy"),
    columns="Batch CrossEntropy".split())

minimum = frame.loc[frame.CrossEntropy.idxmin()]
vline = holoviews.VLine(minimum.Batch).opts(opts.VLine(color=PLOT.red))
hline = holoviews.HLine(minimum.CrossEntropy).opts(opts.HLine(color=PLOT.red))
line = frame.hvplot(x="Batch", y="CrossEntropy").opts(opts.Curve(color=PLOT.blue))

plot = (line * hline * vline).opts(
    width=PLOT.width, height=PLOT.height,
    title="Evaluation Batch Cross Entropy Loss",
)
output = Embed(plot=plot, file_name="evaluation_cross_entropy")()
print(output)
And the same view for accuracy, this time marking its best (highest) value:

frame = pandas.DataFrame(
    training_loop.history.get("eval", "metrics/Accuracy"),
    columns="Batch Accuracy".split())

best = frame.loc[frame.Accuracy.idxmax()]
vline = holoviews.VLine(best.Batch).opts(opts.VLine(color=PLOT.red))
hline = holoviews.HLine(best.Accuracy).opts(opts.HLine(color=PLOT.red))
line = frame.hvplot(x="Batch", y="Accuracy").opts(opts.Curve(color=PLOT.blue))

plot = (line * hline * vline).opts(
    width=PLOT.width, height=PLOT.height,
    title="Evaluation Batch Accuracy",
)
output = Embed(plot=plot, file_name="evaluation_accuracy")()
print(output)
It seems to be stuck…