How to train Detectron2 with Custom COCO Datasets



Along with the latest PyTorch 1.3 release came the next-generation, ground-up rewrite of the earlier Detectron object detection framework, now called Detectron2. This tutorial will help you get started with this framework by training an instance segmentation model on your own custom COCO dataset. If you want to know how to create COCO datasets, please read my previous post - How to create custom COCO data set for instance segmentation.

For a quick start, we will run our experiment in a Colab Notebook, so you don't need to worry about setting up a development environment on your own machine before getting comfortable with PyTorch 1.3 and Detectron2.

Install Detectron2

In the Colab notebook, just run these 4 lines to install the latest PyTorch 1.3 and Detectron2.

!pip install -U torch torchvision
!pip install git+
!git clone detectron2_repo
!pip install -e detectron2_repo

Click "RESTART RUNTIME" in the cell's output to let your installation take effect.
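
After the restart, it can be worth confirming that the packages are actually importable before going further. The snippet below is only a minimal sketch using the Python standard library; the helper name installed_version is my own and is not part of PyTorch or Detectron2.

```python
import importlib.util
from importlib import metadata


def installed_version(package, module=None):
    """Return the installed version of `package`, or None if it cannot be imported."""
    module = module or package
    if importlib.util.find_spec(module) is None:
        return None  # the module is not importable in this environment
    try:
        return metadata.version(package)
    except metadata.PackageNotFoundError:
        # Importable, but no distribution metadata (e.g. a standard library module).
        return "unknown"


for name in ("torch", "torchvision", "detectron2"):
    print(name, "->", installed_version(name) or "not installed")
```

If any of the three lines prints "not installed", re-run the install cell and restart the runtime again.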


Register a COCO dataset

To tell Detectron2 how to obtain your dataset, we are going to "register" it.

To demonstrate this process, we use the fruits nuts segmentation dataset, which has only 3 classes: date, fig, and hazelnut. We'll train a segmentation model starting from an existing model pre-trained on the COCO dataset, available in detectron2's model zoo.

You can download the dataset like this.

# download, decompress the data
!unzip > /dev/null

Or you can upload your own dataset from here.


Register the fruits_nuts dataset to detectron2, following the detectron2 custom dataset tutorial.

from detectron2.data import MetadataCatalog, DatasetCatalog
from detectron2.data.datasets import register_coco_instances
register_coco_instances("fruits_nuts", {}, "./data/trainval.json", "./data/images")
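
Registration only records where the files live, so a broken annotation file usually shows up later, when the data is first loaded. If you bring your own dataset, a quick check like the sketch below can catch the most common problems early. It uses only the standard library, the helper name check_coco is hypothetical, and the tiny example dict is made up to mirror the layout of a COCO annotation file.

```python
import json
import os


def check_coco(coco):
    """Return a list of human-readable problems found in a COCO-style dict."""
    missing = [k for k in ("images", "annotations", "categories") if k not in coco]
    if missing:
        return ["missing top-level key: " + k for k in missing]

    image_ids = {img["id"] for img in coco["images"]}
    category_ids = {cat["id"] for cat in coco["categories"]}
    problems = []
    for ann in coco["annotations"]:
        if ann["image_id"] not in image_ids:
            problems.append("annotation %s: unknown image_id %s" % (ann["id"], ann["image_id"]))
        if ann["category_id"] not in category_ids:
            problems.append("annotation %s: unknown category_id %s" % (ann["id"], ann["category_id"]))
    return problems


# Hypothetical miniature file with the same layout as ./data/trainval.json.
example = {
    "images": [{"id": 0, "file_name": "0.jpg", "height": 600, "width": 800}],
    "categories": [{"id": 1, "name": "date"}, {"id": 2, "name": "fig"}, {"id": 3, "name": "hazelnut"}],
    "annotations": [{"id": 0, "image_id": 0, "category_id": 3, "bbox": [10, 20, 100, 80]}],
}
print(check_coco(example))  # an empty list means no problems were found

# Run the same check on the real file if it has been downloaded.
path = "./data/trainval.json"
if os.path.exists(path):
    with open(path) as f:
        print(check_coco(json.load(f)))
```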

Each dataset is associated with some metadata. In our case, you can access it by calling fruits_nuts_metadata = MetadataCatalog.get("fruits_nuts"), which returns:

Metadata(evaluator_type='coco', image_root='./data/images', json_file='./data/trainval.json', name='fruits_nuts',
         thing_classes=['date', 'fig', 'hazelnut'], thing_dataset_id_to_contiguous_id={1: 0, 2: 1, 3: 2})

The catalog stores information about the datasets and how to obtain them. To get the actual internal representation, call dataset_dicts = DatasetCatalog.get("fruits_nuts"). The internal format uses one dict to represent the annotations of one image.
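
To make "one dict per image" more concrete, here is a rough pure-Python sketch of the kind of conversion that happens when a COCO file is loaded. It is not Detectron2's own loader: the function name coco_to_dataset_dicts and the sample data are hypothetical, and the string "XYWH_ABS" is only a stand-in for the library's BoxMode.XYWH_ABS enum value. It does show where the {1: 0, 2: 1, 3: 2} mapping in the metadata comes from.

```python
def coco_to_dataset_dicts(coco, image_root):
    """Convert a COCO-style dict into a list with one record (dict) per image."""
    # Map the original category ids (which may be sparse) to contiguous ids 0..N-1.
    category_ids = sorted(cat["id"] for cat in coco["categories"])
    id_map = {cat_id: i for i, cat_id in enumerate(category_ids)}

    records = {}
    for img in coco["images"]:
        records[img["id"]] = {
            "file_name": image_root + "/" + img["file_name"],
            "height": img["height"],
            "width": img["width"],
            "image_id": img["id"],
            "annotations": [],
        }
    for ann in coco["annotations"]:
        records[ann["image_id"]]["annotations"].append({
            "bbox": ann["bbox"],                        # [x, y, width, height]
            "bbox_mode": "XYWH_ABS",                    # stand-in for BoxMode.XYWH_ABS
            "category_id": id_map[ann["category_id"]],  # contiguous id
            "segmentation": ann.get("segmentation", []),
        })
    return list(records.values()), id_map


# Hypothetical one-image dataset with the three classes used in this post.
coco = {
    "images": [{"id": 0, "file_name": "0.jpg", "height": 600, "width": 800}],
    "categories": [{"id": 1, "name": "date"}, {"id": 2, "name": "fig"}, {"id": 3, "name": "hazelnut"}],
    "annotations": [{"id": 0, "image_id": 0, "category_id": 3, "bbox": [10, 20, 100, 80]}],
}
dicts, id_map = coco_to_dataset_dicts(coco, "./data/images")
print(id_map)
print(dicts[0])
```

Note how the hazelnut annotation, stored with category id 3 in the file, ends up with category id 2 in the record; the model only ever sees the contiguous ids.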

To verify the data loading is correct, let's visualize the annotations of randomly selected samples in the dataset:

import random
import cv2
from google.colab.patches import cv2_imshow  # Colab replacement for cv2.imshow
from detectron2.utils.visualizer import Visualizer

for d in random.sample(dataset_dicts, 3):
    img = cv2.imread(d["file_name"])
    visualizer = Visualizer(img[:, :, ::-1], metadata=fruits_nuts_metadata, scale=0.5)
    vis = visualizer.draw_dataset_dict(d)
    cv2_imshow(vis.get_image()[:, :, ::-1])

One of the images might show this.


Train the model

Now, let's fine-tune a COCO-pretrained R50-FPN Mask R-CNN model on the fruits_nuts dataset. It takes ~6 minutes to train 300 iterations on Colab's K80 GPU.

from detectron2.engine import DefaultTrainer
from detectron2.config import get_cfg
import os

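cfg = get_cfg()
# start from the Mask R-CNN R50-FPN config that matches the weights below
# (path assumes the detectron2_repo checkout from the install step)
cfg.merge_from_file("./detectron2_repo/configs/COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml")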
cfg.DATASETS.TRAIN = ("fruits_nuts",)
cfg.DATASETS.TEST = ()  # no metrics implemented for this dataset
cfg.MODEL.WEIGHTS = "detectron2://COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x/137849600/model_final_f10217.pkl"  # initialize from model zoo
cfg.SOLVER.BASE_LR = 0.02
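cfg.SOLVER.MAX_ITER = 300  # 300 iterations seems good enough, but you can certainly train longer
cfg.MODEL.ROI_HEADS.BATCH_SIZE_PER_IMAGE = 128  # faster, and good enough for this toy dataset (128 assumed here; default: 512)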
cfg.MODEL.ROI_HEADS.NUM_CLASSES = 3  # 3 classes (date, fig, hazelnut)

os.makedirs(cfg.OUTPUT_DIR, exist_ok=True)
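trainer = DefaultTrainer(cfg)
trainer.resume_or_load(resume=False)  # start from cfg.MODEL.WEIGHTS, not from an earlier checkpoint
trainer.train()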

In case you switch to your own datasets, change the number of classes, learning rate, or max iterations accordingly.
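
As a rough guide for picking those values, the sketch below derives the class count from the annotation file and applies the common linear scaling rule for the learning rate. Everything here is an assumption on my part rather than something Detectron2 provides: the helper names are hypothetical, the reference point of 0.02 at 16 images per batch should be checked against the base config you actually use, and the dataset size of 18 images is made up for the example.

```python
def num_classes_from_coco(coco):
    """Count the distinct categories in a COCO-style dict."""
    return len({cat["id"] for cat in coco["categories"]})


def scaled_learning_rate(ims_per_batch, reference_lr=0.02, reference_batch=16):
    """Linear scaling rule: keep the learning rate proportional to the batch size."""
    return reference_lr * ims_per_batch / reference_batch


def approx_epochs(max_iter, ims_per_batch, num_images):
    """Roughly how many passes over the dataset a training run makes."""
    return max_iter * ims_per_batch / num_images


coco = {"categories": [{"id": 1, "name": "date"}, {"id": 2, "name": "fig"}, {"id": 3, "name": "hazelnut"}]}
print("NUM_CLASSES:", num_classes_from_coco(coco))
print("BASE_LR for 2 images per batch: {:.4f}".format(scaled_learning_rate(2)))
print("epochs for 300 iterations: {:.1f}".format(approx_epochs(300, 2, 18)))
```

Treat the numbers as starting points; a tiny dataset may still train fine with a larger learning rate, as in this post.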


Make a prediction

Now, we perform inference with the trained model on the fruits_nuts dataset. First, let's create a predictor using the model we just trained:

from detectron2.engine import DefaultPredictor

cfg.MODEL.WEIGHTS = os.path.join(cfg.OUTPUT_DIR, "model_final.pth")
cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.5   # set the testing threshold for this model
cfg.DATASETS.TEST = ("fruits_nuts", )
predictor = DefaultPredictor(cfg)
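
Conceptually, the testing threshold just discards low-confidence detections before they are returned. The pure-Python sketch below illustrates the idea on made-up detections; the function name filter_by_score and the sample scores are hypothetical, and the real filtering happens inside the model on tensors.

```python
def filter_by_score(detections, score_thresh=0.5):
    """Keep only the detections whose score is above the threshold."""
    return [d for d in detections if d["score"] > score_thresh]


# Made-up raw detections for one image.
raw = [
    {"label": "hazelnut", "score": 0.98},
    {"label": "fig", "score": 0.61},
    {"label": "date", "score": 0.32},
]
for det in filter_by_score(raw, score_thresh=0.5):
    print(det["label"], det["score"])
```

Raising the threshold gives fewer but more reliable detections; lowering it does the opposite.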

Then, we randomly select several samples to visualize the prediction results.

from detectron2.utils.visualizer import ColorMode

for d in random.sample(dataset_dicts, 3):    
    im = cv2.imread(d["file_name"])
    outputs = predictor(im)
    v = Visualizer(im[:, :, ::-1],
                   metadata=fruits_nuts_metadata,
                   instance_mode=ColorMode.IMAGE_BW)   # remove the colors of unsegmented pixels
    v = v.draw_instance_predictions(outputs["instances"].to("cpu"))
    cv2_imshow(v.get_image()[:, :, ::-1])

Here is what you get for a sample image with the predictions overlaid.
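
If you want numbers rather than a picture, the predicted class ids can be mapped back to names through thing_classes. The sketch below works on plain Python lists so it runs anywhere; in the notebook you would likely obtain them with outputs["instances"].pred_classes.tolist() and outputs["instances"].scores.tolist(). The helper name summarize and the sample values are made up.

```python
from collections import Counter

thing_classes = ["date", "fig", "hazelnut"]  # same order as in the metadata


def summarize(pred_classes, scores, class_names):
    """Count detections per class name and record the best score for each class."""
    counts = Counter(class_names[c] for c in pred_classes)
    best = {}
    for c, s in zip(pred_classes, scores):
        name = class_names[c]
        best[name] = max(best.get(name, 0.0), s)
    return counts, best


# Made-up predictions for one image: two hazelnuts and one fig.
counts, best = summarize([2, 2, 1], [0.99, 0.87, 0.75], thing_classes)
print(dict(counts))
print(best)
```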


Conclusion and further thought

You might have read my previous tutorial on MMDetection, a similar object detection framework that is also built upon PyTorch. So how does Detectron2 compare with it? Here are a few thoughts.

Both frameworks are easy to configure with a config file that describes how you want to train a model. Detectron2's YAML config files are more efficient for two reasons. First, you can reuse configs by making a "base" config and building the final training config files on top of it, which reduces duplicated code. Second, the config file can be loaded first and then modified as necessary in Python code, which makes it more flexible.
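
The "base plus overrides" idea can be sketched with plain dictionaries. This is only an illustration of the mechanism, not Detectron2's config code (as far as I know, its YAML files point at their parent through a _BASE_ key); the function merge_config and all the values below are hypothetical.

```python
import copy


def merge_config(base, override):
    """Return a new config: `base` with the values in `override` applied recursively."""
    merged = copy.deepcopy(base)
    for key, value in override.items():
        if isinstance(value, dict) and isinstance(merged.get(key), dict):
            merged[key] = merge_config(merged[key], value)  # descend into nested sections
        else:
            merged[key] = copy.deepcopy(value)
    return merged


# A shared "base" config and a small task-specific file built on top of it.
base = {
    "MODEL": {"ROI_HEADS": {"NUM_CLASSES": 80, "BATCH_SIZE_PER_IMAGE": 512}},
    "SOLVER": {"BASE_LR": 0.02, "MAX_ITER": 90000},
}
fruits = {"MODEL": {"ROI_HEADS": {"NUM_CLASSES": 3}}}

cfg = merge_config(base, fruits)
cfg["SOLVER"]["MAX_ITER"] = 300  # further tweaks in Python after loading
print(cfg)
```

Only the values that differ live in the task-specific file, and the base config itself is left untouched.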

What about inference speed? Simply put, Detectron2 is slightly faster than MMDetection for the same Mask R-CNN ResNet-50 FPN model. MMDetection gets 2.45 FPS while Detectron2 achieves 2.59 FPS, a 5.7% speed boost when running inference on a single image. The benchmark is based on the following code.

import time
import numpy as np

times = []
for i in range(20):
    start_time = time.time()
    outputs = predictor(im)
    delta = time.time() - start_time
    times.append(delta)  # record the duration of each run
mean_delta = np.array(times).mean()
fps = 1 / mean_delta
print("Average(sec):{:.2f},fps:{:.2f}".format(mean_delta, fps))
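
If you want to repeat this kind of measurement with other models, a slightly more general helper may be handy. The version below is a sketch under my own assumptions: it adds a few warm-up runs and uses time.perf_counter, the names measure_fps and relative_speedup are hypothetical, and the lambda is only a stand-in for a call such as predictor(im). On a GPU you would probably also want to call torch.cuda.synchronize() before reading the clock.

```python
import time


def measure_fps(fn, runs=20, warmup=3):
    """Call `fn` repeatedly and return (mean seconds per call, calls per second)."""
    for _ in range(warmup):
        fn()  # warm-up runs are not timed
    times = []
    for _ in range(runs):
        start = time.perf_counter()
        fn()
        times.append(time.perf_counter() - start)
    mean = sum(times) / len(times)
    return mean, 1.0 / mean


def relative_speedup(baseline_fps, new_fps):
    """Fractional speed gain of `new_fps` over `baseline_fps`."""
    return new_fps / baseline_fps - 1.0


# Stand-in workload; in the notebook this would be lambda: predictor(im).
mean, fps = measure_fps(lambda: sum(range(10000)))
print("Average(sec):{:.6f},fps:{:.2f}".format(mean, fps))

# The comparison quoted in this post: 2.45 FPS versus 2.59 FPS.
print("{:.1%}".format(relative_speedup(2.45, 2.59)))
```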

So there you have it: Detectron2 makes it super simple to train a custom instance segmentation model with custom datasets. You might find the following resources helpful.

My previous post - How to create custom COCO data set for instance segmentation.

My previous post - How to train an object detection model with mmdetection.

Detectron2 GitHub repository.

The runnable Colab Notebook for this post.
