Trax GRU Model

Creating a GRU Model Using Trax

Imports

# from pypi
from trax import layers
import trax

Middle

Trax Review

Trax allows us to define neural network architectures by stacking layers, similarly to other libraries such as Keras. For this, the Serial combinator is often used, since it lets us stack layers serially using function composition.

Next we'll look at a simple vanilla neural network architecture containing one hidden (dense) layer with 128 units and an output (dense) layer with 10 units, to which we apply a final LogSoftmax layer.

simple = layers.Serial(
  layers.Dense(128),
  layers.Relu(),
  layers.Dense(10),
  layers.LogSoftmax()
)

Each of the layers within the Serial combinator is considered a sublayer. Notice that, unlike in similar libraries, in Trax the activation functions are considered layers. To learn more about the Serial layer, check out its documentation.
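
Since activations are layers, you can call one directly on data like any other layer. Here's a minimal sketch (my own illustration, not from the original definition; the input values are made up):

import numpy as np
from trax import layers

relu = layers.Relu()
# Relu has no weights, so it can be applied without an explicit init.
print(relu(np.array([-2.0, 0.0, 3.0])))  # expected: [0. 0. 3.]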

Here's the representation for it.

print(simple)
Serial[
  Dense_128
  Serial[
    Relu
  ]
  Dense_10
  LogSoftmax
]

Printing the model gives you essentially the same information as the model's definition itself, although notice that the Relu got wrapped in an inner Serial combinator.

By just looking at the definition you can clearly see what is going on inside the neural network. Trax is very straightforward in the way a network is defined.

The GRU Model

To create a GRU model you will need to be familiar with the following layers (see the Trax documentation for each layer for more detail):

  • ShiftRight: Shifts the tensor to the right by padding on axis 1. The mode should be specified; it refers to the context in which the model is being used. Possible values are 'train', 'eval', or 'predict' ('predict' mode is for fast inference). Defaults to 'train'. A small sketch of its behavior follows this list.
  • Embedding: Maps discrete tokens to vectors. Its weight matrix has shape (vocabulary length x dimension of output vectors). The dimension of the output vectors (also called d_feature) is the number of elements in the word embedding.
  • GRU: The GRU layer. It leverages another Trax layer called GRUCell. The number of GRU units should be specified and should match the number of elements in the word embedding. If you want to stack two consecutive GRU layers, you can do it with a Python list comprehension, as in the model below.
  • Dense: A vanilla dense (fully-connected) layer.
  • LogSoftmax: The log softmax function.
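
To make ShiftRight concrete, here's a small sketch with a made-up batch of token sequences; the expected output is my reading of the layer's documented behavior (pad with zeros at the start of axis 1, drop the last position):

import numpy as np
from trax import layers

shift = layers.ShiftRight(mode='train')
tokens = np.array([[1, 2, 3],
                   [4, 5, 6]])
print(shift(tokens))
# expected:
# [[0 1 2]
#  [0 4 5]]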

Putting everything together, the GRU model looks like this.

mode = 'train'
vocab_size = 256
model_dimension = 512
n_layers = 2

GRU = layers.Serial(
    layers.ShiftRight(mode=mode),
    layers.Embedding(vocab_size=vocab_size, d_feature=model_dimension),
    [layers.GRU(n_units=model_dimension) for _ in range(n_layers)],
    layers.Dense(n_units=vocab_size),
    layers.LogSoftmax()
)
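
As a quick sanity check (my addition, not part of the original post), you can initialize the model with a dummy batch of token IDs and confirm the output shape is (batch, sequence length, vocab_size):

import numpy as np
from trax import shapes

tokens = np.zeros((2, 8), dtype=np.int32)  # batch of 2 sequences of length 8
GRU.init(shapes.signature(tokens))
log_probabilities = GRU(tokens)
print(log_probabilities.shape)  # expected: (2, 8, 256)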

Next is a helper function that prints information for every layer (sublayer within Serial).

Try changing the parameters defined before the GRU model and see how it changes.

def show_layers(model, layer_prefix="Serial.sublayers"):
    """Print a summary line for every sublayer of the model."""
    print(f"Total layers: {len(model.sublayers)}\n")
    for index, sublayer in enumerate(model.sublayers):
        print('========')
        print(f'{layer_prefix}_{index}: {sublayer}\n')

show_layers(GRU)
Total layers: 6

========
Serial.sublayers_0: Serial[
  ShiftRight(1)
]

========
Serial.sublayers_1: Embedding_256_512

========
Serial.sublayers_2: GRU_512

========
Serial.sublayers_3: GRU_512

========
Serial.sublayers_4: Dense_256

========
Serial.sublayers_5: LogSoftmax
print(GRU)
Serial[
  Serial[
    ShiftRight(1)
  ]
  Embedding_256_512
  GRU_512
  GRU_512
  Dense_256
  LogSoftmax
]

Interesting that it inserted a second Serial for the ShiftRight…
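
One way to check where the extra Serial comes from (just a guess at the cause, assuming the same Trax version as above) is to print a ShiftRight layer on its own:

print(layers.ShiftRight(mode='train'))

If this also prints as a Serial wrapping ShiftRight(1), then the wrapping comes from the ShiftRight constructor itself rather than from the outer Serial combinator.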