How to train a Keras model to recognize text with variable length


I have played with the Keras official image_ocr.py example for a while and want to share my takeaways in this post.

The official example only trains the model and leaves out the prediction part. My final source code is available both on my GitHub and as a runnable Google Colab notebook. More technical details of OCR (optical character recognition), including the model structure and the CTC loss, are also explained briefly in the following sections.

OCR task declaration 

The input is an image that contains a single line of text, and the text can be at any location in the image. The task for the model is to output the actual text given this image.

For example,

ocr-task

The official image_ocr.py example source code is quite long and may look daunting. It can be broken down into several parts.

  • The generator for the training samples. This part of the source code generates realistic text images that resemble scanned documents, with artificial speckles, random locations and a variety of fonts.
  • The model callback, which saves the model weights and visualizes the performance of the current model on some generated text images after each training epoch.
  • The model construction and training part. We will elaborate more on this part in the next section.

Model structure

The model input is image data. We first feed the data through two convolutional layers, each followed by max pooling, to extract the image features. A Reshape and a Dense layer then reduce the dimensions of the feature vectors before two layers of bidirectional GRUs process the sequential data. The sequential data fed to the GRUs is the image features divided horizontally into steps. The final output Dense layer transforms the output for a given image into an array with the shape (32, 28), representing (# of horizontal steps, # of char labels).

base-model

And here is the part of the code to construct the Keras model.

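from keras import backend as K
from keras.layers import (Input, Dense, Activation, Reshape, Conv2D,
                          MaxPooling2D, GRU, add, concatenate)
from keras.models import Model

# Input Parameters
img_w = 128
img_h = 64
# Network parameters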
conv_filters = 16
kernel_size = (3, 3)
pool_size = 2
time_dense_size = 32
rnn_size = 512
minibatch_size = 32
unique_tokens = 28

if K.image_data_format() == 'channels_first':
    input_shape = (1, img_w, img_h)
else:
    input_shape = (img_w, img_h, 1)

act = 'relu'
input_data = Input(name='the_input', shape=input_shape, dtype='float32')
inner = Conv2D(conv_filters, kernel_size, padding='same',
               activation=act, kernel_initializer='he_normal',
               name='conv1')(input_data)
inner = MaxPooling2D(pool_size=(pool_size, pool_size), name='max1')(inner)
inner = Conv2D(conv_filters, kernel_size, padding='same',
               activation=act, kernel_initializer='he_normal',
               name='conv2')(inner)
inner = MaxPooling2D(pool_size=(pool_size, pool_size), name='max2')(inner)

conv_to_rnn_dims = (img_w // (pool_size ** 2), (img_h // (pool_size ** 2)) * conv_filters)
inner = Reshape(target_shape=conv_to_rnn_dims, name='reshape')(inner)

# cuts down input size going into RNN:
inner = Dense(time_dense_size, activation=act, name='dense1')(inner)

# Two layers of bidirectional GRUs
# GRU seems to work as well, if not better than LSTM:
gru_1 = GRU(rnn_size, return_sequences=True, kernel_initializer='he_normal', name='gru1')(inner)
gru_1b = GRU(rnn_size, return_sequences=True, go_backwards=True, kernel_initializer='he_normal', name='gru1_b')(inner)
gru1_merged = add([gru_1, gru_1b])
gru_2 = GRU(rnn_size, return_sequences=True, kernel_initializer='he_normal', name='gru2')(gru1_merged)
gru_2b = GRU(rnn_size, return_sequences=True, go_backwards=True, kernel_initializer='he_normal', name='gru2_b')(gru1_merged)

# transforms RNN output to character activations:
inner = Dense(unique_tokens, kernel_initializer='he_normal',
              name='dense2')(concatenate([gru_2, gru_2b]))
y_pred = Activation('softmax', name='softmax')(inner)
Model(inputs=input_data, outputs=y_pred)
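To double-check the shapes described above, here is a minimal sketch in plain Python that redoes the arithmetic without Keras. It assumes the channels_last layout, and the shapes dictionary and its keys are only illustrative names, not part of the example's code.

```python
# Shape walk-through for one sample (batch dimension left out).
img_w, img_h = 128, 64
conv_filters, pool_size = 16, 2
time_dense_size, rnn_size, unique_tokens = 32, 512, 28

w, h = img_w, img_h
for _ in range(2):
    # 'same' padding keeps the conv output size; each max-pool halves it
    w, h = w // pool_size, h // pool_size

shapes = {
    'conv2 + max2': (w, h, conv_filters),    # (32, 16, 16)
    'reshape': (w, h * conv_filters),        # (32, 256)
    'dense1': (w, time_dense_size),          # (32, 32)
    'gru1 (added)': (w, rnn_size),           # (32, 512)
    'gru2 (concatenated)': (w, 2 * rnn_size),  # (32, 1024)
    'softmax': (w, unique_tokens),           # (32, 28)
}

for name, shape in shapes.items():
    print(name, shape)
```

The first dimension stays at 32 from the reshape onward, which is why the image width becomes the number of timesteps seen by the GRUs.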

CTC Loss

As we can see in the example image, the text could be located anywhere. How does the model align the input with the output, locate each character in the image and turn it into text? That is where CTC comes into play. CTC stands for connectionist temporal classification.

input-output

Notice that the output of the model has 32 timesteps, but the text itself might not have 32 characters. The CTC cost function allows the RNN to generate output like:

ctc-out

CTC introduces the "blank" token, which doesn't translate into any character itself. Its job is to separate individual characters, so that we can collapse repeated characters that are not separated by a blank.

So the decoding output for the previous sequence will be "a game".

Let's take a look at another example of the text "speed".

ctc-out-speed

According to the decoding principle, we first collapse repeating characters that are not separated by a blank token, and then we remove the blank tokens themselves. Notice that if there were no blank token to separate the two "e"s, they would be collapsed into one.
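The two decoding steps can be sketched in a few lines of plain Python. This is only an illustration: the function name ctc_collapse and the use of "-" to stand for the blank token are assumptions made for this sketch, not part of the Keras example.

```python
from itertools import groupby


def ctc_collapse(tokens, blank='-'):
    """Collapse runs of repeated tokens, then drop the blank token."""
    # groupby yields one key per run of identical consecutive tokens
    merged = [token for token, _ in groupby(tokens)]
    return ''.join(token for token in merged if token != blank)


print(ctc_collapse('ss-pp-ee-e-dd'))  # speed
print(ctc_collapse('ss-pp-eeee-dd'))  # sped
```

The second call shows the failure case described above: without a blank between them, the two "e"s merge into one.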

In Keras, the CTC decoding can be performed in a single function, K.ctc_decode.

import numpy as np
from keras import backend as K

K.get_value(K.ctc_decode(out, input_length=np.ones(out.shape[0])*out.shape[1],
                         greedy=True)[0][0])

Here out is the model output: for each of the 32 timesteps, 28 softmax probability values, one for each of the 28 tokens (a-z, space, and the blank token). We set the parameter greedy to True to perform a greedy search, which means the function will only return the most likely output token sequence.
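For intuition, greedy (best path) decoding can be approximated with NumPy alone: take the most probable token at every timestep, collapse repeats, and drop the blank. The sketch below is a simplified stand-in rather than what K.ctc_decode runs internally; the function name greedy_ctc_decode is hypothetical, and the blank index is passed in explicitly.

```python
import numpy as np


def greedy_ctc_decode(out, blank):
    """Best-path decode of a (batch, timesteps, classes) probability array."""
    best = np.argmax(out, axis=-1)  # most likely class per timestep
    decoded = []
    for path in best:
        # keep a token only where it differs from the previous timestep
        keep = np.concatenate(([True], path[1:] != path[:-1]))
        collapsed = path[keep]
        decoded.append(collapsed[collapsed != blank].tolist())
    return decoded


# Toy example: 3 classes, class 2 plays the role of the blank token.
toy_path = [0, 0, 2, 1, 1, 2, 1]
toy_out = np.eye(3)[[toy_path]]  # one-hot "probabilities", shape (1, 7, 3)
print(greedy_ctc_decode(toy_out, blank=2))  # [[0, 1, 1]]
```

For the model in this post the blank would be the last class (index 27), assuming the a-z, space, blank ordering described above.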

Alternatively, if we want the CTC decoder to return the top N possible output sequences, we can ask it to perform a beam search with a given beam width.

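top_paths = 3
# Decode once; ctc_decode returns (list of decoded tensors, log probabilities)
decoded, log_probs = K.ctc_decode(out, input_length=np.ones(out.shape[0])*out.shape[1],
                                  greedy=False, beam_width=top_paths, top_paths=top_paths)
# Keep the labels of the first sample in the batch for each of the top paths
results = [K.get_value(path)[0] for path in decoded]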

One thing worth mentioning if you are new to the beam search algorithm: the top_paths parameter must be no greater than the beam_width parameter, since the beam width tells the beam search algorithm exactly how many top results to keep track of while iterating over all timesteps.
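To see why the beam width caps the number of results, here is a deliberately simplified beam search over raw token paths. It is not the CTC prefix beam search used by the real decoder (it does not merge paths that collapse to the same text), and the name beam_search_paths is made up for this sketch.

```python
import numpy as np


def beam_search_paths(probs, beam_width):
    """Return up to beam_width (path, log_prob) pairs, best first.

    probs: (timesteps, classes) array of per-timestep probabilities.
    """
    beams = [((), 0.0)]  # start with one empty path
    for step in np.log(probs):
        # extend every surviving path with every possible class
        candidates = [(path + (k,), score + step[k])
                      for path, score in beams
                      for k in range(len(step))]
        candidates.sort(key=lambda c: c[1], reverse=True)
        beams = candidates[:beam_width]  # prune to the beam width
    return beams


toy_probs = np.array([[0.6, 0.4],
                      [0.3, 0.7]])
for path, score in beam_search_paths(toy_probs, beam_width=2):
    print(path, round(float(np.exp(score)), 2))
```

Because only beam_width candidates survive each timestep, there are never more than beam_width finished paths to choose the top results from.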

Right now the output of the decoder will be a sequence of tokens, and we just need to translate the numerical classes back to characters.
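A minimal sketch of that translation step is shown below. It assumes the class order a-z followed by space, with the blank as the last class, and it also skips the -1 values that K.ctc_decode uses to pad shorter sequences. The names alphabet and labels_to_text are illustrative choices for this sketch.

```python
alphabet = 'abcdefghijklmnopqrstuvwxyz '  # 27 characters; class 27 is the blank


def labels_to_text(labels):
    """Map class indices to characters, skipping blank (27) and padding (-1)."""
    return ''.join(alphabet[c] for c in labels if 0 <= c < len(alphabet))


print(labels_to_text([0, 1, 2, 3, 4, 26, 25, -1, -1]))  # abcde z
```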

So far we have only talked about the decoding part of CTC. You may wonder how the model is trained with the CTC loss.

Computing the CTC loss requires more than the true labels and the predicted outputs of the model. It also needs the output sequence length and the length of each of the true labels.

  • y_true holds the true labels. A sample of it could look like [0, 1, 2, 3, 4, 26, 25], which stands for the text sequence ‘abcde z’.
  • y_pred is the output of the softmax layer. A sample of it has the shape (32, 28): 32 timesteps and 28 categories, i.e. ‘a-z’, space and the blank token.
  • input_length is the output sequence length: img_w // downsample_factor - 2 = 128 // 4 - 2 = 30. The 2 accounts for the first 2 RNN output timesteps, which are discarded since the first couple of outputs of the RNN tend to be garbage.
  • label_length will be 7 for the previous y_true sample.
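The values in the list above can be reproduced with a few lines of plain Python. This is a small sketch: text_to_labels is a helper written for this illustration, and downsample_factor is taken to be pool_size ** 2 because the model applies two 2x2 max-pooling layers.

```python
alphabet = 'abcdefghijklmnopqrstuvwxyz '  # a-z plus space; blank is class 27


def text_to_labels(text):
    """Encode a string as the list of class indices used for y_true."""
    return [alphabet.index(char) for char in text]


img_w = 128
pool_size = 2
downsample_factor = pool_size ** 2  # two max-pooling layers

y_true = text_to_labels('abcde z')
label_length = len(y_true)
# the first 2 RNN timesteps are dropped before the loss is computed
input_length = img_w // downsample_factor - 2

print(y_true)        # [0, 1, 2, 3, 4, 26, 25]
print(label_length)  # 7
print(input_length)  # 30
```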

In Keras the CTC loss is packaged in one function K.ctc_batch_cost.

# the actual loss calc occurs here despite it not being
# an internal Keras loss function

def ctc_lambda_func(args):
    y_pred, labels, input_length, label_length = args
    # the 2 is critical here since the first couple outputs of the RNN
    # tend to be garbage:
    y_pred = y_pred[:, 2:, :]
    return K.ctc_batch_cost(labels, y_pred, input_length, label_length)
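To make concrete what a CTC loss measures, the sketch below computes the negative log-likelihood of one label sequence with the standard CTC forward recursion in NumPy. It is an illustration under simplifying assumptions, not the code Keras runs: it handles a single sample, works with raw probabilities instead of log space, and the function name ctc_neg_log_likelihood is made up here.

```python
import numpy as np


def ctc_neg_log_likelihood(probs, labels, blank):
    """-log p(labels | probs) for one sample via the CTC forward recursion.

    probs: (timesteps, classes) softmax output.
    labels: list of class indices, without blanks.
    """
    # extended sequence: blank, l1, blank, l2, ..., blank
    ext = [blank]
    for label in labels:
        ext += [label, blank]
    T, S = probs.shape[0], len(ext)

    alpha = np.zeros((T, S))
    alpha[0, 0] = probs[0, ext[0]]
    if S > 1:
        alpha[0, 1] = probs[0, ext[1]]
    for t in range(1, T):
        for s in range(S):
            total = alpha[t - 1, s]
            if s >= 1:
                total += alpha[t - 1, s - 1]
            # a blank may be skipped only between two different labels
            if s >= 2 and ext[s] != blank and ext[s] != ext[s - 2]:
                total += alpha[t - 1, s - 2]
            alpha[t, s] = total * probs[t, ext[s]]
    # valid paths end on the last label or on the trailing blank
    likelihood = alpha[-1, -1] + (alpha[-1, -2] if S > 1 else 0.0)
    return -np.log(likelihood)


# 2 timesteps, classes: 0 = 'a', 1 = blank, uniform probabilities.
toy_probs = np.full((2, 2), 0.5)
print(np.exp(-ctc_neg_log_likelihood(toy_probs, [0], blank=1)))
```

In the toy case the paths "aa", "a" + blank and blank + "a" all decode to "a", each with probability 0.25, so the likelihood that comes out is their sum.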

Conclusion

Checkpoint results after training the model for 25 epochs.

e24

If you have read this far and have been experimenting along on Google Colab, you should now have a Keras OCR demo running. If you are still eager for more information about CTC and beam search, feel free to check out the following resources.

Sequence Modeling With CTC - An in-depth explanation of the CTC algorithm and other applications where CTC can be applied, such as speech recognition, lip reading from video and so on.

Coursera Beam search video lecture. Quick and easy to understand.

Don't forget to get the source code from my GitHub as well as a runnable Google Colab notebook. 
