04. Train SSD on Pascal VOC dataset

This tutorial goes through the basic building blocks of object detection provided by GluonCV. Specifically, we show how to build a state-of-the-art Single Shot MultiBox Detection (SSD) [Liu16] model by stacking GluonCV components. This is also a good starting point for your own object detection project.

Hint

You can skip the rest of this tutorial and start training your SSD model right away by downloading this script:

Download train_ssd.py

Example usage:

Train a default vgg16_atrous 300x300 model with Pascal VOC on GPU 0:

python train_ssd.py

Train a resnet50_v1 512x512 model on GPU 0,1,2,3:

python train_ssd.py --gpus 0,1,2,3 --network resnet50_v1 --data-shape 512

Check the supported arguments:

python train_ssd.py --help

Dataset

Please first go through the Prepare PASCAL VOC datasets tutorial to set up the Pascal VOC dataset on your disk. Then we are ready to load the training and validation images.

from gluoncv.data import VOCDetection
# typically we use 2007+2012 trainval splits for training data
train_dataset = VOCDetection(splits=[(2007, 'trainval'), (2012, 'trainval')])
# and use 2007 test as validation data
val_dataset = VOCDetection(splits=[(2007, 'test')])

print('Training images:', len(train_dataset))
print('Validation images:', len(val_dataset))

Out:

Training images: 16551
Validation images: 4952
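
As a quick check, you can list the 20 Pascal VOC category names through the dataset's classes attribute:

print('classes:', train_dataset.classes)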

Data transform

We can read an image-label pair from the training dataset:

train_image, train_label = train_dataset[0]
bboxes = train_label[:, :4]
cids = train_label[:, 4:5]
print('image:', train_image.shape)
print('bboxes:', bboxes.shape, 'class ids:', cids.shape)

Out:

image: (375, 500, 3)
bboxes: (5, 4) class ids: (5, 1)

Plot the image, together with the bounding box labels:

from matplotlib import pyplot as plt
from gluoncv.utils import viz

ax = viz.plot_bbox(
    train_image.asnumpy(),
    bboxes,
    labels=cids,
    class_names=train_dataset.classes)
plt.show()

Validation images are quite similar to the training images, because they were essentially split into the two sets at random:

val_image, val_label = val_dataset[0]
bboxes = val_label[:, :4]
cids = val_label[:, 4:5]
ax = viz.plot_bbox(
    val_image.asnumpy(),
    bboxes,
    labels=cids,
    class_names=train_dataset.classes)
plt.show()

For SSD networks, it is critical to apply data augmentation (see the explanations in the paper [Liu16]). GluonCV provides many image and bounding box transform functions for this purpose, and they are convenient to use.

from gluoncv.data.transforms import presets
from gluoncv import utils
from mxnet import nd
utils.random.seed(233)  # fix seed in this tutorial

Apply transforms to the training image:

width, height = 512, 512  # suppose we use 512 as base training size
train_transform = presets.ssd.SSDDefaultTrainTransform(width, height)
val_transform = presets.ssd.SSDDefaultValTransform(width, height)
train_image2, train_label2 = train_transform(train_image, train_label)
print('tensor shape:', train_image2.shape)

Out:

tensor shape: (3, 512, 512)

Images in tensor format look distorted because they no longer sit in the (0, 255) range. Let's convert them back so we can see them clearly.

train_image2 = train_image2.transpose(
    (1, 2, 0)) * nd.array((0.229, 0.224, 0.225)) + nd.array((0.485, 0.456, 0.406))
train_image2 = (train_image2 * 255).clip(0, 255)
ax = viz.plot_bbox(train_image2.asnumpy(), train_label2[:, :4],
                   labels=train_label2[:, 4:5],
                   class_names=train_dataset.classes)
plt.show()

Apply transforms to the validation image:

val_image2, val_label2 = val_transform(val_image, val_label)
val_image2 = val_image2.transpose(
    (1, 2, 0)) * nd.array((0.229, 0.224, 0.225)) + nd.array((0.485, 0.456, 0.406))
val_image2 = (val_image2 * 255).clip(0, 255)
ax = viz.plot_bbox(val_image2.asnumpy(), val_label2[:, :4],
                   labels=val_label2[:, 4:5],
                   class_names=train_dataset.classes)
plt.show()

Transforms used in training include random expanding, random cropping, color distortion, random flipping, etc. In comparison, validation transforms are simpler: only resizing and color normalization are used.
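
Because these training transforms are randomized, repeated calls on the same image produce different results. A small sketch using the train_transform defined above:

# repeated calls yield different crops, flips and color jitter, and
# possibly fewer boxes when a random crop discards some objects
for i in range(3):
    aug_image, aug_label = train_transform(train_image, train_label)
    print('run', i, '-> tensor:', aug_image.shape, 'boxes:', aug_label.shape)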

Data Loader

We will iterate through the entire dataset many times during training. Keep in mind that raw images have to be transformed into tensors (MXNet uses the BCHW layout) before they are fed into neural networks. In addition, to be able to run in mini-batches, images must be resized to the same shape.
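
As a small illustration of this layout change (assuming train_image is the uint8 HWC array loaded above), nd.image.to_tensor performs the HWC-to-CHW conversion and scales values to [0, 1):

chw = nd.image.to_tensor(train_image)  # HWC uint8 -> CHW float32 in [0, 1)
print('before:', train_image.shape, train_image.dtype)
print('after:', chw.shape, chw.dtype)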

A handy DataLoader would be very convenient for us to apply different transforms and aggregate data into mini-batches.

Because the number of objects varies a lot across images, we also have varying label sizes. As a result, we need to pad those labels to the same size. To deal with this problem, GluonCV provides gluoncv.data.batchify.Pad, which handles padding automatically. In addition, gluoncv.data.batchify.Stack is used to stack NDArrays with consistent shapes, and gluoncv.data.batchify.Tuple is used to handle different behaviors across the multiple outputs of a transform function, as sketched below.
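
Here is a minimal sketch with toy labels (not real data) of how Pad aligns labels of different lengths:

from gluoncv.data.batchify import Pad
from mxnet import nd

label_a = nd.ones((1, 6))  # an image with 1 object
label_b = nd.ones((3, 6))  # an image with 3 objects
padded = Pad(pad_val=-1)([label_a, label_b])
print(padded.shape)  # (2, 3, 6): the shorter label is padded with -1 rows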

from gluoncv.data.batchify import Tuple, Stack, Pad
from mxnet.gluon.data import DataLoader

batch_size = 2  # for this tutorial, we use a smaller batch size
# you can make it larger (if your CPU has more cores) to accelerate data loading
num_workers = 0

# behavior of batchify_fn: stack images, and pad labels
batchify_fn = Tuple(Stack(), Pad(pad_val=-1))
train_loader = DataLoader(
    train_dataset.transform(train_transform),
    batch_size,
    shuffle=True,
    batchify_fn=batchify_fn,
    last_batch='rollover',
    num_workers=num_workers)
val_loader = DataLoader(
    val_dataset.transform(val_transform),
    batch_size,
    shuffle=False,
    batchify_fn=batchify_fn,
    last_batch='keep',
    num_workers=num_workers)

for ib, batch in enumerate(train_loader):
    if ib > 3:
        break
    print('data:', batch[0].shape, 'label:', batch[1].shape)

Out:

data: (2, 3, 512, 512) label: (2, 3, 6)
data: (2, 3, 512, 512) label: (2, 1, 6)
data: (2, 3, 512, 512) label: (2, 2, 6)
data: (2, 3, 512, 512) label: (2, 1, 6)
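
Note how the second dimension of the label batch varies: it equals the largest number of objects among the images in each mini-batch, and shorter labels are filled with the pad value -1 so padded rows can be distinguished from real boxes.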

SSD Network

GluonCV’s SSD implementation is a composite Gluon HybridBlock, which means it can be exported to a symbol to run in C++, Scala, and other language bindings (we will cover this usage in future tutorials). In terms of structure, SSD networks are composed of a base feature extraction network, anchor generators, class predictors, and bounding box offset predictors.

For more details on how the SSD detector works, please refer to our introductory tutorial. You can also refer to the original paper to learn more about the intuitions behind SSD.

Gluon Model Zoo has a lot of built-in SSD networks. You can load your favorite one with one simple line of code:

Hint

To avoid downloading models in this tutorial, we set pretrained_base=False. In practice, we usually want to load pre-trained ImageNet weights by setting pretrained_base=True.

from gluoncv import model_zoo
net = model_zoo.get_model('ssd_300_vgg16_atrous_voc', pretrained_base=False)
print(net)

Out:

SSD(
  (features): VGGAtrousExtractor(
    (stages): HybridSequential(
      (0): HybridSequential(
        (0): Conv2D(None -> 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): Activation(relu)
        (2): Conv2D(None -> 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (3): Activation(relu)
      )
      (1): HybridSequential(
        (0): Conv2D(None -> 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): Activation(relu)
        (2): Conv2D(None -> 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (3): Activation(relu)
      )
      (2): HybridSequential(
        (0): Conv2D(None -> 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): Activation(relu)
        (2): Conv2D(None -> 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (3): Activation(relu)
        (4): Conv2D(None -> 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (5): Activation(relu)
      )
      (3): HybridSequential(
        (0): Conv2D(None -> 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): Activation(relu)
        (2): Conv2D(None -> 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (3): Activation(relu)
        (4): Conv2D(None -> 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (5): Activation(relu)
      )
      (4): HybridSequential(
        (0): Conv2D(None -> 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): Activation(relu)
        (2): Conv2D(None -> 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (3): Activation(relu)
        (4): Conv2D(None -> 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (5): Activation(relu)
      )
      (5): HybridSequential(
        (0): Conv2D(None -> 1024, kernel_size=(3, 3), stride=(1, 1), padding=(6, 6), dilation=(6, 6))
        (1): Activation(relu)
        (2): Conv2D(None -> 1024, kernel_size=(1, 1), stride=(1, 1))
        (3): Activation(relu)
      )
    )
    (norm4): Normalize(

    )
    (extras): HybridSequential(
      (0): HybridSequential(
        (0): Conv2D(None -> 256, kernel_size=(1, 1), stride=(1, 1))
        (1): Activation(relu)
        (2): Conv2D(None -> 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
        (3): Activation(relu)
      )
      (1): HybridSequential(
        (0): Conv2D(None -> 128, kernel_size=(1, 1), stride=(1, 1))
        (1): Activation(relu)
        (2): Conv2D(None -> 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
        (3): Activation(relu)
      )
      (2): HybridSequential(
        (0): Conv2D(None -> 128, kernel_size=(1, 1), stride=(1, 1))
        (1): Activation(relu)
        (2): Conv2D(None -> 256, kernel_size=(3, 3), stride=(1, 1))
        (3): Activation(relu)
      )
      (3): HybridSequential(
        (0): Conv2D(None -> 128, kernel_size=(1, 1), stride=(1, 1))
        (1): Activation(relu)
        (2): Conv2D(None -> 256, kernel_size=(3, 3), stride=(1, 1))
        (3): Activation(relu)
      )
    )
  )
  (class_predictors): HybridSequential(
    (0): ConvPredictor(
      (predictor): Conv2D(None -> 84, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    )
    (1): ConvPredictor(
      (predictor): Conv2D(None -> 126, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    )
    (2): ConvPredictor(
      (predictor): Conv2D(None -> 126, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    )
    (3): ConvPredictor(
      (predictor): Conv2D(None -> 126, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    )
    (4): ConvPredictor(
      (predictor): Conv2D(None -> 84, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    )
    (5): ConvPredictor(
      (predictor): Conv2D(None -> 84, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    )
  )
  (box_predictors): HybridSequential(
    (0): ConvPredictor(
      (predictor): Conv2D(None -> 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    )
    (1): ConvPredictor(
      (predictor): Conv2D(None -> 24, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    )
    (2): ConvPredictor(
      (predictor): Conv2D(None -> 24, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    )
    (3): ConvPredictor(
      (predictor): Conv2D(None -> 24, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    )
    (4): ConvPredictor(
      (predictor): Conv2D(None -> 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    )
    (5): ConvPredictor(
      (predictor): Conv2D(None -> 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    )
  )
  (anchor_generators): HybridSequential(
    (0): SSDAnchorGenerator(

    )
    (1): SSDAnchorGenerator(

    )
    (2): SSDAnchorGenerator(

    )
    (3): SSDAnchorGenerator(

    )
    (4): SSDAnchorGenerator(

    )
    (5): SSDAnchorGenerator(

    )
  )
  (bbox_decoder): NormalizedBoxCenterDecoder(

  )
  (cls_decoder): MultiPerClassDecoder(

  )
)

The SSD network is a HybridBlock, as mentioned before. You can call it with an input:

import mxnet as mx
x = mx.nd.zeros(shape=(1, 3, 512, 512))
net.initialize()
cids, scores, bboxes = net(x)

SSD returns three values: cids are the class labels, scores are the confidence scores of each prediction, and bboxes are the absolute coordinates of the corresponding bounding boxes.
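
For example, a minimal sketch of keeping only confident predictions (0.5 is an arbitrary threshold; the untrained network above will not produce meaningful scores):

score_thresh = 0.5
keep = scores[0, :, 0] > score_thresh  # scores has shape (batch, num_boxes, 1)
print('predictions above threshold:', int(keep.sum().asscalar()))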

The SSD network behaves differently in training mode:

from mxnet import autograd
with autograd.train_mode():
    cls_preds, box_preds, anchors = net(x)

In training mode, SSD returns three intermediate values: cls_preds are the class predictions prior to softmax, box_preds are bounding box offsets with one-to-one correspondence to anchors, and anchors are the absolute coordinates of the corresponding anchor boxes, which are fixed because training images share the same input dimensions.
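
For a quick sanity check, you can print the shapes of these training-mode outputs:

print('class predictions:', cls_preds.shape)
print('box predictions:', box_preds.shape)
print('anchors:', anchors.shape)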

Training targets

Unlike a single SoftmaxCrossEntropyLoss used in image classification, the loss used in SSD is more complicated. Don’t worry though, because we have these modules available out of the box.

To speed up training, we let the CPU pre-compute some training targets. This is especially helpful when your CPU is powerful, and you can use -j num_workers to utilize multiple CPU cores.

If we provide anchors to the training transform, it will compute the training targets:

from mxnet import gluon
train_transform = presets.ssd.SSDDefaultTrainTransform(width, height, anchors)
batchify_fn = Tuple(Stack(), Stack(), Stack())
train_loader = DataLoader(
    train_dataset.transform(train_transform),
    batch_size,
    shuffle=True,
    batchify_fn=batchify_fn,
    last_batch='rollover',
    num_workers=num_workers)
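
As a quick sketch, you can apply the new transform to a single sample; with anchors supplied, it now returns the image tensor plus class and box targets:

img, cls_targets, box_targets = train_transform(train_image, train_label)
print('image:', img.shape)
print('class targets:', cls_targets.shape)
print('box targets:', box_targets.shape)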

Loss, Trainer and Training pipeline

from gluoncv.loss import SSDMultiBoxLoss
mbox_loss = SSDMultiBoxLoss()
trainer = gluon.Trainer(
    net.collect_params(), 'sgd',
    {'learning_rate': 0.001, 'wd': 0.0005, 'momentum': 0.9})

for ib, batch in enumerate(train_loader):
    if ib > 0:
        break
    print('data:', batch[0].shape)
    print('class targets:', batch[1].shape)
    print('box targets:', batch[2].shape)
    with autograd.record():
        cls_pred, box_pred, anchors = net(batch[0])
        sum_loss, cls_loss, box_loss = mbox_loss(
            cls_pred, box_pred, batch[1], batch[2])
        # some standard gluon training steps:
        # autograd.backward(sum_loss)
        # trainer.step(1)

Out:

data: (2, 3, 512, 512)
class targets: (2, 24656)
box targets: (2, 24656, 4)

This time we can see that the data loader is actually returning the training targets for us. From here, it is natural to write a Gluon training loop with Trainer that updates the weights, as sketched below.
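
A condensed sketch of such a loop (epochs = 2 is an example value only; the real train_ssd.py adds learning-rate scheduling, logging, and validation on top):

epochs = 2  # example value only; real training runs many more epochs
for epoch in range(epochs):
    for ib, batch in enumerate(train_loader):
        with autograd.record():
            cls_pred, box_pred, anchors = net(batch[0])
            sum_loss, cls_loss, box_loss = mbox_loss(
                cls_pred, box_pred, batch[1], batch[2])
            autograd.backward(sum_loss)
        # the loss is already normalized inside SSDMultiBoxLoss, so step with 1
        trainer.step(1)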

Please check out the full training script for the complete implementation.

References

[Liu16]

Wei Liu, Dragomir Anguelov, Dumitru Erhan, Christian Szegedy, Scott Reed, Cheng-Yang Fu, and Alexander C. Berg. SSD: Single Shot MultiBox Detector. ECCV 2016.
