Trax GRU Model
Creating a GRU Model Using Trax
Imports
# from pypi
from trax import layers
import trax
Middle
Trax Review
Trax allows us to define neural network architectures by stacking layers (similarly to other libraries such as Keras). For this, the Serial combinator is often used, as it lets us stack layers serially using function composition.
Next we'll look at a simple vanilla neural network architecture containing one hidden (dense) layer with 128 cells and an output (dense) layer with 10 cells, on top of which we apply a final LogSoftmax layer.
simple = layers.Serial(
    layers.Dense(128),
    layers.Relu(),
    layers.Dense(10),
    layers.LogSoftmax()
)
Each of the layers within the Serial combinator is considered a sublayer. Notice that unlike in similar libraries, in Trax the activation functions are themselves layers. To learn more about the Serial layer, check out its documentation.
Here's its representation.
print(simple)
Serial[ Dense_128 Serial[ Relu ] Dense_10 LogSoftmax ]
Printing the model gives you essentially the same information as the model's definition itself.
Just by looking at the definition you can clearly see what is going on inside the neural network. Trax is very straightforward in the way a network is defined.
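Before moving on, here's a minimal sketch of actually running this network (the input shape of (1, 32) and the all-ones batch are placeholders chosen just for illustration): a Trax model is initialized from an input signature, and after that it can be called like a function.
import numpy as np
from trax import shapes

# a dummy batch: one example with 32 features (placeholder shape)
x = np.ones((1, 32), dtype=np.float32)

# initialize the model's weights from the input's shape and dtype
simple.init(shapes.signature(x))

y = simple(x)
print(y.shape)  # (1, 10): one log-probability per output class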
The GRU Model
To create a GRU model you will need to be familiar with the following layers (see the Trax documentation for each of them). A small sketch of what ShiftRight does follows the list.
- ShiftRight: Shifts the tensor to the right by padding on axis 1. The mode should be specified, and it refers to the context in which the model is being used; possible values are 'train', 'eval', or 'predict' (predict mode is for fast inference). Defaults to 'train'.
- Embedding: Maps discrete tokens to vectors. Its weight matrix has shape (vocabulary length x dimension of output vectors). The dimension of the output vectors (also called d_feature) is the number of elements in the word embedding.
- GRU: The GRU layer. It leverages another Trax layer called GRUCell. The number of GRU units should be specified and should match the number of elements in the word embedding. If you want to stack two consecutive GRU layers, you can do it with a Python list comprehension.
- Dense: Vanilla dense layer.
- LogSoftmax: The log softmax function.
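Since ShiftRight is probably the least familiar of these, here's a conceptual sketch of its effect using plain NumPy (this mimics what ShiftRight with n_positions=1 does along axis 1; it isn't the actual Trax implementation):
import numpy as np

x = np.array([[1, 2, 3, 4],
              [5, 6, 7, 8]])

# prepend one zero along axis 1, then drop the last column
shifted = np.pad(x, ((0, 0), (1, 0)))[:, :-1]
print(shifted)
# [[0 1 2 3]
#  [0 5 6 7]]
Each sequence gets a padding token prepended and its last element dropped, so at every position the model only sees earlier tokens.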
Putting everything together, the GRU model looks like this.
mode = 'train'
vocab_size = 256
model_dimension = 512
n_layers = 2

GRU = layers.Serial(
    layers.ShiftRight(mode=mode),
    layers.Embedding(vocab_size=vocab_size, d_feature=model_dimension),
    [layers.GRU(n_units=model_dimension) for _ in range(n_layers)],
    layers.Dense(n_units=vocab_size),
    layers.LogSoftmax()
)
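As a quick sanity check, here's a minimal sketch of a forward pass through this model, using the same init-and-call pattern as before (the batch of token ids is made up for illustration):
import numpy as np
from trax import shapes

# a toy batch: one sequence of five token ids (all < vocab_size)
x = np.array([[3, 1, 4, 1, 5]], dtype=np.int32)

# initialize the model's weights from the input signature
GRU.init(shapes.signature(x))

y = GRU(x)
print(y.shape)  # (1, 5, 256): log-probabilities over the vocabulary at each position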
Next is a helper function that prints information for every layer (each sublayer within the Serial).
Try changing the parameters defined before the GRU model and see how the output changes.
def show_layers(model, layer_prefix="Serial.sublayers"):
    """Print the number of sublayers in the model, then each sublayer."""
    print(f"Total layers: {len(model.sublayers)}\n")
    for i in range(len(model.sublayers)):
        print('========')
        print(f'{layer_prefix}_{i}: {model.sublayers[i]}\n')
show_layers(GRU)
Total layers: 6

========
Serial.sublayers_0: Serial[ ShiftRight(1) ]

========
Serial.sublayers_1: Embedding_256_512

========
Serial.sublayers_2: GRU_512

========
Serial.sublayers_3: GRU_512

========
Serial.sublayers_4: Dense_256

========
Serial.sublayers_5: LogSoftmax
print(GRU)
Serial[ Serial[ ShiftRight(1) ] Embedding_256_512 GRU_512 GRU_512 Dense_256 LogSoftmax ]
Interesting that it inserted a second Serial for the ShiftRight…