
The focal loss was proposed for the dense object detection task early this year. It enables training highly accurate dense object detectors despite a foreground-to-background class imbalance on the order of **1:1000**. This tutorial shows how to apply focal loss to train a multi-class classifier on highly imbalanced datasets.

Let's first look at other treatments for imbalanced datasets, and then at how focal loss addresses the issue.

In multi-class classification, a balanced dataset has target labels that are evenly distributed. If one class has overwhelmingly more samples than another, it can be seen as an imbalanced dataset. This imbalance causes two problems:

- Training is inefficient as most samples are easy examples that contribute no useful learning signal;
- The easy examples can overwhelm training and lead to degenerate models.

A common solution is some form of hard negative mining, which samples hard examples during training, or a more complex sampling/reweighting scheme.
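A minimal sketch of the selection step behind hard example mining follows. It assumes that per-sample losses for a batch are already available. The helper `select_hard_examples` and its `keep_fraction` parameter are hypothetical and do not come from any library. Only the hardest samples would then contribute to the gradient update.

```python
import numpy as np

def select_hard_examples(per_sample_loss, keep_fraction=0.25):
    """Return indices of the hardest examples (highest loss), hardest first.

    Hypothetical, simplified selection step in the spirit of hard example
    mining: only the top `keep_fraction` of samples by loss are kept.
    """
    losses = np.asarray(per_sample_loss, dtype=float)
    k = max(1, int(np.ceil(keep_fraction * losses.size)))
    # argsort is ascending, so the last k indices have the largest losses
    return np.argsort(losses)[-k:][::-1]

batch_losses = [0.01, 2.3, 0.05, 0.7, 0.02, 1.1, 0.03, 0.04]
hard_idx = select_hard_examples(batch_losses, keep_fraction=0.25)
print(hard_idx)  # [1 5]
```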

For image classification specifically, data augmentation techniques are also viable for creating synthetic data for under-represented classes.

The focal loss is designed to address class imbalance by down-weighting inliers (easy examples), so that their contribution to the total loss stays small even when there are many of them. Training therefore focuses on a sparse set of hard examples.
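To see the down-weighting in numbers, compare the loss on the true class under cross-entropy and under focal loss. The example uses γ = 2 and α = 1, which are values chosen for illustration. The ratio FL/CE equals the modulating factor (1 - p_t)^γ. An easy example with p_t = 0.9 therefore keeps only 1% of its cross-entropy loss, while a hard example with p_t = 0.1 keeps 81%.

```python
import math

def cross_entropy(p_t):
    """Cross-entropy for the true class: CE(p_t) = -ln(p_t)."""
    return -math.log(p_t)

def focal(p_t, gamma=2.0, alpha=1.0):
    """Focal loss for the true class: FL(p_t) = -alpha * (1 - p_t)^gamma * ln(p_t)."""
    return -alpha * (1.0 - p_t) ** gamma * math.log(p_t)

for p_t in (0.9, 0.5, 0.1):  # easy -> hard example
    ce, fl = cross_entropy(p_t), focal(p_t)
    print(f"p_t={p_t}: CE={ce:.4f}  FL={fl:.4f}  FL/CE={fl / ce:.2f}")
```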

For demonstration, we will build a classifier for the fraud detection dataset on Kaggle. It has an extreme class imbalance: 6354407 normal and 8213 fraud cases in total, roughly 774:1. With such a highly imbalanced dataset, the model can take the easy route of guessing "normal" for every input and still reach an accuracy of 6354407/(6354407+8213) ≈ **99.87%**. However, we want the model to detect the rare fraud cases.
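Both numbers follow directly from the class counts. The snippet below computes the imbalance ratio and the accuracy of a trivial model that always predicts the majority (normal) class.

```python
normal, fraud = 6354407, 8213  # class counts of the fraud detection dataset

ratio = normal / fraud
# Accuracy of a "model" that predicts the majority (normal) class for every input
majority_accuracy = normal / (normal + fraud)

print(f"imbalance ratio ~ {ratio:.1f}:1")
print(f"all-normal accuracy = {majority_accuracy:.2%}")
```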

To show that focal loss is more effective than commonly applied techniques, let's set up a baseline model trained with **class_weight**. This setting tells the model to "pay more attention" to samples from the under-represented fraud class.

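```
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense

# Assumes input_dim (number of features), nb_classes (2 here: normal/fraud),
# X_train and one-hot encoded y_train were prepared elsewhere (see full source code).
model = Sequential()
model.add(Dense(10, input_dim=input_dim, activation='relu', name='input'))
model.add(Dense(20, activation='relu', name='fc1'))
model.add(Dense(10, activation='relu', name='fc2'))
model.add(Dense(nb_classes, activation='softmax', name='output'))
model.compile(loss='categorical_crossentropy',
              optimizer='nadam',
              metrics=['accuracy'])

# Alternative: let scikit-learn compute "balanced" weights. Since y_train is
# one-hot, convert it back to integer labels first:
# import numpy as np
# from sklearn.utils.class_weight import compute_class_weight
# labels = np.argmax(y_train, axis=1)
# weights = compute_class_weight(class_weight='balanced', classes=np.unique(labels), y=labels)
# class_weight = dict(enumerate(weights))

# Weight the fraud class (label 1) 20x more than the normal class (label 0)
class_weight = {0: 1., 1: 20.}
model.fit(X_train, y_train, epochs=3, batch_size=1000, class_weight=class_weight)
```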
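Conceptually, `class_weight` scales each sample's loss by the weight of its true class. With `{0: 1., 1: 20.}`, a poorly predicted fraud case therefore costs 20 times as much as an equally poorly predicted normal one. The NumPy sketch below illustrates only that idea. `weighted_categorical_crossentropy` is a hypothetical helper, and the exact averaging that Keras applies internally may differ between versions.

```python
import numpy as np

def weighted_categorical_crossentropy(y_true, y_prob, class_weight):
    """Per-sample cross-entropy scaled by the weight of each sample's true class.

    y_true: one-hot labels [n, n_classes]; y_prob: softmax outputs [n, n_classes];
    class_weight: dict mapping class index -> weight, like the one passed to fit().
    """
    y_true = np.asarray(y_true, dtype=float)
    y_prob = np.clip(np.asarray(y_prob, dtype=float), 1e-7, 1.0)
    per_sample_ce = -np.sum(y_true * np.log(y_prob), axis=1)
    labels = np.argmax(y_true, axis=1)
    weights = np.array([class_weight[int(c)] for c in labels])
    return np.mean(weights * per_sample_ce)  # averaged over the batch

y_true = np.array([[1, 0], [0, 1]])            # one normal, one fraud sample
y_prob = np.array([[0.9, 0.1], [0.9, 0.1]])    # both predicted "normal"
unweighted = weighted_categorical_crossentropy(y_true, y_prob, {0: 1., 1: 1.})
weighted = weighted_categorical_crossentropy(y_true, y_prob, {0: 1., 1: 20.})
print(unweighted, weighted)
```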

The baseline model achieved an accuracy of **99.87%**, barely distinguishable from the "easy route" of guessing normal for every input.

We also plot the confusion matrix to describe the classifier's performance on the reserved test set. It shows a total of **1140+480=1620** misclassified cases.
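The misclassification count is the sum of the off-diagonal cells of the confusion matrix. The following is a minimal sketch of computing it, assuming one-hot labels and softmax outputs and using scikit-learn's `confusion_matrix`. The helper name `misclassification_report` and the toy arrays are hypothetical stand-ins for `y_test` and `model.predict(X_test)`.

```python
import numpy as np
from sklearn.metrics import confusion_matrix

def misclassification_report(y_true_onehot, y_prob):
    """Return the confusion matrix and the count of off-diagonal (misclassified) samples.

    y_true_onehot: one-hot ground truth, shape [n_samples, n_classes]
    y_prob: softmax outputs, e.g. from model.predict(X_test), same shape
    """
    n_classes = y_true_onehot.shape[1]
    y_true = np.argmax(y_true_onehot, axis=1)
    y_pred = np.argmax(y_prob, axis=1)
    # Rows are true classes, columns are predicted classes
    cm = confusion_matrix(y_true, y_pred, labels=np.arange(n_classes))
    n_wrong = int(cm.sum() - np.trace(cm))  # everything off the diagonal
    return cm, n_wrong

# Toy arrays standing in for y_test and model.predict(X_test)
y_test = np.array([[1, 0], [1, 0], [0, 1], [0, 1], [1, 0]])
probs = np.array([[0.9, 0.1], [0.4, 0.6], [0.2, 0.8], [0.7, 0.3], [0.8, 0.2]])
cm, n_wrong = misclassification_report(y_test, probs)
print(cm)       # [[2 1]
                #  [1 1]]
print(n_wrong)  # 2
```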

Now let's apply focal loss to the same model. You can see how to define the focal loss as a custom loss function for Keras below.

```
import tensorflow as tf

def focal_loss(gamma=2., alpha=4.):

    gamma = float(gamma)
    alpha = float(alpha)

    def focal_loss_fixed(y_true, y_pred):
        """Focal loss for multi-classification
        FL(p_t)=-alpha(1-p_t)^{gamma}ln(p_t)
        Notice: y_pred is probability after softmax
        gradient is d(Fl)/d(p_t) not d(Fl)/d(x) as described in paper
        d(Fl)/d(p_t) * [p_t(1-p_t)] = d(Fl)/d(x)
        Focal Loss for Dense Object Detection
        https://arxiv.org/abs/1708.02002

        Arguments:
            y_true {tensor} -- ground truth labels, shape of [batch_size, num_cls]
            y_pred {tensor} -- model's output, shape of [batch_size, num_cls]

        Keyword Arguments:
            gamma {float} -- (default: {2.0})
            alpha {float} -- (default: {4.0})

        Returns:
            [tensor] -- loss.
        """
        epsilon = 1.e-9
        y_true = tf.convert_to_tensor(y_true, tf.float32)
        y_pred = tf.convert_to_tensor(y_pred, tf.float32)

        # epsilon avoids log(0)
        model_out = tf.add(y_pred, epsilon)
        # Cross-entropy term; tf.math.log replaces tf.log, which was removed in TF 2.x
        ce = tf.multiply(y_true, -tf.math.log(model_out))
        # Modulating factor (1 - p_t)^gamma, kept only for the true class
        weight = tf.multiply(y_true, tf.pow(tf.subtract(1., model_out), gamma))
        fl = tf.multiply(alpha, tf.multiply(weight, ce))
        # y_true is one-hot, so each row has a single non-zero term; max picks it
        reduced_fl = tf.reduce_max(fl, axis=1)
        return tf.reduce_mean(reduced_fl)
    return focal_loss_fixed

# Assumes `model` is a freshly built copy of the baseline architecture
model.compile(loss=focal_loss(alpha=1),
              optimizer='nadam',
              metrics=['accuracy'])
model.fit(X_train, y_train, epochs=3, batch_size=1000)
```

There are two adjustable parameters for focal loss.

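- The focusing parameter γ (gamma) smoothly adjusts the rate at which easy examples are down-weighted. When γ = 0 (and α = 1), focal loss is equivalent to categorical cross-entropy. As γ increases, the effect of the modulating factor increases as well (γ = 2 worked best in the paper's experiments).
- α (alpha): a balancing factor. The paper reports that the α-balanced form yields slightly improved accuracy over the non-α-balanced form. In the Keras implementation above, `alpha` is a single scalar applied to every class, so it rescales the loss uniformly rather than weighting individual classes.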
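The NumPy sketch below is a reference version of the same formula, and it can be used to check both points above. With γ = 0 and α = 1 it matches categorical cross-entropy. It also optionally accepts `alpha` as a per-class vector, so that α_t depends on each sample's true class. The per-class form is a common multi-class generalization of the paper's binary α and is assumed here for illustration. `focal_loss_np`, `categorical_ce_np`, and the α values are hypothetical.

```python
import numpy as np

def focal_loss_np(y_true, y_prob, gamma=2.0, alpha=1.0):
    """Batch-averaged FL(p_t) = -alpha_t * (1 - p_t)^gamma * ln(p_t).

    alpha may be a scalar (as in the Keras code above) or a per-class vector,
    in which case alpha_t is the weight of each sample's true class.
    """
    y_true = np.asarray(y_true, dtype=float)
    y_prob = np.clip(np.asarray(y_prob, dtype=float), 1e-9, 1.0)
    p_t = np.sum(y_true * y_prob, axis=1)  # probability of the true class
    alpha_t = np.sum(y_true * np.broadcast_to(alpha, y_true.shape), axis=1)
    return np.mean(-alpha_t * (1.0 - p_t) ** gamma * np.log(p_t))

def categorical_ce_np(y_true, y_prob):
    """Batch-averaged categorical cross-entropy for one-hot labels."""
    y_prob = np.clip(np.asarray(y_prob, dtype=float), 1e-9, 1.0)
    return np.mean(-np.sum(np.asarray(y_true) * np.log(y_prob), axis=1))

y_true = np.array([[1, 0], [0, 1], [0, 1]])
y_prob = np.array([[0.95, 0.05], [0.3, 0.7], [0.8, 0.2]])

# gamma = 0 and alpha = 1: focal loss reduces to categorical cross-entropy
print(np.isclose(focal_loss_np(y_true, y_prob, gamma=0.0), categorical_ce_np(y_true, y_prob)))
# Per-class alpha weighting the rare class (index 1) more; values are illustrative only
print(focal_loss_np(y_true, y_prob, gamma=2.0, alpha=np.array([0.25, 0.75])))
```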

Now let's compare the performance with the previous classifier.

Focal loss model:

- Accuracy: 99.94%
- Total misclassified test set samples: **766+23=789**, cutting the number of mistakes roughly in half.
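The "roughly in half" claim can be checked directly from the two confusion matrices' off-diagonal counts:

```python
baseline_errors = 1140 + 480   # class_weight baseline
focal_errors = 766 + 23        # focal loss model

reduction = 1 - focal_errors / baseline_errors
print(baseline_errors, focal_errors, f"{reduction:.1%}")
```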

In this quick tutorial, we introduced a new tool for your arsenal to handle highly imbalanced datasets: focal loss. A concrete example showed how to apply focal loss to your classification model with the Keras API.

You can find the full source code for this post on my GitHub.

For a detailed description of focal loss, you can read the paper, https://arxiv.org/abs/1708.02002.
