06. Train Faster-RCNN end-to-end on PASCAL VOC

This tutorial goes through the basic steps of training a Faster-RCNN [Ren15] object detection model provided by GluonCV.

Specifically, we show how to build a state-of-the-art Faster-RCNN model by stacking GluonCV components.

It is highly recommended to read the original papers [Girshick14], [Girshick15], [Ren15] to learn more about the ideas behind Faster R-CNN. The appendix of [He16] and the experimental details in [Lin17] may also be useful references.

Hint

You can skip the rest of this tutorial and start training your Faster-RCNN model right away by downloading this script:

Download train_faster_rcnn.py

Example usage:

Train a default resnet50_v1b model with Pascal VOC on GPU 0:

python train_faster_rcnn.py --gpus 0

Train a resnet50_v1b model on GPU 0,1,2,3:

python train_faster_rcnn.py --gpus 0,1,2,3 --network resnet50_v1b

Check the supported arguments:

python train_faster_rcnn.py --help

Hint

Since much of the content in this tutorial is very similar to 04. Train SSD on Pascal VOC, you can skip any part you already feel comfortable with.

Dataset

Please first go through the Prepare PASCAL VOC datasets tutorial to set up the Pascal VOC dataset on your disk. Then we are ready to load the training and validation images.

from gluoncv.data import VOCDetection

# typically we use 2007+2012 trainval splits for training data
train_dataset = VOCDetection(splits=[(2007, 'trainval'), (2012, 'trainval')])
# and use 2007 test as validation data
val_dataset = VOCDetection(splits=[(2007, 'test')])

print('Training images:', len(train_dataset))
print('Validation images:', len(val_dataset))

Out:

Training images: 16551
Validation images: 4952

Data transform

We can read an image-label pair from the training dataset:

train_image, train_label = train_dataset[6]
bboxes = train_label[:, :4]
cids = train_label[:, 4:5]
print('image:', train_image.shape)
print('bboxes:', bboxes.shape, 'class ids:', cids.shape)

Out:

image: (375, 500, 3)
bboxes: (2, 4) class ids: (2, 1)

Plot the image, together with the bounding box labels:

from matplotlib import pyplot as plt
from gluoncv.utils import viz

ax = viz.plot_bbox(train_image.asnumpy(), bboxes, labels=cids, class_names=train_dataset.classes)
plt.show()

Validation images are quite similar to training images because they were basically split randomly into different sets.


For Faster-RCNN networks, the only data augmentation is a random horizontal flip (a sketch of this step follows the transform setup below).

from gluoncv.data.transforms import presets
from gluoncv import utils
from mxnet import nd
short, max_size = 600, 1000  # resize image to short side 600 px, but keep maximum length within 1000
train_transform = presets.rcnn.FasterRCNNDefaultTrainTransform(short, max_size)
val_transform = presets.rcnn.FasterRCNNDefaultValTransform(short, max_size)
utils.random.seed(233)  # fix seed in this tutorial
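
Internally, the train transform resizes the image and randomly flips it horizontally. Below is a minimal sketch of the flip step, assuming the gluoncv.data.transforms.image.random_flip and gluoncv.data.transforms.bbox.flip helpers behave as described in the GluonCV API; the flip probability is set to 1.0 so the effect is always visible.

from gluoncv.data.transforms import image as timage
from gluoncv.data.transforms import bbox as tbbox

# flip the raw HWC image horizontally (px=1.0 forces the flip for illustration)
flipped_img, flips = timage.random_flip(train_image, px=1.0)
# mirror the box coordinates so they still match the flipped image
h, w, _ = train_image.shape
flipped_bboxes = tbbox.flip(bboxes, (w, h), flip_x=flips[0])
print('original boxes:\n', bboxes)
print('flipped boxes:\n', flipped_bboxes)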

We apply the transforms to a training image and check the shapes of the results:
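
# apply the train transform; it returns the image tensor and the transformed label
train_image2, train_label2 = train_transform(train_image, train_label)
print('tensor shape:', train_image2.shape)
print('box and id shape:', train_label2.shape)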

Out:

tensor shape: (3, 600, 800)
box and id shape: (2, 6)

The image tensors look distorted because the values no longer sit in the (0, 255) range. Let’s convert them back so we can view them clearly.

train_image2 = train_image2.transpose((1, 2, 0)) * nd.array((0.229, 0.224, 0.225)) + nd.array(
    (0.485, 0.456, 0.406))
train_image2 = (train_image2 * 255).asnumpy().astype('uint8')
ax = viz.plot_bbox(train_image2, train_label2[:, :4],
                   labels=train_label2[:, 4:5],
                   class_names=train_dataset.classes)
plt.show()

Data Loader

We will iterate through the entire dataset many times during training. Keep in mind that raw images have to be transformed to tensors (mxnet uses BCHW format) before they are fed into neural networks.

A handy DataLoader would be very convenient for us to apply different transforms and aggregate data into mini-batches.

Because Faster-RCNN handles raw images with various aspect ratios and shapes, we provide gluoncv.data.batchify.Append, which neither stacks nor pads images but instead returns lists. This way, the returned image tensors and labels keep their own shapes, independent of the others in the same batch.

from gluoncv.data.batchify import Tuple, Append, FasterRCNNTrainBatchify
from mxnet.gluon.data import DataLoader

batch_size = 2  # for tutorial, we use smaller batch-size
num_workers = 0  # you can make it larger(if your CPU has more cores) to accelerate data loading

# behavior of batchify_fn: return images and labels as lists, one entry per sample
batchify_fn = Tuple(Append(), Append())
train_loader = DataLoader(train_dataset.transform(train_transform), batch_size, shuffle=True,
                          batchify_fn=batchify_fn, last_batch='rollover', num_workers=num_workers)
val_loader = DataLoader(val_dataset.transform(val_transform), batch_size, shuffle=False,
                        batchify_fn=batchify_fn, last_batch='keep', num_workers=num_workers)

for ib, batch in enumerate(train_loader):
    if ib > 3:
        break
    print('data 0:', batch[0][0].shape, 'label 0:', batch[1][0].shape)
    print('data 1:', batch[0][1].shape, 'label 1:', batch[1][1].shape)

Out:

data 0: (1, 3, 600, 800) label 0: (1, 5, 6)
data 1: (1, 3, 600, 901) label 1: (1, 9, 6)
data 0: (1, 3, 600, 800) label 0: (1, 2, 6)
data 1: (1, 3, 562, 1000) label 1: (1, 1, 6)
data 0: (1, 3, 600, 904) label 0: (1, 1, 6)
data 1: (1, 3, 600, 888) label 1: (1, 2, 6)
data 0: (1, 3, 600, 901) label 0: (1, 1, 6)
data 1: (1, 3, 600, 901) label 1: (1, 1, 6)

Faster-RCNN Network

GluonCV’s Faster-RCNN implementation is a composite Gluon HybridBlock, gluoncv.model_zoo.FasterRCNN. Structurally, Faster-RCNN networks are composed of a base feature extraction network, a Region Proposal Network (including its own anchor system and proposal generator), region-aware pooling layers, class predictors, and bounding box offset predictors.

The Gluon Model Zoo has a few built-in Faster-RCNN networks, with more on the way. You can load your favorite one with a single line of code:

Hint

To avoid downloading models in this tutorial, we set pretrained_base=False. In practice, we usually want to load ImageNet pre-trained weights by setting pretrained_base=True.

from gluoncv import model_zoo

net = model_zoo.get_model('faster_rcnn_resnet50_v1b_voc', pretrained_base=False)
print(net)

Out:

FasterRCNN(
  (features): HybridSequential(
    (0): Conv2D(None -> 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (1): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=True, in_channels=64)
    (2): Activation(relu)
    (3): MaxPool2D(size=(3, 3), stride=(2, 2), padding=(1, 1), ceil_mode=False, global_pool=False, pool_type=max, layout=NCHW)
    (4): HybridSequential(
      (0): BottleneckV1b(
        (conv1): Conv2D(None -> 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn1): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=True, in_channels=64)
        (relu1): Activation(relu)
        (conv2): Conv2D(None -> 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=True, in_channels=64)
        (relu2): Activation(relu)
        (conv3): Conv2D(None -> 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn3): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=True, in_channels=256)
        (relu3): Activation(relu)
        (downsample): HybridSequential(
          (0): Conv2D(None -> 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (1): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=True, in_channels=256)
        )
      )
      (1): BottleneckV1b(
        (conv1): Conv2D(None -> 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn1): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=True, in_channels=64)
        (relu1): Activation(relu)
        (conv2): Conv2D(None -> 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=True, in_channels=64)
        (relu2): Activation(relu)
        (conv3): Conv2D(None -> 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn3): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=True, in_channels=256)
        (relu3): Activation(relu)
      )
      (2): BottleneckV1b(
        (conv1): Conv2D(None -> 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn1): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=True, in_channels=64)
        (relu1): Activation(relu)
        (conv2): Conv2D(None -> 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=True, in_channels=64)
        (relu2): Activation(relu)
        (conv3): Conv2D(None -> 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn3): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=True, in_channels=256)
        (relu3): Activation(relu)
      )
    )
    (5): HybridSequential(
      (0): BottleneckV1b(
        (conv1): Conv2D(None -> 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn1): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=True, in_channels=128)
        (relu1): Activation(relu)
        (conv2): Conv2D(None -> 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
        (bn2): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=True, in_channels=128)
        (relu2): Activation(relu)
        (conv3): Conv2D(None -> 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn3): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=True, in_channels=512)
        (relu3): Activation(relu)
        (downsample): HybridSequential(
          (0): Conv2D(None -> 512, kernel_size=(1, 1), stride=(2, 2), bias=False)
          (1): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=True, in_channels=512)
        )
      )
      (1): BottleneckV1b(
        (conv1): Conv2D(None -> 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn1): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=True, in_channels=128)
        (relu1): Activation(relu)
        (conv2): Conv2D(None -> 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=True, in_channels=128)
        (relu2): Activation(relu)
        (conv3): Conv2D(None -> 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn3): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=True, in_channels=512)
        (relu3): Activation(relu)
      )
      (2): BottleneckV1b(
        (conv1): Conv2D(None -> 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn1): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=True, in_channels=128)
        (relu1): Activation(relu)
        (conv2): Conv2D(None -> 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=True, in_channels=128)
        (relu2): Activation(relu)
        (conv3): Conv2D(None -> 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn3): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=True, in_channels=512)
        (relu3): Activation(relu)
      )
      (3): BottleneckV1b(
        (conv1): Conv2D(None -> 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn1): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=True, in_channels=128)
        (relu1): Activation(relu)
        (conv2): Conv2D(None -> 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=True, in_channels=128)
        (relu2): Activation(relu)
        (conv3): Conv2D(None -> 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn3): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=True, in_channels=512)
        (relu3): Activation(relu)
      )
    )
    (6): HybridSequential(
      (0): BottleneckV1b(
        (conv1): Conv2D(None -> 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn1): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=True, in_channels=256)
        (relu1): Activation(relu)
        (conv2): Conv2D(None -> 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
        (bn2): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=True, in_channels=256)
        (relu2): Activation(relu)
        (conv3): Conv2D(None -> 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn3): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=True, in_channels=1024)
        (relu3): Activation(relu)
        (downsample): HybridSequential(
          (0): Conv2D(None -> 1024, kernel_size=(1, 1), stride=(2, 2), bias=False)
          (1): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=True, in_channels=1024)
        )
      )
      (1): BottleneckV1b(
        (conv1): Conv2D(None -> 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn1): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=True, in_channels=256)
        (relu1): Activation(relu)
        (conv2): Conv2D(None -> 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=True, in_channels=256)
        (relu2): Activation(relu)
        (conv3): Conv2D(None -> 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn3): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=True, in_channels=1024)
        (relu3): Activation(relu)
      )
      (2): BottleneckV1b(
        (conv1): Conv2D(None -> 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn1): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=True, in_channels=256)
        (relu1): Activation(relu)
        (conv2): Conv2D(None -> 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=True, in_channels=256)
        (relu2): Activation(relu)
        (conv3): Conv2D(None -> 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn3): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=True, in_channels=1024)
        (relu3): Activation(relu)
      )
      (3): BottleneckV1b(
        (conv1): Conv2D(None -> 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn1): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=True, in_channels=256)
        (relu1): Activation(relu)
        (conv2): Conv2D(None -> 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=True, in_channels=256)
        (relu2): Activation(relu)
        (conv3): Conv2D(None -> 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn3): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=True, in_channels=1024)
        (relu3): Activation(relu)
      )
      (4): BottleneckV1b(
        (conv1): Conv2D(None -> 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn1): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=True, in_channels=256)
        (relu1): Activation(relu)
        (conv2): Conv2D(None -> 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=True, in_channels=256)
        (relu2): Activation(relu)
        (conv3): Conv2D(None -> 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn3): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=True, in_channels=1024)
        (relu3): Activation(relu)
      )
      (5): BottleneckV1b(
        (conv1): Conv2D(None -> 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn1): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=True, in_channels=256)
        (relu1): Activation(relu)
        (conv2): Conv2D(None -> 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=True, in_channels=256)
        (relu2): Activation(relu)
        (conv3): Conv2D(None -> 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn3): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=True, in_channels=1024)
        (relu3): Activation(relu)
      )
    )
  )
  (top_features): HybridSequential(
    (0): HybridSequential(
      (0): BottleneckV1b(
        (conv1): Conv2D(None -> 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn1): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=True, in_channels=512)
        (relu1): Activation(relu)
        (conv2): Conv2D(None -> 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
        (bn2): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=True, in_channels=512)
        (relu2): Activation(relu)
        (conv3): Conv2D(None -> 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn3): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=True, in_channels=2048)
        (relu3): Activation(relu)
        (downsample): HybridSequential(
          (0): Conv2D(None -> 2048, kernel_size=(1, 1), stride=(2, 2), bias=False)
          (1): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=True, in_channels=2048)
        )
      )
      (1): BottleneckV1b(
        (conv1): Conv2D(None -> 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn1): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=True, in_channels=512)
        (relu1): Activation(relu)
        (conv2): Conv2D(None -> 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=True, in_channels=512)
        (relu2): Activation(relu)
        (conv3): Conv2D(None -> 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn3): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=True, in_channels=2048)
        (relu3): Activation(relu)
      )
      (2): BottleneckV1b(
        (conv1): Conv2D(None -> 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn1): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=True, in_channels=512)
        (relu1): Activation(relu)
        (conv2): Conv2D(None -> 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=True, in_channels=512)
        (relu2): Activation(relu)
        (conv3): Conv2D(None -> 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn3): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=True, in_channels=2048)
        (relu3): Activation(relu)
      )
    )
  )
  (class_predictor): Dense(None -> 21, linear)
  (box_predictor): Dense(None -> 80, linear)
  (cls_decoder): MultiPerClassDecoder(

  )
  (box_decoder): NormalizedBoxCenterDecoder(
    (corner_to_center): BBoxCornerToCenter(

    )
  )
  (rpn): RPN(
    (anchor_generator): RPNAnchorGenerator(

    )
    (conv1): HybridSequential(
      (0): Conv2D(None -> 1024, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): Activation(relu)
    )
    (score): Conv2D(None -> 15, kernel_size=(1, 1), stride=(1, 1))
    (loc): Conv2D(None -> 60, kernel_size=(1, 1), stride=(1, 1))
    (region_proposer): RPNProposal(
      (_box_to_center): BBoxCornerToCenter(

      )
      (_box_decoder): NormalizedBoxCenterDecoder(
        (corner_to_center): BBoxCornerToCenter(

        )
      )
      (_clipper): BBoxClipToImage(

      )
    )
  )
  (sampler): RCNNTargetSampler(

  )
)
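
The printed modules correspond to the components described earlier. Since the attribute names match the printed module names, we can inspect them directly (a quick sanity check, not required for training):

print(type(net.features))         # base feature extractor (ResNet stages up to the 1024-channel stage)
print(type(net.top_features))     # final ResNet stage, applied to pooled region features
print(type(net.rpn))              # region proposal network
print(type(net.class_predictor))  # per-region class scores
print(type(net.box_predictor))    # per-region, per-class box offsets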

The Faster-RCNN network is callable with an image tensor:

import mxnet as mx

x = mx.nd.zeros(shape=(1, 3, 600, 800))
net.initialize()
cids, scores, bboxes = net(x)

Faster-RCNN returns three values: cids are the class labels, scores are the confidence scores of each prediction, and bboxes are the absolute coordinates of the corresponding bounding boxes.
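
To get a feel for these outputs, we can print their shapes; the values themselves are meaningless here since the weights were randomly initialized:

print('class ids:', cids.shape)
print('scores:', scores.shape)
print('bboxes:', bboxes.shape)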

The Faster-RCNN network behaves differently in training mode:

from mxnet import autograd

with autograd.train_mode():
    # this time we need ground-truth to generate high quality roi proposals during training
    gt_box = mx.nd.zeros(shape=(1, 1, 4))
    gt_label = mx.nd.zeros(shape=(1, 1, 1))
    cls_pred, box_pred, roi, samples, matches, rpn_score, rpn_box, anchors, cls_targets, \
        box_targets, box_masks, _ = net(x, gt_box, gt_label)

In training mode, Faster-RCNN returns a lot of intermediate values that are required for end-to-end training: cls_pred are the class predictions prior to softmax; box_pred are bounding box offsets with one-to-one correspondence to the proposals; roi are the proposal candidates; samples and matches are the sampling/matching results for the RPN anchors; rpn_score and rpn_box are the raw outputs from the RPN’s convolutional layers; and anchors are the absolute coordinates of the corresponding anchor boxes.
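
We can likewise peek at the shapes of a few of these intermediate values (the exact sizes depend on the RPN and sampler settings, so we only print them here):

print('roi:', roi.shape)
print('samples:', samples.shape)
print('matches:', matches.shape)
print('rpn score:', rpn_score.shape)
print('rpn box:', rpn_box.shape)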

Training losses

There are four losses involved in end-to-end Faster-RCNN training.

# the loss to penalize incorrect foreground/background prediction
rpn_cls_loss = mx.gluon.loss.SigmoidBinaryCrossEntropyLoss(from_sigmoid=False)
# the loss to penalize inaccurate anchor boxes
rpn_box_loss = mx.gluon.loss.HuberLoss(rho=1 / 9.)  # == smoothl1
# the loss to penalize incorrect classification prediction.
rcnn_cls_loss = mx.gluon.loss.SoftmaxCrossEntropyLoss()
# and finally the loss to penalize inaccurate bounding box refinements
rcnn_box_loss = mx.gluon.loss.HuberLoss()  # == smoothl1
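
HuberLoss with a small rho behaves like the smooth L1 loss used in the original R-CNN papers. Below is a minimal sketch of the piecewise form, assuming the standard definition (quadratic below rho, linear above); it is for intuition only, and the actual losses above are computed by Gluon.

def smooth_l1(pred, label, rho=1.0):
    # quadratic near zero, linear for large errors
    diff = abs(pred - label)
    return 0.5 / rho * diff ** 2 if diff <= rho else diff - 0.5 * rho

print(smooth_l1(0.05, 0., rho=1 / 9.))  # small error: quadratic branch
print(smooth_l1(2.0, 0., rho=1 / 9.))   # large error: linear branch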

RPN training targets

To speed up training, we let the CPU pre-compute RPN training targets. This is especially helpful when your CPU is powerful and you can use -j num_workers to take advantage of multiple CPU cores.

If we provide the network to the training transform function, it will compute the training targets:

train_transform = presets.rcnn.FasterRCNNDefaultTrainTransform(short, max_size, net)
# returns images, labels, rpn_cls_targets, rpn_box_targets, rpn_box_masks as loose per-image lists
batchify_fn = FasterRCNNTrainBatchify(net)
# For the next part, we only use batch size 1
batch_size = 1
train_loader = DataLoader(train_dataset.transform(train_transform), batch_size, shuffle=True,
                          batchify_fn=batchify_fn, last_batch='rollover', num_workers=num_workers)

This time we can see that the data loader actually returns the training targets for us. From here it becomes a standard Gluon training loop: create a Trainer and let it update the weights.

for ib, batch in enumerate(train_loader):
    if ib > 0:
        break
    with autograd.train_mode():
        for data, label, rpn_cls_targets, rpn_box_targets, rpn_box_masks in zip(*batch):
            label = label.expand_dims(0)
            gt_label = label[:, :, 4:5]
            gt_box = label[:, :, :4]
            print('data:', data.shape)
            # box and class labels
            print('box:', gt_box.shape)
            print('label:', gt_label.shape)
            # -1 marks ignored label
            print('rpn cls label:', rpn_cls_targets.shape)
            # mask out ignored box label
            print('rpn box label:', rpn_box_targets.shape)
            print('rpn box mask:', rpn_box_masks.shape)

Out:

data: (3, 600, 800)
box: (1, 6, 4)
label: (1, 6, 1)
rpn cls label: (1, 28500)
rpn box label: (1, 28500, 4)
rpn box mask: (1, 28500, 4)
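
The 28500 entries match the number of RPN anchors. Assuming an effective feature stride of 16, a 600x800 input yields roughly a 38x50 feature map, and the RPN score layer printed earlier has 15 output channels, i.e. 15 anchors per spatial location:

# sanity check on the anchor count (stride-16 feature map, 15 anchors per location)
feat_h, feat_w = 38, 50  # roughly 600 / 16 and 800 / 16
print(feat_h * feat_w * 15)  # 28500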

RCNN training targets

RCNN targets are generated from the intermediate outputs by the target generator stored in the network.

for ib, batch in enumerate(train_loader):
    if ib > 0:
        break
    with autograd.train_mode():
        for data, label, rpn_cls_targets, rpn_box_targets, rpn_box_masks in zip(*batch):
            label = label.expand_dims(0)
            gt_label = label[:, :, 4:5]
            gt_box = label[:, :, :4]
            # network forward
            cls_pred, box_pred, roi, samples, matches, rpn_score, rpn_box, anchors, cls_targets, \
                box_targets, box_masks, _ = net(data.expand_dims(0), gt_box, gt_label)

            print('data:', data.shape)
            # box and class labels
            print('box:', gt_box.shape)
            print('label:', gt_label.shape)
            # rcnn does not have ignored label
            print('rcnn cls label:', cls_targets.shape)
            # mask out ignored box label
            print('rcnn box label:', box_targets.shape)
            print('rcnn box mask:', box_masks.shape)

Out:

data: (3, 600, 800)
box: (1, 2, 4)
label: (1, 2, 1)
rcnn cls label: (1, 128)
rcnn box label: (1, 32, 20, 4)
rcnn box mask: (1, 32, 20, 4)

Training loop

After we have defined the loss functions and generated the training targets, we can write the training loop.

for ib, batch in enumerate(train_loader):
    if ib > 0:
        break
    with autograd.record():
        for data, label, rpn_cls_targets, rpn_box_targets, rpn_box_masks in zip(*batch):
            label = label.expand_dims(0)
            gt_label = label[:, :, 4:5]
            gt_box = label[:, :, :4]
            # network forward
            cls_preds, box_preds, roi, samples, matches, rpn_score, rpn_box, anchors, cls_targets, \
                box_targets, box_masks, _ = net(data.expand_dims(0), gt_box, gt_label)

            # losses of rpn
            rpn_score = rpn_score.squeeze(axis=-1)
            num_rpn_pos = (rpn_cls_targets >= 0).sum()
            rpn_loss1 = rpn_cls_loss(rpn_score, rpn_cls_targets,
                                     rpn_cls_targets >= 0) * rpn_cls_targets.size / num_rpn_pos
            rpn_loss2 = rpn_box_loss(rpn_box, rpn_box_targets,
                                     rpn_box_masks) * rpn_box.size / num_rpn_pos

            # losses of rcnn
            num_rcnn_pos = (cls_targets >= 0).sum()
            rcnn_loss1 = rcnn_cls_loss(cls_preds, cls_targets,
                                       cls_targets >= 0) * cls_targets.size / cls_targets.shape[
                             0] / num_rcnn_pos
            rcnn_loss2 = rcnn_box_loss(box_preds, box_targets, box_masks) * box_preds.size / \
                         box_preds.shape[0] / num_rcnn_pos

        # some standard gluon training steps:
        # autograd.backward([rpn_loss1, rpn_loss2, rcnn_loss1, rcnn_loss2])
        # trainer.step(batch_size)
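
The commented-out steps assume a Trainer was created beforehand. Below is a minimal sketch of that setup; the optimizer and hyper-parameters are illustrative only, see the full training script for the actual values.

# illustrative trainer setup; hyper-parameters are placeholders, not the tuned values
trainer = mx.gluon.Trainer(
    net.collect_params(), 'sgd',
    {'learning_rate': 0.001, 'wd': 5e-4, 'momentum': 0.9})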

Hint

Please check out the full training script for the complete implementation.

References

Girshick14

Ross Girshick, Jeff Donahue, Trevor Darrell, and Jitendra Malik. Rich Feature Hierarchies for Accurate Object Detection and Semantic Segmentation. CVPR 2014.

Girshick15

Ross Girshick. Fast R-CNN. ICCV 2015.

Ren15

Shaoqing Ren, Kaiming He, Ross Girshick, and Jian Sun. Faster R-CNN: Towards Real-Time Object Detection with Region Proposal Networks. NIPS 2015.

He16

Kaiming He, Xiangyu Zhang, Shaoqing Ren, and Jian Sun. Deep Residual Learning for Image Recognition. CVPR 2016.

Lin17

Tsung-Yi Lin, Piotr Dollár, Ross Girshick, Kaiming He, Bharath Hariharan, and Serge Belongie. Feature Pyramid Networks for Object Detection. CVPR 2017.
