Neural Machine Translation: Helper Functions

Helper Functions

We will first implement a few functions that we will use later on. These will be for:

  • the input encoder
  • the pre-attention decoder
  • preparation of the queries, keys, values, and mask.

Imports

# from pypi
from trax import layers
from trax.fastmath import numpy as fastmath_numpy

import trax

Helper functions

Input encoder

The input encoder runs on the input tokens, creates their embeddings, and feeds them to an LSTM network. Its output is the set of activations that will serve as the keys and values for attention. It is a Serial network which uses:

  • tl.Embedding: Converts each token to its vector representation. In this case, it is the size of the vocabulary by the dimension of the model: tl.Embedding(vocab_size, d_model). vocab_size is the number of entries in the given vocabulary. d_model is the number of elements in the word embedding.
  • tl.LSTM: LSTM layer of size d_model. We want to be able to configure how many encoder layers we have, so remember to create a number of LSTM layers equal to the n_encoder_layers parameter.
def input_encoder(input_vocab_size: int, d_model: int,
                     n_encoder_layers: int) -> layers.Serial:
    """ Input encoder runs on the input sentence and creates
    activations that will be the keys and values for attention.

    Args:
       input_vocab_size: vocab size of the input
       d_model:  depth of embedding (n_units in the LSTM cell)
       n_encoder_layers: number of LSTM layers in the encoder

    Returns:
       tl.Serial: The input encoder
    """
    input_encoder = layers.Serial( 
        layers.Embedding(input_vocab_size, d_model),
        [layers.LSTM(d_model) for _ in range(n_encoder_layers)]
    )
    return input_encoder
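As a quick sanity check, here is a small sketch of how the encoder could be exercised. The sizes and token values are made up for illustration (they are not the assignment's real hyperparameters); it just confirms that the output activations have shape (batch_size, padded_length, d_model).

import numpy
from trax import shapes

# toy settings, chosen only for this sketch
toy_encoder = input_encoder(input_vocab_size=10, d_model=4, n_encoder_layers=2)
toy_tokens = numpy.array([[1, 2, 3, 0], [4, 5, 0, 0]])  # (batch_size=2, padded_length=4)

# initialize the weights from the input signature, then run the batch through
toy_encoder.init(shapes.signature(toy_tokens))
activations = toy_encoder(toy_tokens)
print(activations.shape)  # expect (2, 4, 4): (batch_size, padded_length, d_model)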
def test_input_encoder_fn(input_encoder_fn):
    target = input_encoder_fn
    success = 0
    fails = 0

    input_vocab_size = 10
    d_model = 2
    n_encoder_layers = 6

    encoder = target(input_vocab_size, d_model, n_encoder_layers)

    lstms = "\n".join([f'  LSTM_{d_model}'] * n_encoder_layers)

    expected = f"Serial[\n  Embedding_{input_vocab_size}_{d_model}\n{lstms}\n]"

    proposed = str(encoder)

    # Test all layers are in the expected sequence
    try:
        assert(proposed.replace(" ", "") == expected.replace(" ", ""))
        success += 1
    except:
        fails += 1
        print("Wrong model. \nProposed:\n%s" %proposed, "\nExpected:\n%s" %expected)

    # Test the output type
    try:
        assert(isinstance(encoder, trax.layers.combinators.Serial))
        success += 1
        # Test the number of layers
        try:
            # Test 
            assert len(encoder.sublayers) == (n_encoder_layers + 1)
            success += 1
        except:
            fails += 1
            print('The number of sublayers does not match %s <>' %len(encoder.sublayers), " %s" %(n_encoder_layers + 1))
    except:
        fails += 1
        print("The enconder is not an object of ", trax.layers.combinators.Serial)


    if fails == 0:
        print("\033[92m All tests passed")
    else:
        print('\033[92m', success," Tests passed")
        print('\033[91m', fails, " Tests failed")
test_input_encoder_fn(input_encoder)
All tests passed

Pre-attention decoder

The pre-attention decoder runs on the targets and creates activations that are used as queries in attention. This is a Serial network which is composed of the following:

  • tl.ShiftRight: This pads a token to the beginning of your target tokens (e.g. [8, 34, 12] shifted right is [0, 8, 34, 12]). This will act like a start-of-sentence token that will be the first input to the decoder. During training, this shift also allows the target tokens to be passed as input to do teacher forcing.
  • tl.Embedding: Like in the previous function, this converts each token to its vector representation. In this case, it is the size of the vocabulary by the dimension of the model: tl.Embedding(vocab_size, d_model). vocab_size is the number of entries in the given vocabulary. d_model is the number of elements in the word embedding.
  • tl.LSTM: LSTM layer of size d_model.
def pre_attention_decoder(mode: str, target_vocab_size: int, d_model: int) -> layers.Serial:
    """ Pre-attention decoder runs on the targets and creates
    activations that are used as queries in attention.

    Args:
       mode: 'train' or 'eval'
       target_vocab_size: vocab size of the target
       d_model:  depth of embedding (n_units in the LSTM cell)
    Returns:
       tl.Serial: The pre-attention decoder
    """
    return layers.Serial(
        layers.ShiftRight(mode=mode),
        layers.Embedding(target_vocab_size, d_model),
        layers.LSTM(d_model)
    )
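A similar check works for the pre-attention decoder. The sketch below (again with made-up sizes) runs it in 'train' mode; internally the targets are shifted right before being embedded, so the decoder sees the previous token at each position, and the output again has shape (batch_size, padded_length, d_model).

import numpy
from trax import shapes

# toy settings, chosen only for this sketch
toy_decoder = pre_attention_decoder(mode='train', target_vocab_size=10, d_model=4)
toy_targets = numpy.array([[8, 3, 4, 0], [5, 6, 0, 0]])  # (batch_size=2, padded_length=4)

toy_decoder.init(shapes.signature(toy_targets))
queries = toy_decoder(toy_targets)
print(queries.shape)  # expect (2, 4, 4): (batch_size, padded_length, d_model)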
def test_pre_attention_decoder_fn(pre_attention_decoder_fn):
    target = pre_attention_decoder_fn
    success = 0
    fails = 0

    mode = 'train'
    target_vocab_size = 10
    d_model = 2

    decoder = target(mode, target_vocab_size, d_model)

    expected = f"Serial[\n  ShiftRight(1)\n  Embedding_{target_vocab_size}_{d_model}\n  LSTM_{d_model}\n]"

    proposed = str(decoder)

    # Test all layers are in the expected sequence
    try:
        assert(proposed.replace(" ", "") == expected.replace(" ", ""))
        success += 1
    except:
        fails += 1
        print("Wrong model. \nProposed:\n%s" %proposed, "\nExpected:\n%s" %expected)

    # Test the output type
    try:
        assert(isinstance(decoder, trax.layers.combinators.Serial))
        success += 1
        # Test the number of layers
        try:
            # Test 
            assert len(decoder.sublayers) == 3
            success += 1
        except:
            fails += 1
            print('The number of sublayers does not match %s <>' %len(decoder.sublayers), " %s" %3)
    except:
        fails += 1
        print("The enconder is not an object of ", trax.layers.combinators.Serial)


    if fails == 0:
        print("\033[92m All tests passed")
    else:
        print('\033[92m', success," Tests passed")
        print('\033[91m', fails, " Tests failed")

Trax changed the behavior of Fn (or something in there) so that it now always wraps the ShiftRight in a Serial layer, which means the model's string representation no longer matches the test's expectation. Testing strings is kind of gimpy anyway…

It looks like they're using a decorator to check the shape, which then wraps the layer in a Serial combinator. See trax.layers.assert_shape.AssertFunction.
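Printing the layer on its own shows the wrapping directly. This is a sketch of what a recent trax version prints; the exact string may differ by version:

print(layers.ShiftRight(mode='train'))
# prints something like:
# Serial[
#   ShiftRight(1)
# ]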

test_pre_attention_decoder_fn(pre_attention_decoder)
Wrong model. 
Proposed:
Serial[
  Serial[
    ShiftRight(1)
  ]
  Embedding_10_2
  LSTM_2
] 
Expected:
Serial[
  ShiftRight(1)
  Embedding_10_2
  LSTM_2
]
2 Tests passed
1 Tests failed

Preparing the attention input

This function will prepare the inputs to the attention layer. We want to take in the encoder and pre-attention decoder activations and assign them to the queries, keys, and values. In addition, another output here will be the mask that distinguishes real tokens from padding tokens. This mask will be used internally by Trax when computing the softmax, so padding tokens will not have an effect on the computed probabilities. The boolean mask is reshaped and then broadcast up to the four-dimensional shape (batch_size, 1, decoder_length, padded_input_length) used by the attention layer. From the data preparation steps in Section 1 of this assignment, you should know which tokens in the input correspond to padding.

def prepare_attention_input(encoder_activations: fastmath_numpy.array,
                            decoder_activations: fastmath_numpy.array,
                            inputs: fastmath_numpy.array) -> tuple:
    """Prepare queries, keys, values and mask for attention.

    Args:
       encoder_activations fastnp.array(batch_size, padded_input_length, d_model): output from the input encoder
       decoder_activations fastnp.array(batch_size, padded_target_length, d_model): output from the pre-attention decoder
       inputs fastnp.array(batch_size, padded_input_length): padded input tokens

    Returns:
       queries, keys, values and mask for attention.
    """
    keys = encoder_activations
    values = encoder_activations
    queries = decoder_activations    
    mask = inputs != 0

    mask = fastmath_numpy.reshape(mask, (mask.shape[0], 1, 1, mask.shape[1]))
    mask += fastmath_numpy.zeros((1, 1, decoder_activations.shape[1], 1))
    return queries, keys, values, mask
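To see what the reshape and broadcast are doing, here is a sketch with toy values (mirroring the unit test below): the boolean mask starts as (batch_size, padded_input_length), is reshaped to (batch_size, 1, 1, padded_input_length), and adding the zero tensor broadcasts it to (batch_size, 1, decoder_length, padded_input_length), so there is one copy of the padding mask for every decoder position.

# toy values only: two sentences of three tokens each, d_model of 4
toy_encoder_activations = fastmath_numpy.ones((2, 3, 4))
toy_decoder_activations = fastmath_numpy.ones((2, 3, 4))
toy_inputs = fastmath_numpy.array([[1, 2, 3], [1, 4, 0]])  # second sentence ends in padding

queries, keys, values, mask = prepare_attention_input(
    toy_encoder_activations, toy_decoder_activations, toy_inputs)

print(mask.shape)     # expect (2, 1, 3, 3)
print(mask[1, 0, 0])  # expect [1. 1. 0.]: the padding position is masked out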
def test_prepare_attention_input(prepare_attention_input):
    target = prepare_attention_input
    success = 0
    fails = 0

    # This unit test considers a batch size of 2, number_of_tokens = 3, and embedding_size = 4

    enc_act = fastmath_numpy.array([[[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0]],
               [[1, 0, 1, 0], [0, 1, 0, 1], [0, 0, 0, 0]]])
    dec_act = fastmath_numpy.array([[[2, 0, 0, 0], [0, 2, 0, 0], [0, 0, 2, 0]], 
               [[2, 0, 2, 0], [0, 2, 0, 2], [0, 0, 0, 0]]])
    inputs =  fastmath_numpy.array([[1, 2, 3], [1, 4, 0]])

    exp_mask = fastmath_numpy.array([[[[1., 1., 1.], [1., 1., 1.], [1., 1., 1.]]], 
                             [[[1., 1., 0.], [1., 1., 0.], [1., 1., 0.]]]])

    exp_type = type(enc_act)

    queries, keys, values, mask = target(enc_act, dec_act, inputs)

    try:
        assert(fastmath_numpy.allclose(queries, dec_act))
        success += 1
    except:
        fails += 1
        print("Queries does not match the decoder activations")
    try:
        assert(fastmath_numpy.allclose(keys, enc_act))
        success += 1
    except:
        fails += 1
        print("Keys does not match the encoder activations")
    try:
        assert(fastmath_numpy.allclose(values, enc_act))
        success += 1
    except:
        fails += 1
        print("Values does not match the encoder activations")
    try:
        assert(fastmath_numpy.allclose(mask, exp_mask))
        success += 1
    except:
        fails += 1
        print("Mask does not match expected tensor. \nExpected:\n%s" %exp_mask, "\nOutput:\n%s" %mask)

    # Test the output type
    try:
        assert(isinstance(queries, exp_type))
        assert(isinstance(keys, exp_type))
        assert(isinstance(values, exp_type))
        assert(isinstance(mask, exp_type))
        success += 1
    except:
        fails += 1
        print("One of the output object are not of type ", jax.interpreters.xla.DeviceArray)

    if fails == 0:
        print("\033[92m All tests passed")
    else:
        print('\033[92m', success," Tests passed")
        print('\033[91m', fails, " Tests failed")
test_prepare_attention_input(prepare_attention_input)
All tests passed