Solving the Vanishing Gradient Problem with LSTMs
Renduchinthala Sai praneeth Kumar


When diving into the theory behind Recurrent Neural Networks (RNNs) and Long Short-Term Memory (LSTM) networks, two main questions arise:

1. Why do RNNs suffer from vanishing and exploding gradients?

2. How do LSTMs keep the gradients from vanishing or exploding?

When I tried answering these questions, I searched for a mathematical explanation to get a better understanding of how these networks work. I had a hard time finding proofs that were understandable and clear enough for me. After reading the recommended papers and well-known blogs dealing with these questions, I wrote the explanation that worked for me and made me feel I understood the problem and the solution better.

RNNs and vanishing gradients

RNNs suffer from the problem of vanishing gradients, which hampers the learning of long data sequences. The gradients carry the information used in the RNN parameter update, and when the gradient becomes smaller and smaller, the parameter updates become insignificant, which means no real learning is done.

Let’s have a short reminder of what an RNN looks like. We will work with a simple single-hidden-layer RNN with a single output. The network looks like this:

The network has an input sequence of vectors [x(1), x(2), …, x(k)]; at time step t the network has an input vector x(t). Past information and learned knowledge are encoded in the network state vectors [c(1), c(2), …, c(k-1)]; at time step t the network has an input state vector c(t-1). The input vector x(t) and the state vector c(t-1) are concatenated to form the complete input vector at time step t, [c(t-1), x(t)].

The network has two weight matrices: Wrec and Win connecting c(t-1) and x(t), the two parts of the input vector [c(t-1), x(t)], to the hidden layer. For simplicity, we leave out the bias vectors in our computations, and denote W = [Wrec, Win].

The hyperbolic tangent (tanh) function is used as the activation function in the hidden layer.

The network outputs a single vector at the last time step (RNNs can output a vector on each time step, but we’ll use this simpler model).
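In this notation, and assuming the standard single-hidden-layer recurrence, the state update can be written as

c(t) = tanh( Wrec·c(t-1) + Win·x(t) )

and the prediction h(k) is computed from the final state c(k).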

Backpropagation through time (BPTT) in RNNs

After the RNN outputs the prediction vector h(k), we compute the prediction error E(k) and use the Backpropagation Through Time (BPTT) algorithm to compute the gradient


The gradient is used to update the model parameters by:
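In standard gradient-descent notation (the learning rate symbol λ is assumed here), this update reads:

W ← W − λ·∂E/∂W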

And we continue the learning process using the Gradient Descent (GD) algorithm (we use the basic version of GD in this work).

Say we have a learning task that includes T time steps; the gradient of the error at the k-th time step is given by:
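In standard BPTT notation, this expands by the chain rule as

∂E(k)/∂W = Σ_{t=1..k} ∂E(k)/∂h(k) · ∂h(k)/∂c(k) · ∂c(k)/∂c(t) · ∂c(t)/∂W,   with   ∂c(k)/∂c(t) = Π_{i=t+1..k} ∂c(i)/∂c(i-1)   (1)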

Notice that since W=[Wrec, Win], c(t) can be written as:
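That is, using the concatenated weight matrix and the tanh recurrence above:

c(t) = tanh( W·[c(t-1), x(t)] ) = tanh( Wrec·c(t-1) + Win·x(t) )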

Compute the derivative of c(t) and get:
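Differentiating this expression with respect to c(t-1) (treating tanh′ element-wise) gives:

∂c(t)/∂c(t-1) = tanh′( Wrec·c(t-1) + Win·x(t) ) · Wrec   (2)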

Plug (2) into (1) and get our backpropagated gradient
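Substituting term by term:

∂E(k)/∂W = Σ_{t=1..k} ∂E(k)/∂h(k) · ∂h(k)/∂c(k) · [ Π_{i=t+1..k} tanh′( Wrec·c(i-1) + Win·x(i) ) · Wrec ] · ∂c(t)/∂W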

The last expression tends to vanish when k is large; this is due to the derivative of the tanh activation function, which is smaller than 1 in magnitude.

The product of derivatives can also explode if the weights Wrec are large enough to overpower the smaller tanh derivative; this is known as the exploding gradient problem.

We have that the derivative of tanh is at most 1 in magnitude (and typically well below 1 once activations saturate), so each factor in the product reduces the gradient. For a sufficiently large time step k, the product of these derivatives is effectively zero, and the complete error gradient vanishes along with it. The network’s weight update then leaves W essentially unchanged, and no significant learning will be done in reasonable time.
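To make the effect concrete, here is a small NumPy sketch (my own illustration, not code from the article; the sizes and weight scale are arbitrary) that multiplies the per-step Jacobians diag(tanh′)·Wrec together and prints how quickly their product shrinks:

```python
import numpy as np

np.random.seed(0)
n = 8                                   # hidden size (arbitrary)
Wrec = 0.5 * np.random.randn(n, n)      # recurrent weights, modest scale
c = np.random.randn(n)                  # initial hidden state
x = np.random.randn(n)                  # a fixed input term, for simplicity

jac_prod = np.eye(n)                    # running product of Jacobians d c(t) / d c(0)
for step in range(1, 51):
    c = np.tanh(Wrec @ c + x)           # c(t) = tanh(Wrec c(t-1) + Win x(t)), input folded into x
    jac = np.diag(1.0 - c**2) @ Wrec    # d c(t) / d c(t-1) = diag(tanh') @ Wrec
    jac_prod = jac @ jac_prod
    if step % 10 == 0:
        print(f"after {step:2d} steps  ||d c(t)/d c(0)|| ~ {np.linalg.norm(jac_prod):.2e}")
```

With modest weights like these, the norm typically drops by several orders of magnitude within a few dozen steps, which is exactly the vanishing behaviour described above.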



How do LSTMs solve this?

An LSTM network has an input vector [h(t-1),x(t)] at time step t. The network cell state is denoted by c(t). The output vectors passed through the network between consecutive time steps t, t+1 are denoted by h(t).

An LSTM network has three gates that update and control the cell state: the forget gate, the input gate and the output gate. The gates use hyperbolic tangent and sigmoid activation functions.

The forget gate controls what information in the cell state to forget, given the new information that entered the network.

The forget gate’s output is given by:
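In the usual notation (the names W_f and b_f are assumed here; the bias can be dropped as in the RNN section):

f(t) = σ( W_f·[h(t-1), x(t)] + b_f )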

The input gate controls what new information will be encoded into the cell state, given the new input information.

The input gate’s output is the element-wise product of the outputs of two fully connected layers, a sigmoid layer and a tanh layer:
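In the same notation (the names W_i, b_i for the sigmoid layer and W_c, b_c for the tanh layer are assumed):

i(t) = σ( W_i·[h(t-1), x(t)] + b_i ) ⊙ tanh( W_c·[h(t-1), x(t)] + b_c )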

The output gate controls what information encoded in the cell state is sent to the network as input in the following time step; this is done via the output vector h(t).

The output gate’s activations are given by:
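Analogously to the forget gate (with assumed names W_o, b_o):

o(t) = σ( W_o·[h(t-1), x(t)] + b_o )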

and the cell’s output vector is given by:
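In the standard form:

h(t) = o(t) ⊙ tanh( c(t) )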



The LSTM cell state

The long-term dependencies and relations are encoded in the cell state vectors, and it is the cell state’s derivative that can prevent the LSTM gradients from vanishing. The LSTM cell state has the form:
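Using the notation above, where i(t) already denotes the combined (sigmoid ⊙ tanh) output of the input gate:

c(t) = f(t) ⊙ c(t-1) + i(t)

To make the whole cell concrete, here is a minimal NumPy sketch of one LSTM step built from these equations (my own illustration; the function and variable names are not from the article):

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def lstm_step(x_t, h_prev, c_prev, W_f, W_i, W_c, W_o, b_f, b_i, b_c, b_o):
    """One LSTM time step; every W_* multiplies the concatenated vector [h(t-1), x(t)]."""
    z = np.concatenate([h_prev, x_t])      # [h(t-1), x(t)]
    f_t = sigmoid(W_f @ z + b_f)           # forget gate
    i_t = sigmoid(W_i @ z + b_i)           # input gate, sigmoid part
    c_tilde = np.tanh(W_c @ z + b_c)       # input gate, tanh part (candidate values)
    c_t = f_t * c_prev + i_t * c_tilde     # cell state: forget old, add new
    o_t = sigmoid(W_o @ z + b_o)           # output gate
    h_t = o_t * np.tanh(c_t)               # cell output / next hidden state
    return h_t, c_t

# Example usage with arbitrary sizes:
n_h, n_x = 4, 3
rng = np.random.default_rng(0)
Ws = [0.1 * rng.standard_normal((n_h, n_h + n_x)) for _ in range(4)]
bs = [np.zeros(n_h) for _ in range(4)]
h, c = np.zeros(n_h), np.zeros(n_h)
h, c = lstm_step(rng.standard_normal(n_x), h, c, *Ws, *bs)
```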


Backpropagation through time in LSTMs

As in the RNN model, our LSTM network outputs a prediction vector h(k) on the k-th time step. The knowledge encoded in the state vectors c(t) captures long-term dependencies and relations in the sequential data.


The length of the data sequences can be hundreds or even thousands of time steps, making it extremely difficult to learn using a basic RNN.

We compute the gradient used to update the network parameters; the computation is done over T time steps.

As in RNNs, the error term gradient is given by the following sum of T gradients:
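In symbols, using the notation above:

∂E/∂W = Σ_{k=1..T} ∂E(k)/∂W   (3)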

For the complete error gradient to vanish, all of these T sub-gradients need to vanish. If we think of (3) as a series of functions, then by definition this series converges to zero if the sequence of its partial sums (the sums of the first n sub-gradients, for n = 1, …, T) tends to zero.

So if we want (3) not to vanish, our network needs to increase the likelihood that at least some of these sub-gradients will not vanish; in other words, it needs to keep the series of sub-gradients in (3) from converging to zero.

The error gradients in an LSTM network

The gradient of the error for some time step k has the form:
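As in the RNN case, it has the standard BPTT form:

∂E(k)/∂W = Σ_{t=1..k} ∂E(k)/∂h(k) · ∂h(k)/∂c(k) · ∂c(k)/∂c(t) · ∂c(t)/∂W   (4)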


As we have seen, the following product causes the gradients to vanish:
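∂c(k)/∂c(t) = Π_{j=t+1..k} ∂c(j)/∂c(j-1)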

In an LSTM, the state vector c(t) has the form:
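Written out with the standard gate equations above (biases included):

c(t) = σ( W_f·[h(t-1), x(t)] + b_f ) ⊙ c(t-1) + σ( W_i·[h(t-1), x(t)] + b_i ) ⊙ tanh( W_c·[h(t-1), x(t)] + b_c )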

which can be written compactly as
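(reusing i(t) for the input gate’s combined output)

c(t) = f(t) ⊙ c(t-1) + i(t)   (5)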

Notice that the state vector c(t) is a function of c(t-1), f(t) and i(t), which themselves depend on h(t-1) and x(t); all of these dependencies should be taken into account when computing the derivative during backpropagation.

Compute the derivative of (5) and get:
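Applying the product rule, and remembering that f(t) and i(t) depend on c(t-1) through h(t-1) = o(t-1) ⊙ tanh(c(t-1)), the derivative splits into four additive contributions (written loosely, element-wise):

∂c(t)/∂c(t-1) = ∂f(t)/∂c(t-1) ⊙ c(t-1) + f(t) + ∂[σ(W_i·[h(t-1), x(t)] + b_i)]/∂c(t-1) ⊙ tanh(W_c·[h(t-1), x(t)] + b_c) + σ(W_i·[h(t-1), x(t)] + b_i) ⊙ ∂[tanh(W_c·[h(t-1), x(t)] + b_c)]/∂c(t-1)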

We compute (detailed computations are given in the end of the article) the four derivative terms and write:
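Using ∂h(t-1)/∂c(t-1) = o(t-1) ⊙ tanh′(c(t-1)), and writing W_f,rec, W_i,rec, W_c,rec for the parts of the weight matrices that multiply h(t-1), the four terms work out (as a sketch; the exact matrix bookkeeping depends on conventions) to:

∂c(t)/∂c(t-1) = c(t-1) ⊙ σ′(·)·W_f,rec·[ o(t-1) ⊙ tanh′(c(t-1)) ]
+ f(t)
+ tanh( W_c·[h(t-1), x(t)] + b_c ) ⊙ σ′(·)·W_i,rec·[ o(t-1) ⊙ tanh′(c(t-1)) ]
+ σ( W_i·[h(t-1), x(t)] + b_i ) ⊙ tanh′(·)·W_c,rec·[ o(t-1) ⊙ tanh′(c(t-1)) ]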

Denote the four elements comprising the derivative of the cell state by A(t), B(t), C(t) and D(t), in the order they appear above (so that, in particular, B(t) = f(t)), and write the additive gradient as:
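∂c(t)/∂c(t-1) = A(t) + B(t) + C(t) + D(t)   (6)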

Plug (6) into (4) and get the LSTM states gradient:
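∂E(k)/∂W = Σ_{t=1..k} ∂E(k)/∂h(k) · ∂h(k)/∂c(k) · [ Π_{j=t+1..k} ( A(j) + B(j) + C(j) + D(j) ) ] · ∂c(t)/∂W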

Preventing the error gradients from vanishing

Notice that the gradient contains the forget gate’s vector of activations, which allows the network to better control the gradient’s values at each time step, using suitable parameter updates of the forget gate. The presence of the forget gate’s activations allows the LSTM to decide, at each time step, that certain information should not be forgotten and to update the model’s parameters accordingly.


Let’s go over how this property helps us. Say that for some time step k < T the factors ∂c(t)/∂c(t-1) have become very small, so the gradient is on its way to vanishing.

Then, for the gradient not to vanish, we can find a suitable parameter update of the forget gate at time step k+1 that pushes the forget gate’s activations f(k+1), and with them the term B(k+1), toward 1, so that the factor ∂c(k+1)/∂c(k) does not vanish.

It is the presence of the forget gate’s vector of activations in the gradient term, along with the additive structure, which allows the LSTM to find such a parameter update at any time step; as a result, the product of factors is no longer forced toward zero and the gradient doesn’t vanish.

Another important property to notice is that the cell state gradient is an additive function made up of the four elements denoted A(t), B(t), C(t) and D(t). This additive property enables better balancing of gradient values during backpropagation. The LSTM updates and balances the values of the four components, making it more likely that the additive expression does not vanish.

For example, say that for every t in {2,3,…,k} we take the following four neighbourhoods of values as a balancing combination in our gradient:
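Purely as an illustration (these particular numbers are mine, not the article’s), one such balancing combination could be: A(t) ≈ 0.1, B(t) = f(t) ≈ 0.7, C(t) ≈ 0.1, D(t) ≈ 0.1.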

which yields:
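∂c(t)/∂c(t-1) = A(t) + B(t) + C(t) + D(t) ≈ 0.1 + 0.7 + 0.1 + 0.1 = 1 for every t, so the product Π_{t=2..k} ∂c(t)/∂c(t-1) stays close to 1 (with the illustrative values above),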

and the product does not vanish.

This additive property is different from the RNN case, where the gradient contained a single element inside the product. In RNNs, the sum in (3) is made up of expressions with similar behaviour that are all likely to lie in [0, 1], which causes vanishing gradients.

In LSTMs, however, the presence of the forget gate, along with the additive property of the cell state gradients, enables the network to update the parameters in such a way that the different sub-gradients in (3) do not necessarily agree and behave in a similar manner. This makes it less likely that all of the T gradients in (3) will vanish, or in other words, the series of functions does not converge to zero:
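Σ_{k=1..T} ∂E(k)/∂W does not tend to 0,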

and our gradients do not vanish.

As mentioned briefly, the RNN gradients can also explode if the sum in (3) is made up of expressions with similar behaviour that are all significantly greater than 1.

Summing up, we have seen that RNNs suffer from vanishing gradients, caused by long series of multiplications of small values that diminish the gradients and cause the learning process to become degenerate. In an analogous way, RNNs suffer from exploding gradients, caused by large gradient values that hamper the learning process.

LSTMs solve the problem using a unique additive gradient structure that includes direct access to the forget gate’s activations, enabling the network to encourage the desired behaviour of the error gradient using frequent gate updates at every time step of the learning process.
