.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "build/examples_segmentation/train_psp.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        Click :ref:`here ` to download the full example code

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_build_examples_segmentation_train_psp.py:

5. Train PSPNet on ADE20K Dataset
=================================

This is a tutorial on training PSPNet on the ADE20K dataset using Gluon Vision.
Readers should have a basic knowledge of deep learning and be familiar with the
Gluon API. New users may first go through
`A 60-minute Gluon Crash Course `_.

You can `Start Training Now`_ or `Dive into Deep`_.

Start Training Now
~~~~~~~~~~~~~~~~~~

.. hint::

    Feel free to skip the tutorial because the training script is
    self-contained and ready to launch.

    :download:`Download Full Python Script: train.py<../../../scripts/segmentation/train.py>`

    Example training command::

        CUDA_VISIBLE_DEVICES=0,1,2,3 python train.py --dataset ade20k --model psp --backbone resnet50 --syncbn --epochs 120 --lr 0.01 --checkname mycheckpoint

    For more training command options, please run ``python train.py -h``.

    Please check out the `model_zoo <../model_zoo/index.html#semantic-segmentation>`_
    for the training commands used to reproduce the pretrained models.

Dive into Deep
~~~~~~~~~~~~~~

.. GENERATED FROM PYTHON SOURCE LINES 28-33

.. code-block:: default

    import numpy as np
    import mxnet as mx
    from mxnet import gluon, autograd
    import gluoncv

.. GENERATED FROM PYTHON SOURCE LINES 34-47

Pyramid Scene Parsing Network
-----------------------------

.. image:: https://hszhao.github.io/projects/pspnet/figures/pspnet.png
    :width: 80%
    :align: center

(figure credit to `Zhao et al. `_ )

Pyramid Scene Parsing Network (PSPNet) [Zhao17]_ exploits global context
information through different-region-based context aggregation with its
pyramid pooling module.

.. GENERATED FROM PYTHON SOURCE LINES 50-85

PSPNet Model
------------

A Pyramid Pooling Module is built on top of FCN; it combines features at
multiple scales with different receptive field sizes. It pools the feature
maps into different sizes and then concatenates them after upsampling.

The Pyramid Pooling Module is defined as::

    class _PyramidPooling(HybridBlock):
        def __init__(self, in_channels, **kwargs):
            super(_PyramidPooling, self).__init__()
            out_channels = int(in_channels/4)
            with self.name_scope():
                self.conv1 = _PSP1x1Conv(in_channels, out_channels, **kwargs)
                self.conv2 = _PSP1x1Conv(in_channels, out_channels, **kwargs)
                self.conv3 = _PSP1x1Conv(in_channels, out_channels, **kwargs)
                self.conv4 = _PSP1x1Conv(in_channels, out_channels, **kwargs)

        def pool(self, F, x, size):
            return F.contrib.AdaptiveAvgPooling2D(x, output_size=size)

        def upsample(self, F, x, h, w):
            return F.contrib.BilinearResize2D(x, height=h, width=w)

        def hybrid_forward(self, F, x):
            _, _, h, w = x.shape
            feat1 = self.upsample(F, self.conv1(self.pool(F, x, 1)), h, w)
            feat2 = self.upsample(F, self.conv2(self.pool(F, x, 2)), h, w)
            feat3 = self.upsample(F, self.conv3(self.pool(F, x, 3)), h, w)
            feat4 = self.upsample(F, self.conv4(self.pool(F, x, 4)), h, w)
            return F.concat(x, feat1, feat2, feat3, feat4, dim=1)

The PSPNet model is provided in :class:`gluoncv.model_zoo.PSPNet`. To get a
PSP model using the ResNet50 base network for the ADE20K dataset:
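Before building the full network, it can help to verify the shape arithmetic of
the module above. The following is a minimal sketch, not part of the generated
training script: the dummy ``60x60`` feature map and the omission of the 1x1
reduction convolutions are assumptions for illustration. It mirrors the
pool-and-upsample steps with the same contrib operators. In the real module,
each branch reduces ``in_channels=2048`` to ``2048/4 = 512`` channels before
upsampling, so the concatenated output carries
:math:`2048 + 4 \times 512 = 4096` channels, which matches the
``Conv2D(4096 -> 512)`` layer in the head printed below::

    # dummy backbone feature map: batch 1, 2048 channels, 60x60 spatial size
    x = mx.nd.random.uniform(shape=(1, 2048, 60, 60))
    for size in (1, 2, 3, 4):
        # pool to a size x size grid, then upsample back to the input size
        pooled = mx.nd.contrib.AdaptiveAvgPooling2D(x, output_size=size)
        upsampled = mx.nd.contrib.BilinearResize2D(pooled, height=60, width=60)
        print(size, pooled.shape, upsampled.shape)

.. GENERATED FROM PYTHON SOURCE LINES 85-88

..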
code-block:: default model = gluoncv.model_zoo.get_psp(dataset='ade20k', backbone='resnet50', pretrained=False) print(model) .. rst-class:: sphx-glr-script-out Out: .. code-block:: none self.crop_size 480 PSPNet( (conv1): HybridSequential( (0): Conv2D(3 -> 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False) (1): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=64) (2): Activation(relu) (3): Conv2D(64 -> 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (4): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=64) (5): Activation(relu) (6): Conv2D(64 -> 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) ) (bn1): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=128) (relu): Activation(relu) (maxpool): MaxPool2D(size=(3, 3), stride=(2, 2), padding=(1, 1), ceil_mode=False, global_pool=False, pool_type=max, layout=NCHW) (layer1): HybridSequential( (0): BottleneckV1b( (conv1): Conv2D(128 -> 64, kernel_size=(1, 1), stride=(1, 1), bias=False) (bn1): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=64) (relu1): Activation(relu) (conv2): Conv2D(64 -> 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (bn2): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=64) (relu2): Activation(relu) (conv3): Conv2D(64 -> 256, kernel_size=(1, 1), stride=(1, 1), bias=False) (bn3): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=256) (relu3): Activation(relu) (downsample): HybridSequential( (0): Conv2D(128 -> 256, kernel_size=(1, 1), stride=(1, 1), bias=False) (1): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=256) ) ) (1): BottleneckV1b( (conv1): Conv2D(256 -> 64, kernel_size=(1, 1), stride=(1, 1), bias=False) (bn1): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=64) (relu1): Activation(relu) (conv2): Conv2D(64 -> 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (bn2): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=64) (relu2): Activation(relu) (conv3): Conv2D(64 -> 256, kernel_size=(1, 1), stride=(1, 1), bias=False) (bn3): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=256) (relu3): Activation(relu) ) (2): BottleneckV1b( (conv1): Conv2D(256 -> 64, kernel_size=(1, 1), stride=(1, 1), bias=False) (bn1): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=64) (relu1): Activation(relu) (conv2): Conv2D(64 -> 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (bn2): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=64) (relu2): Activation(relu) (conv3): Conv2D(64 -> 256, kernel_size=(1, 1), stride=(1, 1), bias=False) (bn3): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=256) (relu3): Activation(relu) ) ) (layer2): HybridSequential( (0): BottleneckV1b( (conv1): Conv2D(256 -> 128, kernel_size=(1, 1), stride=(1, 1), bias=False) (bn1): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=128) (relu1): Activation(relu) (conv2): Conv2D(128 -> 128, kernel_size=(3, 3), stride=(2, 2), 
padding=(1, 1), bias=False) (bn2): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=128) (relu2): Activation(relu) (conv3): Conv2D(128 -> 512, kernel_size=(1, 1), stride=(1, 1), bias=False) (bn3): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=512) (relu3): Activation(relu) (downsample): HybridSequential( (0): Conv2D(256 -> 512, kernel_size=(1, 1), stride=(2, 2), bias=False) (1): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=512) ) ) (1): BottleneckV1b( (conv1): Conv2D(512 -> 128, kernel_size=(1, 1), stride=(1, 1), bias=False) (bn1): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=128) (relu1): Activation(relu) (conv2): Conv2D(128 -> 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (bn2): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=128) (relu2): Activation(relu) (conv3): Conv2D(128 -> 512, kernel_size=(1, 1), stride=(1, 1), bias=False) (bn3): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=512) (relu3): Activation(relu) ) (2): BottleneckV1b( (conv1): Conv2D(512 -> 128, kernel_size=(1, 1), stride=(1, 1), bias=False) (bn1): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=128) (relu1): Activation(relu) (conv2): Conv2D(128 -> 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (bn2): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=128) (relu2): Activation(relu) (conv3): Conv2D(128 -> 512, kernel_size=(1, 1), stride=(1, 1), bias=False) (bn3): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=512) (relu3): Activation(relu) ) (3): BottleneckV1b( (conv1): Conv2D(512 -> 128, kernel_size=(1, 1), stride=(1, 1), bias=False) (bn1): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=128) (relu1): Activation(relu) (conv2): Conv2D(128 -> 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (bn2): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=128) (relu2): Activation(relu) (conv3): Conv2D(128 -> 512, kernel_size=(1, 1), stride=(1, 1), bias=False) (bn3): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=512) (relu3): Activation(relu) ) ) (layer3): HybridSequential( (0): BottleneckV1b( (conv1): Conv2D(512 -> 256, kernel_size=(1, 1), stride=(1, 1), bias=False) (bn1): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=256) (relu1): Activation(relu) (conv2): Conv2D(256 -> 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (bn2): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=256) (relu2): Activation(relu) (conv3): Conv2D(256 -> 1024, kernel_size=(1, 1), stride=(1, 1), bias=False) (bn3): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=1024) (relu3): Activation(relu) (downsample): HybridSequential( (0): Conv2D(512 -> 1024, kernel_size=(1, 1), stride=(1, 1), bias=False) (1): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=1024) ) ) (1): BottleneckV1b( (conv1): Conv2D(1024 -> 256, kernel_size=(1, 
1), stride=(1, 1), bias=False) (bn1): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=256) (relu1): Activation(relu) (conv2): Conv2D(256 -> 256, kernel_size=(3, 3), stride=(1, 1), padding=(2, 2), dilation=(2, 2), bias=False) (bn2): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=256) (relu2): Activation(relu) (conv3): Conv2D(256 -> 1024, kernel_size=(1, 1), stride=(1, 1), bias=False) (bn3): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=1024) (relu3): Activation(relu) ) (2): BottleneckV1b( (conv1): Conv2D(1024 -> 256, kernel_size=(1, 1), stride=(1, 1), bias=False) (bn1): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=256) (relu1): Activation(relu) (conv2): Conv2D(256 -> 256, kernel_size=(3, 3), stride=(1, 1), padding=(2, 2), dilation=(2, 2), bias=False) (bn2): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=256) (relu2): Activation(relu) (conv3): Conv2D(256 -> 1024, kernel_size=(1, 1), stride=(1, 1), bias=False) (bn3): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=1024) (relu3): Activation(relu) ) (3): BottleneckV1b( (conv1): Conv2D(1024 -> 256, kernel_size=(1, 1), stride=(1, 1), bias=False) (bn1): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=256) (relu1): Activation(relu) (conv2): Conv2D(256 -> 256, kernel_size=(3, 3), stride=(1, 1), padding=(2, 2), dilation=(2, 2), bias=False) (bn2): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=256) (relu2): Activation(relu) (conv3): Conv2D(256 -> 1024, kernel_size=(1, 1), stride=(1, 1), bias=False) (bn3): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=1024) (relu3): Activation(relu) ) (4): BottleneckV1b( (conv1): Conv2D(1024 -> 256, kernel_size=(1, 1), stride=(1, 1), bias=False) (bn1): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=256) (relu1): Activation(relu) (conv2): Conv2D(256 -> 256, kernel_size=(3, 3), stride=(1, 1), padding=(2, 2), dilation=(2, 2), bias=False) (bn2): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=256) (relu2): Activation(relu) (conv3): Conv2D(256 -> 1024, kernel_size=(1, 1), stride=(1, 1), bias=False) (bn3): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=1024) (relu3): Activation(relu) ) (5): BottleneckV1b( (conv1): Conv2D(1024 -> 256, kernel_size=(1, 1), stride=(1, 1), bias=False) (bn1): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=256) (relu1): Activation(relu) (conv2): Conv2D(256 -> 256, kernel_size=(3, 3), stride=(1, 1), padding=(2, 2), dilation=(2, 2), bias=False) (bn2): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=256) (relu2): Activation(relu) (conv3): Conv2D(256 -> 1024, kernel_size=(1, 1), stride=(1, 1), bias=False) (bn3): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=1024) (relu3): Activation(relu) ) ) (layer4): HybridSequential( (0): BottleneckV1b( (conv1): Conv2D(1024 -> 512, kernel_size=(1, 1), stride=(1, 1), bias=False) (bn1): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, 
use_global_stats=False, in_channels=512) (relu1): Activation(relu) (conv2): Conv2D(512 -> 512, kernel_size=(3, 3), stride=(1, 1), padding=(2, 2), dilation=(2, 2), bias=False) (bn2): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=512) (relu2): Activation(relu) (conv3): Conv2D(512 -> 2048, kernel_size=(1, 1), stride=(1, 1), bias=False) (bn3): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=2048) (relu3): Activation(relu) (downsample): HybridSequential( (0): Conv2D(1024 -> 2048, kernel_size=(1, 1), stride=(1, 1), bias=False) (1): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=2048) ) ) (1): BottleneckV1b( (conv1): Conv2D(2048 -> 512, kernel_size=(1, 1), stride=(1, 1), bias=False) (bn1): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=512) (relu1): Activation(relu) (conv2): Conv2D(512 -> 512, kernel_size=(3, 3), stride=(1, 1), padding=(4, 4), dilation=(4, 4), bias=False) (bn2): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=512) (relu2): Activation(relu) (conv3): Conv2D(512 -> 2048, kernel_size=(1, 1), stride=(1, 1), bias=False) (bn3): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=2048) (relu3): Activation(relu) ) (2): BottleneckV1b( (conv1): Conv2D(2048 -> 512, kernel_size=(1, 1), stride=(1, 1), bias=False) (bn1): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=512) (relu1): Activation(relu) (conv2): Conv2D(512 -> 512, kernel_size=(3, 3), stride=(1, 1), padding=(4, 4), dilation=(4, 4), bias=False) (bn2): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=512) (relu2): Activation(relu) (conv3): Conv2D(512 -> 2048, kernel_size=(1, 1), stride=(1, 1), bias=False) (bn3): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=2048) (relu3): Activation(relu) ) ) (head): _PSPHead( (psp): _PyramidPooling( (conv1): HybridSequential( (0): Conv2D(2048 -> 512, kernel_size=(1, 1), stride=(1, 1), bias=False) (1): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=512) (2): Activation(relu) ) (conv2): HybridSequential( (0): Conv2D(2048 -> 512, kernel_size=(1, 1), stride=(1, 1), bias=False) (1): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=512) (2): Activation(relu) ) (conv3): HybridSequential( (0): Conv2D(2048 -> 512, kernel_size=(1, 1), stride=(1, 1), bias=False) (1): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=512) (2): Activation(relu) ) (conv4): HybridSequential( (0): Conv2D(2048 -> 512, kernel_size=(1, 1), stride=(1, 1), bias=False) (1): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=512) (2): Activation(relu) ) ) (block): HybridSequential( (0): Conv2D(4096 -> 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (1): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=512) (2): Activation(relu) (3): Dropout(p = 0.1, axes=()) (4): Conv2D(512 -> 150, kernel_size=(1, 1), stride=(1, 1)) ) ) (auxlayer): _FCNHead( (block): HybridSequential( (0): Conv2D(1024 -> 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (1): 
BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=256) (2): Activation(relu) (3): Dropout(p = 0.1, axes=()) (4): Conv2D(256 -> 150, kernel_size=(1, 1), stride=(1, 1)) ) ) )

.. GENERATED FROM PYTHON SOURCE LINES 89-93

Dataset and Data Augmentation
-----------------------------

Image transforms for color normalization:

.. GENERATED FROM PYTHON SOURCE LINES 93-99

.. code-block:: default

    from mxnet.gluon.data.vision import transforms
    input_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize([.485, .456, .406], [.229, .224, .225]),
    ])

.. GENERATED FROM PYTHON SOURCE LINES 100-102

We provide semantic segmentation datasets in :class:`gluoncv.data`.
For example, we can easily get the ADE20K dataset:

.. GENERATED FROM PYTHON SOURCE LINES 102-111

.. code-block:: default

    trainset = gluoncv.data.ADE20KSegmentation(split='train', transform=input_transform)
    print('Training images:', len(trainset))
    # set batch_size = 2 for this toy example
    batch_size = 2
    # create the training loader
    train_data = gluon.data.DataLoader(
        trainset, batch_size, shuffle=True, last_batch='rollover',
        num_workers=batch_size)

.. rst-class:: sphx-glr-script-out

 Out:

 .. code-block:: none

    Training images: 20210

.. GENERATED FROM PYTHON SOURCE LINES 112-121

For data augmentation, we follow the standard routine to transform the input
image and the ground-truth label map synchronously. (*Note that "nearest"-mode
upsampling is applied to the label maps to avoid corrupting the boundaries.*)
We first randomly scale the input image by a factor of 0.5 to 2.0, then rotate
it by -10 to 10 degrees, and crop the image with padding if needed. Finally, a
random Gaussian blur is applied.

Randomly pick one example for visualization:

.. GENERATED FROM PYTHON SOURCE LINES 121-134

.. code-block:: default

    import random
    from datetime import datetime
    random.seed(datetime.now().timestamp())
    # randint is inclusive on both ends, so subtract 1 from the dataset length
    idx = random.randint(0, len(trainset) - 1)
    img, mask = trainset[idx]
    from gluoncv.utils.viz import get_color_pallete, DeNormalize
    # get the color palette for visualizing the mask
    mask = get_color_pallete(mask.asnumpy(), dataset='ade20k')
    mask.save('mask.png')
    # denormalize the image
    img = DeNormalize([.485, .456, .406], [.229, .224, .225])(img)
    img = np.transpose((img.asnumpy()*255).astype(np.uint8), (1, 2, 0))

.. GENERATED FROM PYTHON SOURCE LINES 135-136

Plot the image and mask:

.. GENERATED FROM PYTHON SOURCE LINES 136-150

.. code-block:: default

    from matplotlib import pyplot as plt
    import matplotlib.image as mpimg
    # subplot 1 for the image
    fig = plt.figure()
    fig.add_subplot(1, 2, 1)
    plt.imshow(img)
    # subplot 2 for the mask
    mmask = mpimg.imread('mask.png')
    fig.add_subplot(1, 2, 2)
    plt.imshow(mmask)
    # display
    plt.show()

.. image-sg:: /build/examples_segmentation/images/sphx_glr_train_psp_001.png
   :alt: train psp
   :srcset: /build/examples_segmentation/images/sphx_glr_train_psp_001.png
   :class: sphx-glr-single-img

.. GENERATED FROM PYTHON SOURCE LINES 151-160

Training Details
----------------

- Training Losses:

  We apply a standard per-pixel Softmax Cross Entropy Loss to train PSPNet.
  Additionally, an auxiliary loss as in PSPNet [Zhao17]_ can be enabled when
  training with the command-line flag ``--aux``. This creates an additional
  FCN "head" after stage 3 of the base network.

.. GENERATED FROM PYTHON SOURCE LINES 160-163

.. code-block:: default

    from gluoncv.loss import MixSoftmaxCrossEntropyLoss
    criterion = MixSoftmaxCrossEntropyLoss(aux=True)
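As a quick sanity check, the criterion can be called directly on dummy
predictions before it is wrapped for data-parallel training below. With
``aux=True`` it takes the main prediction, the auxiliary-head prediction, and
the label map, and combines the two per-pixel cross-entropy terms roughly as
``loss_main + aux_weight * loss_aux`` (``aux_weight`` defaults to 0.2 in
GluonCV). This is a minimal sketch with made-up shapes, not part of the
generated script::

    nclass = 150  # number of ADE20K classes
    # dummy logits from the main PSP head and the auxiliary FCN head
    pred_main = mx.nd.random.uniform(shape=(batch_size, nclass, 60, 60))
    pred_aux = mx.nd.random.uniform(shape=(batch_size, nclass, 60, 60))
    # dummy label map holding per-pixel class indices
    label = mx.nd.zeros((batch_size, 60, 60))
    print(criterion(pred_main, pred_aux, label))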
.. GENERATED FROM PYTHON SOURCE LINES 164-171

- Learning Rate and Scheduling:

  We use different learning rates for the PSP "head" and the base network. For
  the PSP "head", we use :math:`10\times` the base learning rate, because those
  layers are learned from scratch. We use a poly-like learning rate scheduler
  for FCN training, provided in :class:`gluoncv.utils.LRScheduler`. The
  learning rate is given by
  :math:`lr = base\_lr \times \left(1 - \frac{iter}{total\_iter}\right)^{power}`.
  For example, with :math:`power = 0.9`, halfway through training
  (:math:`iter = 0.5 \times total\_iter`) the learning rate is roughly
  :math:`0.54 \times base\_lr`.

.. GENERATED FROM PYTHON SOURCE LINES 171-174

.. code-block:: default

    lr_scheduler = gluoncv.utils.LRScheduler(mode='poly', base_lr=0.001,
                                             nepochs=50,
                                             iters_per_epoch=len(train_data),
                                             power=0.9)

.. GENERATED FROM PYTHON SOURCE LINES 175-176

- Data parallelism for multi-GPU training; we use the CPU here for demo purposes only.

.. GENERATED FROM PYTHON SOURCE LINES 176-181

.. code-block:: default

    from gluoncv.utils.parallel import *
    ctx_list = [mx.cpu(0)]
    model = DataParallelModel(model, ctx_list)
    criterion = DataParallelCriterion(criterion, ctx_list)

.. GENERATED FROM PYTHON SOURCE LINES 182-183

- Create the SGD solver:

.. GENERATED FROM PYTHON SOURCE LINES 183-191

.. code-block:: default

    kv = mx.kv.create('local')
    optimizer = gluon.Trainer(model.module.collect_params(), 'sgd',
                              {'lr_scheduler': lr_scheduler,
                               'wd': 0.0001,
                               'momentum': 0.9,
                               'multi_precision': True},
                              kvstore=kv)

.. GENERATED FROM PYTHON SOURCE LINES 192-195

The training loop
-----------------

.. GENERATED FROM PYTHON SOURCE LINES 195-213

.. code-block:: default

    train_loss = 0.0
    epoch = 0
    for i, (data, target) in enumerate(train_data):
        with autograd.record(True):
            outputs = model(data)
            losses = criterion(outputs, target)
            mx.nd.waitall()
            autograd.backward(losses)
        optimizer.step(batch_size)
        for loss in losses:
            train_loss += loss.asnumpy()[0] / len(losses)
        print('Epoch %d, batch %d, training loss %.3f' % (epoch, i, train_loss/(i+1)))
        # demo only: stop after a few iterations
        if i > 1:
            print('Terminated for this demo...')
            break

.. rst-class:: sphx-glr-script-out

 Out:

 .. code-block:: none

    Epoch 0, batch 0, training loss 2.556
    Epoch 0, batch 1, training loss 4.315
    Epoch 0, batch 2, training loss 5.093
    Terminated for this demo...

.. GENERATED FROM PYTHON SOURCE LINES 214-226

You can `Start Training Now`_.

References
----------

.. [Long15] Long, Jonathan, Evan Shelhamer, and Trevor Darrell. \
    "Fully convolutional networks for semantic segmentation." \
    Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition (CVPR). 2015.

.. [Zhao17] Zhao, Hengshuang, Jianping Shi, Xiaojuan Qi, Xiaogang Wang, and Jiaya Jia. \
    "Pyramid scene parsing network." \
    Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition (CVPR). 2017.

.. rst-class:: sphx-glr-timing

   **Total running time of the script:** ( 1 minutes 51.820 seconds)

.. _sphx_glr_download_build_examples_segmentation_train_psp.py:

.. only:: html

  .. container:: sphx-glr-footer
      :class: sphx-glr-footer-example

  .. container:: sphx-glr-download sphx-glr-download-python

     :download:`Download Python source code: train_psp.py `

  .. container:: sphx-glr-download sphx-glr-download-jupyter

     :download:`Download Jupyter notebook: train_psp.ipynb `

.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery `_