How to use return_state or return_sequences in Keras


You may have noticed that several Keras recurrent layers take two parameters, return_state and return_sequences. In this post, I am going to show you what they mean and when to use them in real-life cases.

To understand what they mean, we first need to crack open a recurrent layer a little, such as the most commonly used LSTM and GRU.

RNN in a nutshell

The most primitive recurrent layer implemented in Keras is SimpleRNN, which suffers from the vanishing gradients problem, making it hard to capture long-range dependencies. LSTM and GRU, by contrast, are each equipped with their own "gates" that keep long-term information from "vanishing" away.

[Figure: rnn]

In the graph above, we can see that given an input sequence to an RNN layer, the RNN cell at each time step generates an output known as the hidden state, a<t>.

Depending on which RNN you use, a<t> is computed differently.

[Figure: GRU and LSTM formulas]

c<t> for each RNN cell in the above formulas is known as the cell state. For GRU, a given time step's cell state equals its output hidden state. For LSTM, the output hidden state a<t> is produced by "gating" the cell state c<t> with the output gate Γo, so a<t> and c<t> are not the same. Don't worry about the rest of the formulas. A basic understanding of RNNs should be enough for this tutorial.
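For reference, the two relations described in this section are commonly written as below. This is a sketch in the same notation, not a copy of the original formulas: the update gate Γu, the forget gate Γf and the candidate value c̃<t> are symbols assumed here for illustration, and the gate definitions themselves are left out.

```latex
\begin{aligned}
\text{GRU:}\quad c^{\langle t\rangle} &= \Gamma_u \odot \tilde{c}^{\langle t\rangle} + (1-\Gamma_u) \odot c^{\langle t-1\rangle},
  \qquad a^{\langle t\rangle} = c^{\langle t\rangle} \\
\text{LSTM:}\quad c^{\langle t\rangle} &= \Gamma_u \odot \tilde{c}^{\langle t\rangle} + \Gamma_f \odot c^{\langle t-1\rangle},
  \qquad a^{\langle t\rangle} = \Gamma_o \odot \tanh\!\left(c^{\langle t\rangle}\right)
\end{aligned}
```

The only point that matters for the rest of the post is the right-hand column: a GRU exposes a single state, while an LSTM carries two distinct ones.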

Return sequences

return_sequences controls whether the layer returns the hidden state a<t> for every time step. By default, return_sequences is set to False in Keras RNN layers, which means the RNN layer only returns the last hidden state output a<T>. The last hidden state output captures an abstract representation of the input sequence. In some cases, it is all we need, such as a classification or regression model where the RNN is followed by Dense layer(s) to generate logits for news topic classification or a score for sentiment analysis, or a generative model that produces the softmax probabilities for the next possible character.

In other cases, we need the full sequence as the output, and setting return_sequences to True is necessary.

Let's define a Keras model consisting of only an LSTM layer. We use constant initializers so that the output is reproducible for demo purposes.

from keras.models import Model
from keras.layers import Input
from keras.layers import LSTM
from numpy import array
import keras
k_init = keras.initializers.Constant(value=0.1)
b_init = keras.initializers.Constant(value=0)
r_init = keras.initializers.Constant(value=0.1)
# LSTM units
units = 1

# define model
inputs1 = Input(shape=(3, 2))
lstm1 = LSTM(units, return_sequences=True, kernel_initializer=k_init, bias_initializer=b_init, recurrent_initializer=r_init)(inputs1)
model = Model(inputs=inputs1, outputs=lstm1)
# define input data
data = array([0.1, 0.2, 0.3, 0.1, 0.2, 0.3]).reshape((1,3,2))
# make and show prediction
output = model.predict(data)
print(output, output.shape)

output:

[[[0.00767819]
  [0.01597687]
  [0.02480672]]] (1, 3, 1)

We can see that the shape of the LSTM layer's output array is (1,3,1), which stands for (#Samples, #Time steps, #LSTM units). By comparison, when return_sequences is set to False, the shape is (#Samples, #LSTM units), because only the hidden state of the last time step is returned.

# define model
inputs1 = Input(shape=(3, 2))
lstm1 = LSTM(units, kernel_initializer=k_init, bias_initializer=b_init, recurrent_initializer=r_init)(inputs1)
model = Model(inputs=inputs1, outputs=lstm1)
# define input data
data = array([0.1, 0.2, 0.3, 0.1, 0.2, 0.3]).reshape((1,3,2))
# make and show prediction
preds = model.predict(data)
print(preds, preds.shape)

output:

[[0.02480672]] (1, 1)

There are two primary situations where you would set return_sequences to True to return the full sequence.

  1. Stacking RNN layers: every RNN layer that feeds another RNN layer should set return_sequences to True so that the following layer receives the full sequence as input.
  2. We want to generate a classification for each time step.
    1. Speech recognition, or its much simpler form, trigger word detection, where we generate a value between 0 and 1 for each time step representing whether the trigger word is present.
    2. OCR (optical character recognition) sequence modeling with CTC.
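Both situations can be combined in one small model. The sketch below stacks two LSTM layers and puts a Dense layer with a sigmoid on top to score every time step, in the spirit of trigger word detection. The layer sizes are arbitrary values chosen for illustration, and the weights are untrained, so only the output shape is meaningful.

```python
import numpy as np
from keras.models import Model
from keras.layers import Input, LSTM, Dense

timesteps, features = 3, 2  # same input shape as the demo above

inputs = Input(shape=(timesteps, features))
# The first LSTM feeds another LSTM, so it must return the full sequence.
x = LSTM(4, return_sequences=True)(inputs)
# The second LSTM also returns the full sequence because we score every step.
x = LSTM(4, return_sequences=True)(x)
# Dense acts on the last axis, giving one value in [0, 1] per time step.
per_step = Dense(1, activation="sigmoid")(x)
model = Model(inputs=inputs, outputs=per_step)

data = np.array([0.1, 0.2, 0.3, 0.1, 0.2, 0.3], dtype="float32")
data = data.reshape((1, timesteps, features))
scores = model.predict(data, verbose=0)
print(scores.shape)  # (1, 3, 1)
```

If the first LSTM kept the default return_sequences=False, it would hand a 2D tensor to the second LSTM, which expects a 3D (samples, time steps, features) input.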

Return states

return_state controls whether the layer also returns its last state, which for LSTM includes the cell state c<T>. For GRU, as we discussed in the "RNN in a nutshell" section, a<t>=c<t>, so you can get by without this parameter. But for LSTM, the hidden state and the cell state are not the same.

In Keras, we can output an RNN's last cell state in addition to its hidden state by setting return_state to True.

# define model
inputs1 = Input(shape=(3, 2))
lstm1, state_h, state_c = LSTM(units, return_state=True, kernel_initializer=k_init, bias_initializer=b_init, recurrent_initializer=r_init)(inputs1)
model = Model(inputs=inputs1, outputs=[lstm1, state_h, state_c])
# define input data
data = array([0.1, 0.2, 0.3, 0.1, 0.2, 0.3]).reshape((1,3,2))

# make and show prediction
output = model.predict(data)
print(output)
for a in output:
    print(a.shape) 

Output:

[array([[0.02480672]], dtype=float32), array([[0.02480672]], dtype=float32), array([[0.04864851]], dtype=float32)]
(1, 1)
(1, 1)
(1, 1)

The output of the LSTM layer has three components, (a<T>, a<T>, c<T>), where "T" stands for the last time step. Each one has the shape (#Samples, #LSTM units).
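To see where these numbers come from, the recurrence can be replayed by hand in plain Python. This is only a sketch under two assumptions that are not stated in the post: the gates use the hard sigmoid 0.2x + 0.5 clipped to [0, 1], and the forget gate carries a bias of 1 (the unit_forget_bias behavior). Newer Keras releases may default to a plain sigmoid for the gates, in which case a fresh run would print slightly different values. The function name lstm_forward is made up for this illustration.

```python
import math

def hard_sigmoid(x):
    # Assumed gate activation: 0.2 * x + 0.5, clipped to [0, 1].
    return max(0.0, min(1.0, 0.2 * x + 0.5))

def lstm_forward(sequence, w=0.1, u=0.1, bias=0.0, forget_bias=1.0):
    """Run a one-unit LSTM over `sequence`.

    Returns (hidden state per step, last hidden state, last cell state).
    """
    h, c = 0.0, 0.0
    hidden_states = []
    for x_t in sequence:
        # All kernel weights equal w and all recurrent weights equal u,
        # so every gate sees the same weighted sum.
        z = w * sum(x_t) + u * h + bias
        i = hard_sigmoid(z)                # input (update) gate
        f = hard_sigmoid(z + forget_bias)  # forget gate, assumed bias of 1
        o = hard_sigmoid(z)                # output gate
        c = f * c + i * math.tanh(z)       # new cell state c<t>
        h = o * math.tanh(c)               # new hidden state a<t>
        hidden_states.append(h)
    return hidden_states, h, c

# Same values as `data` after reshape((1, 3, 2)).
sequence = [(0.1, 0.2), (0.3, 0.1), (0.2, 0.3)]
hidden_states, last_h, last_c = lstm_forward(sequence)
print([round(v, 8) for v in hidden_states])
print(round(last_h, 8), round(last_c, 8))
```

Under these assumptions the hidden states come out at roughly 0.0077, 0.0160 and 0.0248 and the last cell state at roughly 0.0486, in line with the Keras output above. The last hidden state is simply the final entry of the sequence, which is why the first two components are identical when return_sequences is left at False.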

The main reason to set return_state is that an RNN may need to have its state initialized with the state from the previous time step while the weights are shared, such as in an encoder-decoder model. A snippet of the code from an encoder-decoder model is shown below.

decoder_lstm = LSTM(latent_dim, return_sequences=True, return_state=True)
for _ in range(max_decoder_seq_length):
    # Run the decoder on one timestep
    outputs, state_h, state_c = decoder_lstm(inputs,
                                             initial_state=states)
    outputs = decoder_dense(outputs)
    # Store the current prediction (we will concatenate all predictions later)
    all_outputs.append(outputs)
    # Reinject the outputs as inputs for the next loop iteration
    # as well as update the states
    inputs = outputs
    states = [state_h, state_c]
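The loop above relies on names defined elsewhere in the full model (latent_dim, decoder_dense, inputs, states and so on). As a minimal self-contained sketch of just the state hand-off, the model below passes the encoder's last hidden and cell state to the decoder through initial_state. The sizes and the names enc_steps, dec_steps and num_features are assumptions made for this illustration, and the weights are untrained.

```python
import numpy as np
from keras.models import Model
from keras.layers import Input, LSTM, Dense

latent_dim = 4               # size of the LSTM state (illustrative value)
num_features = 2             # features per time step (illustrative value)
enc_steps, dec_steps = 3, 5  # sequence lengths (illustrative values)

# Encoder: keep only the last hidden state and cell state.
encoder_inputs = Input(shape=(enc_steps, num_features))
_, state_h, state_c = LSTM(latent_dim, return_state=True)(encoder_inputs)
encoder_states = [state_h, state_c]

# Decoder: starts from the encoder states and returns the full sequence.
decoder_inputs = Input(shape=(dec_steps, num_features))
decoder_lstm = LSTM(latent_dim, return_sequences=True, return_state=True)
decoder_seq, _, _ = decoder_lstm(decoder_inputs, initial_state=encoder_states)
decoder_outputs = Dense(num_features)(decoder_seq)

model = Model([encoder_inputs, decoder_inputs], decoder_outputs)

enc = np.random.rand(1, enc_steps, num_features).astype("float32")
dec = np.random.rand(1, dec_steps, num_features).astype("float32")
preds = model.predict([enc, dec], verbose=0)
print(preds.shape)  # (1, 5, 2)
```

The decoder's own returned states are discarded here because the whole target sequence is fed in at once. In a step-by-step loop like the snippet above, they are what gets passed back in as initial_state on the next iteration.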

You may have noticed that in the above encoder-decoder model both return_sequences and return_state are set to True. In that case, the output of the LSTM has three components, (a<1...T>, a<T>, c<T>). Doing the same with our previous example makes the difference easier to see.

# define model
inputs1 = Input(shape=(3, 2))
lstm1, state_h, state_c = LSTM(units, return_sequences=True, return_state=True, kernel_initializer=k_init, bias_initializer=b_init, recurrent_initializer=r_init)(inputs1)
model = Model(inputs=inputs1, outputs=[lstm1, state_h, state_c])
# define input data
data = array([0.1, 0.2, 0.3, 0.1, 0.2, 0.3]).reshape((1,3,2))
# make and show prediction
output = model.predict(data)
print(output)
for a in output:
    print(a.shape) 

Output

[array([[[0.00767819],
        [0.01597687],
        [0.02480672]]], dtype=float32), array([[0.02480672]], dtype=float32), array([[0.04864851]], dtype=float32)]
(1, 3, 1)
(1, 1)
(1, 1)

One thing worth mentioning is that if we replace LSTM with GRU, the output will have only two components, (a<1...T>, c<T>), since in GRU a<T>=c<T>.
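A quick way to check this is to build the same one-layer model with a GRU and compare the last step of the returned sequence with the returned state. This is a sketch with default (random) initializers, so the printed values vary from run to run and only the shapes and the equality check are meaningful.

```python
import numpy as np
from keras.models import Model
from keras.layers import Input, GRU

inputs = Input(shape=(3, 2))
# A GRU has a single state, so only two tensors are returned.
seq, last_state = GRU(1, return_sequences=True, return_state=True)(inputs)
model = Model(inputs=inputs, outputs=[seq, last_state])

data = np.array([0.1, 0.2, 0.3, 0.1, 0.2, 0.3], dtype="float32")
data = data.reshape((1, 3, 2))
seq_out, state_out = model.predict(data, verbose=0)

print(seq_out.shape, state_out.shape)  # (1, 3, 1) (1, 1)
# The last time step of the sequence is the returned state.
print(np.allclose(seq_out[:, -1, :], state_out))
```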

Conclusion

To understand how to use return_sequences and return_state, we started off with a short introduction to two commonly used recurrent layers, LSTM and GRU, and how their cell state and hidden state are derived. Next, we dived into some cases for applying each of the two arguments, along with tips on when to consider using them in your next model.

You can find the source code for this post on my GitHub repo.
