RNNs and Vanishing Gradients

Vanishing Gradients

This will be a look at the problem of vanishing gradients from an intuitive standpoint.

Background

Adding layers to a neural network introduces multiplicative effects in both forward and backward propagation. Backpropagation in particular presents a problem, because the gradients of activation functions can be very small. Multiplied together across many layers, their product can become vanishingly small, which means the weights in the front layers barely get updated and training stops making progress.

Gradients of the sigmoid function, for example, are in the range 0 to 0.25. To calculate the gradients for the front layers of a neural network, the chain rule multiplies these small values together, starting at the last layer and working backwards to the first, so the gradient shrinks roughly exponentially with the number of layers.
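To make the multiplicative effect concrete, here is a simplified scalar sketch of the chain rule for a stack of \(n\) sigmoid layers, using notation introduced just for this illustration: \(a_k = \sigma(z_k)\) is the activation of layer \(k\) and \(z_k = w_k a_{k-1} + b_k\) is its pre-activation.

\[
\frac{\partial L}{\partial a_1} = \frac{\partial L}{\partial a_n} \prod_{k=2}^{n} w_k\, \sigma'(z_k), \qquad 0 < \sigma'(z_k) \le \frac{1}{4}
\]

Unless the weights \(w_k\) are large enough to compensate, the product of the \(\sigma'\) factors alone shrinks like \(\left(\tfrac{1}{4}\right)^{n-1}\) at best, which is exactly the decay simulated later in this post.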

Imports

# python
from collections import namedtuple
from functools import partial

# pypi
import holoviews
import hvplot.pandas
import numpy
import pandas

# another project
from graeae import EmbedHoloviews

Set Up

SLUG = "rnns-and-vanishing-gradients"
Embed = partial(EmbedHoloviews,
                folder_path=f"files/posts/nlp/{SLUG}")
Plot = namedtuple("Plot", ["width", "height", "fontscale", "tan", "blue", "red"])
PLOT = Plot(
    width=900,
    height=750,
    fontscale=2,
    tan="#ddb377",
    blue="#4687b7",
    red="#ce7b6d",
)

Middle

The Data

This will be an evenly spaced set of points over an interval (see numpy.linspace).

STOP, STEPS = 10, 100
x = numpy.linspace(-STOP, STOP, STEPS)

The Sigmoid

Our activation function will be the sigmoid (well, technically the logistic function).

def sigmoid(x: numpy.ndarray) -> numpy.ndarray:
    """The logistic sigmoid: squashes inputs into the range (0, 1)."""
    return 1 / (1 + numpy.exp(-x))

Now we'll calculate the activations for our input data.

activations = sigmoid(x)

The Gradient

Our gradient is the derivative of the sigmoid, which can be written in terms of the sigmoid's output, so the function below expects the activation \(\sigma(x)\) as its argument rather than the raw input.

def gradient(x: numpy.ndarray) -> numpy.ndarray:
    """The sigmoid's derivative, given the activation sigma(x) as input."""
    return x * (1 - x)
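Why pass in the activation rather than the raw input? The derivative of the logistic function can be written entirely in terms of its output:

\[
\sigma'(x) = \frac{e^{-x}}{\left(1 + e^{-x}\right)^{2}} = \sigma(x)\bigl(1 - \sigma(x)\bigr)
\]

So feeding the function \(\sigma(x)\) (the activations we just computed) gives the gradient directly.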

Now we can get the gradients for our activations.

gradients = gradient(activations)

Plotting the Sigmoid

# the point where the tangent line touches the sigmoid
tangent_x = 0
tangent_y = sigmoid(tangent_x)
# how far to extend the tangent line on either side of that point
span = 2

# the slope of the tangent is the sigmoid's gradient at that point
gradient_tangent = gradient(sigmoid(tangent_x))

# the tangent line: y = y0 + slope * (x - x0)
tangent_plot_x = numpy.linspace(tangent_x - span, tangent_x + span, STEPS)
tangent_plot_y = tangent_y + gradient_tangent * (tangent_plot_x - tangent_x)

frame = pandas.DataFrame.from_dict(
    {"X": x,
     "Sigmoid": activations,
     "X-Tangent": tangent_plot_x,
     "Y-Tangent": tangent_plot_y,
     "Gradient": gradients})
plot = (frame.hvplot(x="X", y="Sigmoid").opts(color=PLOT.blue)
        * frame.hvplot(x="X", y="Gradient").opts(color=PLOT.red)
        * frame.hvplot(x="X-Tangent",
                       y="Y-Tangent").opts(color=PLOT.tan)).opts(
            title="Sigmoid and Tangent",
            width=PLOT.width,
            height=PLOT.height,
            fontscale=PLOT.fontscale)
output = Embed(plot=plot, file_name="sigmoid_tangent")()
print(output)

Figure Missing: Sigmoid and Tangent

The thing to notice is that as the input data moves away from the center (at 0) the gradients get smaller in either direction, rapidly approaching zero.
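As a quick numerical check (a small sketch reusing the sigmoid and gradient functions above, with a handful of arbitrarily chosen points), the gradient is already tiny just a few units away from zero:

# the sigmoid gradient at a few points moving away from the center
for point in (0, 2, 5, 10):
    print(f"x = {point:>2}  gradient = {gradient(sigmoid(point)):.6f}")

This comes out to roughly 0.25, 0.105, 0.0066, and 0.000045, so even before any multiplication across layers the raw material is small.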

The Numerical Impact

Multiplication & Decay

Multiplying numbers smaller than 1 together produces smaller and smaller numbers. Below is an example that takes the gradient at the input x = 0 and multiplies it by itself over n steps. Look how quickly it 'vanishes' to almost zero, even though \(\sigma(0) = 0.5\) gives a sigmoid gradient of 0.25, the largest gradient the sigmoid can produce.

A Decay Simulation

Input data

n = 6
x = 0

gradients = gradient(sigmoid(x))
steps = numpy.arange(1, n + 1)
print("-- Inputs --")
print("steps :", n)
print("x value :", x)
print("sigmoid :", "{:.5f}".format(sigmoid(x)))
print("gradient :", "{:.5f}".format(gradients), "\n")
-- Inputs --
steps : 6
x value : 0
sigmoid : 0.50000
gradient : 0.25000 

Plot The Decay

# multiply the 0.25 gradient by itself over and over (a cumulative product)
decaying_values = (numpy.ones(len(steps)) * gradients).cumprod()
data = pandas.DataFrame.from_dict(dict(Step=steps, Gradient=decaying_values))
plot = data.hvplot(x="Step", y="Gradient").opts(
    title="Cumulative Gradient",
    width=PLOT.width,
    height=PLOT.height,
    fontscale=PLOT.fontscale
)
output = Embed(plot=plot, file_name="cumulative_gradient")()
print(output)

Figure Missing: Cumulative Gradient

The point is that the cumulative gradient very quickly approaches zero.
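Since the best case multiplies 0.25 by itself at every step, the decay can also be read off in closed form; here's a small sketch extending the count beyond the six steps used above:

# the best-case cumulative gradient after a given number of multiplications
for count in (6, 20, 50):
    print(f"steps = {count:>2}  product = {0.25**count:.3e}")

That's roughly 2.4e-04 after six steps, 9.1e-13 after twenty, and 7.9e-31 after fifty, so for a deep network, or a recurrent network unrolled across many time steps, the earliest layers see essentially no gradient at all.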

So, How Do You Fix This?

One solution is to use activation functions whose gradients don't collapse toward zero the way the sigmoid's does; a quick numerical sketch of that idea follows. Other solutions involve more sophisticated model design, but that's a discussion for another time.
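As a hedged illustration of the first idea (ReLU isn't used anywhere in this post; it's just a convenient contrast), the ReLU activation has a gradient of exactly 1 for positive inputs, so the same cumulative product doesn't decay:

# compare cumulative gradient products: the sigmoid's best case (0.25)
# versus ReLU's gradient for positive inputs (exactly 1)
SIGMOID_BEST, RELU_POSITIVE = 0.25, 1.0
for layers in (6, 20, 50):
    print(f"layers = {layers:>2}  "
          f"sigmoid: {SIGMOID_BEST**layers:.3e}  "
          f"ReLU: {RELU_POSITIVE**layers:.1f}")

The sigmoid column collapses toward zero while the ReLU column stays at 1, which is one reason ReLU-style activations are a common first answer to vanishing gradients.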