Sentiment Analysis: Deep Learning Model

Beginning

Previously we created sentiment analysis models using the Logistic Regression and Naive Bayes algorithms. However, if we were to give those models an example like:

This movie was almost good.

The model would have predicted a positive sentiment for that review, most likely because a word-count-based model sees "good" and weighs it positively while giving "almost" little or no weight. The sentence, however, expresses the negative sentiment that the movie was not good. To address those kinds of misclassifications we will write a program that uses deep neural networks to identify sentiment in text.

This model will follow a structure similar to the Continuous Bag of Words model (Introducing the CBOW Model) that we looked at previously - indeed, most deep networks are built the same way; what differs from model to model is the architecture, the inputs, and the outputs. Although we looked at Trax and JAX in a previous post (Introducing Trax), we'll start off with a review of some of their features, and then in future posts we'll implement the actual model.

Imports

# from python
import os
import random

# from pypi
from trax import layers
import trax
import trax.fastmath.numpy as numpy

Set Up

The Random Seed

trax.supervised.trainer_lib.init_random_number_generators(31)

Middle

Trax Review

JAX Arrays

First, the JAX reimplementation of numpy (from trax.fastmath).

an_array = numpy.array(5.0)
display(an_array)
print(type(an_array))
DeviceArray(5., dtype=float32)
<class 'jax.interpreters.xla._DeviceArray'>

Note: trax (via JAX) is strict about typing here: an integer 5 won't work once we take gradients later on, so the input has to be a float.
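
Since trax.fastmath.numpy is meant as a drop-in for numpy, here's a minimal sketch (assuming the same session and imports as above) showing that the familiar numpy operations work on these arrays, and that an integer input quietly produces an int32 array, which is what the gradient machinery below rejects.

# trax.fastmath.numpy mirrors the numpy API, so the usual
# operations (including broadcasting) work on DeviceArrays.
a_vector = numpy.array([1.0, 2.0, 3.0])
print(a_vector + an_array)   # expected: [6. 7. 8.]
print(numpy.mean(a_vector))  # expected: 2.0

# An integer input creates an int32 array rather than float32;
# JAX's grad (used below) rejects integer dtypes.
an_integer_array = numpy.array(5)
print(an_integer_array.dtype)  # expected: int32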

Squaring

Now we'll create a function to square the array.

def square(x):
    """Square the input (works element-wise on arrays)."""
    return x**2
print(f"f({an_array}) -> {square(an_array)}")
f(5.0) -> 25.0

Gradients

The gradient (derivative) of the function \(f(x) = x^2\) with respect to its input \(x\) is \(2x\):

  • The derivative of \(x^2\) is \(2x\).
  • When \(x\) is 5, then \(2x = 10\) (checked numerically below).
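
Before handing the work over to trax, here's a minimal finite-difference sketch of that arithmetic (it only assumes the square function defined above); the approximation should land near 10.

# approximate f'(5) with the difference quotient (f(5 + delta) - f(5))/delta
delta = 1e-3
approximation = (square(5.0 + delta) - square(5.0))/delta
print(approximation)  # expected: approximately 10.001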

You can calculate the gradient of a function with trax.fastmath.grad, passing the function itself in as the fun argument.

  • In this case the function we want the gradient of is square.
  • The object returned (saved in square_gradient in this example) is itself a function that can calculate the gradient of square for a given trax.fastmath.numpy array.

Use trax.fastmath.grad to build the gradient (derivative) function.

square_gradient = trax.fastmath.grad(fun=square)

print(type(square_gradient))
<class 'function'>

Call the newly created function and pass in a value for x (the DeviceArray stored in an_array).

gradient_calculation = square_gradient(an_array)
display(gradient_calculation)
DeviceArray(10., dtype=float32)

The function returned by trax.fastmath.grad takes in x = 5 and calculates the gradient of square, which is 2x, or 10. The result comes back as a DeviceArray from the jax library.
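
Because the output of trax.fastmath.grad is just another function, you can feed it back into grad. As a quick sketch (same session assumed), the second derivative of square should be the constant 2, whatever the input point.

# grad composes: the derivative of square_gradient (2x) is the constant 2
second_derivative = trax.fastmath.grad(fun=square_gradient)
print(second_derivative(an_array))  # expected: 2.0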

End

Now that we've had a brief review of Trax, let's move on to loading the data.