Stack Semantics

Stack Semantics in Trax

This section covers stack semantics in Trax, which will help in understanding how to use layers like Select and Residual that operate on elements in the stack. If you've taken a computer science class before, you will recall that a stack is a data structure that follows the Last In, First Out (LIFO) principle. That is, the latest element pushed onto the stack will also be the first one popped off. If you're not yet familiar with stacks, then you may find this short tutorial useful. In a nutshell, all you really need to remember is that a stack puts elements one on top of the other. You need to know what is on top of the stack to know which element you will be popping.
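As a quick refresher, the LIFO behaviour can be sketched with a plain Python list, treating the end of the list as the top of the stack. This is only an illustration of the data structure, not of anything Trax does internally.

```python
# A plain Python list used as a stack: the end of the list is the top.
stack = []
stack.append(4)    # push 4
stack.append(3)    # push 3; 3 now sits on top of 4

top = stack.pop()  # pop removes and returns the most recently pushed element
print(top)         # 3
print(stack)       # [4]
```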

Imports

# pypi
import numpy
from trax import fastmath, layers, shapes

Middle

The Serial Combinator is Stack Oriented.

Most of the time, stack orientation in Trax is encountered through the Serial layer, so we will use it to see how stack orientation works. We will define two simple Function layers:

  1. Addition
  2. Multiplication

Suppose we want to make the simple calculation \((3 + 4) \times 15 + 3\). We'll use Serial to perform the calculations in the following order 3 4 add 15 mul 3 add. The steps of the calculation are shown in the table below.

| Stack Operations      | Stack |
|-----------------------+-------|
| Push(4)               | 4     |
| Push(3)               | 4 3   |
| Push(Add Pop() Pop()) | 7     |
| Push(15)              | 7 15  |
| Push(Mul Pop() Pop()) | 105   |
| Push(3)               | 105 3 |
| Push(Add Pop() Pop()) | 108   |

The first column shows the operations performed on the stack and the second column shows what is on the stack afterwards. The rightmost element in the second column represents the top of the stack (e.g. in the second row, Push(3) pushes 3 on top of the stack and 4 is now under it).

After finishing the steps the stack contains 108 which is the answer to our simple computation.

From this, the following can be concluded: a stack-based layer handles data in only one way, by taking data from the top of the stack (popping) and putting data back on top of the stack (pushing). Any expression that can be written conventionally can also be written this way, and is thus amenable to being interpreted by a stack-oriented layer like Serial.
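To make the push and pop steps concrete, here is a minimal sketch of a postfix evaluator in plain Python. The names evaluate_postfix and OPERATIONS are hypothetical helpers for illustration and are not part of Trax.

```python
import operator

# Hypothetical helper table (not part of Trax): maps token names to functions.
OPERATIONS = {"add": operator.add, "mul": operator.mul}


def evaluate_postfix(tokens):
    """Evaluate a postfix token list, using a list whose end is the stack top."""
    stack = []
    for token in tokens:
        if token in OPERATIONS:
            # An operator pops its two operands and pushes the result.
            right = stack.pop()
            left = stack.pop()
            stack.append(OPERATIONS[token](left, right))
        else:
            # A number is simply pushed.
            stack.append(token)
    return stack


result = evaluate_postfix([3, 4, "add", 15, "mul", 3, "add"])
print(result)  # [108]
```

The token order matches the sequence 3 4 add 15 mul 3 add used above, and the single value left on the stack is the answer.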

Defining addition

We're going to define a Trax function layer (Fn) for addition.

def Addition():
    layer_name = "Addition" 

    def func(x, y):
        return x + y

    return layers.Fn(layer_name, func)

Test it out.

add = Addition()
print(type(add))
<class 'trax.layers.base.PureLayer'>
print("name :", add.name)
print("expected inputs :", add.n_in)
print("promised outputs :", add.n_out)
name : Addition
expected inputs : 2
promised outputs : 1
x = numpy.array([3])
y = numpy.array([4])

print(f"{x} + {y} = {add((x, y))}")
[3] + [4] = [7]

Defining multiplication

def Multiplication():
    layer_name = "Multiplication"

    def func(x, y):
        return x * y

    return layers.Fn(layer_name, func)

Test it out.

mul = Multiplication()

The properties.

print("name :", mul.name)
print("expected inputs :", mul.n_in)
print("promised outputs :", mul.n_out, "\n")
name : Multiplication
expected inputs : 2
promised outputs : 1 

Some Inputs.

x = numpy.array([7])
y = numpy.array([15])
print("x :", x)
print("y :", y)
x : [7]
y : [15]

The Output

z = mul((x, y))
print(f"{x} * {y} = {z}")
[7] * [15] = [105]

Implementing the computations using the Serial combinator

serial = layers.Serial(
    Addition(), Multiplication(), Addition()
)
inputs = (numpy.array([3]), numpy.array([4]), numpy.array([15]), numpy.array([3]))

serial.init(shapes.signature(inputs))
print(serial, "\n")
print("name :", serial.name)
print("sublayers :", serial.sublayers)
print("expected inputs :", serial.n_in)
print("promised outputs :", serial.n_out, "\n")
Serial_in4[
  Addition_in2
  Multiplication_in2
  Addition_in2
] 

name : Serial
sublayers : [Addition_in2, Multiplication_in2, Addition_in2]
expected inputs : 4
promised outputs : 1 
print(f"{inputs} -> {serial(inputs)}")
(array([3]), array([4]), array([15]), array([3])) -> [108]

The example above, in which two simple addition and multiplication functions were combined with the Serial combinator, shows how stack semantics work in Trax.
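The behaviour seen above can be imitated with a toy model in plain Python. This is a sketch under the assumption that the first element of the input tuple is the top of the stack and that each layer consumes its n_in inputs from the top; run_serial is a hypothetical helper, not the Trax implementation.

```python
def run_serial(layers, inputs):
    """Toy model of a Serial combinator; inputs[0] is treated as the stack top."""
    stack = list(inputs)
    for function, n_in in layers:
        # Take n_in items off the top of the stack ...
        arguments, stack = stack[:n_in], stack[n_in:]
        # ... and push the single result back on top.
        stack = [function(*arguments)] + stack
    return stack


# Each toy layer is a (function, n_in) pair.
add = (lambda x, y: x + y, 2)
mul = (lambda x, y: x * y, 2)

result = run_serial([add, mul, add], (3, 4, 15, 3))
print(result)  # [108]
```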

The tl.Select combinator in the context of the Serial combinator

Having understood how stack semantics work in Trax, we will demonstrate how the tl.Select combinator works.

First example of tl.Select

Suppose we want to make the simple calculation \((3 + 4) \times 3 + 4\). We can use Select to perform the calculations in the following manner:

  1. input 3 4
  2. tl.Select([0, 1, 0, 1])
  3. add
  4. mul
  5. add.

tl.Select requires a list or tuple of 0-based indices that select elements relative to the top of the stack. In our example, the top of the stack is 3 (index 0), followed by 4 (index 1). We use Select to copy the top two elements and push all four elements back onto the stack, which, read from the top down, then contains 3 4 3 4. The steps of the calculation are shown in the table below. As in the previous table, the first column shows the operation and the second column shows the contents of the stack after that operation, with the top of the stack on the right.

| Stack Operations           | Stack   |
|----------------------------+---------|
| Push(4)                    | 4       |
| Push(3)                    | 4 3     |
| Push(Select([0, 1, 0, 1])) | 4 3 4 3 |
| Push(Add Pop() Pop())      | 4 3 7   |
| Push(Mul Pop() Pop())      | 4 21    |
| Push(Add Pop() Pop())      | 25      |

After processing all the inputs the stack contains 25 which is the result of the calculations.
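The same steps can be traced with a toy model in plain Python, where index 0 of the list is the top of the stack. The helpers select and apply_binary are hypothetical, and the default for n_in (one more than the largest index) is an assumption about how Select behaves, consistent with the Select[0,1,0,1]_in2_out4 signature shown in the output further down.

```python
def select(stack, indices, n_in=None):
    """Toy model of Select on a stack whose index 0 is the top."""
    if n_in is None:
        n_in = max(indices) + 1  # assumed default: just enough inputs to cover the indices
    popped, rest = stack[:n_in], stack[n_in:]
    return [popped[i] for i in indices] + rest


def apply_binary(stack, function):
    """Pop the top two elements and push function(top, next)."""
    return [function(stack[0], stack[1])] + stack[2:]


stack = [3, 4]
stack = select(stack, [0, 1, 0, 1])
print(stack)  # [3, 4, 3, 4]
stack = apply_binary(stack, lambda x, y: x + y)
print(stack)  # [7, 3, 4]
stack = apply_binary(stack, lambda x, y: x * y)
print(stack)  # [21, 4]
stack = apply_binary(stack, lambda x, y: x + y)
print(stack)  # [25]
```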

serial = layers.Serial(
    layers.Select([0, 1, 0, 1]),
    Addition(),
    Multiplication(),
    Addition()
)

Now we'll create the input.

x = (numpy.array([3]), numpy.array([4]))
serial.init(shapes.signature(x))
print(serial, "\n")
print("name :", serial.name)
print("sublayers :", serial.sublayers)
print("expected inputs :", serial.n_in)
print("promised outputs :", serial.n_out, "\n")
Serial_in2[
  Select[0,1,0,1]_in2_out4
  Addition_in2
  Multiplication_in2
  Addition_in2
] 

name : Serial
sublayers : [Select[0,1,0,1]_in2_out4, Addition_in2, Multiplication_in2, Addition_in2]
expected inputs : 2
promised outputs : 1 
print(f"{x} -> {serial(x)}")
(array([3]), array([4])) -> [25]

Select Makes It More Like a Collection

Note that since you are passing in indices to Select, you aren't really using it like a stack, even if behind the scenes it's using push and pop.

serial = layers.Serial(
    layers.Select([2, 1, 1, 2]),
    Addition(),
    Multiplication(),
    Addition()
)

x = (numpy.array([3]), numpy.array([4]), numpy.array([5]))
serial.init(shapes.signature(x))

print(f"{x} -> {serial(x)}")
(array([3]), array([4]), array([5])) -> [41]
print((5 + 4) * 4 + 5)
41
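Seen this way, the selection step is just indexing into the inputs. A small sketch in plain Python (with inputs[0] standing in for the top of the stack) shows which elements are picked before the arithmetic runs:

```python
inputs = (3, 4, 5)       # inputs[0] plays the role of the stack top
indices = [2, 1, 1, 2]

# Pick elements by position, exactly as the index list dictates.
selected = [inputs[i] for i in indices]
print(selected)  # [5, 4, 4, 5]

a, b, c, d = selected
total = (a + b) * c + d
print(total)  # 41
```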

Another example of tl.Select

Suppose we want to make the simple calculation \((3 + 4) \times 4\). We can use Select to perform the calculations in the following manner:

  1. 4
  2. 3
  3. tl.Select([0,1,0,1])
  4. add
  5. tl.Select([0], n_in=2)
  6. mul

The example is a bit contrived, but it demonstrates the flexibility of the command. The second tl.Select pops two elements (specified by n_in) from the stack, starting from index 0 (i.e. the top of the stack). This means that 7 and 3 will be popped (because n_in = 2), but only 7 is placed back on top because the layer selects only [0]. As in the previous table, the first column shows the operation and the second column shows the contents of the stack after that operation.

| Stack Operations           | Stack   |
|----------------------------+---------|
| Push(4)                    | 4       |
| Push(3)                    | 4 3     |
| Push(Select([0, 1, 0, 1])) | 4 3 4 3 |
| Push(Add Pop() Pop())      | 4 3 7   |
| Push(Select([0], n_in=2))  | 4 7     |
| Push(Mul Pop() Pop())      | 28      |
After processing all the inputs the stack contains 28 which is the answer we get above.
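The effect of n_in can be isolated in a short plain-Python sketch. The select helper here is a hypothetical toy model (index 0 is the stack top), not the Trax layer itself; it starts from the stack as it stands right after the add step.

```python
def select(stack, indices, n_in=None):
    """Toy model of Select on a stack whose index 0 is the top."""
    if n_in is None:
        n_in = max(indices) + 1  # assumed default when n_in is not given
    popped, rest = stack[:n_in], stack[n_in:]
    return [popped[i] for i in indices] + rest


stack = [7, 3, 4]                   # state after the add step, top first
stack = select(stack, [0], n_in=2)  # pops 7 and 3, pushes back only 7
print(stack)  # [7, 4]

product = stack[0] * stack[1]
print(product)  # 28
```

Without n_in=2, the toy select would pop only the 7 and push it straight back, leaving 3 in place for the multiplication.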

serial = layers.Serial(
    layers.Select([0, 1, 0, 1]),
    Addition(),
    layers.Select([0], n_in=2),
    Multiplication()
)
inputs = (numpy.array([3]), numpy.array([4]))
serial.init(shapes.signature(inputs))
print(serial, "\n")
print("name :", serial.name)
print("sublayers :", serial.sublayers)
print("expected inputs :", serial.n_in)
print("promised outputs :", serial.n_out)
Serial_in2[
  Select[0,1,0,1]_in2_out4
  Addition_in2
  Select[0]_in2
  Multiplication_in2
] 

name : Serial
sublayers : [Select[0,1,0,1]_in2_out4, Addition_in2, Select[0]_in2, Multiplication_in2]
expected inputs : 2
promised outputs : 1
print(f"{inputs} -> {serial(inputs)}")
(array([3]), array([4])) -> [28]

In summary, what Select does in this example is make a copy of the inputs in order to be used further along in the stack of operations.

The tl.Residual combinator in the context of the Serial combinator

tl.Residual

Residual networks are frequently used to make deep models easier to train, and Trax already has a built-in layer for them. The Residual layer computes the element-wise sum of the stack-top input and the output of the layer series. Let's first see how it is used in the code below:

serial = layers.Serial(
    layers.Select([0, 1, 0, 1]),
    layers.Residual(Addition())
)

print(serial, "\n")
print("name :", serial.name)
print("expected inputs :", serial.n_in)
print("promised outputs :", serial.n_out)
Serial_in2_out3[
  Select[0,1,0,1]_in2_out4
  Serial_in2[
    Branch_in2_out2[
      None
      Addition_in2
    ]
    Add_in2
  ]
] 

name : Serial
expected inputs : 2
promised outputs : 3

Here, we use the Serial combinator to define our model. The inputs first go through a Select layer, followed by a Residual layer that takes the Fn: Addition() layer as an argument. This means the Residual layer will take the stack-top input at that point and add it to the output of the Fn: Addition() layer. You can picture it like the diagram below, where x1 and x2 are the inputs to the model:

Now, let's try running our model with some sample inputs and see the result:

x1 = numpy.array([3])
x2 = numpy.array([4])

print(f"{x1} + {x2} -> {serial((x1, x2))}")
[3] + [4] -> (array([10]), array([3]), array([4]))

As you can see, the Residual layer remembers the stack-top input (i.e. 3) and adds it to the result of the Fn: Addition() layer (i.e. 3 + 4 = 7). The output of Residual(Addition()) is then 3 + 7 = 10, which is pushed onto the stack.

On a different note, you'll notice that the Select layer has 4 outputs but the Fn: Addition() layer only pops 2 inputs from the stack. This means the duplicate inputs (i.e. the 2 rightmost arrows of the Select outputs in the figure above) remain in the stack. This is why you still see them in the output of our simple serial network (i.e. array([3]), array([4])). This is useful if you want to use these duplicate inputs in another layer further down the network.
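The arithmetic can be checked with a toy model in plain Python. The residual helper below is hypothetical and only mirrors the behaviour described here (remember the stack-top input, run the wrapped function, add the two); it is not how Trax builds the layer.

```python
def residual(function, n_in, stack):
    """Toy model of Residual on a stack whose index 0 is the top."""
    arguments, rest = stack[:n_in], stack[n_in:]
    shortcut = arguments[0]  # the stack-top input is remembered
    # Add the shortcut to the wrapped function's output; untouched items stay below.
    return [shortcut + function(*arguments)] + rest


stack = [3, 4, 3, 4]  # what Select([0, 1, 0, 1]) leaves for the inputs (3, 4)
result = residual(lambda x, y: x + y, 2, stack)
print(result)  # [10, 3, 4]
```

The two trailing values are the duplicates that the wrapped function never consumed.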

Modifying the network

To strengthen your understanding, you can modify the network above and examine the outputs you get. For example, you can pass the Fn: Multiplication() layer instead in the Residual block:

serial = layers.Serial(
    layers.Select([0, 1, 0, 1]), 
    layers.Residual(Multiplication())
)

print(serial, "\n")
print("name :", serial.name)
print("expected inputs :", serial.n_in)
print("promised outputs :", serial.n_out)
Serial_in2_out3[
  Select[0,1,0,1]_in2_out4
  Serial_in2[
    Branch_in2_out2[
      None
      Multiplication_in2
    ]
    Add_in2
  ]
] 

name : Serial
expected inputs : 2
promised outputs : 3

This means you'll have a different output that will be added to the stack top input saved by the Residual block. The diagram becomes like this:

And you'll get 3 + (3 * 4) = 15 as output of the Residual block:

x1 = numpy.array([3])
x2 = numpy.array([4])

y = serial((x1, x2))
print(f"{x1} * {x2} -> {y}")
[3] * [4] -> (array([15]), array([3]), array([4]))