Sentiment Analysis: Deep Learning Model
Beginning
Previously, we created sentiment analysis models using the Logistic Regression and Naive Bayes algorithms. However, if we were to give those models an example like:
This movie was almost good.
The model would have predicted a positive sentiment for that review. That sentence, however, expresses the negative sentiment that the movie was not good. To fix those kinds of misclassifications, we will write a program that uses deep neural networks to identify sentiment in text.
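To see why a word-count model stumbles here, consider a minimal sketch of a linear classifier over unigram counts. The weights below are made up purely for illustration; they aren't taken from the earlier models.

# hypothetical per-word weights - not from the earlier Logistic Regression model
hypothetical_weights = {
    "this": 0.0,
    "movie": 0.0,
    "was": 0.0,
    "almost": -0.2,
    "good": 1.5,
}

tokens = "this movie was almost good".split()

# word order is ignored, so "almost" can't flip the meaning of "good"
score = sum(hypothetical_weights.get(token, 0.0) for token in tokens)
print(f"score: {score:.1f} -> {'positive' if score > 0 else 'negative'}")

The per-word sum comes out positive because nothing in the sentence is strongly negative on its own; the negation only emerges from the combination "almost good", which a bag-of-words representation can't see.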
This model will follow a structure similar to the Continuous Bag of Words Model (Introducing the CBOW Model) that we looked at previously; indeed, most deep nets have a similar structure. The only things that change are the model architecture, the inputs, and the outputs. Although we looked at Trax and JAX in a previous post (Introducing Trax), we'll start off with a review of some of their features, and then in future posts we'll implement the actual model.
Imports
# from python
import os
import random
# from pypi
from trax import layers
import trax
import trax.fastmath.numpy as numpy
Set Up
The Random Seed
# seed trax's random number generators to make the runs reproducible
trax.supervised.trainer_lib.init_random_number_generators(31)
Middle
Trax Review
JAX Arrays
First, let's look at the JAX reimplementation of numpy (imported above from trax.fastmath).
an_array = numpy.array(5.0)
display(an_array)
print(type(an_array))
DeviceArray(5., dtype=float32)
<class 'jax.interpreters.xla._DeviceArray'>
Note: the trax library is strict about typing, so 5 won't work; it has to be a float.
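Since trax.fastmath.numpy mirrors the regular numpy API, the familiar array operations should work on DeviceArrays as well. Here's a quick sketch (assuming the imports and an_array from above):

another_array = numpy.array([1.0, 2.0, 3.0])

# elementwise arithmetic and reductions work like regular numpy
print(another_array * 2)
print(another_array.sum())
print(another_array.dtype)

# a scalar DeviceArray can be converted back to a plain python float
print(float(an_array))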
Squaring
Now we'll create a function to square the array.
def square(x):
    return x**2
print(f"f({an_array}) -> {square(an_array)}")
f(5.0) -> 25.0
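Because square is written with numpy-style arithmetic, it should also work elementwise on whole arrays, not just scalars:

print(square(numpy.array([1.0, 2.0, 3.0])))

[1. 4. 9.]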
Gradients
The gradient (derivative) of function \(f\) with respect to its input \(x\) is the derivative of \(x^2\):
- The derivative of \(x^2\) is \(2x\).
- When \(x\) is 5, then \(2x = 10\).
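For completeness, the power rule result can be checked with the limit definition of the derivative:

\[
f'(x) = \lim_{h \to 0} \frac{(x+h)^2 - x^2}{h} = \lim_{h \to 0} \frac{2xh + h^2}{h} = \lim_{h \to 0} (2x + h) = 2x
\]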
You can calculate the gradient of a function by using trax.fastmath.grad(fun=) and passing in the function.
- In this case the function you want to take the gradient of is square.
- The object returned (saved in square_gradient in this example) is a function that can calculate the gradient of square for a given trax.fastmath.numpy array.

Use trax.fastmath.grad to calculate the gradient (derivative) of the function.
square_gradient = trax.fastmath.grad(fun=square)
print(type(square_gradient))
<class 'function'>
Call the newly created function and pass in a value for x (the DeviceArray stored in an_array).
gradient_calculation = square_gradient(an_array)
display(gradient_calculation)
DeviceArray(10., dtype=float32)
The function returned by trax.fastmath.grad takes in x=5 and calculates the gradient of square, which is 2x, which equals 10. The value is also a DeviceArray from the jax library.
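As a sanity check, we can compare the result from trax.fastmath.grad against a centered finite-difference approximation of the derivative. This is a minimal sketch reusing square, square_gradient, and an_array from above; finite_difference is a helper defined here, not something from trax.

# approximate f'(x) with the centered difference (f(x + h) - f(x - h)) / (2h)
def finite_difference(function, x, h=1e-3):
    return (function(x + h) - function(x - h)) / (2 * h)

print(f"finite difference: {finite_difference(square, an_array)}")
print(f"trax.fastmath.grad: {square_gradient(an_array)}")

The two values should agree, up to float32 rounding in the finite-difference estimate.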
End
Now that we've had a brief review of Trax let's move on to loading the data.