Neural Machine Translation: Training the Model

Training Our Model

In the previous post we defined our model for machine translation. In this post we'll train the model on our data.

Doing supervised training in Trax is pretty straightforward (short example here). We will be instantiating three classes for this: TrainTask, EvalTask, and Loop. Let's take a closer look at each of these in the sections below.

Imports

# python
from collections import namedtuple
from contextlib import redirect_stdout
from functools import partial
from pathlib import Path

import sys

# pypi
from holoviews import opts
from trax import layers, optimizers
from trax.supervised import lr_schedules, training

import holoviews
import hvplot.pandas
import pandas

# this project
from neurotic.nlp.machine_translation import DataGenerator, NMTAttn

# related
from graeae import EmbedHoloviews, Timer

Set Up

train_batch_stream = DataGenerator().batch_generator
eval_batch_stream = DataGenerator(training=False).batch_generator
SLUG = "neural-machine-translation-training-the-model"
Embed = partial(EmbedHoloviews, folder_path=f"files/posts/nlp/{SLUG}")

Plot = namedtuple("Plot", ["width", "height", "fontscale", "tan", "blue", "red"])
PLOT = Plot(
    width=900,
    height=750,
    fontscale=2,
    tan="#ddb377",
    blue="#4687b7",
    red="#ce7b6d",
)
TIMER = Timer()
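
Before setting up the training it doesn't hurt to peek at what the generators emit. This is only a sanity-check sketch: it assumes the stream yields (inputs, targets, weights) triples of arrays, which is what the Trax training loop expects, and the variable names are just for illustration.

# peek at one batch (assumes the stream yields (inputs, targets, weights) triples)
sample_inputs, sample_targets, sample_weights = next(DataGenerator(training=False).batch_generator)
print(f"inputs:  {sample_inputs.shape}")
print(f"targets: {sample_targets.shape}")
print(f"weights: {sample_weights.shape}")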

Training

TrainTask

The TrainTask class allows us to define the labeled data to use for training and the feedback mechanisms to compute the loss and update the weights.

train_task = training.TrainTask(

    # use the train batch stream as labeled data
    labeled_data = train_batch_stream,

    # use the cross entropy loss
    loss_layer = layers.WeightedCategoryCrossEntropy(),

    # use the Adam optimizer with learning rate of 0.01
    optimizer = optimizers.Adam(0.01),

    # use the `trax.lr.warmup_and_rsqrt_decay` as the learning rate schedule
    # have 1000 warmup steps with a max value of 0.01
    lr_schedule = lr_schedules.warmup_and_rsqrt_decay(1000, 0.01),

    # have a checkpoint every 10 steps
    n_steps_per_checkpoint=10,
)
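
The schedule ramps the learning rate up linearly over the warmup steps and then decays it with the reciprocal square root of the step number. Trax has its own implementation; the function below is only a rough sketch of the shape a 1000-step warmup to 0.01 produces, not the library code.

# a rough sketch of the warmup-and-rsqrt-decay shape (not the Trax implementation)
def sketch_warmup_and_rsqrt_decay(step: int, n_warmup_steps: int=1000, max_value: float=0.01) -> float:
    """Linear warmup to max_value, then decay proportional to 1/sqrt(step)."""
    warmup = max_value * step / n_warmup_steps
    decay = max_value * (n_warmup_steps / step)**0.5
    return min(warmup, decay)

for step in (1, 500, 1000, 4000, 16000):
    print(f"step {step:>6}: learning rate ~{sketch_warmup_and_rsqrt_decay(step):0.5f}")

The next function runs a few sanity checks over the TrainTask's settings.
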
def test_train_task(train_task):
    target = train_task
    success = 0
    fails = 0

    # Test the labeled data parameter
    try:
        strlabel = str(target._labeled_data)
        assert(strlabel.find("generator") and strlabel.find('add_loss_weights'))
        success += 1
    except:
        fails += 1
        print("Wrong labeled data parameter")

    # Test the cross entropy loss data parameter
    try:
        strlabel = str(target._loss_layer)
        assert(strlabel == "CrossEntropyLoss_in3")
        success += 1
    except:
        fails += 1
        print("Wrong loss functions. CrossEntropyLoss_in3 was expected")

    # Test the optimizer parameter
    try:
        assert(isinstance(target.optimizer, trax.optimizers.adam.Adam))
        success += 1
    except:
        fails += 1
        print("Wrong optimizer")

    # Test the schedule parameter
    try:
        assert(isinstance(target._lr_schedule,trax.supervised.lr_schedules._BodyAndTail))
        success += 1
    except:
        fails += 1
        print("Wrong learning rate schedule type")

    # Test the _n_steps_per_checkpoint parameter
    try:
        assert(target._n_steps_per_checkpoint==10)
        success += 1
    except:
        fails += 1
        print("Wrong checkpoint step frequency")

    if fails == 0:
        print("\033[92m All tests passed")
    else:
        print('\033[92m', success," Tests passed")
        print('\033[91m', fails, " Tests failed")
    return
test_train_task(train_task)
Wrong loss functions. CrossEntropyLoss_in3 was expected
Wrong optimizer
Wrong learning rate schedule type
2 Tests passed
3 Tests failed

The code has changed a bit since this test was written (it expects the older CrossEntropyLoss layer and references the top-level trax module, which we never imported), so it won't pass without updates.

EvalTask

The EvalTask, on the other hand, lets us see how the model is doing while it trains. For our application, we want it to report the cross entropy loss and the accuracy.

eval_task = training.EvalTask(

    ## use the eval batch stream as labeled data
    labeled_data=eval_batch_stream,

    ## use the cross entropy loss and accuracy as metrics
    metrics=[layers.WeightedCategoryCrossEntropy(), layers.Accuracy()],
)
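
By default the EvalTask computes its metrics on a single batch at each checkpoint, which can make the curves noisy. If that turns out to be a problem, EvalTask also accepts an n_eval_batches argument to average over more batches. This is an alternative we aren't using here (check the signature in your version of Trax):

# an alternative evaluation task that averages the metrics over ten batches
# (not used in this run; n_eval_batches defaults to 1)
smoother_eval_task = training.EvalTask(
    labeled_data=eval_batch_stream,
    metrics=[layers.WeightedCategoryCrossEntropy(), layers.Accuracy()],
    n_eval_batches=10,
)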

Loop

The Loop class defines the model we will train as well as the train and eval tasks to execute. Its run() method allows us to execute the training for a specified number of steps.

output_dir = Path("~/models/machine_translation/").expanduser()
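
The Loop writes its checkpoints into this directory, and if a checkpoint is already there it will, as far as I can tell, resume from it instead of starting fresh, so it's worth checking what's in the folder before a run.

# make sure the output directory exists and see whether we'd be resuming from a checkpoint
output_dir.mkdir(parents=True, exist_ok=True)
existing = sorted(path.name for path in output_dir.iterdir())
print(f"existing checkpoint files: {existing if existing else 'none'}")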

Define the training loop.

training_loop = training.Loop(NMTAttn(mode='train'),
                              train_task,
                              eval_tasks=[eval_task],
                              output_dir=output_dir)
train_steps = 1000

with TIMER, \
     open("/tmp/machine_translation_training.log", "w") as temp_file, \
     redirect_stdout(temp_file):
    training_loop.run(train_steps)
Started: 2021-03-09 18:31:58.844878
Ended: 2021-03-09 20:14:43.090358
Elapsed: 1:42:44.245480
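
The Loop keeps the evaluation metrics in its history attribute, so we can pull out the cross entropy values and plot them against the training step.
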
frame = pandas.DataFrame(
    training_loop.history.get("eval", "metrics/WeightedCategoryCrossEntropy"),
    columns="Batch CrossEntropy".split())

minimum = frame.loc[frame.CrossEntropy.idxmin()]
vline = holoviews.VLine(minimum.Batch).opts(opts.VLine(color=PLOT.red))
hline = holoviews.HLine(minimum.CrossEntropy).opts(opts.HLine(color=PLOT.red))
line = frame.hvplot(x="Batch", y="CrossEntropy").opts(opts.Curve(color=PLOT.blue))

plot = (line * hline * vline).opts(
    width=PLOT.width, height=PLOT.height,
    title="Evaluation Batch Cross Entropy Loss",
)
output = Embed(plot=plot, file_name="evaluation_cross_entropy")()
print(output)

Figure Missing: Evaluation Batch Cross Entropy Loss

frame = pandas.DataFrame(
    training_loop.history.get("eval", "metrics/Accuracy"),
    columns="Batch Accuracy".split())

# mark the best (highest) accuracy rather than the minimum
best = frame.loc[frame.Accuracy.idxmax()]
vline = holoviews.VLine(best.Batch).opts(opts.VLine(color=PLOT.red))
hline = holoviews.HLine(best.Accuracy).opts(opts.HLine(color=PLOT.red))
line = frame.hvplot(x="Batch", y="Accuracy").opts(opts.Curve(color=PLOT.blue))

plot = (line * hline * vline).opts(
    width=PLOT.width, height=PLOT.height,
    title="Evaluation Batch Accuracy",
)
output = Embed(plot=plot, file_name="evaluation_accuracy")()
print(output)

Figure Missing: Evaluation Batch Accuracy

It seems to be stuck…
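
One way to put a number on "stuck" is to look at the tail of the history the Loop recorded: if the last few evaluation accuracies are essentially flat, running more of the same training probably won't help. A quick check using the same history API as the plots above:

# print the last few (step, accuracy) entries from the evaluation history
accuracy_history = training_loop.history.get("eval", "metrics/Accuracy")
for step, accuracy in accuracy_history[-5:]:
    print(f"step {step:>5}: accuracy {accuracy:0.3f}")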

End

Now that we've trained the model, in the next post we'll test it to see how well it does. The overview post, with links to all the posts in this series, is here.
