Table Of Contents
Table Of Contents

Source code for gluoncv.model_zoo.simple_pose.simple_pose_resnet

# ------------------------------------------------------------------------------
# Copyright (c) Microsoft
# Licensed under the MIT License.
# Written by Bin Xiao (Bin.Xiao@microsoft.com)
# ------------------------------------------------------------------------------

# coding: utf-8
# pylint: disable=missing-docstring,unused-argument,arguments-differ

from __future__ import division

__all__ = ['get_simple_pose_resnet', 'SimplePoseResNet',
           'simple_pose_resnet18_v1b',
           'simple_pose_resnet50_v1b', 'simple_pose_resnet101_v1b',
           'simple_pose_resnet152_v1b',
           'simple_pose_resnet50_v1d', 'simple_pose_resnet101_v1d',
           'simple_pose_resnet152_v1d']

from mxnet.context import cpu
from mxnet.gluon.block import HybridBlock
from mxnet.gluon import nn
from mxnet import initializer
import gluoncv as gcv

class SimplePoseResNet(HybridBlock):
    """Pose network made of a ResNet backbone, a stack of transposed
    convolutions and a final convolution with one output channel per joint.

    Parameters
    ----------
    base_name : str, default 'resnet50_v1b'
        Backbone name handed to ``get_model``. Names ending in ``'v1'`` use
        the backbone's ``features`` attribute; every other name uses its
        ``conv1`` ... ``layer4`` attributes.
    pretrained_base : default False
        Forwarded to ``get_model`` as ``pretrained``.
    pretrained_ctx : default cpu()
        Forwarded to ``get_model`` as ``ctx``.
    num_joints : int, default 17
        Number of channels produced by the final convolution.
    num_deconv_layers : int, default 3
        Number of deconvolution stages. Each stage is a ``Conv2DTranspose``
        with stride 2, followed by batch normalization and ReLU.
    num_deconv_filters : sequence of int, default (256, 256, 256)
        Output channels of each stage; its length must equal
        ``num_deconv_layers``.
    num_deconv_kernels : sequence of int, default (4, 4, 4)
        Kernel size of each stage (2, 3 or 4); its length must equal
        ``num_deconv_layers``.
    final_conv_kernel : int, default 1
        Kernel size of the final convolution. Padding is 1 when this is 3
        and 0 otherwise.
    deconv_with_bias : bool, default False
        ``use_bias`` setting of the transposed convolutions.
    **kwargs
        Passed to the ``HybridBlock`` constructor.
    """

    def __init__(self, base_name='resnet50_v1b',
                 pretrained_base=False, pretrained_ctx=cpu(),
                 num_joints=17,
                 num_deconv_layers=3,
                 num_deconv_filters=(256, 256, 256),
                 num_deconv_kernels=(4, 4, 4),
                 final_conv_kernel=1, deconv_with_bias=False, **kwargs):
        super(SimplePoseResNet, self).__init__(**kwargs)

        # Imported here rather than at module level, as in the original code.
        from ..model_zoo import get_model
        base_network = get_model(base_name, pretrained=pretrained_base, ctx=pretrained_ctx,
                                 norm_layer=gcv.nn.BatchNormCudnnOff)

        # Pick which attributes of the backbone are chained together.
        if base_name.endswith('v1'):
            backbone_parts = ['features']
        else:
            backbone_parts = ['conv1', 'bn1', 'relu', 'maxpool',
                              'layer1', 'layer2', 'layer3', 'layer4']
        self.resnet = nn.HybridSequential()
        for part in backbone_parts:
            self.resnet.add(getattr(base_network, part))

        self.deconv_with_bias = deconv_with_bias

        # used for deconv layers
        self.deconv_layers = self._make_deconv_layer(
            num_deconv_layers,
            num_deconv_filters,
            num_deconv_kernels,
        )

        self.final_layer = nn.Conv2D(
            channels=num_joints,
            kernel_size=final_conv_kernel,
            strides=1,
            padding=1 if final_conv_kernel == 3 else 0,
            weight_initializer=initializer.Normal(0.001),
            bias_initializer=initializer.Zero()
        )

    def _get_deconv_cfg(self, deconv_kernel):
        """Return ``(kernel, padding, output_padding)`` for a deconv kernel size.

        Raises
        ------
        ValueError
            If ``deconv_kernel`` is not 2, 3 or 4. Previously this case fell
            through and failed with an unbound-local error at the return.
        """
        if deconv_kernel == 4:
            padding = 1
            output_padding = 0
        elif deconv_kernel == 3:
            padding = 1
            output_padding = 1
        elif deconv_kernel == 2:
            padding = 0
            output_padding = 0
        else:
            raise ValueError(
                'Unsupported deconv kernel size %r; expected 2, 3 or 4' % (deconv_kernel,))

        return deconv_kernel, padding, output_padding

    def _make_deconv_layer(self, num_layers, num_filters, num_kernels):
        """Build the deconvolution stack as a ``HybridSequential``.

        Each of the ``num_layers`` stages adds a stride-2 ``Conv2DTranspose``,
        a ``BatchNormCudnnOff`` layer and a ReLU activation.
        ``num_filters`` and ``num_kernels`` must both have ``num_layers``
        entries, otherwise an ``AssertionError`` is raised.
        """
        assert num_layers == len(num_filters), \
            'ERROR: num_deconv_layers is different from len(num_deconv_filters)'
        # The message used to name num_deconv_filters here as well.
        assert num_layers == len(num_kernels), \
            'ERROR: num_deconv_layers is different from len(num_deconv_kernels)'

        layer = nn.HybridSequential(prefix='')
        with layer.name_scope():
            for i in range(num_layers):
                kernel, padding, output_padding = \
                    self._get_deconv_cfg(num_kernels[i])

                planes = num_filters[i]
                layer.add(
                    nn.Conv2DTranspose(
                        channels=planes,
                        kernel_size=kernel,
                        strides=2,
                        padding=padding,
                        output_padding=output_padding,
                        use_bias=self.deconv_with_bias,
                        weight_initializer=initializer.Normal(0.001),
                        bias_initializer=initializer.Zero()))
                layer.add(gcv.nn.BatchNormCudnnOff(gamma_initializer=initializer.One(),
                                                   beta_initializer=initializer.Zero()))
                layer.add(nn.Activation('relu'))
                # Channel count of the most recent stage; not read elsewhere
                # in this class.
                self.inplanes = planes

        return layer

    def hybrid_forward(self, F, x):
        """Run ``x`` through the backbone, the deconv stack and the final conv."""
        x = self.resnet(x)

        x = self.deconv_layers(x)
        x = self.final_layer(x)

        return x
def get_simple_pose_resnet(base_name, pretrained=False, ctx=cpu(),
                           root='~/.mxnet/models', **kwargs):
    """Build a ``SimplePoseResNet`` and optionally load stored parameters.

    Parameters
    ----------
    base_name : str
        Backbone name; passed to ``SimplePoseResNet`` and used to form the
        parameter-file name ``'simple_pose_<base_name>'``.
    pretrained : bool or str, default False
        When truthy, parameters are loaded; the value is passed to
        ``get_model_file`` as ``tag``.
    ctx : Context, default cpu()
        Context passed to ``load_parameters``.
    root : str, default '~/.mxnet/models'
        Passed to ``get_model_file`` as ``root``.
    **kwargs
        Forwarded to the ``SimplePoseResNet`` constructor.

    Returns
    -------
    SimplePoseResNet
        The constructed network.
    """
    model = SimplePoseResNet(base_name, **kwargs)

    # Nothing to load: hand back the freshly built network.
    if not pretrained:
        return model

    from ..model_store import get_model_file
    params_file = get_model_file('simple_pose_%s' % (base_name),
                                 tag=pretrained, root=root)
    model.load_parameters(params_file, ctx=ctx)
    return model
def simple_pose_resnet18_v1b(**kwargs):
    r"""Simple-pose network with a ResNet-18 backbone (``resnet18_v1b``), from the
    `"Simple Baselines for Human Pose Estimation and Tracking"
    <https://arxiv.org/abs/1804.06208>`_ paper.

    Every keyword argument is forwarded to ``get_simple_pose_resnet``.

    Parameters
    ----------
    pretrained : bool or str
        A boolean decides whether the default pretrained weights are loaded;
        a string gives the hashtag of a particular version of the weights.
    ctx : Context, default CPU
        Context on which the pretrained weights are loaded.
    root : str, default '~/.mxnet/models'
        Directory in which model parameters are kept.
    """
    backbone = 'resnet18_v1b'
    return get_simple_pose_resnet(backbone, **kwargs)
def simple_pose_resnet50_v1b(**kwargs):
    r"""Simple-pose network with a ResNet-50 backbone (``resnet50_v1b``), from the
    `"Simple Baselines for Human Pose Estimation and Tracking"
    <https://arxiv.org/abs/1804.06208>`_ paper.

    Every keyword argument is forwarded to ``get_simple_pose_resnet``.

    Parameters
    ----------
    pretrained : bool or str
        A boolean decides whether the default pretrained weights are loaded;
        a string gives the hashtag of a particular version of the weights.
    ctx : Context, default CPU
        Context on which the pretrained weights are loaded.
    root : str, default '~/.mxnet/models'
        Directory in which model parameters are kept.
    """
    backbone = 'resnet50_v1b'
    return get_simple_pose_resnet(backbone, **kwargs)
def simple_pose_resnet101_v1b(**kwargs):
    r"""Simple-pose network with a ResNet-101 backbone (``resnet101_v1b``), from the
    `"Simple Baselines for Human Pose Estimation and Tracking"
    <https://arxiv.org/abs/1804.06208>`_ paper.

    Every keyword argument is forwarded to ``get_simple_pose_resnet``.

    Parameters
    ----------
    pretrained : bool or str
        A boolean decides whether the default pretrained weights are loaded;
        a string gives the hashtag of a particular version of the weights.
    ctx : Context, default CPU
        Context on which the pretrained weights are loaded.
    root : str, default '~/.mxnet/models'
        Directory in which model parameters are kept.
    """
    backbone = 'resnet101_v1b'
    return get_simple_pose_resnet(backbone, **kwargs)
def simple_pose_resnet152_v1b(**kwargs):
    r"""Simple-pose network with a ResNet-152 backbone (``resnet152_v1b``), from the
    `"Simple Baselines for Human Pose Estimation and Tracking"
    <https://arxiv.org/abs/1804.06208>`_ paper.

    Every keyword argument is forwarded to ``get_simple_pose_resnet``.

    Parameters
    ----------
    pretrained : bool or str
        A boolean decides whether the default pretrained weights are loaded;
        a string gives the hashtag of a particular version of the weights.
    ctx : Context, default CPU
        Context on which the pretrained weights are loaded.
    root : str, default '~/.mxnet/models'
        Directory in which model parameters are kept.
    """
    backbone = 'resnet152_v1b'
    return get_simple_pose_resnet(backbone, **kwargs)
def simple_pose_resnet50_v1d(**kwargs):
    r"""Simple-pose network with a ResNet-50-d backbone (``resnet50_v1d``), from the
    `"Simple Baselines for Human Pose Estimation and Tracking"
    <https://arxiv.org/abs/1804.06208>`_ paper.

    Every keyword argument is forwarded to ``get_simple_pose_resnet``.

    Parameters
    ----------
    pretrained : bool or str
        A boolean decides whether the default pretrained weights are loaded;
        a string gives the hashtag of a particular version of the weights.
    ctx : Context, default CPU
        Context on which the pretrained weights are loaded.
    root : str, default '~/.mxnet/models'
        Directory in which model parameters are kept.
    """
    backbone = 'resnet50_v1d'
    return get_simple_pose_resnet(backbone, **kwargs)
def simple_pose_resnet101_v1d(**kwargs):
    r"""Simple-pose network with a ResNet-101-d backbone (``resnet101_v1d``), from the
    `"Simple Baselines for Human Pose Estimation and Tracking"
    <https://arxiv.org/abs/1804.06208>`_ paper.

    Every keyword argument is forwarded to ``get_simple_pose_resnet``.

    Parameters
    ----------
    pretrained : bool or str
        A boolean decides whether the default pretrained weights are loaded;
        a string gives the hashtag of a particular version of the weights.
    ctx : Context, default CPU
        Context on which the pretrained weights are loaded.
    root : str, default '~/.mxnet/models'
        Directory in which model parameters are kept.
    """
    backbone = 'resnet101_v1d'
    return get_simple_pose_resnet(backbone, **kwargs)
def simple_pose_resnet152_v1d(**kwargs):
    r"""Simple-pose network with a ResNet-152-d backbone (``resnet152_v1d``), from the
    `"Simple Baselines for Human Pose Estimation and Tracking"
    <https://arxiv.org/abs/1804.06208>`_ paper.

    Every keyword argument is forwarded to ``get_simple_pose_resnet``.

    Parameters
    ----------
    pretrained : bool or str
        A boolean decides whether the default pretrained weights are loaded;
        a string gives the hashtag of a particular version of the weights.
    ctx : Context, default CPU
        Context on which the pretrained weights are loaded.
    root : str, default '~/.mxnet/models'
        Directory in which model parameters are kept.
    """
    backbone = 'resnet152_v1d'
    return get_simple_pose_resnet(backbone, **kwargs)