Bag of Tricks for Image Classification with Convolutional Neural Networks in Keras

This tutorial shows you how to implement several tricks for the image classification task in the Keras API, as described in the paper https://arxiv.org/abs/1812.01187v2. These tricks work across various CNN models such as ResNet-50, Inception-V3, and MobileNet.

Large-batch training

For the same number of epochs, training with a large batch size tends to result in lower validation accuracy than training with smaller batch sizes. The paper describes four heuristics that help mitigate this downside of large-batch training while improving accuracy and training speed.

Linearly scaling learning rate

Increase the learning rate linearly with the batch size. For example:

Batch size       Learning rate
256              0.1
256 * 2 = 512    0.1 * 2 = 0.2

In the Keras API, you can scale the learning rate along with the batch size like this.

from tensorflow.keras import optimizers
base_batch_size = 256
base_lr = 0.1
multiplier = 2
batch_size = base_batch_size * multiplier
lr = base_lr * multiplier
# Create the model
# ...
# Compile and train the model.
sgd = optimizers.SGD(lr=lr)
model.compile(loss='mean_squared_error', optimizer=sgd)
model.fit(x, y, batch_size=batch_size)

Learning rate warmup

Using a learning rate that is too large may cause numerical instability, especially at the very beginning of training when the parameters are randomly initialized. The warmup strategy increases the learning rate linearly from 0 to the initial learning rate during the first N epochs or m batches.

Keras comes with the LearningRateScheduler callback, which can update the learning rate once per training epoch; to achieve finer, per-batch updates, here is how you can implement a custom Keras callback.

import numpy as np
from tensorflow import keras
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense
from tensorflow.keras import backend as K


class WarmUpLearningRateScheduler(keras.callbacks.Callback):
    """Warmup learning rate scheduler
    """

    def __init__(self, warmup_batches, init_lr, verbose=0):
        """Constructor for warmup learning rate scheduler

        Arguments:
            warmup_batches {int} -- Number of batches for warmup.
            init_lr {float} -- Learning rate after warmup.

        Keyword Arguments:
            verbose {int} -- 0: quiet, 1: update messages. (default: {0})
        """

        super(WarmUpLearningRateScheduler, self).__init__()
        self.warmup_batches = warmup_batches
        self.init_lr = init_lr
        self.verbose = verbose
        self.batch_count = 0
        self.learning_rates = []

    def on_batch_end(self, batch, logs=None):
        self.batch_count = self.batch_count + 1
        lr = K.get_value(self.model.optimizer.lr)
        self.learning_rates.append(lr)

    def on_batch_begin(self, batch, logs=None):
        if self.batch_count <= self.warmup_batches:
            lr = self.batch_count*self.init_lr/self.warmup_batches
            K.set_value(self.model.optimizer.lr, lr)
            if self.verbose > 0:
                print('\nBatch %05d: WarmUpLearningRateScheduler setting learning '
                      'rate to %s.' % (self.batch_count + 1, lr))


# Create a model.
model = Sequential()
model.add(Dense(32, activation='relu', input_dim=100))
model.add(Dense(10, activation='softmax'))
model.compile(optimizer='rmsprop',
              loss='categorical_crossentropy',
              metrics=['accuracy'])

# Number of training samples.
sample_count = 12

# Total epochs to train.
epochs = 7

# Number of warmup epochs.
warmup_epoch = 5

# Training batch size, set small value here for demonstration purpose.
batch_size = 4

# Generate dummy data.
data = np.random.random((sample_count, 100))
labels = np.random.randint(10, size=(sample_count, 1))

# Convert labels to categorical one-hot encoding.
one_hot_labels = keras.utils.to_categorical(labels, num_classes=10)

# Compute the number of warmup batches.
warmup_batches = warmup_epoch * sample_count / batch_size

# Create the Learning rate scheduler.
warm_up_lr = WarmUpLearningRateScheduler(warmup_batches, init_lr=0.001)

# Train the model with the per-batch warmup callback.
model.fit(data, one_hot_labels, epochs=epochs, batch_size=batch_size,
          verbose=0, callbacks=[warm_up_lr])

warm_up_lr.learning_rates now contains the scheduled learning rate for each training batch. Let's visualize it.
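
A minimal plotting sketch (this mirrors the matplotlib snippet used in the cosine decay section further down):

import matplotlib.pyplot as plt

# Plot the per-batch learning rates recorded by the warmup callback.
plt.plot(warm_up_lr.learning_rates)
plt.xlabel('Training batch', fontsize=20)
plt.ylabel('lr', fontsize=20)
plt.title('Warmup learning rate schedule', fontsize=20)
plt.grid()
plt.show()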

[Figure: learning rate warmup schedule]

Zero γ in the last batch normalization layer of each ResNet block

Batch normalization scales a batch of inputs with γ and shifts them with β. Both γ and β are learnable parameters, whose elements Keras initializes to 1s and 0s respectively by default.

With the zero-γ initialization heuristic, we initialize γ = 0 for every BN layer that sits at the end of a residual block. All residual blocks then simply return their inputs at the start of training, which mimics a network with fewer layers and is easier to train in the initial stage.

For an identity ResNet block, when the last BN's γ is initialized to zero, the block simply passes the shortcut input through to the downstream layers.
[Figure: identity ResNet block with the last BN's γ initialized to zero]

Here is how this ResNet block can be implemented in Keras; the only change is the gamma_initializer='zeros' argument on the last BatchNormalization layer.

from tensorflow.keras.layers import Conv2D, add
from tensorflow.keras.layers import BatchNormalization, Activation

def add_layer(X, channel_num, name):
    """ResNet block

    Arguments:
        X {tensor} -- Input tensor.
        channel_num {int} -- Number of Conv2D output channels.
        name {str} -- Name of this ResNet block.

    Returns:
        tensor -- Output tensor.
    """

    conv_1 = Conv2D(channel_num, (3, 3), padding='same',
                    name=name + '0_conv')(X)
    conv_1 = BatchNormalization(name=name + '0_bn')(conv_1)
    conv_1 = Activation('relu', name=name + '0_relu')(conv_1)
    conv_1 = Conv2D(channel_num, (3, 3), padding='same',
                    name=name + '1_conv')(conv_1)
    # Zero gamma - Last BN for each ResNet block, easier to train at the initial stage.
    conv_1 = BatchNormalization(
        gamma_initializer='zeros', name=name + '1_bn')(conv_1)
    merge_data = add([conv_1, X], name=name + '1_add')
    out = Activation('relu', name=name + '2_relu')(merge_data)
    return out
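
A minimal usage sketch (the input shape and block names are placeholders): because add() merges conv_1 with the raw input X, the input tensor must already have channel_num channels.

from tensorflow.keras import Input, Model

# Build a tiny model out of two of the ResNet blocks defined above.
inputs = Input(shape=(32, 32, 64))
x = add_layer(inputs, channel_num=64, name='res1_')
x = add_layer(x, channel_num=64, name='res2_')
model = Model(inputs, x)
model.summary()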

No bias decay

Standard weight decay applies an L2 regularization to all parameters, driving their values towards 0; it works by adding penalties on the layer weights to the loss function.

The heuristic recommends applying the regularization only to the weights to avoid overfitting. Other parameters, including the biases and the γ and β in BN layers, are left unregularized.

In Keras, it is effortless to apply the L2 regularization to kernel weights. The option bias_regularizer is also available but not recommended.

from tensorflow.keras import regularizers
from tensorflow.keras.layers import Conv2D

model.add(Conv2D(64, (3, 3),
                 kernel_regularizer=regularizers.l2(0.01)))
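
For a slightly fuller sketch of the heuristic (the layer sizes, input shape, and the 1e-4 decay factor are illustrative choices, not taken from the paper), regularize only the Conv2D and Dense kernels and leave the biases and BN parameters unregularized:

from tensorflow.keras import regularizers
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import (Conv2D, BatchNormalization, Activation,
                                     GlobalAveragePooling2D, Dense)

weight_decay = 1e-4  # illustrative L2 factor

model = Sequential([
    # L2 penalty on the kernel only; the bias and the BN gamma/beta stay unregularized.
    Conv2D(64, (3, 3), padding='same', input_shape=(32, 32, 3),
           kernel_regularizer=regularizers.l2(weight_decay)),
    BatchNormalization(),
    Activation('relu'),
    GlobalAveragePooling2D(),
    Dense(10, activation='softmax',
          kernel_regularizer=regularizers.l2(weight_decay)),
])
model.compile(optimizer='sgd', loss='categorical_crossentropy')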

Training Refinements

Cosine Learning Rate Decay

After the learning rate warmup stage described earlier, we typically decrease the learning rate steadily from its initial value. Compared to widely used strategies such as exponential decay and step decay, cosine decay decreases the learning rate slowly at the beginning, becomes almost linear in the middle, and slows down again at the end, which can improve training progress.
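
As a quick restatement (this follows the paper and the cosine_decay_with_warmup function below with warmup_steps and hold_base_rate_steps set to 0), the learning rate at step t of T total steps is

\eta_t = \frac{\eta}{2}\left(1 + \cos\left(\frac{t\pi}{T}\right)\right),

where η is the base learning rate.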
[Figure: cosine learning rate decay schedule]

Here is a complete example of a cosine learning rate scheduler with a warmup stage in Keras; the scheduler updates the learning rate at the granularity of each update step (batch).

import numpy as np
from tensorflow import keras
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense
from tensorflow.keras import backend as K


def cosine_decay_with_warmup(global_step,
                             learning_rate_base,
                             total_steps,
                             warmup_learning_rate=0.0,
                             warmup_steps=0,
                             hold_base_rate_steps=0):
    """Cosine decay schedule with warm up period.

    Cosine annealing learning rate as described in:
      Loshchilov and Hutter, SGDR: Stochastic Gradient Descent with Warm Restarts.
      ICLR 2017. https://arxiv.org/abs/1608.03983
    In this schedule, the learning rate grows linearly from warmup_learning_rate
    to learning_rate_base for warmup_steps, then transitions to a cosine decay
    schedule.

    Arguments:
        global_step {int} -- global step.
        learning_rate_base {float} -- base learning rate.
        total_steps {int} -- total number of training steps.

    Keyword Arguments:
        warmup_learning_rate {float} -- initial learning rate for warm up. (default: {0.0})
        warmup_steps {int} -- number of warmup steps. (default: {0})
        hold_base_rate_steps {int} -- Optional number of steps to hold base learning rate
                                    before decaying. (default: {0})
    Returns:
      a float representing learning rate.

    Raises:
      ValueError: if warmup_learning_rate is larger than learning_rate_base,
        or if warmup_steps is larger than total_steps.
    """

    if total_steps < warmup_steps:
        raise ValueError('total_steps must be larger or equal to '
                         'warmup_steps.')
    learning_rate = 0.5 * learning_rate_base * (1 + np.cos(
        np.pi *
        (global_step - warmup_steps - hold_base_rate_steps
         ) / float(total_steps - warmup_steps - hold_base_rate_steps)))
    if hold_base_rate_steps > 0:
        learning_rate = np.where(global_step > warmup_steps + hold_base_rate_steps,
                                 learning_rate, learning_rate_base)
    if warmup_steps > 0:
        if learning_rate_base < warmup_learning_rate:
            raise ValueError('learning_rate_base must be larger or equal to '
                             'warmup_learning_rate.')
        slope = (learning_rate_base - warmup_learning_rate) / warmup_steps
        warmup_rate = slope * global_step + warmup_learning_rate
        learning_rate = np.where(global_step < warmup_steps, warmup_rate,
                                 learning_rate)
    return np.where(global_step > total_steps, 0.0, learning_rate)


class WarmUpCosineDecayScheduler(keras.callbacks.Callback):
    """Cosine decay with warmup learning rate scheduler
    """

    def __init__(self,
                 learning_rate_base,
                 total_steps,
                 global_step_init=0,
                 warmup_learning_rate=0.0,
                 warmup_steps=0,
                 hold_base_rate_steps=0,
                 verbose=0):
        """Constructor for cosine decay with warmup learning rate scheduler.

    Arguments:
        learning_rate_base {float} -- base learning rate.
        total_steps {int} -- total number of training steps.

    Keyword Arguments:
        global_step_init {int} -- initial global step, e.g. from previous checkpoint.
        warmup_learning_rate {float} -- initial learning rate for warm up. (default: {0.0})
        warmup_steps {int} -- number of warmup steps. (default: {0})
        hold_base_rate_steps {int} -- Optional number of steps to hold base learning rate
                                    before decaying. (default: {0})
        verbose {int} -- 0: quiet, 1: update messages. (default: {0})
        """

        super(WarmUpCosineDecayScheduler, self).__init__()
        self.learning_rate_base = learning_rate_base
        self.total_steps = total_steps
        self.global_step = global_step_init
        self.warmup_learning_rate = warmup_learning_rate
        self.warmup_steps = warmup_steps
        self.hold_base_rate_steps = hold_base_rate_steps
        self.verbose = verbose
        self.learning_rates = []

    def on_batch_end(self, batch, logs=None):
        self.global_step = self.global_step + 1
        lr = K.get_value(self.model.optimizer.lr)
        self.learning_rates.append(lr)

    def on_batch_begin(self, batch, logs=None):
        lr = cosine_decay_with_warmup(global_step=self.global_step,
                                      learning_rate_base=self.learning_rate_base,
                                      total_steps=self.total_steps,
                                      warmup_learning_rate=self.warmup_learning_rate,
                                      warmup_steps=self.warmup_steps,
                                      hold_base_rate_steps=self.hold_base_rate_steps)
        K.set_value(self.model.optimizer.lr, lr)
        if self.verbose > 0:
            print('\nBatch %05d: setting learning '
                  'rate to %s.' % (self.global_step + 1, lr))


# Create a model.
model = Sequential()
model.add(Dense(32, activation='relu', input_dim=100))
model.add(Dense(10, activation='softmax'))
model.compile(optimizer='rmsprop',
              loss='categorical_crossentropy',
              metrics=['accuracy'])

# Number of training samples.
sample_count = 12

# Total epochs to train.
epochs = 100

# Number of warmup epochs.
warmup_epoch = 10

# Training batch size, set small value here for demonstration purpose.
batch_size = 4

# Base learning rate after warmup.
learning_rate_base = 0.001

total_steps = int(epochs * sample_count / batch_size)

# Compute the number of warmup steps.
warmup_steps = int(warmup_epoch * sample_count / batch_size)

# Generate dummy data.
data = np.random.random((sample_count, 100))
labels = np.random.randint(10, size=(sample_count, 1))

# Convert labels to categorical one-hot encoding.
one_hot_labels = keras.utils.to_categorical(labels, num_classes=10)

# Create the Learning rate scheduler.
warm_up_lr = WarmUpCosineDecayScheduler(learning_rate_base=learning_rate_base,
                                        total_steps=total_steps,
                                        warmup_learning_rate=0.0,
                                        warmup_steps=warmup_steps,
                                        hold_base_rate_steps=0)

# Train the model with the cosine decay with warmup callback.
model.fit(data, one_hot_labels, epochs=epochs, batch_size=batch_size,
          verbose=0, callbacks=[warm_up_lr])

import matplotlib.pyplot as plt
plt.plot(warm_up_lr.learning_rates)
plt.xlabel('Step', fontsize=20)
plt.ylabel('lr', fontsize=20)
plt.axis([0, total_steps, 0, learning_rate_base*1.1])
plt.xticks(np.arange(0, total_steps, 50))
plt.grid()
plt.title('Cosine decay with warmup', fontsize=20)
plt.show()

You can optionally use the hold_base_rate_steps argument in the scheduler, which, as its name suggests, holds the base learning rate for a given number of steps before the cosine decay kicks in. The resulting learning rate schedule has a plateau and looks like the figure below.
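
For example, reusing the variables from the script above and holding the base rate for five epochs after warmup (the five-epoch choice is arbitrary):

# Hold the base learning rate for 5 epochs after warmup before the cosine decay starts.
hold_steps = int(5 * sample_count / batch_size)
warm_up_lr = WarmUpCosineDecayScheduler(learning_rate_base=learning_rate_base,
                                        total_steps=total_steps,
                                        warmup_learning_rate=0.0,
                                        warmup_steps=warmup_steps,
                                        hold_base_rate_steps=hold_steps)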

[Figure: cosine decay with warmup and a held base-rate plateau]

Label Smoothing

Compared to the original one-hot encoded labels, label smoothing changes the construction of the true probability to

q_i = (1 - ε) y_i + ε / K,  where y_i is the original one-hot label

Here ε is a small constant and K is the total number of classes. Label smoothing encourages the fully-connected layer to produce a finite output, which helps the model generalize better and makes it less prone to overfitting. It is also an efficient and theoretically grounded solution for label noise. You can read more about the discussion here.
Here is how you can apply label smoothing to one-hot labels before training a classifier.

from tensorflow.keras.datasets import mnist
from tensorflow import keras
import numpy as np


def smooth_labels(y, smooth_factor):
    '''Convert a matrix of one-hot row-vector labels into smoothed versions.

    # Arguments
        y: matrix of one-hot row-vector labels to be smoothed
        smooth_factor: label smoothing factor (between 0 and 1)

    # Returns
        A matrix of smoothed labels.
    '''
    assert len(y.shape) == 2
    if 0 <= smooth_factor <= 1:
        # label smoothing ref: https://www.robots.ox.ac.uk/~vgg/rg/papers/reinception.pdf
        y *= 1 - smooth_factor
        y += smooth_factor / y.shape[1]
    else:
        raise Exception(
            'Invalid label smoothing factor: ' + str(smooth_factor))
    return y


num_classes = 10
(x_train, y_train), (x_test, y_test) = mnist.load_data()

# Convert class vectors to binary class matrices
y_train = keras.utils.to_categorical(y_train, num_classes)
y_test = keras.utils.to_categorical(y_test, num_classes)
print('Before smoothing: {}'.format(y_train[0]))
smooth_labels(y_train, .1)
print('After smoothing: {}'.format(y_train[0]))

Results

Before smoothing: [0. 0. 0. 0. 0. 1. 0. 0. 0. 0.]
After smoothing: [0.01  0.01  0.01  0.01  0.01  0.90999997  0.01  0.01  0.01  0.01]
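
Alternatively, recent versions of tf.keras can apply the same (1 - ε) · y + ε / K smoothing inside the loss itself via the label_smoothing argument of CategoricalCrossentropy, so the one-hot labels stay untouched. A minimal sketch (model stands for whatever classifier you are compiling):

from tensorflow import keras

# Smoothing is applied to y_true inside the loss; with 0.1 this matches smooth_labels(y, .1) above.
loss = keras.losses.CategoricalCrossentropy(label_smoothing=0.1)
model.compile(optimizer='rmsprop', loss=loss, metrics=['accuracy'])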

Conclusion and further reading

There are two more training refinements not covered in this post, namely,

  • Knowledge distillation, which leverages a pre-trained, larger model's outputs to train a smaller model.
  • Mixup training, which is similar to data augmentation in spirit: it creates additional training data by forming new examples through weighted linear interpolation of pairs of examples (a minimal sketch follows this list). Take a look at implementing it in another post.
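
For reference, the core of mixup fits in a few lines. A minimal NumPy sketch (the helper name and the alpha=0.2 default are illustrative choices, not from this post's source code):

import numpy as np

def mixup_batch(x1, y1, x2, y2, alpha=0.2):
    # Sample the mixing weight from a Beta(alpha, alpha) distribution.
    lam = np.random.beta(alpha, alpha)
    # Weighted linear interpolation of both the inputs and the one-hot labels.
    x = lam * x1 + (1 - lam) * x2
    y = lam * y1 + (1 - lam) * y2
    return x, y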

Read the paper https://arxiv.org/abs/1812.01187v2 for detailed information about each trick.

Source code available on my GitHub.
