Neural Machine Translation: The Attention Model
<<imports>>
<<attention-model>>
Defining the Model
In the previous post we made some helper functions to prepare inputs for some of the layers in the model. In this post we'll define the model itself.
Attention Overview
The model we will be building uses an encoder-decoder architecture. This Recurrent Neural Network (RNN) takes in a tokenized version of a sentence in its encoder, then passes it on to the decoder for translation. A regular sequence-to-sequence model with LSTMs works well for short to medium-length sentences, but its quality starts to degrade for longer ones. You can picture it like the figure below, where all of the context of the input sentence is compressed into one vector that is passed into the decoder block. This becomes an issue for very long sentences (e.g. 100 tokens or more) because the context of the first parts of the input will have very little effect on the final vector passed to the decoder.
Adding an attention layer to this model avoids this problem by giving the decoder access to all parts of the input sentence. To illustrate, let's use a 4-word input sentence as shown below. Remember that a hidden state is produced at each timestep of the encoder (represented by the orange rectangles). These are all passed to the attention layer, and each is given a score based on the current activation (i.e. hidden state) of the decoder. For instance, consider the figure below, where the first prediction "Wie" has already been made. To produce the next prediction, the attention layer first receives all the encoder hidden states (the orange rectangles) as well as the decoder hidden state from producing the word "Wie" (the first green rectangle). Given this information, it scores each of the encoder hidden states to decide which one the decoder should focus on to produce the next word. After training, the model might have learned to align with the second encoder hidden state and so assign a high probability to the word "geht". If we are using greedy decoding, we output that word as the next symbol, then repeat the process to produce the following word until we reach an end-of-sentence prediction.
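The greedy decoding loop described above can be sketched in a few lines. This is only an illustration: toy_model is a hypothetical stand-in for the trained network, and the token ids (including the EOS id) are made up.

```python
from typing import Callable, List

EOS = 1  # hypothetical end-of-sentence token id


def greedy_decode(next_log_probs: Callable[[List[int]], List[float]],
                  max_length: int = 20) -> List[int]:
    """Repeatedly pick the most likely next token until EOS or max_length."""
    output: List[int] = []
    while len(output) < max_length:
        log_probs = next_log_probs(output)
        # greedy choice: the index with the highest log probability
        token = max(range(len(log_probs)), key=log_probs.__getitem__)
        output.append(token)
        if token == EOS:
            break
    return output


def toy_model(decoded_so_far: List[int]) -> List[float]:
    """Stand-in for the real model: scripted scores over a 4-token vocabulary."""
    script = [2, 3, EOS]  # e.g. "Wie", "geht", end-of-sentence
    target = script[min(len(decoded_so_far), len(script) - 1)]
    return [0.0 if index == target else -10.0 for index in range(4)]


decoded = greedy_decode(toy_model)
```

With the real model, next_log_probs would run the network on the input sentence plus the tokens decoded so far and return the log probabilities for the next position.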
There are different ways to implement attention and the one we'll use is the Scaled Dot Product Attention which has the form:
\[ Attention(Q, K, V) = softmax \left(\frac{QK^T}{\sqrt{d_k}} \right)V \]
You can think of it as computing scores using queries (Q) and keys (K), followed by a multiplication with the values (V) to get a context vector at a particular timestep of the decoder. This context vector is fed to the decoder RNN to get a set of probabilities for the next predicted word. Dividing by the square root of the key dimensionality (\(\sqrt{d_k}\)) keeps the dot products from growing with \(d_k\), which would otherwise push the softmax into regions where its gradients are very small. For our machine translation application, the encoder activations (i.e. encoder hidden states) will be the keys and values, while the decoder activations (i.e. decoder hidden states) will be the queries.
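To make the formula concrete, here is a minimal pure-Python sketch of it on nested lists. It is not the Trax implementation: it ignores masking, multiple heads and dropout, and the function names and toy numbers are made up for illustration.

```python
import math
from typing import List

Matrix = List[List[float]]


def softmax(row: List[float]) -> List[float]:
    """Numerically stable softmax over one row of scores."""
    largest = max(row)
    exps = [math.exp(value - largest) for value in row]
    total = sum(exps)
    return [value / total for value in exps]


def scaled_dot_product_attention(queries: Matrix, keys: Matrix,
                                 values: Matrix) -> Matrix:
    """softmax(Q K^T / sqrt(d_k)) V for plain nested lists (no masking)."""
    d_k = len(keys[0])
    context = []
    for query in queries:
        # one score per key: how well this decoder state matches each encoder state
        scores = [sum(q * k for q, k in zip(query, key)) / math.sqrt(d_k)
                  for key in keys]
        weights = softmax(scores)
        # the weighted sum of the values is the context vector for this query
        context.append([sum(w * value[col] for w, value in zip(weights, values))
                        for col in range(len(values[0]))])
    return context


# toy example: two encoder states (keys and values), one decoder state (query)
keys = [[1.0, 0.0], [0.0, 1.0]]
values = [[10.0, 0.0], [0.0, 10.0]]
queries = [[5.0, 0.0]]
context = scaled_dot_product_attention(queries, keys, values)
```

Because the query points in the same direction as the first key, most of the attention weight goes to the first value, so the first component of the context vector dominates.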
You will see in the upcoming sections that this complex architecture and mechanism can be implemented with just a few lines of code.
Imports
# pypi
from trax import layers
import trax
# this project
from neurotic.nlp.machine_translation import (
NMTAttn)
Implementation
Overview
We are now ready to implement our sequence-to-sequence model with attention. This will be a Serial network, illustrated in the diagram below. It shows the layers you'll be using in Trax, and each step can be implemented with a one-line command. The discussion after the figure points to the relevant layer for each step.
- Step 0: Prepare the input encoder and pre-attention decoder branches. We've already defined these earlier as helper functions, so it's just a matter of calling those functions and assigning the results to variables.
- Step 1: Create a Serial network. This will stack the layers in the next steps one after the other. As before, we'll use tl.Serial.
- Step 2: Make a copy of the input and target tokens. As you see in the diagram above, the input and target tokens will be fed into different layers of the model. We'll use the tl.Select layer to create copies of these tokens, arranging them as [input tokens, target tokens, input tokens, target tokens].
- Step 3: Create a parallel branch to feed the input tokens to the input_encoder and the target tokens to the pre_attention_decoder. We'll use tl.Parallel to create these sublayers in parallel, remembering to pass the variables defined in Step 0 as parameters to this layer.
- Step 4: Next, call the `prepare_attention_input` function to convert the encoder and pre-attention decoder activations to a format that the attention layer will accept. You can use tl.Fn to call this function. Note: Pass the prepare_attention_input function as the f parameter in tl.Fn without any arguments or parentheses.
- Step 5: We will now feed the (queries, keys, values, and mask) to the tl.AttentionQKV layer. This computes the scaled dot product attention and outputs the attention activations and mask. Take note that although it is a one-liner, this layer is actually composed of a deep network made up of several branches. The implementation from the Trax source (on GitHub) is shown here so you can see the different layers used.
def AttentionQKV(d_feature, n_heads=1, dropout=0.0, mode='train'):
    """Returns a layer that maps (q, k, v, mask) to (activations, mask).

    See `Attention` above for further context/details.

    Args:
      d_feature: Depth/dimensionality of feature embedding.
      n_heads: Number of attention heads.
      dropout: Probababilistic rate for internal dropout applied to attention
          activations (based on query-key pairs) before dotting them with values.
      mode: Either 'train' or 'eval'.
    """
    return cb.Serial(
        cb.Parallel(
            core.Dense(d_feature),
            core.Dense(d_feature),
            core.Dense(d_feature),
        ),
        PureAttention(  # pylint: disable=no-value-for-parameter
            n_heads=n_heads, dropout=dropout, mode=mode),
        core.Dense(d_feature),
    )
Having deep layers poses the risk of vanishing gradients during training, and we want to mitigate that. To improve the ability of the network to learn, we can insert a tl.Residual layer to add the output of AttentionQKV to the queries input. You can do this in trax by simply nesting the AttentionQKV layer inside the Residual layer. The library will take care of branching and adding for you.
- Step 6: We will not need the mask for the model we're building so we can safely drop it. At this point in the network, the signal stack currently has [attention activations, mask, target tokens] and you can use tl.Select to output just [attention activations, target tokens].
- Step 7: We can now feed the attention-weighted output to the LSTM decoder. We can stack multiple tl.LSTM layers to improve the output, so remember to append LSTMs equal to the number defined by the n_decoder_layers parameter to the model.
- Step 8: We want to determine the probabilities of each subword in the vocabulary and you can set this up easily with a tl.Dense layer by making its size equal to the size of our vocabulary.
- Step 9: Normalize the output to log probabilities by passing the activations in Step 8 to a tl.LogSoftmax layer.
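The stack manipulation in Steps 2 and 6 can be emulated with plain lists. The select function below is a hypothetical stand-in for tl.Select, written under the assumption that the layer takes as many items off the stack as its largest index requires, pushes back the chosen ones, and leaves the rest of the stack untouched.

```python
from typing import List, Optional, Sequence


def select(indices: Sequence[int], stack: List[str],
           n_in: Optional[int] = None) -> List[str]:
    """Emulates a Select layer on a data stack (illustration only)."""
    if n_in is None:
        # assumption: take just enough items to cover the largest index
        n_in = max(indices) + 1
    taken, rest = stack[:n_in], stack[n_in:]
    return [taken[index] for index in indices] + rest


# Step 2: duplicate the (input tokens, target tokens) pair
stack = select([0, 1, 0, 1], ["input tokens", "target tokens"])

# after Steps 3 to 5 the stack described in Step 6 looks like this
stack_after_attention = ["attention activations", "mask", "target tokens"]

# Step 6: keep items 0 and 2, which drops the mask
final = select([0, 2], stack_after_attention)
```

The second copy of the target tokens made in Step 2 is what survives at the bottom of the stack through the attention steps, so it is still available next to the model output at the end.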
The Implementation
# pypi
from trax import layers
# this project
from .help_me import input_encoder as input_encoder_fn
from .help_me import pre_attention_decoder as pre_attention_decoder_fn
from .help_me import prepare_attention_input as prepare_attention_input_fn
def NMTAttn(input_vocab_size: int=33300,
            target_vocab_size: int=33300,
            d_model: int=1024,
            n_encoder_layers: int=2,
            n_decoder_layers: int=2,
            n_attention_heads: int=4,
            attention_dropout: float=0.0,
            mode: str='train') -> layers.Serial:
    """Returns an LSTM sequence-to-sequence model with attention.

    The input to the model is a pair (input tokens, target tokens), e.g.,
    an English sentence (tokenized) and its translation into German (tokenized).

    Args:
     input_vocab_size: int: vocab size of the input
     target_vocab_size: int: vocab size of the target
     d_model: int: depth of embedding (n_units in the LSTM cell)
     n_encoder_layers: int: number of LSTM layers in the encoder
     n_decoder_layers: int: number of LSTM layers in the decoder after attention
     n_attention_heads: int: number of attention heads
     attention_dropout: float, dropout for the attention layer
     mode: str: 'train', 'eval' or 'predict', predict mode is for fast inference

    Returns:
     An LSTM sequence-to-sequence model with attention.
    """
    # Step 0: call the helper function to create layers for the input encoder
    input_encoder = input_encoder_fn(input_vocab_size, d_model, n_encoder_layers)

    # Step 0: call the helper function to create layers for the pre-attention decoder
    pre_attention_decoder = pre_attention_decoder_fn(mode, target_vocab_size, d_model)

    # Step 1: create a serial network
    model = layers.Serial(
        # Step 2: copy input tokens and target tokens as they will be needed later.
        layers.Select([0, 1, 0, 1]),

        # Step 3: run input encoder on the input and pre-attention decoder on the target.
        layers.Parallel(input_encoder, pre_attention_decoder),

        # Step 4: prepare queries, keys, values and mask for attention.
        layers.Fn('PrepareAttentionInput', prepare_attention_input_fn, n_out=4),

        # Step 5: run the AttentionQKV layer
        # nest it inside a Residual layer to add to the pre-attention decoder
        # activations (i.e. queries)
        layers.Residual(layers.AttentionQKV(d_model,
                                            n_heads=n_attention_heads,
                                            dropout=attention_dropout, mode=mode)),

        # Step 6: drop the attention mask (index 1), keeping the activations
        # and the target tokens
        layers.Select([0, 2]),

        # Step 7: run the rest of the RNN decoder
        [layers.LSTM(d_model) for _ in range(n_decoder_layers)],

        # Step 8: prepare output by making it the right size
        layers.Dense(target_vocab_size),

        # Step 9: Log-softmax for output
        layers.LogSoftmax()
    )
    return model
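To see what Steps 8 and 9 produce, here is a small pure-Python sketch of a log-softmax over one timestep's scores. It is an illustration rather than the Trax layer, and the logits are made-up numbers standing in for the output of the Dense layer over a 4-entry vocabulary.

```python
import math
from typing import List


def log_softmax(logits: List[float]) -> List[float]:
    """log(softmax(logits)), computed stably with the log-sum-exp trick."""
    largest = max(logits)
    log_total = largest + math.log(
        sum(math.exp(value - largest) for value in logits))
    return [value - log_total for value in logits]


# hypothetical Dense-layer output for one decoder timestep
logits = [2.0, 0.5, -1.0, 0.0]
log_probs = log_softmax(logits)

# exponentiating recovers probabilities that sum to one
probabilities = [math.exp(value) for value in log_probs]

# the highest log probability marks the token greedy decoding would pick
best = max(range(len(log_probs)), key=log_probs.__getitem__)
```

Working in log space keeps the values numerically well-behaved, and the ordering of the tokens is the same as for the plain probabilities.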
def test_NMTAttn(NMTAttn):
    test_cases = [
        {
            "name":"simple_test_check",
            "expected":"Serial_in2_out2[\n Select[0,1,0,1]_in2_out4\n Parallel_in2_out2[\n Serial[\n Embedding_33300_1024\n LSTM_1024\n LSTM_1024\n ]\n Serial[\n ShiftRight(1)\n Embedding_33300_1024\n LSTM_1024\n ]\n ]\n PrepareAttentionInput_in3_out4\n Serial_in4_out2[\n Branch_in4_out3[\n None\n Serial_in4_out2[\n Parallel_in3_out3[\n Dense_1024\n Dense_1024\n Dense_1024\n ]\n PureAttention_in4_out2\n Dense_1024\n ]\n ]\n Add_in2\n ]\n Select[0,2]_in3_out2\n LSTM_1024\n LSTM_1024\n Dense_33300\n LogSoftmax\n]",
            "error":"The NMTAttn is not defined properly."
        },
        {
            "name":"layer_len_check",
            "expected":9,
            "error":"We found {} layers in your model. It should be 9.\nCheck the LSTM stack before the dense layer"
        },
        {
            "name":"selection_layer_check",
            "expected":["Select[0,1,0,1]_in2_out4", "Select[0,2]_in3_out2"],
            "error":"Look at your selection layers."
        }
    ]

    success = 0
    fails = 0

    for test_case in test_cases:
        try:
            if test_case['name'] == "simple_test_check":
                assert test_case["expected"] == str(NMTAttn())
                success += 1
            if test_case['name'] == "layer_len_check":
                if test_case["expected"] == len(NMTAttn().sublayers):
                    success += 1
                else:
                    print(test_case["error"].format(len(NMTAttn().sublayers)))
                    fails += 1
            if test_case['name'] == "selection_layer_check":
                model = NMTAttn()
                output = [str(model.sublayers[0]),str(model.sublayers[4])]
                check_count = 0
                for i in range(2):
                    if test_case["expected"][i] != output[i]:
                        print(test_case["error"])
                        fails += 1
                        break
                    else:
                        check_count += 1
                if check_count == 2:
                    success += 1
        except:
            print(test_case['error'])
            fails += 1

    if fails == 0:
        print("\033[92m All tests passed")
    else:
        print('\033[92m', success," Tests passed")
        print('\033[91m', fails, " Tests failed")
    return test_cases
test_cases = test_NMTAttn(NMTAttn)
The NMTAttn is not defined properly.
2 Tests passed
1 Tests failed
The first check compares str(NMTAttn()) against a hard-coded string, and the model printed below differs from that string (for example, ShiftRight is wrapped in an extra Serial), so that check fails while the layer-count and selection-layer checks pass.
model = NMTAttn()
print(model)
Serial_in2_out2[
  Select[0,1,0,1]_in2_out4
  Parallel_in2_out2[
    Serial[
      Embedding_33300_1024
      LSTM_1024
      LSTM_1024
    ]
    Serial[
      Serial[
        ShiftRight(1)
      ]
      Embedding_33300_1024
      LSTM_1024
    ]
  ]
  PrepareAttentionInput_in3_out4
  Serial_in4_out2[
    Branch_in4_out3[
      None
      Serial_in4_out2[
        _in4_out4
        Serial_in4_out2[
          Parallel_in3_out3[
            Dense_1024
            Dense_1024
            Dense_1024
          ]
          PureAttention_in4_out2
          Dense_1024
        ]
        _in2_out2
      ]
    ]
    Add_in2
  ]
  Select[0,2]_in3_out2
  LSTM_1024
  LSTM_1024
  Dense_33300
  LogSoftmax
]
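In the printout, the Branch_in4_out3 block followed by Add_in2 appears to be how the Residual wrapper shows up: one branch passes the first input (the queries) through unchanged, the other runs the attention layers, and Add sums the two. A rough pure-Python sketch of that behaviour follows. The residual and fake_attention functions are hypothetical stand-ins, not Trax code, and the numbers are made up.

```python
from typing import Callable, List, Tuple

Vector = List[float]


def residual(body: Callable[..., Tuple[Vector, str]],
             *inputs) -> Tuple[Vector, str]:
    """Adds the first input (the shortcut) to the body's first output."""
    shortcut = inputs[0]
    activations, mask = body(*inputs)
    summed = [s + a for s, a in zip(shortcut, activations)]
    return summed, mask


def fake_attention(queries, keys, values, mask):
    """Stand-in for AttentionQKV: made-up activations, mask passed through."""
    return [0.5 * value for value in values], mask


queries = [1.0, 2.0, 3.0]
keys = values = [2.0, 2.0, 2.0]
output, mask = residual(fake_attention, queries, keys, values, "mask")
```

The shortcut gives gradients a direct path around the attention layers, which is the motivation given in Step 5 for nesting AttentionQKV inside Residual.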