Neural Machine Translation: Helper Functions
Table of Contents
Helper Functions
We will first implement a few functions that we will use later on. These will be for:
- the input encoder
- the pre-attention decoder
- preparation of the queries, keys, values, and mask.
Imports
# from pypi
from trax import layers
from trax.fastmath import numpy as fastmath_numpy
import trax
Helper functions
Input encoder
The input encoder runs on the input tokens, creates its embeddings, and feeds it to an LSTM network. This outputs the activations that will be the keys and values for attention. It is a Serial network which uses:
- tl.Embedding: Converts each token to its vector representation. In this case, it is the size of the vocabulary by the dimension of the model:
tl.Embedding(vocab_size, d_model)
.vocab_size
is the number of entries in the given vocabulary.d_model
is the number of elements in the word embedding. - tl.LSTM: LSTM layer of size
d_model
. We want to be able to configure how many encoder layers we have, so remember to create a number of LSTM layers equal to the n_encoder_layers
parameter.
def input_encoder(input_vocab_size: int, d_model: int,
                  n_encoder_layers: int) -> layers.Serial:
    """Build the input encoder for the attention model.

    The encoder runs on the input sentence and produces the activations
    that will serve as the keys and values for attention.

    Args:
        input_vocab_size: vocab size of the input
        d_model: depth of embedding (n_units in the LSTM cell)
        n_encoder_layers: number of LSTM layers in the encoder

    Returns:
        tl.Serial: The input encoder
    """
    # Embed the tokens, then push them through a stack of LSTM layers.
    lstm_stack = [layers.LSTM(d_model) for _ in range(n_encoder_layers)]
    return layers.Serial(
        layers.Embedding(input_vocab_size, d_model),
        lstm_stack,
    )
def test_input_encoder_fn(input_encoder_fn):
    """Check that ``input_encoder_fn`` builds the expected Serial model.

    Args:
        input_encoder_fn: callable with signature
            (input_vocab_size, d_model, n_encoder_layers) -> layers.Serial

    Prints a summary of passed/failed sub-tests.
    """
    target = input_encoder_fn
    success = 0
    fails = 0

    input_vocab_size = 10
    d_model = 2
    n_encoder_layers = 6
    encoder = target(input_vocab_size, d_model, n_encoder_layers)

    lstms = "\n".join([f' LSTM_{d_model}'] * n_encoder_layers)
    expected = f"Serial[\n Embedding_{input_vocab_size}_{d_model}\n{lstms}\n]"
    proposed = str(encoder)

    # Test all layers are in the expected sequence (whitespace ignored).
    # NOTE(review): catch only AssertionError so real errors surface
    # instead of being silently counted as test failures.
    try:
        assert proposed.replace(" ", "") == expected.replace(" ", "")
        success += 1
    except AssertionError:
        fails += 1
        print("Wrong model. \nProposed:\n%s" %proposed, "\nExpected:\n%s" %expected)

    # Test the output type; the sublayer count only makes sense if the
    # type check passed, hence the nested try.
    try:
        assert isinstance(encoder, trax.layers.combinators.Serial)
        success += 1
        # Test the number of layers: one Embedding plus the LSTM stack.
        try:
            assert len(encoder.sublayers) == (n_encoder_layers + 1)
            success += 1
        except AssertionError:
            fails += 1
            print('The number of sublayers does not match %s <>' %len(encoder.sublayers), " %s" %(n_encoder_layers + 1))
    except AssertionError:
        fails += 1
        # NOTE(review): fixed typo "enconder" -> "encoder" in this message.
        print("The encoder is not an object of ", trax.layers.combinators.Serial)

    if fails == 0:
        print("\033[92m All tests passed")
    else:
        print('\033[92m', success," Tests passed")
        print('\033[91m', fails, " Tests failed")


test_input_encoder_fn(input_encoder)
[92m All tests passed
Pre-attention decoder
The pre-attention decoder runs on the targets and creates activations that are used as queries in attention. This is a Serial network which is composed of the following:
- tl.ShiftRight: This pads a token to the beginning of your target tokens (e.g.
[8, 34, 12]
shifted right is[0, 8, 34, 12]
). This will act like a start-of-sentence token that will be the first input to the decoder. During training, this shift also allows the target tokens to be passed as input to do teacher forcing. - tl.Embedding: Like in the previous function, this converts each token to its vector representation. In this case, it is the size of the vocabulary by the dimension of the model:
tl.Embedding(vocab_size, d_model)
.vocab_size
is the number of entries in the given vocabulary.d_model
is the number of elements in the word embedding. - tl.LSTM: LSTM layer of size
d_model
.
def pre_attention_decoder(mode: str, target_vocab_size: int, d_model: int) -> layers.Serial:
    """Build the pre-attention decoder.

    The decoder runs on the targets and produces the activations used
    as the queries in attention.

    Args:
        mode: 'train' or 'eval'
        target_vocab_size: vocab size of the target
        d_model: depth of embedding (n_units in the LSTM cell)

    Returns:
        tl.Serial: The pre-attention decoder
    """
    # Shift targets right (start-of-sentence token / teacher forcing),
    # embed them, then run a single LSTM layer.
    shift = layers.ShiftRight(mode=mode)
    embed = layers.Embedding(target_vocab_size, d_model)
    lstm = layers.LSTM(d_model)
    return layers.Serial(shift, embed, lstm)
def test_pre_attention_decoder_fn(pre_attention_decoder_fn):
    """Check that ``pre_attention_decoder_fn`` builds the expected model.

    Args:
        pre_attention_decoder_fn: callable with signature
            (mode, target_vocab_size, d_model) -> layers.Serial

    Prints a summary of passed/failed sub-tests.
    """
    target = pre_attention_decoder_fn
    success = 0
    fails = 0

    mode = 'train'
    target_vocab_size = 10
    d_model = 2
    decoder = target(mode, target_vocab_size, d_model)

    # Newer trax versions wrap ShiftRight in its own Serial layer (see
    # trax.layers.assert_shape.AssertFunction), so accept either string
    # representation of the model.
    expected = f"Serial[\n ShiftRight(1)\n Embedding_{target_vocab_size}_{d_model}\n LSTM_{d_model}\n]"
    expected_wrapped = (f"Serial[\n Serial[\n ShiftRight(1)\n ]\n"
                        f" Embedding_{target_vocab_size}_{d_model}\n LSTM_{d_model}\n]")
    proposed = str(decoder)

    # Test all layers are in the expected sequence (whitespace ignored).
    # NOTE(review): catch only AssertionError so real errors surface.
    try:
        normalized = proposed.replace(" ", "")
        assert normalized in (expected.replace(" ", ""),
                              expected_wrapped.replace(" ", ""))
        success += 1
    except AssertionError:
        fails += 1
        print("Wrong model. \nProposed:\n%s" %proposed, "\nExpected:\n%s" %expected)

    # Test the output type; the sublayer count only makes sense if the
    # type check passed, hence the nested try.
    try:
        assert isinstance(decoder, trax.layers.combinators.Serial)
        success += 1
        # Test the number of layers: ShiftRight, Embedding, LSTM.
        try:
            assert len(decoder.sublayers) == 3
            success += 1
        except AssertionError:
            fails += 1
            print('The number of sublayers does not match %s <>' %len(decoder.sublayers), " %s" %3)
    except AssertionError:
        fails += 1
        # NOTE(review): fixed message -- this tests the decoder, and the
        # original said "enconder".
        print("The decoder is not an object of ", trax.layers.combinators.Serial)

    if fails == 0:
        print("\033[92m All tests passed")
    else:
        print('\033[92m', success," Tests passed")
        print('\033[91m', fails, " Tests failed")
They changed the behavior of the Fn
(or something related) so that it now always wraps the ShiftRight in a Serial layer, so the model's string representation no longer matches the test's expected string. Comparing string representations is fragile anyway…
It looks like they're using a decorator to check the shape which then wraps it in a Serial layer. See trax.layers.assert_shape.AssertFunction
# Run the pre-attention decoder checks defined above.
test_pre_attention_decoder_fn(pre_attention_decoder)
Wrong model. Proposed: Serial[ Serial[ ShiftRight(1) ] Embedding_10_2 LSTM_2 ] Expected: Serial[ ShiftRight(1) Embedding_10_2 LSTM_2 ] [92m 2 Tests passed [91m 1 Tests failed
Preparing the attention input
This function will prepare the inputs to the attention layer. We want to take in the encoder and pre-attention decoder activations and assign it to the queries, keys, and values. In addition, another output here will be the mask to distinguish real tokens from padding tokens. This mask will be used internally by Trax when computing the softmax so padding tokens will not have an effect on the computed probabilities. From the data preparation steps in Section 1 of this assignment, you should know which tokens in the input correspond to padding.
def prepare_attention_input(encoder_activations: fastmath_numpy.array,
                            decoder_activations: fastmath_numpy.array,
                            inputs: fastmath_numpy.array) -> tuple:
    """Prepare queries, keys, values and mask for attention.

    Args:
        encoder_activations fastnp.array(batch_size, padded_input_length, d_model): output from the input encoder
        decoder_activations fastnp.array(batch_size, padded_input_length, d_model): output from the pre-attention decoder
        inputs fastnp.array(batch_size, padded_input_length): padded input tokens

    Returns:
        queries, keys, values and mask for attention.
    """
    # The encoder output supplies both keys and values; the decoder
    # output supplies the queries.
    queries = decoder_activations
    keys = values = encoder_activations

    # Padding tokens are 0; real tokens become True in the mask.
    padding_mask = inputs != 0
    batch_size, input_length = padding_mask.shape

    # Reshape to (batch, 1, 1, input_length), then broadcast along the
    # query axis to (batch, 1, query_length, input_length) by adding a
    # zero tensor (this also promotes the mask to a numeric dtype).
    mask = fastmath_numpy.reshape(padding_mask,
                                  (batch_size, 1, 1, input_length))
    mask = mask + fastmath_numpy.zeros((1, 1, decoder_activations.shape[1], 1))

    return queries, keys, values, mask
def test_prepare_attention_input(prepare_attention_input):
target = prepare_attention_input
success = 0
fails = 0
#This unit test consider a batch size = 2, number_of_tokens = 3 and embedding_size = 4
enc_act = fastmath_numpy.array([[[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0]],
[[1, 0, 1, 0], [0, 1, 0, 1], [0, 0, 0, 0]]])
dec_act = fastmath_numpy.array([[[2, 0, 0, 0], [0, 2, 0, 0], [0, 0, 2, 0]],
[[2, 0, 2, 0], [0, 2, 0, 2], [0, 0, 0, 0]]])
inputs = fastmath_numpy.array([[1, 2, 3], [1, 4, 0]])
exp_mask = fastmath_numpy.array([[[[1., 1., 1.], [1., 1., 1.], [1., 1., 1.]]],
[[[1., 1., 0.], [1., 1., 0.], [1., 1., 0.]]]])
exp_type = type(enc_act)
queries, keys, values, mask = target(enc_act, dec_act, inputs)
try:
assert(fastmath_numpy.allclose(queries, dec_act))
success += 1
except:
fails += 1
print("Queries does not match the decoder activations")
try:
assert(fastmath_numpy.allclose(keys, enc_act))
success += 1
except:
fails += 1
print("Keys does not match the encoder activations")
try:
assert(fastmath_numpy.allclose(values, enc_act))
success += 1
except:
fails += 1
print("Values does not match the encoder activations")
try:
assert(fastmath_numpy.allclose(mask, exp_mask))
success += 1
except:
fails += 1
print("Mask does not match expected tensor. \nExpected:\n%s" %exp_mask, "\nOutput:\n%s" %mask)
# Test the output type
try:
assert(isinstance(queries, exp_type))
assert(isinstance(keys, exp_type))
assert(isinstance(values, exp_type))
assert(isinstance(mask, exp_type))
success += 1
except:
fails += 1
print("One of the output object are not of type ", jax.interpreters.xla.DeviceArray)
if fails == 0:
print("\033[92m All tests passed")
else:
print('\033[92m', success," Tests passed")
print('\033[91m', fails, " Tests failed")
# Run the attention-input preparation checks defined above.
test_prepare_attention_input(prepare_attention_input)
[92m All tests passed