NER: Evaluating the Model
Beginning
Now we'll evaluate our model using the test set. To do this we'll need to create a mask to avoid counting the padding tokens when computing the accuracy.
- Step 1: Calling model(sentences) will give us the predicted output.
- Step 2: The output will be the prediction with an added dimension: for each word in each sentence there will be a vector of probabilities, one per tag type. For each word in each sentence we'll need to pick the maximum-valued tag. This will require np.argmax and careful use of the axis argument.
- Step 3: Create a mask to prevent counting pad characters. It will have the same dimensions as the labels (and as the output once the argmax has collapsed the probability axis).
- Step 4: Compute the accuracy metric by comparing the outputs against the test labels. Take the sum of the matches and divide by the total number of unpadded tokens, using the mask to exclude the padded ones (there's a toy sketch of these steps just after this list).
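As a rough sketch of those four steps on a toy batch (the array values, the PAD id, and the shapes are made up purely for illustration; the real version appears further down as evaluate_prediction):

import numpy

# Step 1 stand-in: a toy "prediction" for one sentence with 3 tokens and 4 possible tags
pred = numpy.array([[[0.1, 0.7, 0.1, 0.1],
                     [0.6, 0.2, 0.1, 0.1],
                     [0.2, 0.2, 0.4, 0.2]]])
labels = numpy.array([[1, 0, 3]])  # pretend the last token is padding
PAD = 3

# Step 2: pick the highest-scoring tag for each token (argmax over the last axis)
outputs = numpy.argmax(pred, axis=-1)
# Step 3: mask out the padded positions
mask = labels != PAD
# Step 4: fraction correct, counting only the un-padded tokens
print(numpy.sum((outputs == labels)[mask])/numpy.sum(mask))

1.0

Without the mask, the mismatch at the pad position would have dragged the toy accuracy down to 2/3, which is exactly the kind of distortion the mask is there to prevent.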
Imports
# python
from collections import namedtuple
from functools import partial
from pathlib import Path
import random
# pypi
import holoviews
import hvplot.pandas
import jax
import numpy
import pandas
import trax
# this project
from neurotic.nlp.named_entity_recognition import (DataGenerator,
NER,
NERData,
TOKEN)
# another project
from graeae import EmbedHoloviews
Set Up
Plotting
slug = "ner-evaluating-the-model"
Embed = partial(EmbedHoloviews, folder_path=f"files/posts/nlp/{slug}")
Plot = namedtuple("Plot", ["width", "height", "fontscale", "tan", "blue", "red"])
PLOT = Plot(
width=900,
height=750,
fontscale=2,
tan="#ddb377",
blue="#4687b7",
red="#ce7b6d",
)
The Previous Code
data = NERData().data
model = NER(vocabulary_size=len(data.vocabulary),
tag_count=len(data.tags)).model
Settings = namedtuple("Settings", ["batch_size", "padding_id", "seed"])
SETTINGS = Settings(batch_size=64,
padding_id=data.vocabulary[TOKEN.pad],
seed=33)
model.init_from_file(Path("~/models/ner/model.pkl.gz").expanduser())
print(model)
random.seed(SETTINGS.seed)
test_generator = DataGenerator(x=data.data_sets.x_test,
y=data.data_sets.y_test,
batch_size=SETTINGS.batch_size,
padding=SETTINGS.padding_id)
Serial[
  Embedding_35180_50
  LSTM_50
  Dense_18
  LogSoftmax
]
Middle
As a reminder, here's what happens when you apply a boolean comparison to a numpy array.
a = numpy.array([1, 2, 3, 4])
print(a == 2)
[False True False False]
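A boolean array can also be used to index another array, keeping only the positions where it's True. That's how the padded tokens will get dropped from the accuracy calculation:

mask = a != 2
print(mask)
print(a[mask])

[True False True True]
[1 3 4]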
A Test Input
x, y = next(test_generator)
print(f"x's shape: {x.shape} y's shape: {y.shape}")
predictions = model(x)
print(type(predictions))
print(f"predictions has shape: {predictions.shape}")
x's shape: (64, 44) y's shape: (64, 44)
<class 'jax.interpreters.xla._DeviceArray'>
predictions has shape: (64, 44, 18)
Note: the model's prediction has 3 axes:
- the number of examples
- the number of words in each example (padded to be as long as the longest sentence in the batch)
- the number of possible targets (the tag classes; 18 of them, going by the output shape and the Dense_18 layer). This is the axis the argmax below will collapse.
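Taking the argmax over that last axis reduces the tag probabilities to a single predicted tag ID per token, leaving an array with the same shape as the labels:

outputs = numpy.argmax(predictions, axis=-1)
print(f"outputs has shape: {outputs.shape}")

outputs has shape: (64, 44)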
def evaluate_prediction(pred: jax.interpreters.xla._DeviceArray,
labels: numpy.ndarray,
pad: int=SETTINGS.padding_id) -> float:
"""Calculates the accuracy of a prediction
Args:
pred: prediction array with shape
(num examples, max sentence length in batch, num of classes)
labels: array of size (batch_size, seq_len)
pad: integer representing pad character
Returns:
accuracy: fraction of correct predictions
"""
outputs = numpy.argmax(pred, axis=-1)
mask = labels != pad
return numpy.sum((outputs==labels)[mask])/numpy.sum(mask)
accuracy = evaluate_prediction(model(x), y)
print("accuracy: ", accuracy)
accuracy: 0.9636752
Hmm, it does pretty well.
Plotting
Let's look at running more batches. It occurred to me that you could also just evaluate the whole test set at once; I don't know what's special about using batches, beyond getting a distribution of per-batch accuracies to plot instead of a single number (a single-number version is sketched after the plot code below).
repetitions = range(len(data.data_sets.x_test)//SETTINGS.batch_size)
nexts = (next(test_generator) for repetition in repetitions)
accuracies = [evaluate_prediction(model(x), y) for x, y in nexts]

# use a new name for the frame so we don't clobber the NERData `data` object
frame = pandas.DataFrame.from_dict(dict(Accuracy=accuracies))
plot = frame.Accuracy.hvplot(kind="hist", color=PLOT.tan).opts(
    title="Accuracy Distribution",
    height=PLOT.height,
    width=PLOT.width,
    fontscale=PLOT.fontscale)
output = Embed(plot=plot, file_name="accuracy_distribution")()
print(output)
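As an aside on the whole-set-at-once idea: instead of plotting a distribution of per-batch accuracies, you could accumulate the correct-prediction and un-padded-token counts across the batches and report one overall number. Here's a rough sketch that reuses the model, DataGenerator, and settings from above (the overall_accuracy helper and the fresh generator are additions for illustration, not part of the original code):

def overall_accuracy(model, generator, batches: int,
                     pad: int=SETTINGS.padding_id) -> float:
    """Accuracy over all tokens in `batches` batches, ignoring padding."""
    correct = total = 0
    for batch in range(batches):
        x, y = next(generator)
        # predicted tag IDs for every token in the batch
        outputs = numpy.argmax(model(x), axis=-1)
        # only count the positions that aren't padding
        mask = y != pad
        correct += numpy.sum((outputs == y)[mask])
        total += numpy.sum(mask)
    return correct/total

generator = DataGenerator(x=data.data_sets.x_test,
                          y=data.data_sets.y_test,
                          batch_size=SETTINGS.batch_size,
                          padding=SETTINGS.padding_id)
print(f"overall accuracy: {overall_accuracy(model, generator, len(repetitions))}")

Weighting by the number of un-padded tokens like this means short batches don't count as much as long ones, which is the main way it differs from averaging the per-batch accuracies in the histogram above.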