Jax, Numpy, and Perplexity
Beginning
Imports
Note to future self: The default jax installation from pip is CPU-only. To get it to run on the GPU (which seems to be the main reason to use it) you need to specify a CUDA build. Right now the command is:
pip install jaxlib==0.1.57+cuda111 -f https://storage.googleapis.com/jax-releases/jax_releases.html
Here cuda111 refers to the fact that I have CUDA 11.1 installed on the server, so I need that version. See the installation instructions for more information (and to see if anything has changed).
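A quick way to check that jax actually found the GPU rather than silently falling back to the CPU (a minimal sketch, assuming jax imports cleanly):
# list the devices jax can see - should include GPU entries, not just CPU
import jax
print(jax.devices())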
# from python
from argparse import Namespace
from pathlib import Path
import os
# from pypi
from dotenv import load_dotenv
from trax import layers
import numpy
import trax
import trax.fastmath.numpy as trax_numpy
Set Up
The Data Paths
load_dotenv("posts/nlp/.env", override=True)
Paths = Namespace(
targets=Path(os.environ["RNN_TARGETS"]).expanduser(),
predictions=Path(os.environ["RNN_PREDICTIONS"]).expanduser()
)
The Random Seed
SEED = 32
# trax no longer has a global seed setting - pass it to the training.Loop
# trax.supervised.trainer_lib.init_random_number_generators(SEED)
numpy.random.seed(SEED)
Middle
Numpy vs Trax
One important change to take into consideration is that the type of the resulting objects differs depending on which numpy you use. With regular numpy you get numpy.ndarray
, but with Trax's numpy you get jax.interpreters.xla.DeviceArray
. The two types map to each other, so if you find error logs mentioning the DeviceArray type, don't worry about it - treat it as you would an ndarray and march ahead.
You can get a random numpy array using the numpy.random.random()
function.
This is one of the functionalities that Trax's numpy does not currently support in the same way as regular numpy - jax handles randomness with explicit keys instead of global state (see the sketch after the example below).
numpy_array = numpy.random.random((5,10))
print(f"The regular numpy array looks like this:\n\n {numpy_array}\n")
print(f"It is of type: {type(numpy_array)}")
The regular numpy array looks like this:

 [[0.85888927 0.37271115 0.55512878 0.95565655 0.7366696  0.81620514 0.10108656 0.92848807 0.60910917 0.59655344]
 [0.09178413 0.34518624 0.66275252 0.44171349 0.55148779 0.70371249 0.58940123 0.04993276 0.56179184 0.76635847]
 [0.91090833 0.09290995 0.90252139 0.46096041 0.45201847 0.99942549 0.16242374 0.70937058 0.16062408 0.81077677]
 [0.03514717 0.53488673 0.16650012 0.30841038 0.04506241 0.23857613 0.67483453 0.78238275 0.69520163 0.32895445]
 [0.49403187 0.52412136 0.29854125 0.46310814 0.98478429 0.50113492 0.39807245 0.72790532 0.86333097 0.02616954]]

It is of type: <class 'numpy.ndarray'>
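For reference, jax's own random functions take an explicit PRNG key derived from a seed instead of relying on global state. A minimal sketch using jax.random directly:
# jax has no global random state - you build an explicit key from a seed
from jax import random

key = random.PRNGKey(SEED)
jax_array = random.uniform(key, shape=(5, 10))
print(jax_array)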
You can easily cast regular numpy arrays or lists into trax numpy arrays using the trax.fastmath.numpy.array()
function:
trax_numpy_array = trax_numpy.array(numpy_array)
print(f"The trax numpy array looks like this:\n\n {trax_numpy_array}\n")
print(f"It is of type: {type(trax_numpy_array)}")
The trax numpy array looks like this:

 [[0.8588893  0.37271115 0.55512875 0.9556565  0.7366696  0.81620514 0.10108656 0.9284881  0.60910916 0.59655344]
 [0.09178413 0.34518623 0.6627525  0.44171348 0.5514878  0.70371246 0.58940125 0.04993276 0.56179184 0.7663585 ]
 [0.91090834 0.09290995 0.9025214  0.46096042 0.45201847 0.9994255  0.16242374 0.7093706  0.16062407 0.81077677]
 [0.03514718 0.5348867  0.16650012 0.30841038 0.04506241 0.23857613 0.67483455 0.7823827  0.69520164 0.32895446]
 [0.49403188 0.52412134 0.29854125 0.46310815 0.9847843  0.50113493 0.39807245 0.72790533 0.86333096 0.02616954]]

It is of type: <class 'jax.interpreters.xla._DeviceArray'>
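Since the two types map to each other, you can also go the other way - a DeviceArray implements the numpy array protocol, so numpy.asarray converts it back to an ndarray (a quick sketch):
# regular numpy can consume DeviceArrays directly via __array__
round_trip = numpy.asarray(trax_numpy_array)
print(type(round_trip))  # <class 'numpy.ndarray'>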
The previous section was a quick look at Trax's numpy. However, this notebook also aims to show you how to calculate the perplexity of a trained model.
Calculating Perplexity
The perplexity is a metric that measures how well a probability model predicts a sample and it is commonly used to evaluate language models. It is defined as:
\[ P(W) = \sqrt[N]{\prod_{i=1}^{N} \frac{1}{P(w_i| w_1,\ldots,w_{i-1})}} \]
As an implementation hack, you would usually take the log of that formula: this lets you work directly with the log probabilities your RNN
outputs, and the logarithm turns the Nth root into a factor of 1/N and the product into a sum, which makes the computation simpler and more numerically stable. You should also take care of the padding, since you do not want to include it when calculating the perplexity (that would make the perplexity look artificially good). Taking the log of the formula above gives:
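\[ \log P(W) = -\frac{1}{N} \sum_{i=1}^{N} \log P(w_i| w_1,\ldots,w_{i-1}) \]
This is what we will compute below: the negated mean of the log probabilities of the target words, with the padding positions excluded.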
We're going to use some pre-made arrays.
predictions = numpy.load(Paths.predictions)
targets = numpy.load(Paths.targets)
Now we'll cast the numpy arrays to jax.interpreters.xla.DeviceArrays.
predictions = trax_numpy.array(predictions)
targets = trax_numpy.array(targets)
print(f'predictions has shape: {predictions.shape}')
print(f'targets has shape: {targets.shape}')
predictions has shape: (32, 64, 256)
targets has shape: (32, 64)
Notice that the predictions have an extra dimension whose length is the size of the vocabulary. Because of this you will need a way of reshaping targets
to match - for this we will use trax.layers.one_hot.
Also note that we can get the size of the last dimension using predictions.shape[-1]
.
reshaped_targets = layers.one_hot(x=targets, n_categories=predictions.shape[-1])
print(f'reshaped_targets has shape: {reshaped_targets.shape}')
reshaped_targets has shape: (32, 64, 256)
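To see what one_hot is doing, here's a tiny example with a hypothetical three-word vocabulary:
# each integer index becomes a row with a 1 in that position
tiny = layers.one_hot(x=numpy.array([0, 2, 1]), n_categories=3)
print(tiny)
# [[1. 0. 0.]
#  [0. 0. 1.]
#  [0. 1. 0.]]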
Since the predictions hold log probabilities, multiplying them by the one-hot reshaped targets and summing across the last (vocabulary) dimension picks out, for each position, the log probability that the model assigned to the actual target word. This gives the total log perplexity.
total_log_perplexity = trax_numpy.sum(predictions * reshaped_targets, axis= -1)
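An equivalent way to pick out those values without building the one-hot tensor is to index directly with the targets. A sketch using jax numpy's take_along_axis, which should give the same result:
# select the log probability of each target word by index instead of one-hot
picked = trax_numpy.take_along_axis(
    predictions, targets[..., None].astype(trax_numpy.int32), axis=-1)
picked = trax_numpy.squeeze(picked, axis=-1)  # drop the trailing size-1 axis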
Now you will need to account for the padding so this metric is not artificially deflated (since a lower perplexity means a better model). To identify which elements are padding and which are not, you can use trax_numpy.equal()
to get a tensor with True
where the target is 0 (the padding) and False
in the positions of actual values.
equals_zero = trax_numpy.equal(targets, 0)
print(equals_zero)
[[False False False ...  True  True  True]
 [False False False ...  True  True  True]
 [False False False ...  True  True  True]
 ...
 [False False False ...  True  True  True]
 [False False False ...  True  True  True]
 [False False False ...  True  True  True]]
equals_zero
is a boolean array with True
wherever the cell held a 0 and False
everywhere else. To make it numeric we can subtract it from 1 (in Python, True is treated as 1 and False as 0), flipping it so the actual words become 1 and the padding becomes 0.
non_pad = 1.0 - equals_zero
print(f'non_pad has shape: {non_pad.shape}\n')
print(f'non_pad looks like this: \n\n {non_pad}')
non_pad has shape: (32, 64)

non_pad looks like this:

 [[1. 1. 1. ... 0. 0. 0.]
 [1. 1. 1. ... 0. 0. 0.]
 [1. 1. 1. ... 0. 0. 0.]
 ...
 [1. 1. 1. ... 0. 0. 0.]
 [1. 1. 1. ... 0. 0. 0.]
 [1. 1. 1. ... 0. 0. 0.]]
Now if we multiply total_log_perplexity
by non_pad
we'll zero out all the entries of total_log_perplexity
where non_pad
is zero.
real_log_perplexity = total_log_perplexity * non_pad
print(f'real perplexity still has shape: {real_log_perplexity.shape}')
real perplexity still has shape: (32, 64)
We can check the effect of filtering out the padding by looking at the two log perplexity tensors.
print(f'log perplexity tensor before filtering padding: \n\n {total_log_perplexity}\n')
print(f'log perplexity tensor after filtering padding: \n\n {real_log_perplexity}')
log perplexity tensor before filtering padding:

 [[ -5.396545    -1.0311184   -0.66916656 ... -22.37673    -23.18771    -21.843483  ]
 [ -4.5857706   -1.1341286   -8.538033   ... -20.15686    -26.837097   -23.57502   ]
 [ -5.2223887   -1.2824144   -0.17312431 ... -21.328228   -19.854412   -33.88444   ]
 ...
 [ -5.396545   -17.291681    -4.360766   ... -20.825802   -21.065838   -22.443115  ]
 [ -5.9313164  -14.247417    -0.2637329  ... -26.743248   -18.38433    -22.355278  ]
 [ -5.670536    -0.10595131   0.         ... -23.332523   -28.087376   -23.878807  ]]

log perplexity tensor after filtering padding:

 [[ -5.396545    -1.0311184   -0.66916656 ...  -0.          -0.          -0.        ]
 [ -4.5857706   -1.1341286   -8.538033   ...  -0.          -0.          -0.        ]
 [ -5.2223887   -1.2824144   -0.17312431 ...  -0.          -0.          -0.        ]
 ...
 [ -5.396545   -17.291681    -4.360766   ...  -0.          -0.          -0.        ]
 [ -5.9313164  -14.247417    -0.2637329  ...  -0.          -0.          -0.        ]
 [ -5.670536    -0.10595131   0.         ...  -0.          -0.          -0.        ]]
To get a single average log perplexity across all the elements in the batch you can sum across both dimensions and divide by the number of non-padding elements (the sum of non_pad). Note that the entries are log probabilities, so the sum is negated to give the (positive) log perplexity of the model.
log_perplexity = -trax_numpy.sum(real_log_perplexity) / trax_numpy.sum(non_pad)
print(f"log perplexity: {log_perplexity:0.4f}, "
f"perplexity: {trax_numpy.exp(log_perplexity):0.4f}")
log perplexity: 2.3281, perplexity: 10.2586
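Putting all of the steps together, here's a sketch of the whole calculation as a single function. It just bundles up what we did above, assuming predictions holds log probabilities and 0 is the padding ID:
def log_perplexity(predictions, targets, pad_id=0):
    """Average log perplexity over the non-padding tokens.

    predictions: log probabilities with shape (batch, length, vocabulary)
    targets: token IDs with shape (batch, length)
    """
    # pick out the log probability of each target word
    one_hot_targets = layers.one_hot(x=targets, n_categories=predictions.shape[-1])
    log_probabilities = trax_numpy.sum(predictions * one_hot_targets, axis=-1)
    # mask out the padding positions (assumes pad_id marks the padding)
    non_pad = 1.0 - trax_numpy.equal(targets, pad_id)
    # negate the mean log probability of the real tokens
    return -trax_numpy.sum(log_probabilities * non_pad) / trax_numpy.sum(non_pad)

print(f"log perplexity: {log_perplexity(predictions, targets):0.4f}")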