How to do mixup training from image files in Keras



Previously, we introduced a bag of tricks to improve image classification performance with convolutional networks in Keras. This time, we will take a closer look at the last of those tricks: mixup.

What is the mixup training?

The paper mixup: BEYOND EMPIRICAL RISK MINIMIZATION offers an alternative to traditional image augmentation techniques like zooming and rotation: form a new example through a weighted linear interpolation of two existing examples.

x̃ = λ * xi + (1 − λ) * xj
ỹ = λ * yi + (1 − λ) * yj

(xi, yi) and (xj, yj) are two examples drawn at random from our training data, and λ ∈ [0, 1]. In practice, λ is randomly sampled from the beta distribution, i.e. Beta(α, α).
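To make the interpolation concrete, here is a minimal NumPy sketch of mixing a single pair of examples. The image shape, label encoding and α value below are illustrative assumptions, not something prescribed by the paper or this post.

import numpy as np

alpha = 0.2  # beta distribution parameter (illustrative choice)

# Two hypothetical examples: 150x150 RGB images with one-hot labels over two classes.
xi, yi = np.random.rand(150, 150, 3), np.array([1.0, 0.0])
xj, yj = np.random.rand(150, 150, 3), np.array([0.0, 1.0])

# Sample the mixing coefficient lambda from Beta(alpha, alpha).
lam = np.random.beta(alpha, alpha)

# Weighted linear interpolation of both the inputs and the labels.
x_mix = lam * xi + (1 - lam) * xj
y_mix = lam * yi + (1 - lam) * yj

print(lam, y_mix)  # y_mix is a soft label, e.g. [0.93, 0.07] when lam = 0.93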

α ∈ [0.1, 0.4] leads to improved performance; a smaller α creates less of a mixup effect, whereas a large α leads to underfitting.

As you can see in the following graph, given a small α = 0.2, the beta distribution samples more values closer to either 0 or 1, making the mixup result closer to either one of the two examples.

(Probability density of the Beta(0.2, 0.2) distribution)
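If you don't have the plot handy, you can verify this behaviour yourself by sampling from Beta(0.2, 0.2) and checking how many draws land near the extremes (a quick verification sketch, not part of the original post):

import numpy as np

samples = np.random.beta(0.2, 0.2, size=100000)

# Fraction of sampled lambda values within 0.1 of either 0 or 1.
near_extremes = np.mean((samples < 0.1) | (samples > 0.9))
print('fraction near 0 or 1:', near_extremes)  # typically around 0.7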

What are the benefits of mixup training?

While traditional data augmentation, like that provided by the Keras ImageDataGenerator class, consistently leads to improved generalization, the procedure is dataset-dependent and thus requires the use of expert knowledge.

Besides, data augmentation does not model the relationship between examples of different classes.

On the other hand,

  • Mixup is a data-agnostic data augmentation routine.
  • It makes decision boundaries transition linearly from class to class, providing a smoother estimate of uncertainty.
  • It reduces the memorization of corrupt labels.
  • It increases robustness to adversarial examples and stabilizes the training of generative adversarial networks.

Mixup image data generator in Keras

Want to give mixup a spin? Let's implement an image data generator that reads images from files and works with Keras model.fit_generator() out of the box.

import numpy as np


class MixupImageDataGenerator():
    def __init__(self, generator, directory, batch_size, img_height, img_width, alpha=0.2, subset=None):
        """Constructor for mixup image data generator.

        Arguments:
            generator {object} -- An instance of Keras ImageDataGenerator.
            directory {str} -- Image directory.
            batch_size {int} -- Batch size.
            img_height {int} -- Image height in pixels.
            img_width {int} -- Image width in pixels.

        Keyword Arguments:
            alpha {float} -- Mixup beta distribution alpha parameter. (default: {0.2})
            subset {str} -- 'training' or 'validation' if validation_split is specified in
            `generator` (ImageDataGenerator).(default: {None})
        """

        self.batch_index = 0
        self.batch_size = batch_size
        self.alpha = alpha

        # First iterator yielding tuples of (x, y)
        self.generator1 = generator.flow_from_directory(directory,
                                                        target_size=(
                                                            img_height, img_width),
                                                        class_mode="categorical",
                                                        batch_size=batch_size,
                                                        shuffle=True,
                                                        subset=subset)

        # Second iterator yielding tuples of (x, y)
        self.generator2 = generator.flow_from_directory(directory,
                                                        target_size=(
                                                            img_height, img_width),
                                                        class_mode="categorical",
                                                        batch_size=batch_size,
                                                        shuffle=True,
                                                        subset=subset)

        # Number of images across all classes in image directory.
        self.n = self.generator1.samples

    def reset_index(self):
        """Reset the generator indexes array.
        """

        self.generator1._set_index_array()
        self.generator2._set_index_array()

    def on_epoch_end(self):
        self.reset_index()

    def reset(self):
        self.batch_index = 0

    def __len__(self):
        # round up
        return (self.n + self.batch_size - 1) // self.batch_size

    def get_steps_per_epoch(self):
        """Get number of steps per epoch based on batch size and
        number of images.

        Returns:
            int -- steps per epoch.
        """

        return self.n // self.batch_size

    def __next__(self):
        """Get next batch input/output pair.

        Returns:
            tuple -- batch of input/output pair, (inputs, outputs).
        """

        if self.batch_index == 0:
            self.reset_index()

        current_index = (self.batch_index * self.batch_size) % self.n
        if self.n > current_index + self.batch_size:
            self.batch_index += 1
        else:
            self.batch_index = 0

        # Get a pair of inputs and outputs from the two iterators.
        X1, y1 = self.generator1.next()
        X2, y2 = self.generator2.next()

        # Randomly sample one lambda value per image from the beta distribution.
        # Use the actual batch size returned by the iterators, since the last
        # batch of an epoch may contain fewer than `batch_size` images.
        n_samples = X1.shape[0]
        l = np.random.beta(self.alpha, self.alpha, n_samples)

        X_l = l.reshape(n_samples, 1, 1, 1)
        y_l = l.reshape(n_samples, 1)

        # Perform the mixup on both the inputs and the labels.
        X = X1 * X_l + X2 * (1 - X_l)
        y = y1 * y_l + y2 * (1 - y_l)
        return X, y

    def __iter__(self):
        while True:
            yield next(self)

The core of the mixup generator consists of a pair of iterators that sample images randomly from the directory, one batch at a time, with the mixup performed in the __next__ method.

Then you can create the training and validation generators for fitting the model. Notice that we don't use mixup in the validation generator.

from keras.preprocessing.image import ImageDataGenerator

train_dir = "./data"

batch_size = 5
validation_split = 0.3
img_height = 150
img_width = 150
epochs = 10

# Optional additional image augmentation with ImageDataGenerator.
input_imgen = ImageDataGenerator(
    rescale=1./255,
    rotation_range=5,
    width_shift_range=0.05,
    height_shift_range=0,
    shear_range=0.05,
    zoom_range=0,
    brightness_range=(1, 1.3),
    horizontal_flip=True,
    fill_mode='nearest',
    validation_split=validation_split)

# Create training and validation generator.
train_generator = MixupImageDataGenerator(generator=input_imgen,
                                          directory=train_dir,
                                          batch_size=batch_size,
                                          img_height=img_height,
                                          img_width=img_width,
                                          subset='training')
validation_generator = input_imgen.flow_from_directory(train_dir,
                                                       target_size=(
                                                           img_height, img_width),
                                                       class_mode="categorical",
                                                       batch_size=batch_size,
                                                       shuffle=True,
                                                       subset='validation')

print('training steps: ', train_generator.get_steps_per_epoch())
print('validation steps: ', validation_generator.samples // batch_size)


# Build a Keras image classification model as usual.
# ...
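
# A minimal example model, purely illustrative -- any Keras image
# classification model with a softmax output would work here.
from keras.models import Sequential
from keras.layers import Conv2D, MaxPooling2D, Flatten, Dense

num_classes = validation_generator.num_classes

model = Sequential([
    Conv2D(32, (3, 3), activation='relu', input_shape=(img_height, img_width, 3)),
    MaxPooling2D((2, 2)),
    Conv2D(64, (3, 3), activation='relu'),
    MaxPooling2D((2, 2)),
    Flatten(),
    Dense(64, activation='relu'),
    Dense(num_classes, activation='softmax'),
])

# Mixup produces soft (non one-hot) labels, so use categorical_crossentropy.
model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])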

# Start the training.
history = model.fit_generator(
    train_generator,
    steps_per_epoch=train_generator.get_steps_per_epoch(),
    validation_data=validation_generator,
    validation_steps=validation_generator.samples // batch_size,
    epochs=epochs)

We can visualize a batch of mixup images and labels with the following snippet in a Jupyter notebook.

from IPython.display import display
from keras.preprocessing import image

sample_x, sample_y = next(train_generator)
for i in range(batch_size):
    display(image.array_to_img(sample_x[i]))
print(sample_y)

The following picture illustrates how mixup works.

(A sample batch of mixup images with their mixed labels)

Conclusion and further thoughts

You might be thinking that mixing up more than two examples at a time could lead to better training. On the contrary, combinations of three or more examples with weights sampled from the multivariate generalization of the beta distribution do not provide further gain, but increase the computation cost of mixup. Moreover, interpolating only between inputs with equal labels did not lead to the performance gains of mixup.
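For reference, the multivariate generalization of the beta distribution mentioned above is the Dirichlet distribution, so a hypothetical three-example mixup would sample its weights as sketched below. This is purely illustrative; as noted, the paper found it adds computational cost without further gains.

import numpy as np

alpha = 0.2

# Three hypothetical examples with one-hot labels over two classes.
x1, x2, x3 = [np.random.rand(150, 150, 3) for _ in range(3)]
y1, y2, y3 = np.array([1.0, 0.0]), np.array([0.0, 1.0]), np.array([1.0, 0.0])

# Dirichlet(alpha, alpha, alpha) weights are non-negative and sum to one.
w = np.random.dirichlet([alpha, alpha, alpha])

x_mix = w[0] * x1 + w[1] * x2 + w[2] * x3
y_mix = w[0] * y1 + w[1] * y2 + w[2] * y3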

Check out the full source code on my GitHub.

