4. Train FCN on Pascal VOC Dataset

This is a step-by-step semantic segmentation tutorial using the GluonCV toolkit. Readers should have basic knowledge of deep learning and be familiar with the Gluon API. New users may first go through A 60-minute Gluon Crash Course. You can Start Training Now or Dive into Deep.

Start Training Now

Hint

Feel free to skip the tutorial because the training script is self-contained and ready to launch.

Download Full Python Script: train.py

Example training command:

# First training on augmented set
CUDA_VISIBLE_DEVICES=0,1,2,3 python train.py --dataset pascal_aug --model fcn --backbone resnet50 --lr 0.001 --checkname mycheckpoint
# Finetuning on original set
CUDA_VISIBLE_DEVICES=0,1,2,3 python train.py --dataset pascal_voc --model fcn --backbone resnet50 --lr 0.0001 --checkname mycheckpoint --resume runs/pascal_aug/fcn/mycheckpoint/checkpoint.params

For more training command options, please run python train.py -h. Please check out the model_zoo for the training commands used to reproduce the pretrained models.

Dive into Deep

import numpy as np
import mxnet as mx
from mxnet import gluon, autograd

import gluoncv

Fully Convolutional Network

https://cdn-images-1.medium.com/max/800/1*wRkj6lsQ5ckExB5BoYkrZg.png

(figure credit to Long et al.)

State-of-the-art approaches to semantic segmentation are typically based on the Fully Convolutional Network (FCN) [Long15]. The key idea is that the network is “fully convolutional”: it does not have any fully connected layers, so it can accept inputs of arbitrary size and make dense per-pixel predictions. The base (encoder) network is typically pre-trained on ImageNet, because features learned from a diverse set of images contain rich contextual information, which is beneficial for semantic segmentation.
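As a quick illustration (this toy network is not part of GluonCV), a model built only from convolutional layers produces an output whose spatial size tracks the input, so no fixed input resolution is required:

from mxnet import nd
from mxnet.gluon import nn

# Toy "fully convolutional" network: only convolutions, no Dense layers.
toy_fcn = nn.HybridSequential()
toy_fcn.add(nn.Conv2D(channels=16, kernel_size=3, padding=1),
            nn.Activation('relu'),
            nn.Conv2D(channels=21, kernel_size=1))  # per-pixel class scores
toy_fcn.initialize()

# The same network accepts different input sizes and makes dense predictions.
print(toy_fcn(nd.zeros((1, 3, 224, 224))).shape)  # (1, 21, 224, 224)
print(toy_fcn(nd.zeros((1, 3, 320, 480))).shape)  # (1, 21, 320, 480)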

Model Dilation

Adapting a base network pre-trained on ImageNet leads to a loss of spatial resolution, because these networks are originally designed for the classification task. Following the standard practice in recent semantic segmentation work, we apply a dilation strategy to stage 3 and stage 4 of the pre-trained network, which produces feature maps with an output stride of 8 (models are provided in gluoncv.model_zoo.ResNetV1b). Visualization of dilated/atrous convolution (figure credit to conv_arithmetic):

https://raw.githubusercontent.com/vdumoulin/conv_arithmetic/master/gif/dilation.gif
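As a small sanity check (illustrative only, not part of the tutorial script), a dilated 3x3 convolution with padding equal to its dilation rate preserves the spatial size of its input, which is how the dilated stages keep the output stride at 8 instead of 32:

from mxnet import nd
from mxnet.gluon import nn

# A 3x3 convolution with dilation=2 and padding=2 enlarges the receptive field
# without shrinking the feature map.
dilated_conv = nn.Conv2D(channels=8, kernel_size=3, padding=2, dilation=2)
dilated_conv.initialize()
print(dilated_conv(nd.zeros((1, 4, 28, 28))).shape)  # (1, 8, 28, 28)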

Loading a dilated ResNet50 is simply:
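# Download and load an ImageNet pre-trained dilated ResNet50 (resnet50_v1b)
pretrained_net = gluoncv.model_zoo.resnet50_v1b(pretrained=True)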

Out:

Downloading /root/.mxnet/models/resnet50_v1b-0ecdba34.zip from https://apache-mxnet.s3-accelerate.dualstack.amazonaws.com/gluon/models/resnet50_v1b-0ecdba34.zip...


For convenience, we provide a base model for semantic segmentation, gluoncv.model_zoo.segbase.SegBaseModel, which automatically loads the pre-trained dilated ResNet and provides a convenient method base_forward(input) to get the stage 3 and stage 4 feature maps:

basemodel = gluoncv.model_zoo.segbase.SegBaseModel(nclass=10, aux=False)
x = mx.nd.random.uniform(shape=(1, 3, 224, 224))
c3, c4 = basemodel.base_forward(x)
print('Shapes of c3 & c4 featuremaps are ', c3.shape, c4.shape)

Out:

Downloading /root/.mxnet/models/resnet50_v1s-25a187fa.zip from https://apache-mxnet.s3-accelerate.dualstack.amazonaws.com/gluon/models/resnet50_v1s-25a187fa.zip...

Shapes of c3 & c4 featuremaps are  (1, 1024, 28, 28) (1, 2048, 28, 28)

FCN Model

We build a fully convolutional “head” on top of the base network; the FCNHead is defined as:

from mxnet.gluon import nn, HybridBlock

class _FCNHead(HybridBlock):
    def __init__(self, in_channels, channels, norm_layer, **kwargs):
        super(_FCNHead, self).__init__()
        with self.name_scope():
            self.block = nn.HybridSequential()
            inter_channels = in_channels // 4
            with self.block.name_scope():
                self.block.add(nn.Conv2D(in_channels=in_channels, channels=inter_channels,
                                         kernel_size=3, padding=1))
                self.block.add(norm_layer(in_channels=inter_channels))
                self.block.add(nn.Activation('relu'))
                self.block.add(nn.Dropout(0.1))
                self.block.add(nn.Conv2D(in_channels=inter_channels, channels=channels,
                                         kernel_size=1))

    def hybrid_forward(self, F, x):
        return self.block(x)
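As a hedged usage sketch (reusing the c4 feature map from the SegBaseModel example above; the channel counts are the ones printed there), the head maps the 2048-channel stage-4 features to 21 per-pixel class scores at 1/8 of the input resolution:

# Build a head for 21 Pascal VOC classes on top of the 2048-channel stage-4 features.
head = _FCNHead(in_channels=2048, channels=21, norm_layer=nn.BatchNorm)
head.initialize()
scores = head(c4)  # c4 comes from the base_forward example above
print(scores.shape)  # expected: (1, 21, 28, 28)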

The FCN model is provided in gluoncv.model_zoo.FCN. To get an FCN model with a ResNet50 base network for the Pascal VOC dataset:

model = gluoncv.model_zoo.get_fcn(dataset='pascal_voc', backbone='resnet50', pretrained=False)
print(model)

Out:

FCN(
  (conv1): HybridSequential(
    (0): Conv2D(3 -> 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
    (1): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=64)
    (2): Activation(relu)
    (3): Conv2D(64 -> 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (4): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=64)
    (5): Activation(relu)
    (6): Conv2D(64 -> 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  )
  (bn1): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=128)
  (relu): Activation(relu)
  (maxpool): MaxPool2D(size=(3, 3), stride=(2, 2), padding=(1, 1), ceil_mode=False, global_pool=False, pool_type=max, layout=NCHW)
  (layer1): HybridSequential(
    (0): BottleneckV1b(
      (conv1): Conv2D(128 -> 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=64)
      (relu1): Activation(relu)
      (conv2): Conv2D(64 -> 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=64)
      (relu2): Activation(relu)
      (conv3): Conv2D(64 -> 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=256)
      (relu3): Activation(relu)
      (downsample): HybridSequential(
        (0): Conv2D(128 -> 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (1): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=256)
      )
    )
    (1): BottleneckV1b(
      (conv1): Conv2D(256 -> 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=64)
      (relu1): Activation(relu)
      (conv2): Conv2D(64 -> 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=64)
      (relu2): Activation(relu)
      (conv3): Conv2D(64 -> 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=256)
      (relu3): Activation(relu)
    )
    (2): BottleneckV1b(
      (conv1): Conv2D(256 -> 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=64)
      (relu1): Activation(relu)
      (conv2): Conv2D(64 -> 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=64)
      (relu2): Activation(relu)
      (conv3): Conv2D(64 -> 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=256)
      (relu3): Activation(relu)
    )
  )
  (layer2): HybridSequential(
    (0): BottleneckV1b(
      (conv1): Conv2D(256 -> 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=128)
      (relu1): Activation(relu)
      (conv2): Conv2D(128 -> 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (bn2): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=128)
      (relu2): Activation(relu)
      (conv3): Conv2D(128 -> 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=512)
      (relu3): Activation(relu)
      (downsample): HybridSequential(
        (0): Conv2D(256 -> 512, kernel_size=(1, 1), stride=(2, 2), bias=False)
        (1): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=512)
      )
    )
    (1): BottleneckV1b(
      (conv1): Conv2D(512 -> 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=128)
      (relu1): Activation(relu)
      (conv2): Conv2D(128 -> 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=128)
      (relu2): Activation(relu)
      (conv3): Conv2D(128 -> 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=512)
      (relu3): Activation(relu)
    )
    (2): BottleneckV1b(
      (conv1): Conv2D(512 -> 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=128)
      (relu1): Activation(relu)
      (conv2): Conv2D(128 -> 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=128)
      (relu2): Activation(relu)
      (conv3): Conv2D(128 -> 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=512)
      (relu3): Activation(relu)
    )
    (3): BottleneckV1b(
      (conv1): Conv2D(512 -> 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=128)
      (relu1): Activation(relu)
      (conv2): Conv2D(128 -> 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=128)
      (relu2): Activation(relu)
      (conv3): Conv2D(128 -> 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=512)
      (relu3): Activation(relu)
    )
  )
  (layer3): HybridSequential(
    (0): BottleneckV1b(
      (conv1): Conv2D(512 -> 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=256)
      (relu1): Activation(relu)
      (conv2): Conv2D(256 -> 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=256)
      (relu2): Activation(relu)
      (conv3): Conv2D(256 -> 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=1024)
      (relu3): Activation(relu)
      (downsample): HybridSequential(
        (0): Conv2D(512 -> 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (1): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=1024)
      )
    )
    (1): BottleneckV1b(
      (conv1): Conv2D(1024 -> 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=256)
      (relu1): Activation(relu)
      (conv2): Conv2D(256 -> 256, kernel_size=(3, 3), stride=(1, 1), padding=(2, 2), dilation=(2, 2), bias=False)
      (bn2): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=256)
      (relu2): Activation(relu)
      (conv3): Conv2D(256 -> 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=1024)
      (relu3): Activation(relu)
    )
    (2): BottleneckV1b(
      (conv1): Conv2D(1024 -> 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=256)
      (relu1): Activation(relu)
      (conv2): Conv2D(256 -> 256, kernel_size=(3, 3), stride=(1, 1), padding=(2, 2), dilation=(2, 2), bias=False)
      (bn2): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=256)
      (relu2): Activation(relu)
      (conv3): Conv2D(256 -> 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=1024)
      (relu3): Activation(relu)
    )
    (3): BottleneckV1b(
      (conv1): Conv2D(1024 -> 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=256)
      (relu1): Activation(relu)
      (conv2): Conv2D(256 -> 256, kernel_size=(3, 3), stride=(1, 1), padding=(2, 2), dilation=(2, 2), bias=False)
      (bn2): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=256)
      (relu2): Activation(relu)
      (conv3): Conv2D(256 -> 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=1024)
      (relu3): Activation(relu)
    )
    (4): BottleneckV1b(
      (conv1): Conv2D(1024 -> 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=256)
      (relu1): Activation(relu)
      (conv2): Conv2D(256 -> 256, kernel_size=(3, 3), stride=(1, 1), padding=(2, 2), dilation=(2, 2), bias=False)
      (bn2): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=256)
      (relu2): Activation(relu)
      (conv3): Conv2D(256 -> 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=1024)
      (relu3): Activation(relu)
    )
    (5): BottleneckV1b(
      (conv1): Conv2D(1024 -> 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=256)
      (relu1): Activation(relu)
      (conv2): Conv2D(256 -> 256, kernel_size=(3, 3), stride=(1, 1), padding=(2, 2), dilation=(2, 2), bias=False)
      (bn2): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=256)
      (relu2): Activation(relu)
      (conv3): Conv2D(256 -> 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=1024)
      (relu3): Activation(relu)
    )
  )
  (layer4): HybridSequential(
    (0): BottleneckV1b(
      (conv1): Conv2D(1024 -> 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=512)
      (relu1): Activation(relu)
      (conv2): Conv2D(512 -> 512, kernel_size=(3, 3), stride=(1, 1), padding=(2, 2), dilation=(2, 2), bias=False)
      (bn2): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=512)
      (relu2): Activation(relu)
      (conv3): Conv2D(512 -> 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=2048)
      (relu3): Activation(relu)
      (downsample): HybridSequential(
        (0): Conv2D(1024 -> 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (1): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=2048)
      )
    )
    (1): BottleneckV1b(
      (conv1): Conv2D(2048 -> 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=512)
      (relu1): Activation(relu)
      (conv2): Conv2D(512 -> 512, kernel_size=(3, 3), stride=(1, 1), padding=(4, 4), dilation=(4, 4), bias=False)
      (bn2): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=512)
      (relu2): Activation(relu)
      (conv3): Conv2D(512 -> 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=2048)
      (relu3): Activation(relu)
    )
    (2): BottleneckV1b(
      (conv1): Conv2D(2048 -> 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=512)
      (relu1): Activation(relu)
      (conv2): Conv2D(512 -> 512, kernel_size=(3, 3), stride=(1, 1), padding=(4, 4), dilation=(4, 4), bias=False)
      (bn2): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=512)
      (relu2): Activation(relu)
      (conv3): Conv2D(512 -> 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=2048)
      (relu3): Activation(relu)
    )
  )
  (head): _FCNHead(
    (block): HybridSequential(
      (0): Conv2D(2048 -> 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (1): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=512)
      (2): Activation(relu)
      (3): Dropout(p = 0.1, axes=())
      (4): Conv2D(512 -> 21, kernel_size=(1, 1), stride=(1, 1))
    )
  )
  (auxlayer): _FCNHead(
    (block): HybridSequential(
      (0): Conv2D(1024 -> 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (1): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=256)
      (2): Activation(relu)
      (3): Dropout(p = 0.1, axes=())
      (4): Conv2D(256 -> 21, kernel_size=(1, 1), stride=(1, 1))
    )
  )
)

Dataset and Data Augmentation

Image transform for color normalization:

from mxnet.gluon.data.vision import transforms
input_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([.485, .456, .406], [.229, .224, .225]),
])

We provide semantic segmentation datasets in gluoncv.data. For example, we can easily get the Pascal VOC 2012 dataset:

trainset = gluoncv.data.VOCSegmentation(split='train', transform=input_transform)
print('Training images:', len(trainset))
# set batch_size = 2 for toy example
batch_size = 2
# Create Training Loader
train_data = gluon.data.DataLoader(
    trainset, batch_size, shuffle=True, last_batch='rollover',
    num_workers=batch_size)

Out:

Training images: 2913

For data augmentation, we follow the standard routine of transforming the input image and the ground-truth label map synchronously. (Note that “nearest” mode resizing is applied to the label maps to avoid corrupting the class boundaries.) We first randomly scale the input image by a factor of 0.5 to 2.0, then rotate it between -10 and 10 degrees, and crop it with padding if needed. Finally, a random Gaussian blur is applied.
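A minimal sketch of such a synchronized transform (illustrative only; crop_size and the label padding value are assumptions, not the exact GluonCV implementation) applies the same random parameters to the image and its label map, using bilinear resizing for the image and nearest-neighbor for the label:

import random
from PIL import Image, ImageOps, ImageFilter

def sync_transform(img, mask, crop_size=480):
    # random scale (0.5x to 2.0x), same factor for image and label
    scale = random.uniform(0.5, 2.0)
    w, h = img.size
    img = img.resize((int(w * scale), int(h * scale)), Image.BILINEAR)
    mask = mask.resize((int(w * scale), int(h * scale)), Image.NEAREST)
    # random rotation in [-10, 10] degrees
    deg = random.uniform(-10, 10)
    img = img.rotate(deg, resample=Image.BILINEAR)
    mask = mask.rotate(deg, resample=Image.NEAREST)
    # pad if smaller than the crop size, then take the same random crop
    pad_w, pad_h = max(crop_size - img.size[0], 0), max(crop_size - img.size[1], 0)
    img = ImageOps.expand(img, border=(0, 0, pad_w, pad_h), fill=0)
    mask = ImageOps.expand(mask, border=(0, 0, pad_w, pad_h), fill=255)  # 255 = ignored label (assumption)
    x = random.randint(0, img.size[0] - crop_size)
    y = random.randint(0, img.size[1] - crop_size)
    img = img.crop((x, y, x + crop_size, y + crop_size))
    mask = mask.crop((x, y, x + crop_size, y + crop_size))
    # random Gaussian blur on the image only
    if random.random() < 0.5:
        img = img.filter(ImageFilter.GaussianBlur(radius=random.random()))
    return img, mask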

Randomly pick one example for visualization:

import random
from datetime import datetime
random.seed(datetime.now().timestamp())
idx = random.randint(0, len(trainset) - 1)
img, mask = trainset[idx]
from gluoncv.utils.viz import get_color_pallete, DeNormalize
# get the color palette to visualize the mask
mask = get_color_pallete(mask.asnumpy(), dataset='pascal_voc')
mask.save('mask.png')
# denormalize the image
img = DeNormalize([.485, .456, .406], [.229, .224, .225])(img)
img = np.transpose((img.asnumpy()*255).astype(np.uint8), (1, 2, 0))

Plot the image and mask

from matplotlib import pyplot as plt
import matplotlib.image as mpimg
# subplot 1 for img
fig = plt.figure()
fig.add_subplot(1,2,1)

plt.imshow(img)
# subplot 2 for the mask
mmask = mpimg.imread('mask.png')
fig.add_subplot(1,2,2)
plt.imshow(mmask)
# display
plt.show()
(Output figure: the randomly picked training image and its color-coded segmentation mask, shown side by side.)

Training Details

  • Training Losses:

    We apply a standard per-pixel softmax cross-entropy loss to train the FCN. For the Pascal VOC dataset, we ignore the loss from the boundary class (number 22). Additionally, an auxiliary loss at stage 3, as in PSPNet [Zhao17], can be enabled by training with the --aux option. This creates an additional FCN “head” after stage 3.

from gluoncv.loss import MixSoftmaxCrossEntropyLoss
criterion = MixSoftmaxCrossEntropyLoss(aux=True)
  • Learning Rate and Scheduling:

    We use different learning rates for the FCN “head” and the base network. For the FCN “head”, we use \(10\times\) the base learning rate, because those layers are learned from scratch. We use a poly-like learning rate scheduler for FCN training, provided in gluoncv.utils.LRScheduler. The learning rate is given by \(lr = base\_lr \times (1 - \frac{iter}{total\_iters})^{power}\)

lr_scheduler = gluoncv.utils.LRScheduler('poly', base_lr=0.001,
                                         nepochs=50, iters_per_epoch=len(train_data), power=0.9)
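As a quick illustration of the schedule (the values simply mirror the arguments above; this snippet is not part of the training script), the learning rate decays smoothly from the base value toward zero over the total number of iterations:

base_lr, power = 0.001, 0.9
total_iters = 50 * len(train_data)
for it in (0, total_iters // 2, total_iters - 1):
    lr = base_lr * (1 - it / total_iters) ** power
    print('iteration %d of %d: lr = %.6f' % (it, total_iters, lr))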
  • DataParallel for multi-GPU training (we use the CPU here for demo purposes only)

from gluoncv.utils.parallel import *
ctx_list = [mx.cpu(0)]
model = DataParallelModel(model, ctx_list)
criterion = DataParallelCriterion(criterion, ctx_list)
  • Create SGD solver

kv = mx.kv.create('device')
optimizer = gluon.Trainer(model.module.collect_params(), 'sgd',
                          {'lr_scheduler': lr_scheduler,
                           'wd':0.0001,
                           'momentum': 0.9,
                           'multi_precision': True},
                          kvstore = kv)

The training loop

train_loss = 0.0
epoch = 0
for i, (data, target) in enumerate(train_data):
    with autograd.record(True):
        outputs = model(data)
        losses = criterion(outputs, target)
        mx.nd.waitall()
        autograd.backward(losses)
    optimizer.step(batch_size)
    for loss in losses:
        train_loss += loss.asnumpy()[0] / len(losses)
    print('Epoch %d, batch %d, training loss %.3f'%(epoch, i, train_loss/(i+1)))
    # just demo for 2 iters
    if i > 1:
        print('Terminated for this demo...')
        break

Out:

Epoch 0, batch 0, training loss 4.147
Epoch 0, batch 1, training loss 3.964
Epoch 0, batch 2, training loss 3.723
Terminated for this demo...

You can Start Training Now.

References

Long15

Long, Jonathan, Evan Shelhamer, and Trevor Darrell. “Fully convolutional networks for semantic segmentation.” Proceedings of the IEEE conference on computer vision and pattern recognition. 2015.

Zhao17

Zhao, Hengshuang, Jianping Shi, Xiaojuan Qi, Xiaogang Wang, and Jiaya Jia. “Pyramid scene parsing network.” IEEE Conf. on Computer Vision and Pattern Recognition (CVPR). 2017.
