2. Train Mask RCNN end-to-end on MS COCO
This tutorial goes through the steps for training a Mask R-CNN [He17] instance segmentation model provided by GluonCV.
Mask R-CNN is an extension to the Faster R-CNN [Ren15] object detection model. As such, this tutorial is also an extension to 06. Train Faster-RCNN end-to-end on PASCAL VOC. We will focus on the extra work on top of Faster R-CNN to show how to use GluonCV components to construct a Mask R-CNN model.
It is highly recommended to read the original papers [Girshick14], [Girshick15], [Ren15], [He17] to learn more about the ideas behind Mask R-CNN. The appendix of [He16] and the experimental details in [Lin17] may also be useful references.
Hint
Please first go through the Prepare COCO datasets tutorial to set up the MS COCO dataset on your disk.
Hint
You can skip the rest of this tutorial and start training your Mask RCNN model right away by downloading this script:
Example usage:
Train a default resnet50_v1b model with the COCO dataset on GPU 0:
python train_mask_rcnn.py --gpus 0
Train on GPU 0,1,2,3:
python train_mask_rcnn.py --gpus 0,1,2,3
Check the supported arguments:
python train_mask_rcnn.py --help
Dataset
Make sure the COCO dataset has been set up on your disk. Then we are ready to load the training and validation images.
from gluoncv.data import COCOInstance
# typically we use the train2017 split (i.e. train2014 + valminusminival2014) as training data
# the COCO dataset contains images without any annotated objects,
# which must be skipped during training to prevent empty labels
train_dataset = COCOInstance(splits='instances_train2017', skip_empty=True)
# and the val2017 split (i.e. minival2014, 5k images) as validation data
val_dataset = COCOInstance(splits='instances_val2017', skip_empty=False)
print('Training images:', len(train_dataset))
print('Validation images:', len(val_dataset))
Out:
loading annotations into memory...
Done (t=13.42s)
creating index...
index created!
loading annotations into memory...
Done (t=0.38s)
creating index...
index created!
Training images: 117266
Validation images: 5000
Data transform
We can read an (image, label, segm) tuple from the training dataset:
train_image, train_label, train_segm = train_dataset[6]
bboxes = train_label[:, :4]
cids = train_label[:, 4:5]
print('image:', train_image.shape)
print('bboxes:', bboxes.shape, 'class ids:', cids.shape)
# segm has one entry per object; each entry is a list of polygons, and each polygon is an (N, 2) array of points on the object boundary
print('masks', [[poly.shape for poly in polys] for polys in train_segm])
Out:
image: (500, 381, 3)
bboxes: (9, 4) class ids: (9, 1)
masks [[(95, 2)], [(32, 2)], [(31, 2)], [(50, 2)], [(54, 2)], [(13, 2)], [(24, 2)], [(10, 2), (15, 2)], [(21, 2)]]
Plot the image with boxes and labels:
from matplotlib import pyplot as plt
from gluoncv.utils import viz
fig = plt.figure(figsize=(10, 10))
ax = fig.add_subplot(1, 1, 1)
ax = viz.plot_bbox(train_image, bboxes, labels=cids, class_names=train_dataset.classes, ax=ax)
plt.show()
To actually see the object segmentation, we need to convert the polygons to masks:
import numpy as np
from gluoncv.data.transforms import mask as tmask
width, height = train_image.shape[1], train_image.shape[0]
train_masks = np.stack([tmask.to_mask(polys, (width, height)) for polys in train_segm])
plt_image = viz.plot_mask(train_image, train_masks)
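Conceptually, tmask.to_mask(polys, (width, height)) turns the polygons of one object into a single binary mask of the image size. The NumPy sketch below illustrates that idea with an even-odd point-in-polygon test evaluated at pixel centers. The helper polygon_to_mask is hypothetical and is not GluonCV's implementation, so it may disagree with it on pixels that touch a polygon's boundary.

```python
import numpy as np


def polygon_to_mask(polys, width, height):
    """Rasterize a list of (N, 2) polygons into one binary (height, width) mask.

    Hypothetical helper, not GluonCV's implementation: a pixel is foreground
    when its center lies inside any polygon (even-odd ray-casting test).
    """
    ys, xs = np.mgrid[0:height, 0:width]
    px, py = xs + 0.5, ys + 0.5  # pixel-center coordinates
    mask = np.zeros((height, width), dtype=bool)
    for poly in polys:
        pts = np.asarray(poly, dtype=np.float64).reshape(-1, 2)
        x0, y0 = pts[:, 0], pts[:, 1]
        x1, y1 = np.roll(x0, -1), np.roll(y0, -1)  # edge end points
        inside = np.zeros((height, width), dtype=bool)
        for ax, ay, bx, by in zip(x0, y0, x1, y1):
            # does a horizontal ray going right from the pixel center cross edge a->b?
            straddles = (ay > py) != (by > py)
            with np.errstate(divide="ignore", invalid="ignore"):
                x_cross = ax + (py - ay) * (bx - ax) / (by - ay)
            inside ^= straddles & (px < x_cross)
        mask |= inside  # an object may consist of several polygons
    return mask.astype(np.uint8)


square = np.array([[2, 2], [6, 2], [6, 6], [2, 6]])
print(polygon_to_mask([square], width=10, height=8).sum())  # 16
```

Taking the union over polygons matters for objects such as the eighth one printed earlier, which is described by two separate polygons.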
Now plot the image with boxes, labels and masks:
fig = plt.figure(figsize=(10, 10))
ax = fig.add_subplot(1, 1, 1)
ax = viz.plot_bbox(plt_image, bboxes, labels=cids, class_names=train_dataset.classes, ax=ax)
plt.show()
Data transforms, i.e. decoding and transformation, are identical to Faster R-CNN, except that the segmentation polygons are an additional input.
gluoncv.data.transforms.presets.rcnn.MaskRCNNDefaultTrainTransform converts the segmentation polygons to binary segmentation masks.
gluoncv.data.transforms.presets.rcnn.MaskRCNNDefaultValTransform ignores the segmentation polygons and returns the image tensor and [im_height, im_width, im_scale].
from gluoncv.data.transforms import presets
from gluoncv import utils
from mxnet import nd
short, max_size = 600, 1000 # resize image to short side 600 px, but keep maximum length within 1000
train_transform = presets.rcnn.MaskRCNNDefaultTrainTransform(short, max_size)
val_transform = presets.rcnn.MaskRCNNDefaultValTransform(short, max_size)
utils.random.seed(233) # fix seed in this tutorial
Apply the transforms to the training image:
train_image2, train_label2, train_masks2 = train_transform(train_image, train_label, train_segm)
print('tensor shape:', train_image2.shape)
print('box and id shape:', train_label2.shape)
print('mask shape', train_masks2.shape)
Out:
tensor shape: (3, 787, 600)
box and id shape: (9, 5)
mask shape (9, 787, 600)
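The tensor shape (3, 787, 600) follows from the resize rule in the transform code: the 381 px short side of the 500 x 381 sample is scaled to 600 px (a factor of about 1.575), which turns the 500 px long side into about 787 px, still within the 1000 px limit. The sketch below reproduces that arithmetic; the helper name is hypothetical and GluonCV's own rounding details may differ.

```python
def resize_short_within_dims(height, width, short=600, max_size=1000):
    """Hypothetical helper: return (new_height, new_width) after scaling the
    short side to `short` while keeping the long side within `max_size`."""
    im_min, im_max = min(height, width), max(height, width)
    scale = short / im_min
    if round(im_max * scale) > max_size:
        # the long side would overflow, so scale by the long side instead
        scale = max_size / im_max
    return int(round(height * scale)), int(round(width * scale))


print(resize_short_within_dims(500, 381))   # (787, 600)
print(resize_short_within_dims(400, 1000))  # (400, 1000): max_size wins
```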
The image tensor looks distorted because it has been normalized and its values no longer lie in the (0, 255) range. Let’s convert it back so we can see it clearly.
plt_image2 = train_image2.transpose((1, 2, 0)) * nd.array((0.229, 0.224, 0.225)) + nd.array(
(0.485, 0.456, 0.406))
plt_image2 = (plt_image2 * 255).asnumpy().astype('uint8')
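The code above undoes a per-channel normalization. Assuming the transform scales pixels to [0, 1] and then normalizes with the ImageNet mean and standard deviation (the constants used above), the following NumPy sketch shows the forward step and its inverse as a round trip. The function names are hypothetical; rounding before the uint8 cast avoids off-by-one values that plain truncation can introduce.

```python
import numpy as np

# ImageNet channel statistics, the same constants as in the code above
MEAN = np.array([0.485, 0.456, 0.406])
STD = np.array([0.229, 0.224, 0.225])


def normalize(image_hwc_uint8):
    """uint8 HWC image -> float CHW tensor with per-channel mean/std normalization."""
    x = image_hwc_uint8.astype(np.float64) / 255.0
    x = (x - MEAN) / STD
    return x.transpose(2, 0, 1)


def denormalize(tensor_chw):
    """Invert `normalize`: float CHW tensor -> uint8 HWC image."""
    x = tensor_chw.transpose(1, 2, 0) * STD + MEAN
    return np.clip(np.round(x * 255.0), 0, 255).astype(np.uint8)


rng = np.random.default_rng(0)
img = rng.integers(0, 256, size=(4, 5, 3), dtype=np.uint8)
print(np.array_equal(img, denormalize(normalize(img))))  # True
```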
The transform has already converted the polygons to masks, so we can plot them directly.
width, height = plt_image2.shape[1], plt_image2.shape[0]
plt_image2 = viz.plot_mask(plt_image2, train_masks2)
fig = plt.figure(figsize=(10, 10))
ax = fig.add_subplot(1, 1, 1)
ax = viz.plot_bbox(plt_image2, train_label2[:, :4],
labels=train_label2[:, 4:5],
class_names=train_dataset.classes,
ax=ax)
plt.show()
Data Loader
The data loader is identical to the Faster R-CNN one, except that masks are an additional input and output.
from gluoncv.data.batchify import Tuple, Append, MaskRCNNTrainBatchify
from mxnet.gluon.data import DataLoader
batch_size = 2 # for tutorial, we use smaller batch-size
num_workers = 0  # you can make it larger (if your CPU has more cores) to accelerate data loading
train_bfn = Tuple(*[Append() for _ in range(3)])
train_loader = DataLoader(train_dataset.transform(train_transform), batch_size, shuffle=True,
batchify_fn=train_bfn, last_batch='rollover', num_workers=num_workers)
val_bfn = Tuple(*[Append() for _ in range(2)])
val_loader = DataLoader(val_dataset.transform(val_transform), batch_size, shuffle=False,
batchify_fn=val_bfn, last_batch='keep', num_workers=num_workers)
for ib, batch in enumerate(train_loader):
    if ib > 3:
        break
    print('data 0:', batch[0][0].shape, 'label 0:', batch[1][0].shape, 'mask 0:', batch[2][0].shape)
    print('data 1:', batch[0][1].shape, 'label 1:', batch[1][1].shape, 'mask 1:', batch[2][1].shape)
Out:
data 0: (1, 3, 600, 901) label 0: (1, 2, 5) mask 0: (1, 2, 600, 901)
data 1: (1, 3, 800, 600) label 1: (1, 1, 5) mask 1: (1, 1, 800, 600)
data 0: (1, 3, 798, 600) label 0: (1, 2, 5) mask 0: (1, 2, 798, 600)
data 1: (1, 3, 600, 600) label 1: (1, 18, 5) mask 1: (1, 18, 600, 600)
data 0: (1, 3, 600, 800) label 0: (1, 1, 5) mask 0: (1, 1, 600, 800)
data 1: (1, 3, 600, 600) label 1: (1, 5, 5) mask 1: (1, 5, 600, 600)
data 0: (1, 3, 600, 800) label 0: (1, 2, 5) mask 0: (1, 2, 600, 800)
data 1: (1, 3, 800, 600) label 1: (1, 21, 5) mask 1: (1, 21, 800, 600)
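Because images have different sizes after the aspect-preserving resize, a batch cannot be stacked into one array. The Append batchify instead keeps one array per sample and gives each a leading axis of size 1, which is why every shape printed above starts with 1. The NumPy sketch below mimics that behavior with a hypothetical append_batchify function; it is an illustration, not GluonCV's Append class.

```python
import numpy as np


def append_batchify(samples):
    """Hypothetical stand-in for an Append-style batchify function.

    `samples` is a list of (image, label, mask) tuples whose arrays may differ
    in shape. Instead of stacking (which needs equal shapes), each field is
    collected into a list and every array gets a leading batch axis of 1.
    """
    fields = zip(*samples)  # -> images, labels, masks
    return tuple([np.expand_dims(a, 0) for a in field] for field in fields)


samples = [
    (np.zeros((3, 600, 901)), np.zeros((2, 5)), np.zeros((2, 600, 901))),
    (np.zeros((3, 800, 600)), np.zeros((1, 5)), np.zeros((1, 800, 600))),
]
batch = append_batchify(samples)
print(batch[0][0].shape, batch[1][0].shape, batch[2][0].shape)
# (1, 3, 600, 901) (1, 2, 5) (1, 2, 600, 901)
```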
Mask RCNN Network
In GluonCV, the Mask R-CNN network gluoncv.model_zoo.MaskRCNN inherits from the Faster R-CNN network gluoncv.model_zoo.FasterRCNN.
The GluonCV model zoo provides several pretrained Mask R-CNN networks. You can load your favorite one with a single line of code:
Hint
To avoid downloading models in this tutorial, we set pretrained_base=False; in practice we usually want to load pre-trained ImageNet models by setting pretrained_base=True.
from gluoncv import model_zoo
net = model_zoo.get_model('mask_rcnn_resnet50_v1b_coco', pretrained_base=False)
print(net)
Out:
MaskRCNN(
(features): HybridSequential(
(0): Conv2D(None -> 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
(1): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=True, in_channels=64)
(2): Activation(relu)
(3): MaxPool2D(size=(3, 3), stride=(2, 2), padding=(1, 1), ceil_mode=False, global_pool=False, pool_type=max, layout=NCHW)
(4): HybridSequential(
(0): BottleneckV1b(
(conv1): Conv2D(None -> 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn1): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=True, in_channels=64)
(relu1): Activation(relu)
(conv2): Conv2D(None -> 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=True, in_channels=64)
(relu2): Activation(relu)
(conv3): Conv2D(None -> 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn3): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=True, in_channels=256)
(relu3): Activation(relu)
(downsample): HybridSequential(
(0): Conv2D(None -> 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
(1): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=True, in_channels=256)
)
)
(1): BottleneckV1b(
(conv1): Conv2D(None -> 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn1): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=True, in_channels=64)
(relu1): Activation(relu)
(conv2): Conv2D(None -> 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=True, in_channels=64)
(relu2): Activation(relu)
(conv3): Conv2D(None -> 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn3): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=True, in_channels=256)
(relu3): Activation(relu)
)
(2): BottleneckV1b(
(conv1): Conv2D(None -> 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn1): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=True, in_channels=64)
(relu1): Activation(relu)
(conv2): Conv2D(None -> 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=True, in_channels=64)
(relu2): Activation(relu)
(conv3): Conv2D(None -> 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn3): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=True, in_channels=256)
(relu3): Activation(relu)
)
)
(5): HybridSequential(
(0): BottleneckV1b(
(conv1): Conv2D(None -> 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn1): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=True, in_channels=128)
(relu1): Activation(relu)
(conv2): Conv2D(None -> 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
(bn2): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=True, in_channels=128)
(relu2): Activation(relu)
(conv3): Conv2D(None -> 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn3): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=True, in_channels=512)
(relu3): Activation(relu)
(downsample): HybridSequential(
(0): Conv2D(None -> 512, kernel_size=(1, 1), stride=(2, 2), bias=False)
(1): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=True, in_channels=512)
)
)
(1): BottleneckV1b(
(conv1): Conv2D(None -> 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn1): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=True, in_channels=128)
(relu1): Activation(relu)
(conv2): Conv2D(None -> 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=True, in_channels=128)
(relu2): Activation(relu)
(conv3): Conv2D(None -> 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn3): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=True, in_channels=512)
(relu3): Activation(relu)
)
(2): BottleneckV1b(
(conv1): Conv2D(None -> 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn1): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=True, in_channels=128)
(relu1): Activation(relu)
(conv2): Conv2D(None -> 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=True, in_channels=128)
(relu2): Activation(relu)
(conv3): Conv2D(None -> 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn3): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=True, in_channels=512)
(relu3): Activation(relu)
)
(3): BottleneckV1b(
(conv1): Conv2D(None -> 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn1): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=True, in_channels=128)
(relu1): Activation(relu)
(conv2): Conv2D(None -> 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=True, in_channels=128)
(relu2): Activation(relu)
(conv3): Conv2D(None -> 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn3): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=True, in_channels=512)
(relu3): Activation(relu)
)
)
(6): HybridSequential(
(0): BottleneckV1b(
(conv1): Conv2D(None -> 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn1): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=True, in_channels=256)
(relu1): Activation(relu)
(conv2): Conv2D(None -> 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
(bn2): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=True, in_channels=256)
(relu2): Activation(relu)
(conv3): Conv2D(None -> 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn3): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=True, in_channels=1024)
(relu3): Activation(relu)
(downsample): HybridSequential(
(0): Conv2D(None -> 1024, kernel_size=(1, 1), stride=(2, 2), bias=False)
(1): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=True, in_channels=1024)
)
)
(1): BottleneckV1b(
(conv1): Conv2D(None -> 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn1): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=True, in_channels=256)
(relu1): Activation(relu)
(conv2): Conv2D(None -> 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=True, in_channels=256)
(relu2): Activation(relu)
(conv3): Conv2D(None -> 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn3): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=True, in_channels=1024)
(relu3): Activation(relu)
)
(2): BottleneckV1b(
(conv1): Conv2D(None -> 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn1): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=True, in_channels=256)
(relu1): Activation(relu)
(conv2): Conv2D(None -> 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=True, in_channels=256)
(relu2): Activation(relu)
(conv3): Conv2D(None -> 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn3): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=True, in_channels=1024)
(relu3): Activation(relu)
)
(3): BottleneckV1b(
(conv1): Conv2D(None -> 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn1): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=True, in_channels=256)
(relu1): Activation(relu)
(conv2): Conv2D(None -> 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=True, in_channels=256)
(relu2): Activation(relu)
(conv3): Conv2D(None -> 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn3): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=True, in_channels=1024)
(relu3): Activation(relu)
)
(4): BottleneckV1b(
(conv1): Conv2D(None -> 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn1): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=True, in_channels=256)
(relu1): Activation(relu)
(conv2): Conv2D(None -> 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=True, in_channels=256)
(relu2): Activation(relu)
(conv3): Conv2D(None -> 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn3): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=True, in_channels=1024)
(relu3): Activation(relu)
)
(5): BottleneckV1b(
(conv1): Conv2D(None -> 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn1): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=True, in_channels=256)
(relu1): Activation(relu)
(conv2): Conv2D(None -> 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=True, in_channels=256)
(relu2): Activation(relu)
(conv3): Conv2D(None -> 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn3): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=True, in_channels=1024)
(relu3): Activation(relu)
)
)
)
(top_features): HybridSequential(
(0): HybridSequential(
(0): BottleneckV1b(
(conv1): Conv2D(None -> 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn1): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=True, in_channels=512)
(relu1): Activation(relu)
(conv2): Conv2D(None -> 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
(bn2): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=True, in_channels=512)
(relu2): Activation(relu)
(conv3): Conv2D(None -> 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn3): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=True, in_channels=2048)
(relu3): Activation(relu)
(downsample): HybridSequential(
(0): Conv2D(None -> 2048, kernel_size=(1, 1), stride=(2, 2), bias=False)
(1): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=True, in_channels=2048)
)
)
(1): BottleneckV1b(
(conv1): Conv2D(None -> 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn1): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=True, in_channels=512)
(relu1): Activation(relu)
(conv2): Conv2D(None -> 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=True, in_channels=512)
(relu2): Activation(relu)
(conv3): Conv2D(None -> 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn3): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=True, in_channels=2048)
(relu3): Activation(relu)
)
(2): BottleneckV1b(
(conv1): Conv2D(None -> 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn1): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=True, in_channels=512)
(relu1): Activation(relu)
(conv2): Conv2D(None -> 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=True, in_channels=512)
(relu2): Activation(relu)
(conv3): Conv2D(None -> 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn3): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=True, in_channels=2048)
(relu3): Activation(relu)
)
)
)
(class_predictor): Dense(None -> 81, linear)
(box_predictor): Dense(None -> 320, linear)
(cls_decoder): MultiPerClassDecoder(
)
(box_decoder): NormalizedBoxCenterDecoder(
(corner_to_center): BBoxCornerToCenter(
)
)
(rpn): RPN(
(anchor_generator): RPNAnchorGenerator(
)
(conv1): HybridSequential(
(0): Conv2D(None -> 1024, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(1): Activation(relu)
)
(score): Conv2D(None -> 15, kernel_size=(1, 1), stride=(1, 1))
(loc): Conv2D(None -> 60, kernel_size=(1, 1), stride=(1, 1))
(region_proposer): RPNProposal(
(_box_to_center): BBoxCornerToCenter(
)
(_box_decoder): NormalizedBoxCenterDecoder(
(corner_to_center): BBoxCornerToCenter(
)
)
(_clipper): BBoxClipToImage(
)
)
)
(sampler): RCNNTargetSampler(
)
(mask): Mask(
(deconv): Conv2DTranspose(256 -> 0, kernel_size=(2, 2), stride=(2, 2))
(mask): Conv2D(None -> 80, kernel_size=(1, 1), stride=(1, 1))
)
(mask_target): MaskTargetGenerator(
)
)
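The widths of the prediction layers in this printout can be checked by hand. COCO has 80 object categories, so the classifier outputs 80 + 1 scores (the extra one for background), the box regressor outputs 4 values per foreground class, and the mask head predicts one mask per foreground class. The RPN's 15 score channels and 60 box channels correspond to 15 anchors per feature-map location; the split into 5 scales x 3 aspect ratios below is an assumption about the model's default anchor configuration.

```python
# Bookkeeping for the layer widths printed above, assuming 80 COCO
# foreground classes and 5 anchor scales x 3 aspect ratios.
num_classes = 80
num_scales, num_ratios = 5, 3

class_outputs = num_classes + 1            # +1 for background (class_predictor)
box_outputs = num_classes * 4              # 4 regression values per class (box_predictor)
anchors_per_location = num_scales * num_ratios
rpn_score_channels = anchors_per_location  # one objectness logit per anchor (rpn score)
rpn_loc_channels = anchors_per_location * 4  # 4 regression values per anchor (rpn loc)
mask_channels = num_classes                # one mask per foreground class (mask head)

print(class_outputs, box_outputs, rpn_score_channels, rpn_loc_channels, mask_channels)
# 81 320 15 60 80
```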
Mask R-CNN takes the same inputs as Faster R-CNN but produces an additional output:
- cids are the class labels,
- scores are the confidence scores of each prediction,
- bboxes are the absolute coordinates of the corresponding bounding boxes,
- masks are the predicted segmentation masks corresponding to each bounding box.
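Once these outputs have been converted to NumPy arrays for a single image (for example by indexing the first image of the batch and calling .asnumpy()), low-confidence detections can be dropped with a score threshold. The sketch below assumes the per-image layout described in its docstring and uses dummy arrays; the function name and the 0.5 threshold are illustrative choices, not part of the GluonCV API.

```python
import numpy as np


def filter_detections(cids, scores, bboxes, masks, thresh=0.5):
    """Keep the detections of one image whose score is at least `thresh`.

    Assumed layout (hypothetical): cids and scores are (N, 1), bboxes is (N, 4)
    and masks has N as its first axis. Slots with a low or negative score,
    such as padding, are dropped by the same threshold.
    """
    keep = scores[:, 0] >= thresh
    return cids[keep], scores[keep], bboxes[keep], masks[keep]


# dummy outputs for three detections; the last one mimics a padded slot
cids = np.array([[0.0], [2.0], [-1.0]])
scores = np.array([[0.9], [0.3], [-1.0]])
bboxes = np.zeros((3, 4))
masks = np.zeros((3, 14, 14))
kept = filter_detections(cids, scores, bboxes, masks, thresh=0.5)
print([a.shape for a in kept])  # [(1, 1), (1, 1), (1, 4), (1, 14, 14)]
```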
During training, an additional output is returned: mask_preds, the per-class mask predictions, in addition to cls_preds and box_preds.
import mxnet as mx
from mxnet import autograd
# a dummy all-zero input batch (1 image, 3 channels, 600 x 800) is enough to run the network
x = mx.nd.zeros(shape=(1, 3, 600, 800))
net.initialize()
with autograd.train_mode():
    # this time we need ground-truth to generate high quality roi proposals during training
    gt_box = mx.nd.zeros(shape=(1, 1, 4))
    gt_label = mx.nd.zeros(shape=(1, 1, 1))
    cls_pred, box_pred, mask_pred, roi, samples, matches, rpn_score, rpn_box, anchors, \
        cls_targets, box_targets, box_masks, indices = net(x, gt_box, gt_label)
Training losses
Mask R-CNN adds one loss on top of the Faster R-CNN losses: a per-pixel binary cross-entropy on the predicted masks.
# the loss to penalize incorrect foreground/background prediction
rpn_cls_loss = mx.gluon.loss.SigmoidBinaryCrossEntropyLoss(from_sigmoid=False)
# the loss to penalize inaccurate anchor boxes
rpn_box_loss = mx.gluon.loss.HuberLoss(rho=1 / 9.) # == smoothl1
# the loss to penalize incorrect classification prediction.
rcnn_cls_loss = mx.gluon.loss.SoftmaxCrossEntropyLoss()
# and finally the loss to penalize inaccurate proposals
rcnn_box_loss = mx.gluon.loss.HuberLoss() # == smoothl1
# the loss to penalize incorrect segmentation pixel prediction
rcnn_mask_loss = mx.gluon.loss.SigmoidBinaryCrossEntropyLoss(from_sigmoid=False)
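The "== smoothl1" comments rely on MXNet's HuberLoss having the same piecewise form as smooth L1 with transition point rho: quadratic, 0.5 / rho * x**2, for |x| < rho and linear, |x| - 0.5 * rho, beyond. With rho = 1/9 this matches the Faster R-CNN smooth L1 with sigma = 3, since the transition point there is 1 / sigma**2. The mask loss uses from_sigmoid=False, i.e. it takes raw logits. The NumPy sketch below writes out these element-wise formulas (before Gluon's averaging); the function names are illustrative, not MXNet APIs.

```python
import numpy as np


def huber(pred, label, rho=1.0):
    """Element-wise Huber loss in the form used by MXNet's HuberLoss."""
    x = np.abs(pred - label)
    return np.where(x < rho, 0.5 / rho * x ** 2, x - 0.5 * rho)


def smooth_l1(pred, label, beta=1.0):
    """Smooth L1 with transition point beta (beta = 1 / sigma**2)."""
    x = np.abs(pred - label)
    return np.where(x < beta, 0.5 * x ** 2 / beta, x - 0.5 * beta)


def sigmoid_bce_from_logits(logits, labels):
    """Numerically stable binary cross-entropy on raw logits (from_sigmoid=False)."""
    return np.maximum(logits, 0) - logits * labels + np.log1p(np.exp(-np.abs(logits)))


d = np.linspace(-1.0, 1.0, 21)
print(np.allclose(huber(d, 0.0, rho=1 / 9.), smooth_l1(d, 0.0, beta=1 / 9.)))  # True
```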
Training targets
RPN and RCNN training targets are the same as in 06. Train Faster-RCNN end-to-end on PASCAL VOC.
We also offload RPN target computation to the CPU data-loading workers, so the network is passed to the transform:
train_transform = presets.rcnn.MaskRCNNDefaultTrainTransform(short, max_size, net)
# return images, labels, masks, rpn_cls_targets, rpn_box_targets, rpn_box_masks loosely
batchify_fn = MaskRCNNTrainBatchify(net)
# For the next part, we only use batch size 1
batch_size = 1
train_loader = DataLoader(train_dataset.transform(train_transform), batch_size, shuffle=True,
batchify_fn=batchify_fn, last_batch='rollover', num_workers=num_workers)
Mask targets are generated from intermediate network outputs after the RCNN targets have been generated.
for ib, batch in enumerate(train_loader):
    if ib > 0:
        break
    with autograd.train_mode():
        for data, label, masks, rpn_cls_targets, rpn_box_targets, rpn_box_masks in zip(*batch):
            label = label.expand_dims(0)
            gt_label = label[:, :, 4:5]
            gt_box = label[:, :, :4]
            # network forward
            cls_pred, box_pred, mask_pred, roi, samples, matches, rpn_score, rpn_box, anchors, \
                cls_targets, box_targets, box_masks, indices = \
                net(data.expand_dims(0), gt_box, gt_label)
            # generate targets for mask head
            roi = mx.nd.concat(
                *[mx.nd.take(roi[i], indices[i]) for i in range(indices.shape[0])], dim=0) \
                .reshape((indices.shape[0], -1, 4))
            m_cls_targets = mx.nd.concat(
                *[mx.nd.take(cls_targets[i], indices[i]) for i in range(indices.shape[0])], dim=0) \
                .reshape((indices.shape[0], -1))
            matches = mx.nd.concat(
                *[mx.nd.take(matches[i], indices[i]) for i in range(indices.shape[0])], dim=0) \
                .reshape((indices.shape[0], -1))
            mask_targets, mask_masks = net.mask_target(roi, masks.expand_dims(0), matches,
                                                       m_cls_targets)
            print('data:', data.shape)
            # box and class labels
            print('box:', gt_box.shape)
            print('label:', gt_label.shape)
            # -1 marks ignored label
            print('rpn cls label:', rpn_cls_targets.shape)
            # mask out ignored box label
            print('rpn box label:', rpn_box_targets.shape)
            print('rpn box mask:', rpn_box_masks.shape)
            # rcnn does not have ignored label
            print('rcnn cls label:', cls_targets.shape)
            # mask out ignored box label
            print('rcnn box label:', box_targets.shape)
            print('rcnn box mask:', box_masks.shape)
            print('rcnn mask label:', mask_targets.shape)
            print('rcnn mask mask:', mask_masks.shape)
Out:
data: (3, 831, 600)
box: (1, 1, 4)
label: (1, 1, 1)
rpn cls label: (1, 29640)
rpn box label: (1, 29640, 4)
rpn box mask: (1, 29640, 4)
rcnn cls label: (1, 128)
rcnn box label: (1, 32, 80, 4)
rcnn box mask: (1, 32, 80, 4)
rcnn mask label: (1, 32, 80, 14, 14)
rcnn mask mask: (1, 32, 80, 14, 14)
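The target shapes above can also be derived by hand. Assuming the RPN runs on the stride-16 output of the trunk printed earlier (features stages (0) to (6)) with 15 anchors per location, an 831 x 600 input gives a 52 x 38 feature map and 52 x 38 x 15 = 29640 anchors. The 128 RCNN classification labels versus 32 box and mask targets are consistent with sampling 128 RoIs per image with a positive fraction of 0.25, which we assume to be the sampler defaults here; only positive samples receive box and mask targets.

```python
import math


def conv_out(n, kernel, stride, pad):
    """Output length of a convolution or pooling window along one axis (floor mode)."""
    return (n + 2 * pad - kernel) // stride + 1


def feature_size(n):
    """Spatial size after the stride-16 part of the ResNet-50 v1b trunk printed above."""
    n = conv_out(n, 7, 2, 3)  # conv1: 7x7, stride 2, padding 3
    n = conv_out(n, 3, 2, 1)  # max pool: 3x3, stride 2, padding 1
    n = conv_out(n, 3, 2, 1)  # first 3x3 stride-2 conv of stage (5)
    n = conv_out(n, 3, 2, 1)  # first 3x3 stride-2 conv of stage (6)
    return n


h, w = 831, 600
fh, fw = feature_size(h), feature_size(w)
anchors_per_location = 15  # assumed: 5 scales x 3 ratios
print(fh, fw, fh * fw * anchors_per_location)  # 52 38 29640
assert (fh, fw) == (math.ceil(h / 16), math.ceil(w / 16))

num_sample, pos_ratio = 128, 0.25  # assumed RCNN sampler defaults
print(int(num_sample * pos_ratio))  # 32
```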
Training loop
After we have defined the loss functions and generated the training targets, we can write the training loop.
for ib, batch in enumerate(train_loader):
    if ib > 0:
        break
    with autograd.record():
        for data, label, masks, rpn_cls_targets, rpn_box_targets, rpn_box_masks in zip(*batch):
            label = label.expand_dims(0)
            gt_label = label[:, :, 4:5]
            gt_box = label[:, :, :4]
            # network forward
            cls_preds, box_preds, mask_preds, roi, samples, matches, rpn_score, rpn_box, anchors, \
                cls_targets, box_targets, box_masks, indices = \
                net(data.expand_dims(0), gt_box, gt_label)
            # generate targets for mask head
            roi = mx.nd.concat(
                *[mx.nd.take(roi[i], indices[i]) for i in range(indices.shape[0])], dim=0) \
                .reshape((indices.shape[0], -1, 4))
            m_cls_targets = mx.nd.concat(
                *[mx.nd.take(cls_targets[i], indices[i]) for i in range(indices.shape[0])], dim=0) \
                .reshape((indices.shape[0], -1))
            matches = mx.nd.concat(
                *[mx.nd.take(matches[i], indices[i]) for i in range(indices.shape[0])], dim=0) \
                .reshape((indices.shape[0], -1))
            mask_targets, mask_masks = net.mask_target(roi, masks.expand_dims(0), matches,
                                                       m_cls_targets)
            # losses of rpn
            rpn_score = rpn_score.squeeze(axis=-1)
            num_rpn_pos = (rpn_cls_targets >= 0).sum()
            rpn_loss1 = rpn_cls_loss(rpn_score, rpn_cls_targets,
                                     rpn_cls_targets >= 0) * rpn_cls_targets.size / num_rpn_pos
            rpn_loss2 = rpn_box_loss(rpn_box, rpn_box_targets,
                                     rpn_box_masks) * rpn_box.size / num_rpn_pos
            # losses of rcnn
            num_rcnn_pos = (cls_targets >= 0).sum()
            rcnn_loss1 = rcnn_cls_loss(cls_preds, cls_targets, cls_targets >= 0) * \
                cls_targets.size / cls_targets.shape[0] / num_rcnn_pos
            rcnn_loss2 = rcnn_box_loss(box_preds, box_targets, box_masks) * box_preds.size / \
                box_preds.shape[0] / num_rcnn_pos
            # loss of mask
            mask_loss = rcnn_mask_loss(mask_preds, mask_targets, mask_masks) * mask_targets.size / \
                mask_targets.shape[0] / mask_masks.sum()
            # some standard gluon training steps:
            # autograd.backward([rpn_loss1, rpn_loss2, rcnn_loss1, rcnn_loss2, mask_loss])
            # trainer.step(batch_size)
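The "* x.size / num_valid" factors in the loop compensate for Gluon losses averaging over every non-batch element, ignored ones included: after the mask (sample weight) zeroes out ignored entries, the mean is still taken over all elements, so multiplying by the element count and dividing by the number of valid entries yields the average over valid entries only. With the batch size of 1 used here, the extra division by shape[0] in the RCNN losses changes nothing. A toy NumPy check with made-up numbers:

```python
import numpy as np

# made-up per-element losses and a validity mask (1 = counted, 0 = ignored)
losses = np.array([[0.5, 1.0, 2.0, 4.0]])
valid = np.array([[1.0, 0.0, 1.0, 0.0]])

# Gluon-style: apply the mask, then average over the non-batch axis
per_sample = (losses * valid).mean(axis=1)

# rescaling by (number of elements / number of valid elements), as in
# `* rpn_cls_targets.size / num_rpn_pos`, gives the mean over valid entries
normalized = per_sample * losses.size / valid.sum()
print(normalized)  # [1.25]
assert np.allclose(normalized, (losses * valid).sum() / valid.sum())
```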
Hint
Please check out the full training script for the complete implementation.
References
- Girshick14
Ross Girshick, Jeff Donahue, Trevor Darrell, and Jitendra Malik. Rich Feature Hierarchies for Accurate Object Detection and Semantic Segmentation. CVPR 2014.
- Girshick15
Ross Girshick. Fast R-CNN. ICCV 2015.
- Ren15
Shaoqing Ren, Kaiming He, Ross Girshick, and Jian Sun. Faster R-CNN: Towards Real-Time Object Detection with Region Proposal Networks. NIPS 2015.
- He16
Kaiming He, Xiangyu Zhang, Shaoqing Ren, and Jian Sun. Deep Residual Learning for Image Recognition. CVPR 2016.
- Lin17
Tsung-Yi Lin, Piotr Dollár, Ross Girshick, Kaiming He, Bharath Hariharan, and Serge Belongie. Feature Pyramid Networks for Object Detection. CVPR 2017.
- He17
Kaiming He, Georgia Gkioxari, Piotr Dollár, and Ross Girshick. Mask R-CNN. ICCV 2017.