.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "build/examples_instance/train_mask_rcnn_coco.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        Click :ref:`here ` to download the full example code

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_build_examples_instance_train_mask_rcnn_coco.py:

2. Train Mask RCNN end-to-end on MS COCO
===========================================

This tutorial goes through the steps of training a Mask R-CNN [He17]_ instance
segmentation model provided by GluonCV.

Mask R-CNN is an extension of the Faster R-CNN [Ren15]_ object detection model.
As such, this tutorial is also an extension of
:doc:`../examples_detection/train_faster_rcnn_voc`. We focus on the extra work on top
of Faster R-CNN and show how to use GluonCV components to construct a Mask R-CNN model.

It is highly recommended to read the original papers [Girshick14]_, [Girshick15]_,
[Ren15]_, [He17]_ to learn more about the ideas behind Mask R-CNN.
The appendix of [He16]_ and the experimental details in [Lin17]_ may also be useful references.

.. hint::

   Please first go through the :ref:`sphx_glr_build_examples_datasets_mscoco.py` tutorial
   to set up the MS COCO dataset on your disk.

.. hint::

   You can skip the rest of this tutorial and start training your Mask RCNN model
   right away by downloading this script:
   :download:`Download train_mask_rcnn.py <../../../scripts/instance/mask_rcnn/train_mask_rcnn.py>`

   Example usage:

   Train a default resnet50_v1b model with the COCO dataset on GPU 0:

   .. code-block:: bash

       python train_mask_rcnn.py --gpus 0

   Train on GPUs 0,1,2,3:

   .. code-block:: bash

       python train_mask_rcnn.py --gpus 0,1,2,3

   Check the supported arguments:

   .. code-block:: bash

       python train_mask_rcnn.py --help

.. GENERATED FROM PYTHON SOURCE LINES 51-56

Dataset
-------

Make sure the COCO dataset has been set up on your disk.
Then we are ready to load the training and validation images.

.. GENERATED FROM PYTHON SOURCE LINES 56-69

.. code-block:: default

    from gluoncv.data import COCOInstance

    # we typically use the train2017 split (roughly the old train2014 plus 35k images
    # taken from val2014) as training data
    # the COCO dataset contains images without any object annotations,
    # which must be skipped during training to prevent empty labels
    train_dataset = COCOInstance(splits='instances_train2017', skip_empty=True)
    # and the val2017 split (i.e. the old minival 5k images) as validation data
    val_dataset = COCOInstance(splits='instances_val2017', skip_empty=False)

    print('Training images:', len(train_dataset))
    print('Validation images:', len(val_dataset))

.. rst-class:: sphx-glr-script-out

 Out:

 .. code-block:: none

    loading annotations into memory...
    Done (t=13.42s)
    creating index...
    index created!
    loading annotations into memory...
    Done (t=0.38s)
    creating index...
    index created!
    Training images: 117266
    Validation images: 5000

.. GENERATED FROM PYTHON SOURCE LINES 70-73

Data transform
--------------

We can read an (image, label, segm) tuple from the training dataset:

.. GENERATED FROM PYTHON SOURCE LINES 73-81

.. code-block:: default

    train_image, train_label, train_segm = train_dataset[6]
    bboxes = train_label[:, :4]
    cids = train_label[:, 4:5]
    print('image:', train_image.shape)
    print('bboxes:', bboxes.shape, 'class ids:', cids.shape)
    # segm is a list of polygons, each an array of points on the object boundary
    print('masks', [[poly.shape for poly in polys] for polys in train_segm])

.. rst-class:: sphx-glr-script-out

 Out:

 ..
code-block:: none image: (500, 381, 3) bboxes: (9, 4) class ids: (9, 1) masks [[(95, 2)], [(32, 2)], [(31, 2)], [(50, 2)], [(54, 2)], [(13, 2)], [(24, 2)], [(10, 2), (15, 2)], [(21, 2)]] .. GENERATED FROM PYTHON SOURCE LINES 82-83 Plot the image with boxes and labels: .. GENERATED FROM PYTHON SOURCE LINES 83-91 .. code-block:: default from matplotlib import pyplot as plt from gluoncv.utils import viz fig = plt.figure(figsize=(10, 10)) ax = fig.add_subplot(1, 1, 1) ax = viz.plot_bbox(train_image, bboxes, labels=cids, class_names=train_dataset.classes, ax=ax) plt.show() .. image-sg:: /build/examples_instance/images/sphx_glr_train_mask_rcnn_coco_001.png :alt: train mask rcnn coco :srcset: /build/examples_instance/images/sphx_glr_train_mask_rcnn_coco_001.png :class: sphx-glr-single-img .. GENERATED FROM PYTHON SOURCE LINES 92-93 To actually see the object segmentation, we need to convert polygons to masks .. GENERATED FROM PYTHON SOURCE LINES 93-100 .. code-block:: default import numpy as np from gluoncv.data.transforms import mask as tmask width, height = train_image.shape[1], train_image.shape[0] train_masks = np.stack([tmask.to_mask(polys, (width, height)) for polys in train_segm]) plt_image = viz.plot_mask(train_image, train_masks) .. GENERATED FROM PYTHON SOURCE LINES 101-102 Now plot the image with boxes, labels and masks .. GENERATED FROM PYTHON SOURCE LINES 102-107 .. code-block:: default fig = plt.figure(figsize=(10, 10)) ax = fig.add_subplot(1, 1, 1) ax = viz.plot_bbox(plt_image, bboxes, labels=cids, class_names=train_dataset.classes, ax=ax) plt.show() .. image-sg:: /build/examples_instance/images/sphx_glr_train_mask_rcnn_coco_002.png :alt: train mask rcnn coco :srcset: /build/examples_instance/images/sphx_glr_train_mask_rcnn_coco_002.png :class: sphx-glr-single-img .. GENERATED FROM PYTHON SOURCE LINES 108-114 Data transforms, i.e. decoding and transformation, are identical to Faster R-CNN with the exception of segmentation polygons as an additional input. :py:class:`gluoncv.data.transforms.presets.rcnn.MaskRCNNDefaultTrainTransform` converts the segmentation polygons to binary segmentation mask. :py:class:`gluoncv.data.transforms.presets.rcnn.MaskRCNNDefaultValTransform` ignores the segmentation polygons and returns image tensor and ``[im_height, im_width, im_scale]``. .. GENERATED FROM PYTHON SOURCE LINES 114-118 .. code-block:: default from gluoncv.data.transforms import presets from gluoncv import utils from mxnet import nd .. GENERATED FROM PYTHON SOURCE LINES 119-123 .. code-block:: default short, max_size = 600, 1000 # resize image to short side 600 px, but keep maximum length within 1000 train_transform = presets.rcnn.MaskRCNNDefaultTrainTransform(short, max_size) val_transform = presets.rcnn.MaskRCNNDefaultValTransform(short, max_size) .. GENERATED FROM PYTHON SOURCE LINES 124-126 .. code-block:: default utils.random.seed(233) # fix seed in this tutorial .. GENERATED FROM PYTHON SOURCE LINES 127-128 apply transforms to train image .. GENERATED FROM PYTHON SOURCE LINES 128-133 .. code-block:: default train_image2, train_label2, train_masks2 = train_transform(train_image, train_label, train_segm) print('tensor shape:', train_image2.shape) print('box and id shape:', train_label2.shape) print('mask shape', train_masks2.shape) .. rst-class:: sphx-glr-script-out Out: .. code-block:: none tensor shape: (3, 787, 600) box and id shape: (9, 5) mask shape (9, 787, 600) .. 
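The printed tensor shape can be worked out from the resize rule the RCNN presets use:
scale the short side to ``short`` (600) pixels while keeping the long side within
``max_size`` (1000). Below is a minimal sketch of that arithmetic; the helper function
is illustrative only and not part of GluonCV, which implements a similar rule internally:

.. code-block:: python

    def expected_resized_hw(height, width, short=600, max_size=1000):
        """Sketch of the short-side / max-size resize rule used by the RCNN presets."""
        scale = short / min(height, width)
        if scale * max(height, width) > max_size:
            # the long side would exceed max_size, so cap it instead
            scale = max_size / max(height, width)
        return int(round(height * scale)), int(round(width * scale))

    # the sample image above is 500 x 381 (H x W)
    print(expected_resized_hw(500, 381))  # -> (787, 600), matching the (3, 787, 600) tensor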
.. GENERATED FROM PYTHON SOURCE LINES 134-136

The image tensors look distorted because they have been normalized and no longer sit in
the (0, 255) range. Let's convert them back so we can see them clearly.

.. GENERATED FROM PYTHON SOURCE LINES 136-140

.. code-block:: default

    plt_image2 = train_image2.transpose((1, 2, 0)) * nd.array((0.229, 0.224, 0.225)) + nd.array(
        (0.485, 0.456, 0.406))
    plt_image2 = (plt_image2 * 255).asnumpy().astype('uint8')

.. GENERATED FROM PYTHON SOURCE LINES 141-142

The transform has already converted the polygons to masks, so we can plot them directly.

.. GENERATED FROM PYTHON SOURCE LINES 142-153

.. code-block:: default

    width, height = plt_image2.shape[1], plt_image2.shape[0]
    plt_image2 = viz.plot_mask(plt_image2, train_masks2)

    fig = plt.figure(figsize=(10, 10))
    ax = fig.add_subplot(1, 1, 1)
    ax = viz.plot_bbox(plt_image2, train_label2[:, :4],
                       labels=train_label2[:, 4:5],
                       class_names=train_dataset.classes, ax=ax)
    plt.show()

.. image-sg:: /build/examples_instance/images/sphx_glr_train_mask_rcnn_coco_003.png
   :alt: train mask rcnn coco
   :srcset: /build/examples_instance/images/sphx_glr_train_mask_rcnn_coco_003.png
   :class: sphx-glr-single-img

.. GENERATED FROM PYTHON SOURCE LINES 154-157

Data Loader
-----------

The data loader is identical to the Faster R-CNN one, except that it also carries the
masks as an additional input and output.

.. GENERATED FROM PYTHON SOURCE LINES 157-177

.. code-block:: default

    from gluoncv.data.batchify import Tuple, Append, MaskRCNNTrainBatchify
    from mxnet.gluon.data import DataLoader

    batch_size = 2  # for this tutorial we use a small batch size
    num_workers = 0  # increase it (if your CPU has more cores) to accelerate data loading

    train_bfn = Tuple(*[Append() for _ in range(3)])
    train_loader = DataLoader(train_dataset.transform(train_transform), batch_size, shuffle=True,
                              batchify_fn=train_bfn, last_batch='rollover', num_workers=num_workers)
    val_bfn = Tuple(*[Append() for _ in range(2)])
    val_loader = DataLoader(val_dataset.transform(val_transform), batch_size, shuffle=False,
                            batchify_fn=val_bfn, last_batch='keep', num_workers=num_workers)

    for ib, batch in enumerate(train_loader):
        if ib > 3:
            break
        print('data 0:', batch[0][0].shape, 'label 0:', batch[1][0].shape,
              'mask 0:', batch[2][0].shape)
        print('data 1:', batch[0][1].shape, 'label 1:', batch[1][1].shape,
              'mask 1:', batch[2][1].shape)

.. rst-class:: sphx-glr-script-out

 Out:

 .. code-block:: none

    data 0: (1, 3, 600, 901) label 0: (1, 2, 5) mask 0: (1, 2, 600, 901)
    data 1: (1, 3, 800, 600) label 1: (1, 1, 5) mask 1: (1, 1, 800, 600)
    data 0: (1, 3, 798, 600) label 0: (1, 2, 5) mask 0: (1, 2, 798, 600)
    data 1: (1, 3, 600, 600) label 1: (1, 18, 5) mask 1: (1, 18, 600, 600)
    data 0: (1, 3, 600, 800) label 0: (1, 1, 5) mask 0: (1, 1, 600, 800)
    data 1: (1, 3, 600, 600) label 1: (1, 5, 5) mask 1: (1, 5, 600, 600)
    data 0: (1, 3, 600, 800) label 0: (1, 2, 5) mask 0: (1, 2, 600, 800)
    data 1: (1, 3, 800, 600) label 1: (1, 21, 5) mask 1: (1, 21, 800, 600)

.. GENERATED FROM PYTHON SOURCE LINES 178-191

Mask RCNN Network
-------------------

In GluonCV, the Mask R-CNN network :py:class:`gluoncv.model_zoo.MaskRCNN` inherits from
the Faster R-CNN network :py:class:`gluoncv.model_zoo.FasterRCNN`.

The `Gluon Model Zoo <../../model_zoo/index.html>`__ provides several pretrained
Mask R-CNN networks. You can load your favorite one with a single line of code:

.. hint::

   To avoid downloading models in this tutorial, we set ``pretrained_base=False``.
   In practice, we usually want to load ImageNet pre-trained weights by setting
   ``pretrained_base=True``.

.. GENERATED FROM PYTHON SOURCE LINES 191-196

..
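If you are not sure which variants exist, you can list them from the model zoo before
loading one with ``get_model`` as done in the next code block. A small sketch
(``gluoncv.model_zoo.get_model_list`` returns all registered model names; the exact
list depends on your GluonCV version):

.. code-block:: python

    from gluoncv import model_zoo

    # keep only the registered names that correspond to Mask R-CNN variants
    mask_rcnn_models = [name for name in model_zoo.get_model_list()
                        if name.startswith('mask_rcnn')]
    print(mask_rcnn_models)  # e.g. includes 'mask_rcnn_resnet50_v1b_coco'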
code-block:: default from gluoncv import model_zoo net = model_zoo.get_model('mask_rcnn_resnet50_v1b_coco', pretrained_base=False) print(net) .. rst-class:: sphx-glr-script-out Out: .. code-block:: none MaskRCNN( (features): HybridSequential( (0): Conv2D(None -> 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False) (1): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=True, in_channels=64) (2): Activation(relu) (3): MaxPool2D(size=(3, 3), stride=(2, 2), padding=(1, 1), ceil_mode=False, global_pool=False, pool_type=max, layout=NCHW) (4): HybridSequential( (0): BottleneckV1b( (conv1): Conv2D(None -> 64, kernel_size=(1, 1), stride=(1, 1), bias=False) (bn1): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=True, in_channels=64) (relu1): Activation(relu) (conv2): Conv2D(None -> 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (bn2): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=True, in_channels=64) (relu2): Activation(relu) (conv3): Conv2D(None -> 256, kernel_size=(1, 1), stride=(1, 1), bias=False) (bn3): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=True, in_channels=256) (relu3): Activation(relu) (downsample): HybridSequential( (0): Conv2D(None -> 256, kernel_size=(1, 1), stride=(1, 1), bias=False) (1): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=True, in_channels=256) ) ) (1): BottleneckV1b( (conv1): Conv2D(None -> 64, kernel_size=(1, 1), stride=(1, 1), bias=False) (bn1): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=True, in_channels=64) (relu1): Activation(relu) (conv2): Conv2D(None -> 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (bn2): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=True, in_channels=64) (relu2): Activation(relu) (conv3): Conv2D(None -> 256, kernel_size=(1, 1), stride=(1, 1), bias=False) (bn3): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=True, in_channels=256) (relu3): Activation(relu) ) (2): BottleneckV1b( (conv1): Conv2D(None -> 64, kernel_size=(1, 1), stride=(1, 1), bias=False) (bn1): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=True, in_channels=64) (relu1): Activation(relu) (conv2): Conv2D(None -> 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (bn2): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=True, in_channels=64) (relu2): Activation(relu) (conv3): Conv2D(None -> 256, kernel_size=(1, 1), stride=(1, 1), bias=False) (bn3): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=True, in_channels=256) (relu3): Activation(relu) ) ) (5): HybridSequential( (0): BottleneckV1b( (conv1): Conv2D(None -> 128, kernel_size=(1, 1), stride=(1, 1), bias=False) (bn1): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=True, in_channels=128) (relu1): Activation(relu) (conv2): Conv2D(None -> 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False) (bn2): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=True, in_channels=128) (relu2): Activation(relu) (conv3): Conv2D(None -> 512, kernel_size=(1, 1), stride=(1, 1), bias=False) (bn3): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=True, in_channels=512) (relu3): Activation(relu) (downsample): HybridSequential( (0): Conv2D(None -> 512, 
kernel_size=(1, 1), stride=(2, 2), bias=False) (1): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=True, in_channels=512) ) ) (1): BottleneckV1b( (conv1): Conv2D(None -> 128, kernel_size=(1, 1), stride=(1, 1), bias=False) (bn1): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=True, in_channels=128) (relu1): Activation(relu) (conv2): Conv2D(None -> 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (bn2): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=True, in_channels=128) (relu2): Activation(relu) (conv3): Conv2D(None -> 512, kernel_size=(1, 1), stride=(1, 1), bias=False) (bn3): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=True, in_channels=512) (relu3): Activation(relu) ) (2): BottleneckV1b( (conv1): Conv2D(None -> 128, kernel_size=(1, 1), stride=(1, 1), bias=False) (bn1): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=True, in_channels=128) (relu1): Activation(relu) (conv2): Conv2D(None -> 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (bn2): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=True, in_channels=128) (relu2): Activation(relu) (conv3): Conv2D(None -> 512, kernel_size=(1, 1), stride=(1, 1), bias=False) (bn3): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=True, in_channels=512) (relu3): Activation(relu) ) (3): BottleneckV1b( (conv1): Conv2D(None -> 128, kernel_size=(1, 1), stride=(1, 1), bias=False) (bn1): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=True, in_channels=128) (relu1): Activation(relu) (conv2): Conv2D(None -> 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (bn2): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=True, in_channels=128) (relu2): Activation(relu) (conv3): Conv2D(None -> 512, kernel_size=(1, 1), stride=(1, 1), bias=False) (bn3): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=True, in_channels=512) (relu3): Activation(relu) ) ) (6): HybridSequential( (0): BottleneckV1b( (conv1): Conv2D(None -> 256, kernel_size=(1, 1), stride=(1, 1), bias=False) (bn1): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=True, in_channels=256) (relu1): Activation(relu) (conv2): Conv2D(None -> 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False) (bn2): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=True, in_channels=256) (relu2): Activation(relu) (conv3): Conv2D(None -> 1024, kernel_size=(1, 1), stride=(1, 1), bias=False) (bn3): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=True, in_channels=1024) (relu3): Activation(relu) (downsample): HybridSequential( (0): Conv2D(None -> 1024, kernel_size=(1, 1), stride=(2, 2), bias=False) (1): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=True, in_channels=1024) ) ) (1): BottleneckV1b( (conv1): Conv2D(None -> 256, kernel_size=(1, 1), stride=(1, 1), bias=False) (bn1): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=True, in_channels=256) (relu1): Activation(relu) (conv2): Conv2D(None -> 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (bn2): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=True, in_channels=256) (relu2): Activation(relu) (conv3): Conv2D(None -> 1024, 
kernel_size=(1, 1), stride=(1, 1), bias=False) (bn3): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=True, in_channels=1024) (relu3): Activation(relu) ) (2): BottleneckV1b( (conv1): Conv2D(None -> 256, kernel_size=(1, 1), stride=(1, 1), bias=False) (bn1): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=True, in_channels=256) (relu1): Activation(relu) (conv2): Conv2D(None -> 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (bn2): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=True, in_channels=256) (relu2): Activation(relu) (conv3): Conv2D(None -> 1024, kernel_size=(1, 1), stride=(1, 1), bias=False) (bn3): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=True, in_channels=1024) (relu3): Activation(relu) ) (3): BottleneckV1b( (conv1): Conv2D(None -> 256, kernel_size=(1, 1), stride=(1, 1), bias=False) (bn1): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=True, in_channels=256) (relu1): Activation(relu) (conv2): Conv2D(None -> 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (bn2): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=True, in_channels=256) (relu2): Activation(relu) (conv3): Conv2D(None -> 1024, kernel_size=(1, 1), stride=(1, 1), bias=False) (bn3): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=True, in_channels=1024) (relu3): Activation(relu) ) (4): BottleneckV1b( (conv1): Conv2D(None -> 256, kernel_size=(1, 1), stride=(1, 1), bias=False) (bn1): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=True, in_channels=256) (relu1): Activation(relu) (conv2): Conv2D(None -> 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (bn2): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=True, in_channels=256) (relu2): Activation(relu) (conv3): Conv2D(None -> 1024, kernel_size=(1, 1), stride=(1, 1), bias=False) (bn3): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=True, in_channels=1024) (relu3): Activation(relu) ) (5): BottleneckV1b( (conv1): Conv2D(None -> 256, kernel_size=(1, 1), stride=(1, 1), bias=False) (bn1): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=True, in_channels=256) (relu1): Activation(relu) (conv2): Conv2D(None -> 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (bn2): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=True, in_channels=256) (relu2): Activation(relu) (conv3): Conv2D(None -> 1024, kernel_size=(1, 1), stride=(1, 1), bias=False) (bn3): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=True, in_channels=1024) (relu3): Activation(relu) ) ) ) (top_features): HybridSequential( (0): HybridSequential( (0): BottleneckV1b( (conv1): Conv2D(None -> 512, kernel_size=(1, 1), stride=(1, 1), bias=False) (bn1): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=True, in_channels=512) (relu1): Activation(relu) (conv2): Conv2D(None -> 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False) (bn2): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=True, in_channels=512) (relu2): Activation(relu) (conv3): Conv2D(None -> 2048, kernel_size=(1, 1), stride=(1, 1), bias=False) (bn3): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=True, in_channels=2048) 
(relu3): Activation(relu) (downsample): HybridSequential( (0): Conv2D(None -> 2048, kernel_size=(1, 1), stride=(2, 2), bias=False) (1): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=True, in_channels=2048) ) ) (1): BottleneckV1b( (conv1): Conv2D(None -> 512, kernel_size=(1, 1), stride=(1, 1), bias=False) (bn1): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=True, in_channels=512) (relu1): Activation(relu) (conv2): Conv2D(None -> 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (bn2): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=True, in_channels=512) (relu2): Activation(relu) (conv3): Conv2D(None -> 2048, kernel_size=(1, 1), stride=(1, 1), bias=False) (bn3): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=True, in_channels=2048) (relu3): Activation(relu) ) (2): BottleneckV1b( (conv1): Conv2D(None -> 512, kernel_size=(1, 1), stride=(1, 1), bias=False) (bn1): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=True, in_channels=512) (relu1): Activation(relu) (conv2): Conv2D(None -> 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (bn2): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=True, in_channels=512) (relu2): Activation(relu) (conv3): Conv2D(None -> 2048, kernel_size=(1, 1), stride=(1, 1), bias=False) (bn3): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=True, in_channels=2048) (relu3): Activation(relu) ) ) ) (class_predictor): Dense(None -> 81, linear) (box_predictor): Dense(None -> 320, linear) (cls_decoder): MultiPerClassDecoder( ) (box_decoder): NormalizedBoxCenterDecoder( (corner_to_center): BBoxCornerToCenter( ) ) (rpn): RPN( (anchor_generator): RPNAnchorGenerator( ) (conv1): HybridSequential( (0): Conv2D(None -> 1024, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (1): Activation(relu) ) (score): Conv2D(None -> 15, kernel_size=(1, 1), stride=(1, 1)) (loc): Conv2D(None -> 60, kernel_size=(1, 1), stride=(1, 1)) (region_proposer): RPNProposal( (_box_to_center): BBoxCornerToCenter( ) (_box_decoder): NormalizedBoxCenterDecoder( (corner_to_center): BBoxCornerToCenter( ) ) (_clipper): BBoxClipToImage( ) ) ) (sampler): RCNNTargetSampler( ) (mask): Mask( (deconv): Conv2DTranspose(256 -> 0, kernel_size=(2, 2), stride=(2, 2)) (mask): Conv2D(None -> 80, kernel_size=(1, 1), stride=(1, 1)) ) (mask_target): MaskTargetGenerator( ) ) .. GENERATED FROM PYTHON SOURCE LINES 197-202 Mask-RCNN has identical inputs but produces an additional output. ``cids`` are the class labels, ``scores`` are confidence scores of each prediction, ``bboxes`` are absolute coordinates of corresponding bounding boxes. ``masks`` are predicted segmentation masks corresponding to each bounding box .. GENERATED FROM PYTHON SOURCE LINES 202-208 .. code-block:: default import mxnet as mx x = mx.nd.zeros(shape=(1, 3, 600, 800)) net.initialize() cids, scores, bboxes, masks = net(x) .. GENERATED FROM PYTHON SOURCE LINES 209-212 During training, an additional output is returned: ``mask_preds`` are per class masks predictions in addition to ``cls_preds``, ``box_preds``. .. GENERATED FROM PYTHON SOURCE LINES 212-221 .. 
.. code-block:: default

    from mxnet import autograd

    with autograd.train_mode():
        # this time we need ground-truth to generate high quality roi proposals during training
        gt_box = mx.nd.zeros(shape=(1, 1, 4))
        gt_label = mx.nd.zeros(shape=(1, 1, 1))
        cls_pred, box_pred, mask_pred, roi, samples, matches, rpn_score, rpn_box, anchors, \
            cls_targets, box_targets, box_masks, indices = net(x, gt_box, gt_label)

.. GENERATED FROM PYTHON SOURCE LINES 222-225

Training losses
----------------

There is one additional loss in Mask R-CNN.

.. GENERATED FROM PYTHON SOURCE LINES 225-237

.. code-block:: default

    # the loss to penalize incorrect foreground/background predictions
    rpn_cls_loss = mx.gluon.loss.SigmoidBinaryCrossEntropyLoss(from_sigmoid=False)
    # the loss to penalize inaccurate anchor boxes
    rpn_box_loss = mx.gluon.loss.HuberLoss(rho=1 / 9.)  # == smoothl1
    # the loss to penalize incorrect classification predictions
    rcnn_cls_loss = mx.gluon.loss.SoftmaxCrossEntropyLoss()
    # the loss to penalize inaccurate proposals
    rcnn_box_loss = mx.gluon.loss.HuberLoss()  # == smoothl1
    # and finally the additional loss to penalize incorrect segmentation pixel predictions
    rcnn_mask_loss = mx.gluon.loss.SigmoidBinaryCrossEntropyLoss(from_sigmoid=False)

.. GENERATED FROM PYTHON SOURCE LINES 238-241

Training targets
----------------

The RPN and RCNN training targets are the same as in
:doc:`../examples_detection/train_faster_rcnn_voc`.

.. GENERATED FROM PYTHON SOURCE LINES 243-244

We also push the RPN target computation to the CPU workers, so the network is passed to
the transform.

.. GENERATED FROM PYTHON SOURCE LINES 244-252

.. code-block:: default

    train_transform = presets.rcnn.MaskRCNNDefaultTrainTransform(short, max_size, net)
    # returns images, labels, masks, rpn_cls_targets, rpn_box_targets, rpn_box_masks loosely
    batchify_fn = MaskRCNNTrainBatchify(net)
    # for the next part, we only use batch size 1
    batch_size = 1
    train_loader = DataLoader(train_dataset.transform(train_transform), batch_size, shuffle=True,
                              batchify_fn=batchify_fn, last_batch='rollover', num_workers=num_workers)

.. GENERATED FROM PYTHON SOURCE LINES 253-254

Mask targets are generated from the intermediate outputs after the RCNN targets have been
computed, as sketched conceptually below and implemented in the following code block.

.. GENERATED FROM PYTHON SOURCE LINES 254-298

..
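Conceptually, generating a mask target means cropping the full-image ground-truth mask
to a sampled RoI and resizing the crop to the fixed 14 x 14 resolution predicted by the
mask head. The rough NumPy sketch below only illustrates that idea; the network's
``mask_target`` module operates on batches, uses RoIAlign-style sampling, and matches
each RoI to a ground-truth instance and class:

.. code-block:: python

    import numpy as np

    def naive_mask_target(gt_mask, roi, out_size=14):
        """Crop a (H, W) binary mask to roi=(x0, y0, x1, y1) and pool it to out_size x out_size."""
        x0, y0, x1, y1 = [int(round(c)) for c in roi]
        crop = gt_mask[y0:y1, x0:x1].astype(np.float32)
        # nearest-neighbour resample of the crop onto a fixed out_size x out_size grid
        ys = np.linspace(0, crop.shape[0] - 1, out_size).round().astype(int)
        xs = np.linspace(0, crop.shape[1] - 1, out_size).round().astype(int)
        return (crop[np.ix_(ys, xs)] >= 0.5).astype(np.float32)

    # toy example: a 100 x 100 mask with a filled rectangle, cropped to a box around it
    toy_mask = np.zeros((100, 100), dtype=np.uint8)
    toy_mask[30:70, 40:80] = 1
    print(naive_mask_target(toy_mask, (35, 25, 85, 75)).shape)  # (14, 14)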
code-block:: default for ib, batch in enumerate(train_loader): if ib > 0: break with autograd.train_mode(): for data, label, masks, rpn_cls_targets, rpn_box_targets, rpn_box_masks in zip(*batch): label = label.expand_dims(0) gt_label = label[:, :, 4:5] gt_box = label[:, :, :4] # network forward cls_pred, box_pred, mask_pred, roi, samples, matches, rpn_score, rpn_box, anchors, \ cls_targets, box_targets, box_masks, indices = \ net(data.expand_dims(0), gt_box, gt_label) # generate targets for mask head roi = mx.nd.concat( *[mx.nd.take(roi[i], indices[i]) for i in range(indices.shape[0])], dim=0) \ .reshape((indices.shape[0], -1, 4)) m_cls_targets = mx.nd.concat( *[mx.nd.take(cls_targets[i], indices[i]) for i in range(indices.shape[0])], dim=0) \ .reshape((indices.shape[0], -1)) matches = mx.nd.concat( *[mx.nd.take(matches[i], indices[i]) for i in range(indices.shape[0])], dim=0) \ .reshape((indices.shape[0], -1)) mask_targets, mask_masks = net.mask_target(roi, masks.expand_dims(0), matches, m_cls_targets) print('data:', data.shape) # box and class labels print('box:', gt_box.shape) print('label:', gt_label.shape) # -1 marks ignored label print('rpn cls label:', rpn_cls_targets.shape) # mask out ignored box label print('rpn box label:', rpn_box_targets.shape) print('rpn box mask:', rpn_box_masks.shape) # rcnn does not have ignored label print('rcnn cls label:', cls_targets.shape) # mask out ignored box label print('rcnn box label:', box_targets.shape) print('rcnn box mask:', box_masks.shape) print('rcnn mask label:', mask_targets.shape) print('rcnn mask mask:', mask_masks.shape) .. rst-class:: sphx-glr-script-out Out: .. code-block:: none data: (3, 831, 600) box: (1, 1, 4) label: (1, 1, 1) rpn cls label: (1, 29640) rpn box label: (1, 29640, 4) rpn box mask: (1, 29640, 4) rcnn cls label: (1, 128) rcnn box label: (1, 32, 80, 4) rcnn box mask: (1, 32, 80, 4) rcnn mask label: (1, 32, 80, 14, 14) rcnn mask mask: (1, 32, 80, 14, 14) .. GENERATED FROM PYTHON SOURCE LINES 299-302 Training loop ------------- After we have defined loss function and generated training targets, we can write the training loop. .. GENERATED FROM PYTHON SOURCE LINES 302-353 .. 
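The loop below only computes the losses; the parameter update itself (the commented-out
``autograd.backward`` and ``trainer.step`` calls at the end) requires a Gluon ``Trainer``.
A minimal sketch of creating one is shown here; the learning rate, weight decay and
momentum are illustrative values only, see the full training script for the actual
schedule, warmup and multi-GPU handling:

.. code-block:: python

    from mxnet import gluon

    # collect every trainable parameter of the Mask R-CNN network
    trainer = gluon.Trainer(
        net.collect_params(), 'sgd',
        {'learning_rate': 0.00125,  # illustrative value for batch size 1
         'wd': 1e-4,
         'momentum': 0.9})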
.. code-block:: default

    for ib, batch in enumerate(train_loader):
        if ib > 0:
            break
        with autograd.record():
            for data, label, masks, rpn_cls_targets, rpn_box_targets, rpn_box_masks in zip(*batch):
                label = label.expand_dims(0)
                gt_label = label[:, :, 4:5]
                gt_box = label[:, :, :4]
                # network forward
                cls_preds, box_preds, mask_preds, roi, samples, matches, rpn_score, rpn_box, anchors, \
                    cls_targets, box_targets, box_masks, indices = \
                    net(data.expand_dims(0), gt_box, gt_label)
                # generate targets for the mask head
                roi = mx.nd.concat(
                    *[mx.nd.take(roi[i], indices[i]) for i in range(indices.shape[0])], dim=0) \
                    .reshape((indices.shape[0], -1, 4))
                m_cls_targets = mx.nd.concat(
                    *[mx.nd.take(cls_targets[i], indices[i]) for i in range(indices.shape[0])], dim=0) \
                    .reshape((indices.shape[0], -1))
                matches = mx.nd.concat(
                    *[mx.nd.take(matches[i], indices[i]) for i in range(indices.shape[0])], dim=0) \
                    .reshape((indices.shape[0], -1))
                mask_targets, mask_masks = net.mask_target(roi, masks.expand_dims(0), matches,
                                                           m_cls_targets)
                # losses of rpn
                rpn_score = rpn_score.squeeze(axis=-1)
                num_rpn_pos = (rpn_cls_targets >= 0).sum()
                rpn_loss1 = rpn_cls_loss(rpn_score, rpn_cls_targets,
                                         rpn_cls_targets >= 0) * rpn_cls_targets.size / num_rpn_pos
                rpn_loss2 = rpn_box_loss(rpn_box, rpn_box_targets,
                                         rpn_box_masks) * rpn_box.size / num_rpn_pos
                # losses of rcnn
                num_rcnn_pos = (cls_targets >= 0).sum()
                rcnn_loss1 = rcnn_cls_loss(cls_preds, cls_targets,
                                           cls_targets >= 0) * cls_targets.size / cls_targets.shape[0] / num_rcnn_pos
                rcnn_loss2 = rcnn_box_loss(box_preds, box_targets, box_masks) * box_preds.size / \
                    box_preds.shape[0] / num_rcnn_pos
                # loss of mask
                mask_loss = rcnn_mask_loss(mask_preds, mask_targets, mask_masks) * mask_targets.size / \
                    mask_targets.shape[0] / mask_masks.sum()
            # some standard gluon training steps:
            # autograd.backward([rpn_loss1, rpn_loss2, rcnn_loss1, rcnn_loss2, mask_loss])
            # trainer.step(batch_size)

.. GENERATED FROM PYTHON SOURCE LINES 354-357

.. hint::

   Please check out the full
   :download:`training script <../../../scripts/instance/mask_rcnn/train_mask_rcnn.py>`
   for the complete implementation.

.. GENERATED FROM PYTHON SOURCE LINES 359-368

References
----------

.. [Girshick14] Ross Girshick, Jeff Donahue, Trevor Darrell, and Jitendra Malik.
   Rich Feature Hierarchies for Accurate Object Detection and Semantic Segmentation. CVPR 2014.
.. [Girshick15] Ross Girshick. Fast R-CNN. ICCV 2015.
.. [Ren15] Shaoqing Ren, Kaiming He, Ross Girshick, and Jian Sun.
   Faster R-CNN: Towards Real-Time Object Detection with Region Proposal Networks. NIPS 2015.
.. [He16] Kaiming He, Xiangyu Zhang, Shaoqing Ren, and Jian Sun.
   Deep Residual Learning for Image Recognition. CVPR 2016.
.. [Lin17] Tsung-Yi Lin, Piotr Dollár, Ross Girshick, Kaiming He, Bharath Hariharan, and Serge Belongie.
   Feature Pyramid Networks for Object Detection. CVPR 2017.
.. [He17] Kaiming He, Georgia Gkioxari, Piotr Dollár, and Ross Girshick. Mask R-CNN. ICCV 2017.

.. rst-class:: sphx-glr-timing

   **Total running time of the script:** ( 6 minutes 6.824 seconds)

.. _sphx_glr_download_build_examples_instance_train_mask_rcnn_coco.py:

.. only:: html

 .. container:: sphx-glr-footer
    :class: sphx-glr-footer-example

  .. container:: sphx-glr-download sphx-glr-download-python

     :download:`Download Python source code: train_mask_rcnn_coco.py `

  .. container:: sphx-glr-download sphx-glr-download-jupyter

     :download:`Download Jupyter notebook: train_mask_rcnn_coco.ipynb `

.. only:: html

..
rst-class:: sphx-glr-signature `Gallery generated by Sphinx-Gallery `_