Siamese Networks With Trax
Beginning
Imports
# pypi
from jax.interpreters.xla import _DeviceArray as DeviceArray
from trax import layers
import numpy
import trax
import trax.fastmath.numpy as fast_numpy
Middle
L2 Normalization
Before building the model you will need to define a function that applies L2 normalization to a tensor. Luckily this is pretty straightforward.
def normalize(x: numpy.ndarray) -> DeviceArray:
    """L2 Normalization

    Args:
     x: the data to normalize

    Returns:
     normalized version of x
    """
    return x / fast_numpy.sqrt(fast_numpy.sum(x * x, axis=-1, keepdims=True))
The denominator can be replaced by fast_numpy.linalg.norm(x, axis=-1, keepdims=True) to achieve the same result.
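For example, a version built on the norm function might look like this (normalize_with_norm is a hypothetical name, not part of the original code):

def normalize_with_norm(x: numpy.ndarray) -> DeviceArray:
    """L2 Normalization using the built-in norm (hypothetical alternative)"""
    return x / fast_numpy.linalg.norm(x, axis=-1, keepdims=True)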
tensor = numpy.random.random((2,5))
print(f'The tensor is of type: {type(tensor)}\n\nAnd looks like this:\n\n {tensor}')
The tensor is of type: <class 'numpy.ndarray'>

And looks like this:

[[0.68535982 0.95339335 0.00394827 0.51219226 0.81262096]
 [0.61252607 0.72175532 0.29187607 0.91777412 0.71457578]]
norm_tensor = normalize(tensor)
print(f'The normalized tensor is of type: {type(norm_tensor)}\n\nAnd looks like this:\n\n {norm_tensor}')
The normalized tensor is of type: <class 'jax.interpreters.xla._DeviceArray'>

And looks like this:

[[0.45177674 0.6284596  0.00260263 0.33762783 0.535665  ]
 [0.40091467 0.47240815 0.1910407  0.6007077  0.46770892]]
Notice that the initial tensor was converted from a numpy array to a jax array in the process.
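As a quick check (not part of the original output), each row of the normalized tensor should now have an L2 norm of 1:

# every row should have unit length after normalization
print(fast_numpy.linalg.norm(norm_tensor, axis=-1))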
The Siamese Model
To create a Siamese model you will first need to create an LSTM model using the Serial combinator layer and then use another combinator layer called Parallel to create the Siamese model. You should be familiar with the following layers:

- Serial: A combinator layer that lets you stack layers serially using function composition.
- Embedding: Maps discrete tokens to vectors. It will have shape (vocabulary length X dimension of output vectors). The dimension of the output vectors (also called d_feature) is the number of elements in the word embedding.
- LSTM: The LSTM layer. It leverages another Trax layer called LSTMCell. The number of units should be specified and should match the number of elements in the word embedding.
- Mean: Computes the mean across a desired axis. Mean uses one tensor axis to form groups of values and replaces each group with the mean value of that group.
- Fn: A layer with no weights that applies the function f, which should be specified using lambda syntax.
- Parallel: A combinator layer (like Serial) that applies a list of layers in parallel to its inputs (there's a toy sketch of this right after the list).
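To see how Parallel routes its inputs, here is a minimal sketch (the Double and Negate layers are made up for illustration): each sublayer receives its own slice of the inputs, so the first layer gets the first tensor and the second layer gets the second.

from trax import shapes

# two weightless function layers (hypothetical names)
double = layers.Fn('Double', lambda x: x * 2)
negate = layers.Fn('Negate', lambda x: -x)
branch = layers.Parallel(double, negate)

a = fast_numpy.array([1.0, 2.0])
b = fast_numpy.array([3.0, 4.0])

# combinators need to be initialized before they are called
branch.init(shapes.signature((a, b)))
print(branch((a, b)))  # expect ([2. 4.], [-3. -4.])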
Putting everything together the Siamese model looks like this:
vocab_size = 500
model_dimension = 128
# Define the LSTM model
LSTM = layers.Serial(
    layers.Embedding(vocab_size=vocab_size, d_feature=model_dimension),
    layers.LSTM(model_dimension),
    layers.Mean(axis=1),
    layers.Fn('Normalize', lambda x: normalize(x))
)
# Use the Parallel combinator to create a Siamese model out of the LSTM
Siamese = layers.Parallel(LSTM, LSTM)
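As a smoke test (not part of the original; the batch shape and token IDs here are made up), you can initialize the model from an input signature and push a pair of batches through both branches. Note that because the same LSTM object is passed to Parallel twice, Trax shares the weights between the two branches, which is exactly what a Siamese network needs.

from trax import shapes

# made-up token IDs: two batches of shape (batch size 2, sequence length 10)
input1 = numpy.random.randint(0, vocab_size, (2, 10))
input2 = numpy.random.randint(0, vocab_size, (2, 10))

# initialize the weights from the input signature, then run the model
Siamese.init(shapes.signature((input1, input2)))
vector1, vector2 = Siamese((input1, input2))
print(vector1.shape, vector2.shape)  # each branch outputs (2, 128)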
Next is a helper function that prints information for every layer (sublayer within Serial):
def show_layers(model, layer_prefix):
    print(f"Total layers: {len(model.sublayers)}\n")
    for i in range(len(model.sublayers)):
        print('========')
        print(f'{layer_prefix}_{i}: {model.sublayers[i]}\n')
print('Siamese model:\n')
show_layers(Siamese, 'Parallel.sublayers')
Siamese model:

Total layers: 2

========
Parallel.sublayers_0: Serial[
  Embedding_500_128
  LSTM_128
  Mean
  Normalize
]

========
Parallel.sublayers_1: Serial[
  Embedding_500_128
  LSTM_128
  Mean
  Normalize
]
print('Detail of LSTM models:\n')
show_layers(LSTM, 'Serial.sublayers')
Detail of LSTM models:

Total layers: 4

========
Serial.sublayers_0: Embedding_500_128

========
Serial.sublayers_1: LSTM_128

========
Serial.sublayers_2: Mean

========
Serial.sublayers_3: Normalize