Introducing Trax


This is a first look at Trax, a deep learning framework built by the Google Brain team.

Why Trax and not TensorFlow or PyTorch?

TensorFlow and PyTorch are both extensive frameworks that can do almost anything in deep learning. They offer a lot of flexibility, but that often means verbosity of syntax and extra time to code.

Trax is much more concise. It runs on a TensorFlow backend but lets you train models with one-line commands. Trax also works end to end, so you can get data, build a model, and train it in a single terse statement. This means you can focus on learning instead of spending hours on the idiosyncrasies of a big framework's implementation.

Why not Keras then?

Keras has been part of TensorFlow itself since version 2.0. Trax, on the other hand, is a good fit for implementing recent state-of-the-art models such as Transformers, Reformers, and BERT, because the Google Brain team actively maintains it for advanced deep learning tasks. It also runs on CPUs, GPUs, and TPUs with comparatively few code changes.

How to Code in Trax

Building models in Trax relies on two key concepts: layers and combinators. Trax layers are simple objects that process data and perform computations. They can be chained together into composite layers using Trax combinators, allowing you to build layers and models of any complexity.

Trax, JAX, TensorFlow and Tensor2Tensor

You already know that Trax uses TensorFlow as a backend, but it also uses the JAX library to speed up computation. You can view JAX as an enhanced and optimized version of numpy.

You import their version of numpy using import trax.fastmath.numpy. If you see this line, remember that when calling numpy you are really calling Trax’s version of numpy that is compatible with JAX.

As a result, where you used to encounter the type numpy.ndarray you will now find the type jax.interpreters.xla.DeviceArray. The JAX documentation includes a page listing the numpy functions implemented so far.

Tensor2Tensor is another name you might have heard. It started as an end-to-end solution, much like Trax, but it grew unwieldy and complicated. You can view Trax as the new and improved version: faster and simpler.

Installing Trax

Note that there is another library called TraX which is something different.

We're going to use Trax version 1.3.1 here, so to install it with pip:

pip install trax==1.3.1

Note the == for the version, not =. This is a very big install, so maybe take a break after you run it. You won't get the full benefit of JAX unless you have CUDA set up or can use TPUs, so make sure to set up CUDA if you're not using Google Colab. I also had to install cmake to get Trax to install.


# pypi
import numpy

from trax import layers
from trax import shapes
from trax import fastmath
  • Layers are the basic building blocks for Trax
  • shapes are used for data handling
  • fastmath is the JAX version of numpy that can run on GPUs and TPUs



Layers are the core building blocks in Trax; they are the base classes. They take inputs, compute functions or custom calculations, and return outputs.

Relu Layer

First we'll build a ReLU activation function as a layer. A layer like this is one of the simplest types. Notice there is no object initialization so it works just like a math function.

Note: Activation functions are also layers in Trax, which might look odd if you have been using other frameworks for a long time.

relu = layers.Relu()

You can inspect the properties of a layer:

print("-- Properties --")
print("name :", relu.name)
print("expected inputs :", relu.n_in)
print("promised outputs :", relu.n_out, "\n")
-- Properties --
name : Relu
expected inputs : 1
promised outputs : 1 

We'll make an input for the layer using numpy.

x = numpy.array([-2, -1, 0, 1, 2])
print("-- Inputs --")
print("x :", x, "\n")
-- Inputs --
x : [-2 -1  0  1  2] 

And see what it puts out.

y = relu(x)
print("-- Outputs --")
print("y :", y)
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
-- Outputs --
y : [0 0 0 1 2]

I don't know why, but JAX doesn't think I have a GPU, even though TensorFlow does. This whole thing is a little messed up right now because the current release of TensorFlow doesn't work on Ubuntu 20.10. I'm running it with the nightly build (2.5), but I have to install the Trax dependencies one at a time or the install will clobber TensorFlow with the older version (the one that doesn't work), so there are a lot of places for error.
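
For intuition, ReLU is just max(0, x) applied to each element. Here is a minimal plain-Python sketch of that calculation. It doesn't use Trax at all, and relu_sketch is a hypothetical helper name, but it gives the same values as the layer output above.

```python
def relu_sketch(values):
    """Apply max(0, v) to each element; a plain-Python stand-in for ReLU."""
    return [max(0, v) for v in values]


relu_result = relu_sketch([-2, -1, 0, 1, 2])
print(relu_result)  # [0, 0, 0, 1, 2]
```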

Concatenate Layer

Now a layer that takes 2 inputs. Notice the change in the expected inputs property from 1 to 2.

First create a concatenate trax layer and check out its properties.

concatenate = layers.Concatenate()
print("-- Properties --")
print("name :", concatenate.name)
print("expected inputs :", concatenate.n_in)
print("promised outputs :", concatenate.n_out, "\n")
-- Properties --
name : Concatenate
expected inputs : 2
promised outputs : 1 

Now create the two inputs.

x1 = numpy.array([-10, -20, -30])
x2 = x1 / -10
print("-- Inputs --")
print("x1 :", x1)
print("x2 :", x2, "\n")
-- Inputs --
x1 : [-10 -20 -30]
x2 : [1. 2. 3.] 

And now feed the inputs through the concatenate layer.

y = concatenate([x1, x2])
print("-- Outputs --")
print("y :", y)
-- Outputs --
y : [-10. -20. -30.   1.   2.   3.]

Configuring Layers

You can change the default settings of layers. For example, you can change the expected inputs for a concatenate layer from 2 to 3 using the optional parameter n_items.

concatenate_three = layers.Concatenate(n_items=3)
print("-- Properties --")
print("name :", concatenate_three.name)
print("expected inputs :", concatenate_three.n_in)
print("promised outputs :", concatenate_three.n_out, "\n")
-- Properties --
name : Concatenate
expected inputs : 3
promised outputs : 1 

Create some inputs.

x1 = numpy.array([-10, -20, -30])
x2 = x1 / -10
x3 = x2 * 0.99
print("-- Inputs --")
print("x1 :", x1)
print("x2 :", x2)
print("x3 :", x3, "\n")
-- Inputs --
x1 : [-10 -20 -30]
x2 : [1. 2. 3.]
x3 : [0.99 1.98 2.97] 

And now do the concatenation.

y = concatenate_three([x1, x2, x3])
print("-- Outputs --")
print("y :", y)
-- Outputs --
y : [-10.   -20.   -30.     1.     2.     3.     0.99   1.98   2.97]
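
The pattern here, where a setting passed at construction time fixes how many inputs the layer expects, can be sketched in plain Python. This is only an illustration of the idea, not how Trax implements Concatenate, and make_concatenate is a hypothetical helper.

```python
def make_concatenate(n_items=2):
    """Return a function that joins exactly n_items sequences end to end."""

    def concatenate(inputs):
        # The expected input count was fixed when the function was created.
        if len(inputs) != n_items:
            raise ValueError(f"expected {n_items} inputs, got {len(inputs)}")
        joined = []
        for sequence in inputs:
            joined.extend(sequence)
        return joined

    return concatenate


concat_three = make_concatenate(n_items=3)
joined = concat_three([[-10, -20, -30], [1.0, 2.0, 3.0], [0.99, 1.98, 2.97]])
print(joined)
```

Passing two inputs to concat_three raises a ValueError, which mirrors the way a layer's expected input count is part of its configuration.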

Layer Weights

Some layer types include mutable weights and biases that are used in computation and training. Layers of this type require initialization before use.

For example, the LayerNorm layer calculates normalized data that is also scaled by weights and biases. During initialization you pass the data shape and data type of the inputs so the layer can initialize compatible arrays of weights and biases.

Initialize it.

norm = layers.LayerNorm()

Now some input data.

x = numpy.array([0, 1, 2, 3], dtype="float")

Use the input data signature to get the shape and type for initializing the weights and biases. We need to convert the input datatype from the usual ndarray to a Trax ShapeDtype.

print("Normal shape:", x.shape, "Data Type:", type(x.shape))
print("Shapes Trax:", shapes.signature(x), "Data Type:", type(shapes.signature(x)))
# Initialize the weights and biases from the input signature
norm.init(shapes.signature(x))
Normal shape: (4,) Data Type: <class 'tuple'>
Shapes Trax: ShapeDtype{shape:(4,), dtype:float64} Data Type: <class 'trax.shapes.ShapeDtype'>

Here are its properties.

print("-- Properties --")
print("name :", norm.name)
print("expected inputs :", norm.n_in)
print("promised outputs :", norm.n_out)
-- Properties --
name : LayerNorm
expected inputs : 1
promised outputs : 1

And the weights and biases.

print("weights :", norm.weights[0])
print("biases :", norm.weights[1])
weights : [1. 1. 1. 1.]
biases : [0. 0. 0. 0.]

We have our input array.

print("-- Inputs --")
print("x :", x)
-- Inputs --
x : [0. 1. 2. 3.]

So we can inspect what the layer did to it.

y = norm(x)
print("-- Outputs --")
print("y :", y)
-- Outputs --
y : [-1.3416404  -0.44721344  0.44721344  1.3416404 ]

If you look at it you can see that the positives cancel out the negatives, giving us a sum of 0. I don't know why that's the norm, but maybe it'll become obvious later.
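
A likely explanation for the zero sum: layer normalization subtracts the mean of the inputs and divides by their standard deviation, so the result always has a mean of zero. With weights of 1 and biases of 0 nothing else changes. Here is a minimal plain-Python sketch of that calculation. It is not Trax's implementation, layer_norm_sketch is a hypothetical helper, and the epsilon of 1e-6 is an assumption.

```python
import math


def layer_norm_sketch(values, epsilon=1e-6):
    """Shift values to zero mean and scale them to (roughly) unit variance."""
    mean = sum(values) / len(values)
    variance = sum((v - mean) ** 2 for v in values) / len(values)
    # epsilon (an assumed value) guards against dividing by zero
    scale = math.sqrt(variance + epsilon)
    return [(v - mean) / scale for v in values]


normalized = layer_norm_sketch([0.0, 1.0, 2.0, 3.0])
print(normalized)
print(sum(normalized))
```

The values agree with the layer output above to several decimal places, and the sum is zero up to floating point rounding.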

Custom Layers

You can create your own custom layers too and define custom functions for computations by using layers.Fn. Let me show you how.

Help on function Fn in module trax.layers.base:

Fn(name, f, n_out=1)
    Returns a layer with no weights that applies the function `f`.
    `f` can take and return any number of arguments, and takes only positional
    arguments -- no default or keyword arguments. It often uses JAX-numpy (`jnp`).
    The following, for example, would create a layer that takes two inputs and
    returns two outputs -- element-wise sums and maxima:
        `Fn('SumAndMax', lambda x0, x1: (x0 + x1, jnp.maximum(x0, x1)), n_out=2)`
    The layer's number of inputs (`n_in`) is automatically set to number of
    positional arguments in `f`, but you must explicitly set the number of
    outputs (`n_out`) whenever it's not the default value 1.
    Args:
      name: Class-like name for the resulting layer; for use in debugging.
      f: Pure function from input tensors to output tensors, where each input
          tensor is a separate positional arg, e.g., `f(x0, x1) --> x0 + x1`.
          Output tensors must be packaged as specified in the `Layer` class
      n_out: Number of outputs promised by the layer; default value 1.
    Returns:
      Layer executing the function `f`.
  • Define a custom layer

    In this example we'll create a layer to calculate the input times 2.

    def double_it() -> layers.Fn:
        """A custom layer function that doubles any inputs

        Returns:
         a custom layer that takes one numeric argument and doubles it
        """
        layer_name = "TimesTwo"

        # Custom function for the custom layer
        def func(x):
            return x * 2

        return layers.Fn(layer_name, func)
  • Test it
    double = double_it()
    print("-- Properties --")
    print("name :", double.name)
    print("expected inputs :", double.n_in)
    print("promised outputs :", double.n_out)
    -- Properties --
    name : TimesTwo
    expected inputs : 1
    promised outputs : 1
    x = numpy.array([1, 2, 3])
    print("-- Inputs --")
    print("x :", x, "\n")
    y = double(x)
    print("-- Outputs --")
    print("y :", y)
    -- Inputs --
    x : [1 2 3] 
    -- Outputs --
    y : [2 4 6]
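
The Fn help text says that the number of inputs is set automatically from the positional arguments of f. One way that could work is sketched below in plain Python using the standard inspect module. FnSketch is a hypothetical class for illustration only, not Trax's implementation, and it takes its inputs as separate arguments rather than as a list.

```python
import inspect


class FnSketch:
    """Hypothetical stand-in for a weightless layer that wraps a function."""

    def __init__(self, name, f, n_out=1):
        self.name = name
        self.f = f
        # Infer the number of inputs by counting f's positional parameters.
        positional_kinds = (
            inspect.Parameter.POSITIONAL_ONLY,
            inspect.Parameter.POSITIONAL_OR_KEYWORD,
        )
        parameters = inspect.signature(f).parameters.values()
        self.n_in = sum(1 for p in parameters if p.kind in positional_kinds)
        # The number of outputs can't be inferred, so it is passed explicitly.
        self.n_out = n_out

    def __call__(self, *inputs):
        if len(inputs) != self.n_in:
            raise ValueError(f"expected {self.n_in} inputs, got {len(inputs)}")
        return self.f(*inputs)


sum_and_max = FnSketch("SumAndMax", lambda x0, x1: (x0 + x1, max(x0, x1)), n_out=2)
print(sum_and_max.name, sum_and_max.n_in, sum_and_max.n_out)
result = sum_and_max(3, 5)
print(result)  # (8, 5)
```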


You can combine layers to build more complex layers. Trax provides a set of objects named combinator layers to make this happen. Combinators are themselves layers, so they can be used anywhere a layer can, including inside other combinators.

Serial Combinator

This is the most common combinator and the easiest to use. You could, for example, build a simple neural network by combining layers into a single layer using the Serial combinator. This new layer then acts just like a single layer, so you can inspect its inputs, outputs, and weights, or even combine it into another layer. Combinators can then be used as trainable models. Try adding more layers.

Note: As you might have guessed, if there is a serial combinator there is a parallel combinator as well. Explore the combinators and other layers in the Trax documentation, and look at the repo to understand how these layers are written.

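serial = layers.Serial(
    layers.LayerNorm(),
    layers.Relu(),
    double,
    layers.Dense(n_units=2),
    layers.Dense(n_units=1),
    layers.LogSoftmax()
)
  • Initialization
    x = numpy.array([-2, -1, 0, 1, 2]) #input
    serial.init(shapes.signature(x)) # initialize the weights from the input signature
    print("-- Serial Model --")
    print("-- Properties --")
    print("name :", serial.name)
    print("sublayers :", serial.sublayers)
    print("expected inputs :", serial.n_in)
    print("promised outputs :", serial.n_out)
    print("weights & biases:", serial.weights, "\n")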
    -- Serial Model --
    -- Properties --
    name : Serial
    sublayers : [LayerNorm, Relu, TimesTwo, Dense_2, Dense_1, LogSoftmax]
    expected inputs : 1
    promised outputs : 1
    weights & biases: [(DeviceArray([1, 1, 1, 1, 1], dtype=int32), DeviceArray([0, 0, 0, 0, 0], dtype=int32)), (), (), (DeviceArray([[ 0.19178385,  0.1832077 ],
                 [-0.36949775, -0.03924937],
                 [ 0.43800744,  0.788491  ],
                 [ 0.43107533, -0.3623491 ],
                 [ 0.6186575 ,  0.04764405]], dtype=float32), DeviceArray([-3.0051979e-06,  1.4359505e-06], dtype=float32)), (DeviceArray([[-0.6747592],
                 [-0.8550365]], dtype=float32), DeviceArray([-8.9325863e-07], dtype=float32)), ()] 
    print("-- Inputs --")
    print("x :", x, "\n")
    y = serial(x)
    print("-- Outputs --")
    print("y :", y)
    -- Inputs --
    x : [-2 -1  0  1  2] 
    -- Outputs --
    y : [0.]
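
Conceptually, a serial combinator feeds the output of each layer into the next one. The plain-Python sketch below shows that idea, and also why the final output is 0: the log-softmax of a single value is always log(1) = 0, so a model that ends in a one-unit Dense layer followed by LogSoftmax can only ever output 0. All the names here are hypothetical, and this is not how Trax implements Serial.

```python
import math


def serial_sketch(*functions):
    """Chain functions so that each one receives the previous one's output."""

    def apply_all(x):
        for function in functions:
            x = function(x)
        return x

    return apply_all


def relu_step(values):
    return [max(0, v) for v in values]


def double_step(values):
    return [v * 2 for v in values]


def log_softmax_sketch(values):
    """Log-softmax, subtracting the maximum first for numerical stability."""
    highest = max(values)
    log_total = highest + math.log(sum(math.exp(v - highest) for v in values))
    return [v - log_total for v in values]


pipeline = serial_sketch(relu_step, double_step)
chained = pipeline([-2, -1, 0, 1, 2])
print(chained)  # [0, 0, 0, 2, 4]

single_unit = log_softmax_sketch([3.7])
print(single_unit)  # [0.0]
```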


Just remember to look out for which numpy you are using: the regular numpy or Trax's JAX-compatible numpy. Watch those import blocks. numpy and fastmath.numpy have different data types.

Regular numpy.

x_numpy = numpy.array([1, 2, 3])
print("good old numpy : ", type(x_numpy), "\n")
good old numpy :  <class 'numpy.ndarray'> 

Fastmath and jax numpy.

x_jax = fastmath.numpy.array([1, 2, 3])
print("jax trax numpy : ", type(x_jax))
jax trax numpy :  <class 'jax.interpreters.xla._DeviceArray'>


  • Trax is a concise framework, built on TensorFlow, for end to end machine learning. The key building blocks are layers and combinators.
  • This was a lab from Coursera's Natural Language Processing with Sequence Models course by DeepLearning.AI.