Focal loss was proposed for the dense object detection task in 2017. It enables training highly accurate dense object detectors despite a foreground-background class imbalance on the order of 1:1000. This tutorial will show you how to apply focal loss to train a multi-class classifier on a highly imbalanced dataset.
Let's first take a look at other treatments for imbalanced datasets, and how focal loss solves the issue.
In multi-class classification, a balanced dataset has target labels that are evenly distributed. If one class has overwhelmingly more samples than another, it can be seen as an imbalanced dataset. This imbalance causes two problems:
- Training is inefficient, since most samples are easy examples that contribute little useful learning signal.
- The easy examples can overwhelm training and lead to degenerate models.
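As a quick check before training, you can inspect the label distribution; a minimal sketch with numpy (the labels array y is illustrative):

import numpy as np

# y is a 1-D array of integer class labels
y = np.array([0] * 9990 + [1] * 10)

classes, counts = np.unique(y, return_counts=True)
for cls, count in zip(classes, counts):
    print(f"class {cls}: {count} samples ({count / len(y):.2%})")
# class 0: 9990 samples (99.90%)
# class 1: 10 samples (0.10%)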
A common solution is to perform some form of hard negative mining that samples hard examples during training, or more complex sampling/reweighting schemes, as sketched below.
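For instance, here is a minimal random-oversampling sketch with numpy (the function name and defaults are illustrative, not from the original post):

import numpy as np

def oversample_minority(X, y, minority_class=1, seed=42):
    # Duplicate minority-class rows (sampled with replacement)
    # until both classes have equal counts.
    rng = np.random.default_rng(seed)
    minority_idx = np.flatnonzero(y == minority_class)
    majority_idx = np.flatnonzero(y != minority_class)
    resampled = rng.choice(minority_idx, size=len(majority_idx), replace=True)
    idx = rng.permutation(np.concatenate([majority_idx, resampled]))
    return X[idx], y[idx]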
Specifically for image classification, data augmentation techniques are also a viable way to create synthetic data for under-represented classes.
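For example, a hedged sketch with Keras' ImageDataGenerator (parameter values are illustrative):

from tensorflow.keras.preprocessing.image import ImageDataGenerator

# Generate randomly perturbed copies of minority-class images on the fly
augmenter = ImageDataGenerator(
    rotation_range=15,       # rotate up to 15 degrees
    width_shift_range=0.1,   # random horizontal shifts
    height_shift_range=0.1,  # random vertical shifts
    horizontal_flip=True,
)
# X_minority / y_minority are hypothetical arrays of minority-class images and labels
# batches = augmenter.flow(X_minority, y_minority, batch_size=32)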
The focal loss is designed to address class imbalance by down-weighting inliers (easy examples) such that their contribution to the total loss is small even if their number is large. It focuses on training a sparse set of hard examples.
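To make the down-weighting concrete, here is a quick numeric illustration of the modulating factor (1 - p_t)^gamma at the paper's default gamma = 2:

gamma = 2.0
for p_t in (0.9, 0.5, 0.1):  # well-classified ... badly misclassified
    factor = (1.0 - p_t) ** gamma
    print(f"p_t = {p_t}: modulating factor = {factor:.2f}")
# p_t = 0.9: factor = 0.01 -> easy example, its loss is scaled down 100x
# p_t = 0.5: factor = 0.25
# p_t = 0.1: factor = 0.81 -> hard example, its loss is nearly unchanged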
For demonstration, we will build a classifier for a fraud detection dataset on Kaggle with extreme class imbalance: 6,354,407 normal and 8,213 fraud cases in total, a ratio of roughly 774:1. With such a highly imbalanced dataset, the model can just take the easy route, guessing "normal" for all inputs, to achieve an accuracy of 6,354,407/6,362,620 ≈ 99.87%. However, we want the model to detect the rare fraud cases.
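Here is a sketch of how the data could be prepared (the file name and the isFraud label column are assumptions about the dataset; only numeric feature columns are kept):

import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
from tensorflow.keras.utils import to_categorical

df = pd.read_csv('fraud.csv')   # hypothetical file name
y = df['isFraud'].values        # assumed label column name
X = df.drop(columns=['isFraud']).select_dtypes(include=np.number).values

# Stratify so the rare fraud class appears in both splits
X_train, X_test, y_train_raw, y_test = train_test_split(
    X, y, test_size=0.3, stratify=y, random_state=42)

# One-hot encode the labels for the softmax output
y_train = to_categorical(y_train_raw, num_classes=2)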
To show that focal loss is more effective than these commonly applied techniques, let's set up a baseline model trained with class_weight, which tells the model to "pay more attention" to samples from the under-represented fraud class.
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense

input_dim = X_train.shape[1]  # number of input features
nb_classes = 2                # normal vs. fraud

model = Sequential()
model.add(Dense(10, input_dim=input_dim, activation='relu', name='input'))
model.add(Dense(20, activation='relu', name='fc1'))
model.add(Dense(10, activation='relu', name='fc2'))
model.add(Dense(nb_classes, activation='softmax', name='output'))
model.compile(loss='categorical_crossentropy',
              optimizer='nadam',
              metrics=['accuracy'])

# Alternatively, compute balanced weights from the label frequencies:
# from sklearn.utils import class_weight
# class_weight = class_weight.compute_class_weight(
#     'balanced', classes=np.unique(y_train_raw), y=y_train_raw)
class_weight = {0: 1., 1: 20.}  # weight fraud samples 20x more heavily
model.fit(X_train, y_train, epochs=3, batch_size=1000, class_weight=class_weight)
The baseline model achieved an accuracy of 99.87%, barely better than taking the "easy route" of guessing all normal.
We also plot the confusion matrix to describe the classifier's performance on the held-out test set. You can see there are 1140 + 480 = 1620 misclassified cases in total.
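A sketch of how such a confusion matrix can be computed with scikit-learn (y_test holds the integer test labels from the split above):

import numpy as np
from sklearn.metrics import confusion_matrix

# Predicted class = argmax over the softmax probabilities
y_pred = np.argmax(model.predict(X_test), axis=1)
print(confusion_matrix(y_test, y_pred))  # rows: true class, columns: predicted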
Now let's apply focal loss to the same model architecture. You can see how to define focal loss as a custom loss function for Keras below.
import tensorflow as tf

def focal_loss(gamma=2., alpha=4.):

    gamma = float(gamma)
    alpha = float(alpha)

    def focal_loss_fixed(y_true, y_pred):
        """Focal loss for multi-class classification.

        FL(p_t) = -alpha * (1 - p_t)^gamma * ln(p_t)

        Notice: y_pred is the probability after softmax, so the gradient
        is d(FL)/d(p_t), not d(FL)/d(x) as described in the paper:
        d(FL)/d(p_t) * [p_t * (1 - p_t)] = d(FL)/d(x)

        Focal Loss for Dense Object Detection
        https://arxiv.org/abs/1708.02002

        Arguments:
            y_true {tensor} -- ground truth labels, shape of [batch_size, num_cls]
            y_pred {tensor} -- model's output, shape of [batch_size, num_cls]

        Keyword Arguments:
            gamma {float} -- focusing parameter (default: {2.0})
            alpha {float} -- balancing factor (default: {4.0})

        Returns:
            [tensor] -- scalar loss.
        """
        epsilon = 1.e-9
        y_true = tf.convert_to_tensor(y_true, tf.float32)
        y_pred = tf.convert_to_tensor(y_pred, tf.float32)

        model_out = tf.add(y_pred, epsilon)  # avoid log(0)
        ce = tf.multiply(y_true, -tf.math.log(model_out))
        weight = tf.multiply(y_true, tf.pow(tf.subtract(1., model_out), gamma))
        fl = tf.multiply(alpha, tf.multiply(weight, ce))
        # y_true is one-hot, so taking the max over the class axis picks out
        # the loss at the true-class position
        reduced_fl = tf.reduce_max(fl, axis=1)
        return tf.reduce_mean(reduced_fl)

    return focal_loss_fixed
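As a quick sanity check, with gamma = 0 and alpha = 1 the modulating factor disappears and focal loss should match ordinary categorical cross entropy (a sketch with made-up probabilities):

import numpy as np
import tensorflow as tf

y_true = np.array([[0., 1.], [1., 0.]], dtype=np.float32)
y_pred = np.array([[0.1, 0.9], [0.8, 0.2]], dtype=np.float32)

fl = focal_loss(gamma=0., alpha=1.)(y_true, y_pred)
ce = tf.reduce_mean(tf.keras.losses.categorical_crossentropy(y_true, y_pred))
print(float(fl), float(ce))  # both ≈ 0.164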
# Re-compile the same architecture with focal loss; no class_weight needed
model.compile(loss=focal_loss(alpha=1),
              optimizer='nadam',
              metrics=['accuracy'])
model.fit(X_train, y_train, epochs=3, batch_size=1000)
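One caveat worth knowing: because the returned focal_loss_fixed is a custom function, reloading a saved model requires passing it via custom_objects (the file name below is hypothetical):

from tensorflow.keras.models import load_model

# model.save('fraud_model.h5')
model = load_model('fraud_model.h5',
                   custom_objects={'focal_loss_fixed': focal_loss(alpha=1)})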
There are two adjustable parameters for focal loss: the focusing parameter gamma smoothly adjusts the rate at which easy examples are down-weighted (when gamma = 0, focal loss is equivalent to cross entropy), and alpha acts as a weighting factor that balances the importance of the classes.
Now let's compare the performance with the previous classifier on the same test set.
Focal loss model: [confusion matrix]
In this quick tutorial, we introduced a new tool for your arsenal to handle highly imbalanced datasets: focal loss. A concrete example showed you how to apply focal loss to a classification model built with the Keras API.
You can find the full source code for this post on my GitHub.
For a detailed description of focal loss, you can read the paper, https://arxiv.org/abs/1708.02002.