Jax, Numpy, and Perplexity

Beginning

Imports

Note to future self: The default jax installation from pip is CPU-only; to get it to run on the GPU (which seems to be the main reason to use it) you need to specify a CUDA build. Right now the command is:

pip install jaxlib==0.1.57+cuda111 -f https://storage.googleapis.com/jax-releases/jax_releases.html

Here cuda111 means the wheel is built against CUDA 11.1, which is the version installed on the server, so I need that build. See the installation instructions for more information (and to see if anything changes).
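Once it's installed, a quick way to confirm that jax actually sees the GPU is to list its devices (jax.devices() is part of the public jax API):

import jax

# should show GPU devices, not just CpuDevice, if the CUDA build is active
print(jax.devices())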

# from python
from argparse import Namespace
from pathlib import Path

import os

# from pypi
from dotenv import load_dotenv
from trax import layers

import numpy
import trax
import trax.fastmath.numpy as trax_numpy

Set Up

The Data Paths

load_dotenv("posts/nlp/.env", override=True)
Paths = Namespace(
    targets=Path(os.environ["RNN_TARGETS"]).expanduser(),
    predictions=Path(os.environ["RNN_PREDICTIONS"]).expanduser()
)

The Random Seed

SEED = 32

# trax no longer has a global seed setting - pass it to the training.Loop
# trax.supervised.trainer_lib.init_random_number_generators(SEED)
numpy.random.seed(SEED)
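For the trax side, here's a hedged sketch of where the seed now goes: as of this writing trax.supervised.training.Loop takes a random_seed argument (model and the tasks here are placeholders, defined elsewhere):

# loop = trax.supervised.training.Loop(
#     model,
#     train_task,
#     eval_tasks=[eval_task],
#     random_seed=SEED,
# )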

Middle

Numpy vs Trax

One important change to take into consideration is that the types of the resulting objects differ depending on which numpy you use. With regular numpy you get numpy.ndarray, but with Trax's numpy you get jax.interpreters.xla.DeviceArray. The two types are interchangeable for most purposes, so if you find error logs mentioning the DeviceArray type, don't worry about it; treat it like you would an ndarray and march ahead.

You can get a randomized numpy array by using the numpy.random.random() function.

This is one of the functionalities that Trax's numpy does not currently support in the same way as regular numpy: JAX handles randomness through explicit PRNG keys rather than global state, so there is no drop-in trax_numpy.random.random().

numpy_array = numpy.random.random((5,10))
print(f"The regular numpy array looks like this:\n\n {numpy_array}\n")
print(f"It is of type: {type(numpy_array)}")
The regular numpy array looks like this:

 [[0.85888927 0.37271115 0.55512878 0.95565655 0.7366696  0.81620514
  0.10108656 0.92848807 0.60910917 0.59655344]
 [0.09178413 0.34518624 0.66275252 0.44171349 0.55148779 0.70371249
  0.58940123 0.04993276 0.56179184 0.76635847]
 [0.91090833 0.09290995 0.90252139 0.46096041 0.45201847 0.99942549
  0.16242374 0.70937058 0.16062408 0.81077677]
 [0.03514717 0.53488673 0.16650012 0.30841038 0.04506241 0.23857613
  0.67483453 0.78238275 0.69520163 0.32895445]
 [0.49403187 0.52412136 0.29854125 0.46310814 0.98478429 0.50113492
  0.39807245 0.72790532 0.86333097 0.02616954]]

It is of type: <class 'numpy.ndarray'>
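As an aside, if you want comparable randomness on the JAX side you go through jax.random with an explicit key. A minimal sketch (jax.random is part of jax itself, which trax pulls in as its backend):

from jax import random

# JAX threads randomness through explicit PRNG keys
# instead of numpy-style hidden global state
key = random.PRNGKey(SEED)
jax_array = random.uniform(key, shape=(5, 10))
print(type(jax_array))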

You can easily cast regular numpy arrays or lists into trax numpy arrays using the trax.fastmath.numpy.array() function:

trax_numpy_array = trax_numpy.array(numpy_array)
print(f"The trax numpy array looks like this:\n\n {trax_numpy_array}\n")
print(f"It is of type: {type(trax_numpy_array)}")
The trax numpy array looks like this:

 [[0.8588893  0.37271115 0.55512875 0.9556565  0.7366696  0.81620514
  0.10108656 0.9284881  0.60910916 0.59655344]
 [0.09178413 0.34518623 0.6627525  0.44171348 0.5514878  0.70371246
  0.58940125 0.04993276 0.56179184 0.7663585 ]
 [0.91090834 0.09290995 0.9025214  0.46096042 0.45201847 0.9994255
  0.16242374 0.7093706  0.16062407 0.81077677]
 [0.03514718 0.5348867  0.16650012 0.30841038 0.04506241 0.23857613
  0.67483455 0.7823827  0.69520164 0.32895446]
 [0.49403188 0.52412134 0.29854125 0.46310815 0.9847843  0.50113493
  0.39807245 0.72790533 0.86333096 0.02616954]]

It is of type: <class 'jax.interpreters.xla._DeviceArray'>

The previous section was a quick look at Trax's numpy. The main goal of this notebook, however, is to show how to calculate the perplexity of a trained model.

Calculating Perplexity

Perplexity is a metric that measures how well a probability model predicts a sample, and it is commonly used to evaluate language models. It is defined as:

\[ P(W) = \sqrt[N]{\prod_{i=1}^{N} \frac{1}{P(w_i| w_1,\ldots,w_{i-1})}} \]

As an implementation trick, you usually take the log of that formula: the log lets us work directly with the log probabilities our RNN outputs, turns the N-th root into a factor of 1/N, and turns the product into a sum, which makes the computation simpler and more numerically stable. You should also take care of the padding, since you do not want to include it when calculating the perplexity (that would make the measure artificially good). The algebra behind this process is explained next:

\begin{align}
\log P(W) &= \log\left(\sqrt[N]{\prod_{i=1}^{N} \frac{1}{P(w_i| w_1,\ldots,w_{i-1})}}\right) \\
&= \log\left({\prod_{i=1}^{N} \frac{1}{P(w_i| w_1,\ldots,w_{i-1})}}\right)^{\frac{1}{N}} \\
&= \log\left({\prod_{i=1}^{N}{P(w_i| w_1,\ldots,w_{i-1})}}\right)^{-\frac{1}{N}} \\
&= -\frac{1}{N}\log\left({\prod_{i=1}^{N}{P(w_i| w_1,\ldots,w_{i-1})}}\right) \\
&= -\frac{1}{N}\sum_{i=1}^{N}{\log P(w_i| w_1,\ldots,w_{i-1})}
\end{align}
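To make the equivalence concrete, here is a tiny worked example (made-up probabilities, plain numpy) showing that the direct formula and the log-space version give the same answer:

# made-up per-word probabilities for a three-word sequence
word_probabilities = numpy.array([0.5, 0.25, 0.25])
N = len(word_probabilities)

# direct definition: Nth root of the product of inverse probabilities
direct = numpy.prod(1 / word_probabilities) ** (1 / N)

# log-space version: exp of the negative mean log probability
log_perplexity = -numpy.mean(numpy.log(word_probabilities))
via_logs = numpy.exp(log_perplexity)

print(direct, via_logs)  # both ~3.1748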

We're going to use some pre-made arrays.

predictions = numpy.load(Paths.predictions)
targets = numpy.load(Paths.targets)

Now we'll cast the numpy arrays to jax.interpreters.xla.DeviceArrays.

predictions = trax_numpy.array(predictions)
targets = trax_numpy.array(targets)
print(f'predictions has shape: {predictions.shape}')
print(f'targets has shape: {targets.shape}')
predictions has shape: (32, 64, 256)
targets has shape: (32, 64)

Notice that the predictions have an extra dimension whose length is the size of the vocabulary. Because of this you will need a way of reshaping targets to match that shape, which is what trax.layers.one_hot is for.

Also note that we can get the size of the last dimension using predictions.shape[-1].

reshaped_targets = layers.one_hot(x=targets, n_categories=predictions.shape[-1])
print(f'reshaped_targets has shape: {reshaped_targets.shape}')
reshaped_targets has shape: (32, 64, 256)
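As a quick illustration of what one_hot is doing, a toy example (made-up values, three categories):

# each integer index becomes a row with a 1.0 in that position
print(layers.one_hot(x=trax_numpy.array([1, 0, 2]), n_categories=3))
# [[0. 1. 0.]
#  [1. 0. 0.]
#  [0. 0. 1.]]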

Because reshaped_targets is one-hot, taking the product of the predictions and the reshaped targets and summing across the last dimension picks out, at each position, the log probability the model assigned to the actual target word. This gives us the total log perplexity.

total_log_perplexity = trax_numpy.sum(predictions * reshaped_targets, axis=-1)
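For what it's worth, the same selection can be done without materializing the one-hot tensor; this is a sketch assuming the trax backend exposes jax's take_along_axis under trax_numpy:

# gather the log probability of each target token directly
gathered = trax_numpy.take_along_axis(
    predictions, targets[..., None], axis=-1).squeeze(axis=-1)
print(gathered.shape)  # (32, 64)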

Now you will need to account for the padding so this metric is not artificially deflated (since a lower perplexity means a better model). To identify which elements are padding you can use trax_numpy.equal(), which gives a tensor with True wherever the target is the padding ID (0) and False wherever there is an actual token.

equals_zero = trax_numpy.equal(targets, 0)
print(equals_zero)
[[False False False ...  True  True  True]
 [False False False ...  True  True  True]
 [False False False ...  True  True  True]
 ...
 [False False False ...  True  True  True]
 [False False False ...  True  True  True]
 [False False False ...  True  True  True]]

equals_zero is a boolean array that has True wherever the cell held a 0 and False everywhere else. To make it numeric, and to flip it so that the real tokens are the ones marked, we can subtract it from 1 (in Python, True is treated as 1 and False as 0).

non_pad = 1.0 - equals_zero
print(f'non_pad has shape: {non_pad.shape}\n')
print(f'non_pad looks like this: \n\n {non_pad}')
non_pad has shape: (32, 64)

non_pad looks like this: 

 [[1. 1. 1. ... 0. 0. 0.]
 [1. 1. 1. ... 0. 0. 0.]
 [1. 1. 1. ... 0. 0. 0.]
 ...
 [1. 1. 1. ... 0. 0. 0.]
 [1. 1. 1. ... 0. 0. 0.]
 [1. 1. 1. ... 0. 0. 0.]]

Now if we multiply total_log_perplexity by non_pad we zero out all the entries of total_log_perplexity where non_pad is zero.

real_log_perplexity = total_log_perplexity * non_pad
print(f'real perplexity still has shape: {real_log_perplexity.shape}')
real perplexity still has shape: (32, 64)

We can check the effect of filtering out the padding by looking at the two log perplexity tensors.

print(f'log perplexity tensor before filtering padding: \n\n {total_log_perplexity}\n')
print(f'log perplexity tensor after filtering padding: \n\n {real_log_perplexity}')
log perplexity tensor before filtering padding: 

 [[ -5.396545    -1.0311184   -0.66916656 ... -22.37673    -23.18771
  -21.843483  ]
 [ -4.5857706   -1.1341286   -8.538033   ... -20.15686    -26.837097
  -23.57502   ]
 [ -5.2223887   -1.2824144   -0.17312431 ... -21.328228   -19.854412
  -33.88444   ]
 ...
 [ -5.396545   -17.291681    -4.360766   ... -20.825802   -21.065838
  -22.443115  ]
 [ -5.9313164  -14.247417    -0.2637329  ... -26.743248   -18.38433
  -22.355278  ]
 [ -5.670536    -0.10595131   0.         ... -23.332523   -28.087376
  -23.878807  ]]

log perplexity tensor after filtering padding: 

 [[ -5.396545    -1.0311184   -0.66916656 ...  -0.          -0.
   -0.        ]
 [ -4.5857706   -1.1341286   -8.538033   ...  -0.          -0.
   -0.        ]
 [ -5.2223887   -1.2824144   -0.17312431 ...  -0.          -0.
   -0.        ]
 ...
 [ -5.396545   -17.291681    -4.360766   ...  -0.          -0.
   -0.        ]
 [ -5.9313164  -14.247417    -0.2637329  ...  -0.          -0.
   -0.        ]
 [ -5.670536    -0.10595131   0.         ...  -0.          -0.
   -0.        ]]

To get a single average log perplexity across all the elements in the batch, you sum across both dimensions and divide by the number of non-padding elements. Note the leading negative sign: the masked entries are log probabilities, which are negative, so negating the average gives the positive log perplexity from the formula above.

log_perplexity = -trax_numpy.sum(real_log_perplexity) / trax_numpy.sum(non_pad)
print(f"log perplexity: {log_perplexity:0.4f}, "
      f"perplexity: {trax_numpy.exp(log_perplexity):0.4f}")
log perplexity: 2.3281, perplexity: 10.2586
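Putting the whole thing together, here's the computation as a single function, under the same assumptions as above (the predictions are log probabilities and 0 is the padding ID):

def masked_log_perplexity(predictions, targets, padding_id=0):
    """Average negative log probability of the non-padding targets."""
    one_hot_targets = layers.one_hot(
        x=targets, n_categories=predictions.shape[-1])
    # log probability assigned to the true token at each position
    log_probabilities = trax_numpy.sum(predictions * one_hot_targets, axis=-1)
    # 1.0 for real tokens, 0.0 for padding
    non_pad = 1.0 - trax_numpy.equal(targets, padding_id)
    return (-trax_numpy.sum(log_probabilities * non_pad)
            / trax_numpy.sum(non_pad))

print(f"perplexity: {trax_numpy.exp(masked_log_perplexity(predictions, targets)):0.4f}")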