Siamese Networks With Trax



# pypi
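# DeviceArray is only used as a type hint; this private import path exists in
# older JAX releases and may not be available in newer ones
from jax.interpreters.xla import _DeviceArray as DeviceArray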
from trax import layers

import numpy
import trax
import trax.fastmath.numpy as fast_numpy


L2 Normalization

Before building the model you will need to define a function that applies L2 normalization to a tensor, dividing each vector along the last axis by its length. Luckily this is pretty straightforward.
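For reference, the normalization that follows divides each vector x (one row of the tensor, with d elements) by its Euclidean length, so every resulting row has length 1:

```latex
\hat{x} = \frac{x}{\lVert x \rVert_2},
\qquad
\lVert x \rVert_2 = \sqrt{\sum_{i=1}^{d} x_i^2}
```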

def normalize(x: numpy.ndarray) -> DeviceArray:
    """L2 Normalization

    Args:
     x: the data to normalize

    Returns:
     normalized version of x
    """
    # divide each row (last axis) by the square root of its sum of squares
    return x / fast_numpy.sqrt(fast_numpy.sum(x * x, axis=-1, keepdims=True))

The denominator can be replaced by numpy.linalg.norm(x, axis=-1, keepdims=True) to achieve the same result.
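A quick NumPy check of that claim, as a sketch on random data shaped like the example below (the seed and variable names are illustrative): both denominators agree, and after dividing, every row has unit L2 norm.

```python
import numpy

# Random data shaped like the example tensor: 2 rows, 5 features each.
rng = numpy.random.default_rng(0)
x = rng.random((2, 5))

# Denominator as written in normalize(): square root of the per-row sum of squares.
manual_norm = numpy.sqrt(numpy.sum(x * x, axis=-1, keepdims=True))

# The same denominator via numpy.linalg.norm (L2 norm by default).
linalg_norm = numpy.linalg.norm(x, axis=-1, keepdims=True)

print(numpy.allclose(manual_norm, linalg_norm))  # True

# After dividing, each row's L2 norm is 1.0 up to floating-point rounding.
normalized = x / linalg_norm
print(numpy.linalg.norm(normalized, axis=-1))
```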

tensor = numpy.random.random((2,5))
print(f'The tensor is of type: {type(tensor)}\n\nAnd looks like this:\n\n {tensor}')
The tensor is of type: <class 'numpy.ndarray'>

And looks like this:

 [[0.68535982 0.95339335 0.00394827 0.51219226 0.81262096]
 [0.61252607 0.72175532 0.29187607 0.91777412 0.71457578]]
norm_tensor = normalize(tensor)
print(f'The normalized tensor is of type: {type(norm_tensor)}\n\nAnd looks like this:\n\n {norm_tensor}')
The normalized tensor is of type: <class 'jax.interpreters.xla._DeviceArray'>

And looks like this:

 [[0.45177674 0.6284596  0.00260263 0.33762783 0.535665  ]
 [0.40091467 0.47240815 0.1910407  0.6007077  0.46770892]]

Notice that the initial tensor was converted from a numpy array to a jax array in the process, because the fast_numpy operations return JAX arrays.

The Siamese Model

To create a Siamese model you will first need to create an LSTM model using the Serial combinator layer and then use another combinator layer called Parallel to turn it into the Siamese model. You should be familiar with the following layers:

  • Serial : A combinator layer that lets you stack layers serially using function composition.
  • Embedding : Maps discrete tokens to vectors. Its weights have shape (vocabulary size x dimension of output vectors). The dimension of the output vectors (also called d_feature) is the number of elements in the word embedding.
  • LSTM : The LSTM layer. It leverages another Trax layer called LSTMCell. The number of units must be specified and should match the number of elements in the word embedding.
  • Mean : Computes the mean across a desired axis. Mean uses one tensor axis to form groups of values and replaces each group with the mean value of that group.
  • Fn : A layer with no weights that applies the function f, which is specified here as a lambda expression.
  • Parallel : A combinator layer (like Serial) that applies a list of layers in parallel to its inputs.
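To make the Embedding and Mean descriptions concrete, here is a plain-NumPy sketch of what those two steps do to the shape of a batch. It is an analogy rather than Trax code: the embedding table is random, the sizes are illustrative, and the LSTM that sits between the two steps in the real model is left out (it keeps the (batch, sequence length, units) shape).

```python
import numpy

# Illustrative sizes (the vocabulary and feature sizes match the model below).
vocab_size, d_feature = 500, 128
batch_size, seq_len = 2, 6

rng = numpy.random.default_rng(0)

# An embedding is a (vocab_size x d_feature) weight matrix; looking up a
# token id selects the corresponding row.
embedding_table = rng.normal(size=(vocab_size, d_feature))

# A batch of token-id sequences, shape (batch_size, seq_len).
token_ids = rng.integers(0, vocab_size, size=(batch_size, seq_len))

# Embedding lookup: (batch_size, seq_len) -> (batch_size, seq_len, d_feature).
embedded = embedding_table[token_ids]

# Mean over axis 1 (the sequence axis) collapses each sequence to one vector:
# (batch_size, seq_len, d_feature) -> (batch_size, d_feature).
pooled = embedded.mean(axis=1)

print(embedded.shape)  # (2, 6, 128)
print(pooled.shape)    # (2, 128)
```

Averaging over the sequence axis is what makes the size of the final vector independent of the sequence length.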

Putting everything together, the Siamese model looks like this:

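vocab_size = 500
model_dimension = 128

# Define the LSTM model: token ids -> embeddings -> LSTM outputs
# -> mean over the sequence axis -> one L2-normalized vector per example
LSTM = layers.Serial(
        layers.Embedding(vocab_size=vocab_size, d_feature=model_dimension),
        layers.LSTM(model_dimension),
        layers.Mean(axis=1),
        layers.Fn('Normalize', lambda x: normalize(x))
)

# Use the Parallel combinator to create a Siamese model out of the LSTM
Siamese = layers.Parallel(LSTM, LSTM)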
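The idea behind Parallel(LSTM, LSTM) can be sketched without Trax: one encoder is applied to each of the two inputs, and because its outputs are L2-normalized, a row-wise dot product gives their cosine similarity. In this NumPy sketch, encode and siamese are hypothetical helpers, and the encoder is a simplified stand-in (embedding lookup, mean, normalize) with no LSTM. As far as we understand Trax's handling of repeated layer instances, passing the same LSTM object twice is also what makes the two Trax branches share weights.

```python
from typing import Tuple

import numpy

rng = numpy.random.default_rng(1)

# Hypothetical stand-in encoder weights (no recurrence, unlike the real LSTM).
vocab_size, d_feature = 500, 128
embedding_table = rng.normal(size=(vocab_size, d_feature))


def encode(token_ids: numpy.ndarray) -> numpy.ndarray:
    """Map (batch, seq_len) token ids to (batch, d_feature) unit vectors."""
    pooled = embedding_table[token_ids].mean(axis=1)
    return pooled / numpy.linalg.norm(pooled, axis=-1, keepdims=True)


def siamese(q1: numpy.ndarray, q2: numpy.ndarray) -> Tuple[numpy.ndarray, numpy.ndarray]:
    # Both branches call the same encoder, so they use the same weights.
    return encode(q1), encode(q2)


q1 = rng.integers(0, vocab_size, size=(3, 7))
q2 = q1.copy()
q2[1] = rng.integers(0, vocab_size, size=7)  # make only the second pair differ

v1, v2 = siamese(q1, q2)

# Both outputs are unit vectors, so the row-wise dot product is the cosine similarity.
cosine = numpy.sum(v1 * v2, axis=-1)
print(cosine.round(3))
```

Rows 0 and 2 of the two batches are identical, so their similarity is 1 up to rounding, while row 1 compares unrelated token sequences and scores lower.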

Next is a helper function that prints every sublayer of a combinator layer such as Serial or Parallel:

def show_layers(model, layer_prefix):
    """Print the number of sublayers in model, then each sublayer with an index."""
    print(f"Total layers: {len(model.sublayers)}\n")
    for i, sublayer in enumerate(model.sublayers):
        print(f'{layer_prefix}_{i}: {sublayer}\n')

print('Siamese model:\n')
show_layers(Siamese, 'Parallel.sublayers')
Siamese model:

Total layers: 2

Parallel.sublayers_0: Serial[

Parallel.sublayers_1: Serial[
print('Detail of LSTM models:\n')
show_layers(LSTM, 'Serial.sublayers')
Detail of LSTM models:

Total layers: 4

Serial.sublayers_0: Embedding_500_128

Serial.sublayers_1: LSTM_128

Serial.sublayers_2: Mean

Serial.sublayers_3: Normalize
