Multi-class classification with focal loss for imbalanced datasets



The focal loss was proposed in 2017 for the dense object detection task. It enables training highly accurate dense object detectors even when foreground and background classes are imbalanced on the order of 1:1000. This tutorial will show you how to apply focal loss to train a multi-class classifier on highly imbalanced datasets.


Let's first take a look at other treatments for imbalanced datasets, and at how focal loss addresses the issue.

In multi-class classification, a balanced dataset has target labels that are evenly distributed. If one class has overwhelmingly more samples than another, it can be seen as an imbalanced dataset. This imbalance causes two problems:

  • Training is inefficient as most samples are easy examples that contribute no useful learning signal;
  • The easy examples can overwhelm training and lead to degenerate models.

A common solution is to perform some form of hard negative mining, which samples hard examples during training, or to use more complex sampling/reweighting schemes.
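A minimal sketch of the idea behind hard negative mining, assuming per-sample losses have already been computed for a batch: keep every rare-class sample and only the k negatives the model currently gets most wrong. The helper name `mine_hard_negatives` and the toy numbers are hypothetical, not taken from any library.

```python
def mine_hard_negatives(losses, labels, k):
    """Return indices of all positive samples plus the k hardest negatives.

    losses -- per-sample loss values (higher means harder)
    labels -- 1 for the rare (positive) class, 0 for the common (negative) class
    k      -- number of negatives to keep
    """
    positives = [i for i, y in enumerate(labels) if y == 1]
    negatives = [i for i, y in enumerate(labels) if y == 0]
    # Sort negatives by loss, hardest first, and keep only the top k.
    hardest = sorted(negatives, key=lambda i: losses[i], reverse=True)[:k]
    return sorted(positives + hardest)


# Hypothetical batch: two positives (indices 1 and 5) and four negatives.
losses = [0.01, 2.3, 0.05, 0.9, 0.02, 1.7]
labels = [0, 1, 0, 0, 0, 1]
print(mine_hard_negatives(losses, labels, k=1))
```

Only the negative at index 3 (loss 0.9) survives with k=1; the three near-zero-loss negatives are easy examples and are dropped from the training batch.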

For image classification specifically, data augmentation techniques are also viable for creating synthetic data for under-represented classes.
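As an illustration only (this is not the technique used later in the post), a minimal NumPy sketch that grows the minority class with horizontally flipped copies, assuming images are stored in a NumPy array with width on axis 2. The helper `augment_minority` and the toy data are hypothetical.

```python
import numpy as np


def augment_minority(images, labels, minority_label):
    """Append horizontally flipped copies of every minority-class image.

    images -- array of shape (n, height, width) or (n, height, width, channels)
    labels -- integer class labels of shape (n,)
    """
    mask = labels == minority_label
    # Flip along the width axis (axis 2) to create mirrored synthetic samples.
    flipped = np.flip(images[mask], axis=2)
    new_images = np.concatenate([images, flipped], axis=0)
    new_labels = np.concatenate([labels, labels[mask]], axis=0)
    return new_images, new_labels


images = np.arange(4 * 2 * 3).reshape(4, 2, 3)  # four toy 2x3 "images"
labels = np.array([0, 0, 0, 1])                  # class 1 is under-represented
aug_images, aug_labels = augment_minority(images, labels, minority_label=1)
print(aug_images.shape, np.bincount(aug_labels))
```

Mirroring is only a valid augmentation when the class label does not change under a horizontal flip.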

The focal loss is designed to address class imbalance by down-weighting inliers (easy examples) so that their contribution to the total loss is small even if their number is large. It focuses training on a sparse set of hard examples.
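To make the down-weighting concrete, a small sketch comparing cross-entropy with the focal loss term -(1 - p_t)^γ log(p_t) for one easy and one hard example, where p_t is the probability the model assigns to the true class. The α factor is omitted here and γ = 2 is assumed.

```python
import math


def cross_entropy(p_t):
    # Standard cross-entropy for the probability assigned to the true class.
    return -math.log(p_t)


def focal(p_t, gamma=2.0):
    # Focal loss without the alpha term: (1 - p_t)^gamma * cross-entropy.
    return (1.0 - p_t) ** gamma * cross_entropy(p_t)


for p_t in (0.9, 0.1):  # an easy example and a hard example
    ce, fl = cross_entropy(p_t), focal(p_t)
    print(f"p_t={p_t}: CE={ce:.4f}  FL={fl:.4f}  FL/CE={fl / ce:.2f}")
```

The easy example's loss shrinks by a factor of 100, since (1 - 0.9)^2 = 0.01, while the hard example keeps 81% of its cross-entropy. A large pile of easy examples therefore no longer dominates the total loss.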

Apply focal loss to fraud detection task

For demonstration, we will build a classifier for the fraud detection dataset on Kaggle, which has an extreme class imbalance: 6354407 normal and 8213 fraud cases in total, or roughly 774:1. With such a highly imbalanced dataset, the model can take the easy route and guess "normal" for every input, achieving an accuracy of 6354407/(6354407+8213) ≈ 99.87%. However, we want the model to detect the rare fraud cases.
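Checking the arithmetic with the class counts quoted above:

```python
normal, fraud = 6354407, 8213  # class counts of the fraud detection dataset

ratio = normal / fraud
# Accuracy of a "model" that always predicts "normal".
always_normal_acc = normal / (normal + fraud)
print(f"imbalance ~ {ratio:.1f}:1, always-normal accuracy ~ {always_normal_acc:.4%}")
```

This gives an imbalance of about 774:1 and an always-"normal" accuracy of about 99.87%, which is why accuracy alone says little about fraud detection here.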

To show that focal loss is more effective than commonly applied techniques, let's set up a baseline model trained with class_weight, which tells the model to "pay more attention" to samples from the under-represented fraud class.

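from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense

# input_dim, nb_classes, X_train and one-hot y_train are assumed to be prepared earlier.
model = Sequential()
model.add(Dense(10, input_dim=input_dim, activation='relu', name='input'))
model.add(Dense(20, activation='relu', name='fc1'))
model.add(Dense(10, activation='relu', name='fc2'))
model.add(Dense(nb_classes, activation='softmax', name='output'))
# Optimizer choice is an assumption.
model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])

# Alternative: scikit-learn's 'balanced' weights from integer labels
# (y_labels is a hypothetical name; y_train here is one-hot).
# import numpy as np
# from sklearn.utils.class_weight import compute_class_weight
# weights = compute_class_weight('balanced', classes=np.unique(y_labels), y=y_labels)
# class_weight = dict(enumerate(weights))
class_weight = {0: 1., 1: 20.}  # fraud samples count 20x more than normal ones
model.fit(X_train, y_train, epochs=3, batch_size=1000, class_weight=class_weight)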
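The commented-out alternative relies on scikit-learn's 'balanced' heuristic, which sets each class weight to n_samples / (n_classes × count). A dependency-free sketch of that formula (the helper name `balanced_class_weights` is hypothetical), applied to the dataset counts quoted earlier:

```python
def balanced_class_weights(counts):
    """Mirror the 'balanced' heuristic: n_samples / (n_classes * count_per_class)."""
    n_samples = sum(counts.values())
    n_classes = len(counts)
    return {label: n_samples / (n_classes * c) for label, c in counts.items()}


weights = balanced_class_weights({0: 6354407, 1: 8213})
print(weights)
```

With these counts the fraud class gets a weight of about 387 and the normal class about 0.5, so the weight ratio equals the imbalance ratio (about 774:1). That is far more aggressive than the hand-picked 1:20 used in the baseline.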

The baseline model achieved an accuracy of 99.87%, barely different from the accuracy of simply guessing "normal" for every input.

We also plot the confusion matrix to describe the classifier's performance on the held-out test set. There are 1140+480=1620 misclassified cases in total.
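The off-diagonal cells of a confusion matrix are the misclassified cases, which is how a total such as 1140+480 is obtained. A dependency-free sketch on toy labels (all names and data are hypothetical; in practice sklearn.metrics.confusion_matrix does the same job):

```python
def confusion_matrix(y_true, y_pred, num_classes=2):
    """Rows are true classes, columns are predicted classes."""
    cm = [[0] * num_classes for _ in range(num_classes)]
    for t, p in zip(y_true, y_pred):
        cm[t][p] += 1
    return cm


def misclassified(cm):
    # Every off-diagonal cell counts a mistake.
    return sum(cm[i][j] for i in range(len(cm)) for j in range(len(cm)) if i != j)


def recall(cm, label):
    # Fraction of true `label` samples that were predicted as `label`.
    return cm[label][label] / sum(cm[label])


# Toy labels: 0 = normal, 1 = fraud.
y_true = [0, 0, 0, 0, 0, 0, 1, 1]
y_pred = [0, 0, 0, 0, 0, 1, 1, 0]
cm = confusion_matrix(y_true, y_pred)
print(cm, misclassified(cm), recall(cm, 1))
```

For fraud detection, the recall of the fraud class (row 1) is usually more informative than overall accuracy, because missed fraud cases barely move the accuracy figure.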


Now let's apply focal loss to the same model. You can see how to define the focal loss as a custom loss function for Keras below.

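import tensorflow as tf


def focal_loss(gamma=2., alpha=4.):

    gamma = float(gamma)
    alpha = float(alpha)

    def focal_loss_fixed(y_true, y_pred):
        """Focal loss for multi-class classification.

        Notice: y_pred is the probability after softmax.
        The gradient is d(FL)/d(p_t), not d(FL)/d(x) as described in the paper:
        d(FL)/d(p_t) * [p_t(1-p_t)] = d(FL)/d(x)
        Reference: Focal Loss for Dense Object Detection

        Arguments:
            y_true {tensor} -- ground truth labels (one-hot), shape of [batch_size, num_cls]
            y_pred {tensor} -- model's output, shape of [batch_size, num_cls]

        Keyword Arguments:
            gamma {float} -- (default: {2.0})
            alpha {float} -- (default: {4.0})

        Returns:
            [tensor] -- loss.
        """
        epsilon = 1.e-9
        y_true = tf.convert_to_tensor(y_true, tf.float32)
        y_pred = tf.convert_to_tensor(y_pred, tf.float32)

        model_out = tf.add(y_pred, epsilon)  # avoid log(0)
        # Cross-entropy term, non-zero only at the true class.
        ce = tf.multiply(y_true, -tf.math.log(model_out))
        # Modulating factor (1 - p_t)^gamma down-weights easy examples.
        weight = tf.multiply(y_true, tf.pow(tf.subtract(1., model_out), gamma))
        # A single scalar alpha scales every class equally.
        fl = tf.multiply(alpha, tf.multiply(weight, ce))
        # With one-hot labels only the true-class entry is non-zero, so max equals sum.
        reduced_fl = tf.reduce_max(fl, axis=1)
        return tf.reduce_mean(reduced_fl)
    return focal_loss_fixed


# `model` is assumed to be a freshly built copy of the baseline architecture.
# Optimizer choice is an assumption.
model.compile(optimizer='adam',
              loss=focal_loss(gamma=2., alpha=4.),
              metrics=['accuracy'])
model.fit(X_train, y_train, epochs=3, batch_size=1000)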
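To sanity-check loss values outside a TensorFlow graph, a minimal NumPy sketch that mirrors focal_loss_fixed step by step, assuming one-hot labels and softmax probabilities. The name `focal_loss_np` and the toy batch are hypothetical, not part of Keras.

```python
import numpy as np


def focal_loss_np(y_true, y_pred, gamma=2.0, alpha=4.0, epsilon=1e-9):
    """NumPy mirror of focal_loss_fixed: batch mean of the per-sample focal loss."""
    model_out = y_pred + epsilon
    ce = y_true * -np.log(model_out)              # cross-entropy at the true class
    weight = y_true * (1.0 - model_out) ** gamma  # modulating factor (1 - p_t)^gamma
    fl = alpha * weight * ce
    return fl.max(axis=1).mean()                  # same reduction as the Keras version


y_true = np.array([[1.0, 0.0], [0.0, 1.0]])  # one normal and one fraud sample (one-hot)
y_pred = np.array([[0.9, 0.1], [0.6, 0.4]])  # confident on the first, unsure on the second
print(focal_loss_np(y_true, y_pred))
```

The confident first sample contributes almost nothing, so nearly all of the batch loss comes from the uncertain fraud sample.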

There are two adjustable parameters for focal loss.

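  • The focusing parameter γ (gamma) smoothly adjusts the rate at which easy examples are down-weighted. When γ = 0, focal loss is equivalent to (α-weighted) categorical cross-entropy, and as γ increases, the effect of the modulating factor increases likewise (γ = 2 worked best in the paper's experiments).
  • The weighting factor α (alpha) balances the focal loss; in the paper, the α-balanced form yields slightly improved accuracy over the non-α-balanced form. In the paper α_t depends on the true class, whereas the focal_loss function above applies a single scalar alpha to every class, so there it acts as an overall scale of the loss rather than a class balance.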
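The focal loss paper writes the α-balanced form as FL(p_t) = -α_t (1 - p_t)^γ log(p_t), where α_t depends on the true class. A minimal NumPy sketch of a per-class variant, assuming one α entry per class as a natural multi-class extension (the function names and the α values 0.25/0.75 are hypothetical and purely illustrative). It also checks that γ = 0 reduces focal loss to weighted cross-entropy.

```python
import numpy as np


def focal_loss_per_class(y_true, y_pred, gamma=2.0, alpha=(0.25, 0.75), epsilon=1e-9):
    """Focal loss with a per-class weighting vector alpha (one entry per class).

    y_true -- one-hot labels, shape [batch_size, num_cls]
    y_pred -- softmax probabilities, shape [batch_size, num_cls]
    """
    alpha = np.asarray(alpha, dtype=float)
    p_t = np.sum(y_true * y_pred, axis=1) + epsilon  # probability of the true class
    alpha_t = np.sum(y_true * alpha, axis=1)         # alpha of the true class
    return np.mean(-alpha_t * (1.0 - p_t) ** gamma * np.log(p_t))


def weighted_cross_entropy(y_true, y_pred, alpha, epsilon=1e-9):
    """Class-weighted categorical cross-entropy, for comparison."""
    p_t = np.sum(y_true * y_pred, axis=1) + epsilon
    alpha_t = np.sum(y_true * np.asarray(alpha, dtype=float), axis=1)
    return np.mean(-alpha_t * np.log(p_t))


y_true = np.array([[1.0, 0.0], [0.0, 1.0]])
y_pred = np.array([[0.9, 0.1], [0.6, 0.4]])
# With gamma = 0 the modulating factor is 1, so focal loss equals weighted cross-entropy.
print(np.isclose(focal_loss_per_class(y_true, y_pred, gamma=0.0),
                 weighted_cross_entropy(y_true, y_pred, alpha=(0.25, 0.75))))
```

Unlike a scalar α, a per-class vector actually changes the relative weight of normal and fraud samples.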

Now let's compare the performance with the previous classifier.

Focal loss model:

  • Accuracy: 99.94%
  • Misclassified test set samples in total: 766+23=789, cutting the number of mistakes roughly in half.
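A quick check of the "roughly in half" claim, using only the misclassification counts reported for the two models:

```python
baseline_errors = 1140 + 480  # misclassified test samples, class_weight baseline
focal_errors = 766 + 23       # misclassified test samples, focal loss model

reduction = 1 - focal_errors / baseline_errors
print(f"{baseline_errors} -> {focal_errors} errors ({reduction:.1%} fewer)")
```

Going from 1620 to 789 errors is a reduction of about 51%.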


Conclusion and further reading

In this quick tutorial, we introduced focal loss, a new tool for your arsenal when handling highly imbalanced datasets. A concrete example showed how to adopt focal loss in your classification model with the Keras API.

You can find the full source code for this post on my GitHub.

For a detailed description of focal loss, you can read the paper "Focal Loss for Dense Object Detection".
