.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "build/examples_segmentation/train_fcn.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        Click :ref:`here <sphx_glr_download_build_examples_segmentation_train_fcn.py>` to download the full example code

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_build_examples_segmentation_train_fcn.py:


4. Train FCN on Pascal VOC Dataset
==================================

This is a step-by-step semantic segmentation tutorial using the GluonCV toolkit.
Readers should have basic knowledge of deep learning and be familiar with the
Gluon API. New users may first go through `A 60-minute Gluon Crash Course `_.

You can `Start Training Now`_ or `Dive into Deep`_.

Start Training Now
~~~~~~~~~~~~~~~~~~

.. hint::

    Feel free to skip the tutorial, because the training script is self-contained and ready to launch.

    :download:`Download Full Python Script: train.py<../../../scripts/segmentation/train.py>`

    Example training command::

        # First training on augmented set
        CUDA_VISIBLE_DEVICES=0,1,2,3 python train.py --dataset pascal_aug --model fcn --backbone resnet50 --lr 0.001 --checkname mycheckpoint

        # Finetuning on original set
        CUDA_VISIBLE_DEVICES=0,1,2,3 python train.py --dataset pascal_voc --model fcn --backbone resnet50 --lr 0.0001 --checkname mycheckpoint --resume runs/pascal_aug/fcn/mycheckpoint/checkpoint.params

    For more training command options, please run ``python train.py -h``.

    Please check out the `model_zoo <../model_zoo/index.html#semantic-segmentation>`_ for the training commands used to reproduce the pretrained models.

Dive into Deep
~~~~~~~~~~~~~~

.. GENERATED FROM PYTHON SOURCE LINES 31-36

.. code-block:: default

    import numpy as np
    import mxnet as mx
    from mxnet import gluon, autograd
    import gluoncv

.. GENERATED FROM PYTHON SOURCE LINES 37-56

Fully Convolutional Network
---------------------------

.. image:: https://cdn-images-1.medium.com/max/800/1*wRkj6lsQ5ckExB5BoYkrZg.png
    :width: 70%
    :align: center

(figure credit to `Long et al. `_ )

State-of-the-art approaches to semantic segmentation are typically based on the
Fully Convolutional Network (FCN) [Long15]_. The key idea is that the network is
"fully convolutional": it has no fully connected layers, so it can accept inputs
of arbitrary size and make dense per-pixel predictions. The base (encoder)
network is typically pre-trained on ImageNet, because features learned from a
diverse set of images contain rich contextual information, which is beneficial
for semantic segmentation.

.. GENERATED FROM PYTHON SOURCE LINES 59-78

Model Dilation
--------------

Adapting a base network pre-trained on ImageNet leads to a loss of spatial
resolution, because these networks were originally designed for classification
tasks. Following the standard practice of recent semantic segmentation work, we
apply a dilation strategy to stages 3 and 4 of the pre-trained network, which
produces feature maps with a stride of 8 (the models are provided in
:class:`gluoncv.model_zoo.ResNetV1b`).

Visualization of dilated/atrous convolution
(figure credit to `conv_arithmetic <https://github.com/vdumoulin/conv_arithmetic>`_ ):

.. image:: https://raw.githubusercontent.com/vdumoulin/conv_arithmetic/master/gif/dilation.gif
    :width: 40%
    :align: center
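To make the effect of dilation concrete, here is a minimal sketch (a standalone
``gluon.nn.Conv2D``, not part of the original script): a 3x3 kernel with
``dilation=2`` samples a 5x5 neighborhood, and with matching padding it keeps
the spatial resolution unchanged while enlarging the receptive field.

.. code-block:: default

    from mxnet import nd
    from mxnet.gluon import nn

    # 3x3 kernel with dilation=2 covers a 5x5 neighborhood;
    # padding=2 keeps the output the same spatial size as the input.
    dilated_conv = nn.Conv2D(channels=8, kernel_size=3, padding=2, dilation=2)
    dilated_conv.initialize()
    x = nd.random.uniform(shape=(1, 3, 32, 32))
    print(dilated_conv(x).shape)  # expected: (1, 8, 32, 32)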
Loading a dilated ResNet50 is simply:

.. GENERATED FROM PYTHON SOURCE LINES 78-80

.. code-block:: default

    pretrained_net = gluoncv.model_zoo.resnet50_v1b(pretrained=True)

.. rst-class:: sphx-glr-script-out

 Out:

 .. code-block:: none

    Downloading /root/.mxnet/models/resnet50_v1b-0ecdba34.zip from https://apache-mxnet.s3-accelerate.dualstack.amazonaws.com/gluon/models/resnet50_v1b-0ecdba34.zip...

FCN Model
---------

The FCN model builds a fully convolutional "head" on top of the dilated base
network; the complete model is provided in :class:`gluoncv.model_zoo.FCN`. To get
an FCN with a ResNet50 base network for the Pascal VOC dataset:

.. code-block:: default

    model = gluoncv.model_zoo.get_fcn(dataset='pascal_voc', backbone='resnet50', pretrained=False)
    print(model)

.. rst-class:: sphx-glr-script-out

 Out:

 .. code-block:: none

    FCN(
      (conv1): HybridSequential(
        (0): Conv2D(3 -> 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
        (1): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=64)
        (2): Activation(relu)
        (3): Conv2D(64 -> 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (4): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=64)
        (5): Activation(relu)
        (6): Conv2D(64 -> 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      )
      (bn1): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=128)
      (relu): Activation(relu)
      (maxpool): MaxPool2D(size=(3, 3), stride=(2, 2), padding=(1, 1), ceil_mode=False, global_pool=False, pool_type=max, layout=NCHW)
      (layer1): HybridSequential(
        (0): BottleneckV1b(
          (conv1): Conv2D(128 -> 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=64)
          (relu1): Activation(relu)
          (conv2): Conv2D(64 -> 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=64)
          (relu2): Activation(relu)
          (conv3): Conv2D(64 -> 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=256)
          (relu3): Activation(relu)
          (downsample): HybridSequential(
            (0): Conv2D(128 -> 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
            (1): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=256)
          )
        )
        (1): BottleneckV1b(
          (conv1): Conv2D(256 -> 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=64)
          (relu1): Activation(relu)
          (conv2): Conv2D(64 -> 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=64)
          (relu2): Activation(relu)
          (conv3): Conv2D(64 -> 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=256)
          (relu3): Activation(relu)
        )
        (2): BottleneckV1b(
          (conv1): Conv2D(256 -> 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=64)
          (relu1): Activation(relu)
          (conv2): Conv2D(64 -> 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=64)
          (relu2): Activation(relu)
          (conv3): Conv2D(64 -> 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=256)
          (relu3): Activation(relu)
        )
      )
      (layer2): HybridSequential(
        (0): BottleneckV1b(
          (conv1): Conv2D(256 -> 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=128)
          (relu1): Activation(relu)
          (conv2): Conv2D(128 -> 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
          (bn2): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=128)
          (relu2): Activation(relu)
          (conv3): Conv2D(128 -> 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=512)
          (relu3): Activation(relu)
          (downsample): HybridSequential(
            (0): Conv2D(256 -> 512, kernel_size=(1, 1), stride=(2, 2), bias=False)
            (1): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=512)
          )
        )
        (1): BottleneckV1b(
          (conv1): Conv2D(512 -> 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=128)
          (relu1): Activation(relu)
          (conv2): Conv2D(128 -> 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=128)
          (relu2): Activation(relu)
          (conv3): Conv2D(128 -> 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=512)
          (relu3): Activation(relu)
        )
        (2): BottleneckV1b(
          (conv1): Conv2D(512 -> 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=128)
          (relu1): Activation(relu)
          (conv2): Conv2D(128 -> 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=128)
          (relu2): Activation(relu)
          (conv3): Conv2D(128 -> 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=512)
          (relu3): Activation(relu)
        )
        (3): BottleneckV1b(
          (conv1): Conv2D(512 -> 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=128)
          (relu1): Activation(relu)
          (conv2): Conv2D(128 -> 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=128)
          (relu2): Activation(relu)
          (conv3): Conv2D(128 -> 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=512)
          (relu3): Activation(relu)
        )
      )
      (layer3): HybridSequential(
        (0): BottleneckV1b(
          (conv1): Conv2D(512 -> 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=256)
          (relu1): Activation(relu)
          (conv2): Conv2D(256 -> 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=256)
          (relu2): Activation(relu)
          (conv3): Conv2D(256 -> 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=1024)
          (relu3): Activation(relu)
          (downsample): HybridSequential(
            (0): Conv2D(512 -> 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
            (1): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=1024)
          )
        )
        (1): BottleneckV1b(
          (conv1): Conv2D(1024 -> 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=256)
          (relu1): Activation(relu)
          (conv2): Conv2D(256 -> 256, kernel_size=(3, 3), stride=(1, 1), padding=(2, 2), dilation=(2, 2), bias=False)
          (bn2): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=256)
          (relu2): Activation(relu)
          (conv3): Conv2D(256 -> 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=1024)
          (relu3): Activation(relu)
        )
        (2): BottleneckV1b(
          (conv1): Conv2D(1024 -> 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=256)
          (relu1): Activation(relu)
          (conv2): Conv2D(256 -> 256, kernel_size=(3, 3), stride=(1, 1), padding=(2, 2), dilation=(2, 2), bias=False)
          (bn2): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=256)
          (relu2): Activation(relu)
          (conv3): Conv2D(256 -> 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=1024)
          (relu3): Activation(relu)
        )
        (3): BottleneckV1b(
          (conv1): Conv2D(1024 -> 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=256)
          (relu1): Activation(relu)
          (conv2): Conv2D(256 -> 256, kernel_size=(3, 3), stride=(1, 1), padding=(2, 2), dilation=(2, 2), bias=False)
          (bn2): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=256)
          (relu2): Activation(relu)
          (conv3): Conv2D(256 -> 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=1024)
          (relu3): Activation(relu)
        )
        (4): BottleneckV1b(
          (conv1): Conv2D(1024 -> 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=256)
          (relu1): Activation(relu)
          (conv2): Conv2D(256 -> 256, kernel_size=(3, 3), stride=(1, 1), padding=(2, 2), dilation=(2, 2), bias=False)
          (bn2): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=256)
          (relu2): Activation(relu)
          (conv3): Conv2D(256 -> 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=1024)
          (relu3): Activation(relu)
        )
        (5): BottleneckV1b(
          (conv1): Conv2D(1024 -> 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=256)
          (relu1): Activation(relu)
          (conv2): Conv2D(256 -> 256, kernel_size=(3, 3), stride=(1, 1), padding=(2, 2), dilation=(2, 2), bias=False)
          (bn2): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=256)
          (relu2): Activation(relu)
          (conv3): Conv2D(256 -> 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=1024)
          (relu3): Activation(relu)
        )
      )
      (layer4): HybridSequential(
        (0): BottleneckV1b(
          (conv1): Conv2D(1024 -> 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=512)
          (relu1): Activation(relu)
          (conv2): Conv2D(512 -> 512, kernel_size=(3, 3), stride=(1, 1), padding=(2, 2), dilation=(2, 2), bias=False)
          (bn2): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=512)
          (relu2): Activation(relu)
          (conv3): Conv2D(512 -> 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=2048)
          (relu3): Activation(relu)
          (downsample): HybridSequential(
            (0): Conv2D(1024 -> 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)
            (1): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=2048)
          )
        )
        (1): BottleneckV1b(
          (conv1): Conv2D(2048 -> 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=512)
          (relu1): Activation(relu)
          (conv2): Conv2D(512 -> 512, kernel_size=(3, 3), stride=(1, 1), padding=(4, 4), dilation=(4, 4), bias=False)
          (bn2): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=512)
          (relu2): Activation(relu)
          (conv3): Conv2D(512 -> 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=2048)
          (relu3): Activation(relu)
        )
        (2): BottleneckV1b(
          (conv1): Conv2D(2048 -> 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=512)
          (relu1): Activation(relu)
          (conv2): Conv2D(512 -> 512, kernel_size=(3, 3), stride=(1, 1), padding=(4, 4), dilation=(4, 4), bias=False)
          (bn2): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=512)
          (relu2): Activation(relu)
          (conv3): Conv2D(512 -> 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=2048)
          (relu3): Activation(relu)
        )
      )
      (head): _FCNHead(
        (block): HybridSequential(
          (0): Conv2D(2048 -> 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (1): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=512)
          (2): Activation(relu)
          (3): Dropout(p = 0.1, axes=())
          (4): Conv2D(512 -> 21, kernel_size=(1, 1), stride=(1, 1))
        )
      )
      (auxlayer): _FCNHead(
        (block): HybridSequential(
          (0): Conv2D(1024 -> 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (1): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=256)
          (2): Activation(relu)
          (3): Dropout(p = 0.1, axes=())
          (4): Conv2D(256 -> 21, kernel_size=(1, 1), stride=(1, 1))
        )
      )
    )

.. GENERATED FROM PYTHON SOURCE LINES 121-125

Dataset and Data Augmentation
-----------------------------

Image transforms for color normalization:

.. GENERATED FROM PYTHON SOURCE LINES 125-131

.. code-block:: default

    from mxnet.gluon.data.vision import transforms
    input_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize([.485, .456, .406], [.229, .224, .225]),
    ])
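As a quick sanity check (a sketch with a made-up dummy image, not part of the
original script): ``ToTensor`` converts an HWC ``uint8`` image into a CHW
``float32`` tensor scaled to [0, 1], and ``Normalize`` then standardizes each
color channel.

.. code-block:: default

    dummy = mx.nd.zeros((240, 320, 3), dtype='uint8')  # HWC uint8, as loaded from disk
    out = input_transform(dummy)
    print(out.shape, out.dtype)  # (3, 240, 320) float32, color-normalized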
.. GENERATED FROM PYTHON SOURCE LINES 132-134

We provide semantic segmentation datasets in :class:`gluoncv.data`. For example,
we can easily get the Pascal VOC 2012 dataset:

.. GENERATED FROM PYTHON SOURCE LINES 134-143

.. code-block:: default

    trainset = gluoncv.data.VOCSegmentation(split='train', transform=input_transform)
    print('Training images:', len(trainset))
    # set batch_size = 2 for toy example
    batch_size = 2
    # Create Training Loader
    train_data = gluon.data.DataLoader(
        trainset, batch_size, shuffle=True, last_batch='rollover',
        num_workers=batch_size)

.. rst-class:: sphx-glr-script-out

 Out:

 .. code-block:: none

    Training images: 2913

.. GENERATED FROM PYTHON SOURCE LINES 144-153

For data augmentation, we follow the standard routine and transform the input
image and the ground-truth label map synchronously. (*Note that "nearest" mode
upsampling is applied to the label maps to avoid messing up the boundaries.*)
We first randomly scale the input image by a factor of 0.5 to 2.0, then rotate
it between -10 and 10 degrees, and crop it, with padding if needed. Finally,
random Gaussian blurring is applied.

Randomly pick one example for visualization:

.. GENERATED FROM PYTHON SOURCE LINES 153-166

.. code-block:: default

    import random
    from datetime import datetime
    # seed with the current time so a different example is picked on each run
    random.seed(datetime.now().timestamp())
    # random.randint is inclusive on both ends, so exclude len(trainset)
    idx = random.randint(0, len(trainset) - 1)
    img, mask = trainset[idx]
    from gluoncv.utils.viz import get_color_pallete, DeNormalize
    # get the color palette for visualizing the mask
    mask = get_color_pallete(mask.asnumpy(), dataset='pascal_voc')
    mask.save('mask.png')
    # denormalize the image
    img = DeNormalize([.485, .456, .406], [.229, .224, .225])(img)
    img = np.transpose((img.asnumpy()*255).astype(np.uint8), (1, 2, 0))

.. GENERATED FROM PYTHON SOURCE LINES 167-168

Plot the image and mask:

.. GENERATED FROM PYTHON SOURCE LINES 168-182

.. code-block:: default

    from matplotlib import pyplot as plt
    import matplotlib.image as mpimg
    # subplot 1 for img
    fig = plt.figure()
    fig.add_subplot(1, 2, 1)
    plt.imshow(img)
    # subplot 2 for the mask
    mmask = mpimg.imread('mask.png')
    fig.add_subplot(1, 2, 2)
    plt.imshow(mmask)
    # display
    plt.show()

.. image-sg:: /build/examples_segmentation/images/sphx_glr_train_fcn_001.png
   :alt: train fcn
   :srcset: /build/examples_segmentation/images/sphx_glr_train_fcn_001.png
   :class: sphx-glr-single-img

.. GENERATED FROM PYTHON SOURCE LINES 183-193

Training Details
----------------

- Training Losses:

  We apply a standard per-pixel softmax cross-entropy loss to train FCN. For the
  Pascal VOC dataset, we ignore the loss from the boundary class (labeled 255 in
  the raw masks). Additionally, an auxiliary loss as in PSPNet [Zhao17]_ at
  stage 3 can be enabled when training with the command-line option ``--aux``.
  This creates an additional FCN "head" after stage 3.

.. GENERATED FROM PYTHON SOURCE LINES 193-196

.. code-block:: default

    from gluoncv.loss import MixSoftmaxCrossEntropyLoss
    criterion = MixSoftmaxCrossEntropyLoss(aux=True)

.. GENERATED FROM PYTHON SOURCE LINES 197-204

- Learning Rate and Scheduling:

  We use different learning rates for the FCN "head" and the base network. For
  the FCN "head", we use a :math:`10\times` larger learning rate, because those
  layers are learned from scratch. We use a poly-like learning rate scheduler
  for FCN training, provided in :class:`gluoncv.utils.LRScheduler`. The learning
  rate is given by :math:`lr = base\_lr \times (1 - \frac{iter}{total\_iter})^{power}`.

.. GENERATED FROM PYTHON SOURCE LINES 204-207

.. code-block:: default

    lr_scheduler = gluoncv.utils.LRScheduler('poly', base_lr=0.001,
                                             nepochs=50,
                                             iters_per_epoch=len(train_data),
                                             power=0.9)
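To make the schedule concrete, here is a small illustration (plain Python
arithmetic, not the GluonCV class) of how the formula above decays the learning
rate over the 50-epoch run configured above:

.. code-block:: default

    # Illustration only: evaluate the poly formula at a few points in training.
    base_lr, power = 0.001, 0.9
    total_iter = 50 * len(train_data)  # nepochs * iters_per_epoch, as above
    for it in (0, total_iter // 2, total_iter - 1):
        lr = base_lr * (1 - it / total_iter) ** power
        print('iter %d: lr = %.6f' % (it, lr))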
.. GENERATED FROM PYTHON SOURCE LINES 208-209

- DataParallel for multi-GPU training, using CPU for this demo only:

.. GENERATED FROM PYTHON SOURCE LINES 209-214

.. code-block:: default

    from gluoncv.utils.parallel import DataParallelModel, DataParallelCriterion
    ctx_list = [mx.cpu(0)]
    model = DataParallelModel(model, ctx_list)
    criterion = DataParallelCriterion(criterion, ctx_list)

.. GENERATED FROM PYTHON SOURCE LINES 215-216

- Create the SGD solver:

.. GENERATED FROM PYTHON SOURCE LINES 216-224

.. code-block:: default

    kv = mx.kv.create('device')
    optimizer = gluon.Trainer(model.module.collect_params(), 'sgd',
                              {'lr_scheduler': lr_scheduler,
                               'wd': 0.0001,
                               'momentum': 0.9,
                               'multi_precision': True},
                              kvstore=kv)

.. GENERATED FROM PYTHON SOURCE LINES 225-228

The training loop
-----------------

.. GENERATED FROM PYTHON SOURCE LINES 228-246

.. code-block:: default

    train_loss = 0.0
    epoch = 0
    for i, (data, target) in enumerate(train_data):
        with autograd.record(True):
            outputs = model(data)
            losses = criterion(outputs, target)
            mx.nd.waitall()
            autograd.backward(losses)
        optimizer.step(batch_size)
        for loss in losses:
            train_loss += loss.asnumpy()[0] / len(losses)
        print('Epoch %d, batch %d, training loss %.3f' % (epoch, i, train_loss/(i+1)))
        # just demo for 2 iters
        if i > 1:
            print('Terminated for this demo...')
            break

.. rst-class:: sphx-glr-script-out

 Out:

 .. code-block:: none

    Epoch 0, batch 0, training loss 4.147
    Epoch 0, batch 1, training loss 3.964
    Epoch 0, batch 2, training loss 3.723
    Terminated for this demo...

.. GENERATED FROM PYTHON SOURCE LINES 247-259

You can `Start Training Now`_.

References
----------

.. [Long15] Long, Jonathan, Evan Shelhamer, and Trevor Darrell. \
    "Fully convolutional networks for semantic segmentation." \
    IEEE Conference on Computer Vision and Pattern Recognition (CVPR). 2015.

.. [Zhao17] Zhao, Hengshuang, Jianping Shi, Xiaojuan Qi, Xiaogang Wang, and Jiaya Jia. \
    "Pyramid scene parsing network." \
    IEEE Conference on Computer Vision and Pattern Recognition (CVPR). 2017.

.. rst-class:: sphx-glr-timing

   **Total running time of the script:** ( 0 minutes 38.190 seconds)


.. _sphx_glr_download_build_examples_segmentation_train_fcn.py:

.. only:: html

  .. container:: sphx-glr-footer
     :class: sphx-glr-footer-example

    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: train_fcn.py <train_fcn.py>`

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: train_fcn.ipynb <train_fcn.ipynb>`

.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_