# Source code for gluoncv.model_zoo.monodepthv2.pose_decoder

"""Decoder module of Monodepth2
Code partially borrowed from
https://github.com/nianticlabs/monodepth2/blob/master/networks/pose_decoder.py
"""
# Copyright Niantic 2019. Patent Pending. All rights reserved.
# This software is licensed under the terms of the Monodepth2 licence
# which allows for non-commercial use only, the full terms of which are made
# available in the LICENSE file.

from __future__ import absolute_import, division, print_function

from collections import OrderedDict
import mxnet as mx
import mxnet.gluon.nn as nn

class PoseDecoder(nn.HybridBlock):
    r"""Decoder of Monodepth2 PoseNet.

    The decoder works in three steps:

    1. It takes the last feature map of every encoder output and squeezes
       each one to 256 channels with a 1x1 conv + ReLU.
    2. It concatenates the results along the channel axis and applies three
       convolutions (ReLU after the first two).
    3. It averages over the spatial axes and splits the output into
       axis-angle and translation parts.

    Parameters
    ----------
    num_ch_enc : list
        The channels number of encoder. Only the last entry is used, as the
        input channel count of the "squeeze" conv.
    num_input_features : int
        The number of input sequences. 1 for depth encoder, larger than 1 for
        pose encoder. Must match the number of entries passed to
        ``hybrid_forward`` because the first pose conv expects
        ``num_input_features * 256`` input channels.
    num_frames_to_predict_for : int or None
        The number of output poses between frames; if None, it equals
        ``num_input_features - 1``. (Default: 2)
    stride : int
        The stride number for Conv in pose decoder. (Default: 1)
    """
    def __init__(self, num_ch_enc, num_input_features,
                 num_frames_to_predict_for=2, stride=1):
        super(PoseDecoder, self).__init__()

        self.num_ch_enc = num_ch_enc
        self.num_input_features = num_input_features

        if num_frames_to_predict_for is None:
            num_frames_to_predict_for = num_input_features - 1
        self.num_frames_to_predict_for = num_frames_to_predict_for

        # "squeeze" is a plain string key; the pose convs use ("pose", i).
        self.convs = OrderedDict()
        self.convs["squeeze"] = nn.Conv2D(
            in_channels=self.num_ch_enc[-1], channels=256, kernel_size=1)
        self.convs[("pose", 0)] = nn.Conv2D(
            in_channels=num_input_features * 256, channels=256,
            kernel_size=3, strides=stride, padding=1)
        self.convs[("pose", 1)] = nn.Conv2D(
            in_channels=256, channels=256,
            kernel_size=3, strides=stride, padding=1)
        # 6 values per predicted frame: 3 axis-angle + 3 translation.
        self.convs[("pose", 2)] = nn.Conv2D(
            in_channels=256, channels=6 * num_frames_to_predict_for,
            kernel_size=1)

        # Gluon only auto-registers Block-valued attributes, not Blocks stored
        # inside a plain dict, so register each conv explicitly.
        for k in self.convs:
            self.register_child(self.convs[k])
        # The same convs are also grouped under ``self.net``. Registration
        # order and this attribute are kept unchanged because Gluon derives
        # structural parameter names (save/load_parameters) from them.
        self.net = nn.HybridSequential()
        self.net.add(*self.convs.values())

    def hybrid_forward(self, F, input_features):
        """Predict relative poses from encoder features.

        Parameters
        ----------
        F : module
            ``mxnet.nd`` or ``mxnet.sym`` (supplied by Gluon).
        input_features : list
            One entry per encoder output; only ``f[-1]`` of each entry is
            used (presumably the deepest feature map -- confirm against the
            encoder).

        Returns
        -------
        axisangle, translation : NDArray or Symbol
            Each of shape ``(batch, num_frames_to_predict_for, 1, 3)``.
        """
        # pylint: disable=unused-argument, missing-function-docstring
        last_features = [f[-1] for f in input_features]

        cat_features = [F.relu(self.convs["squeeze"](f)) for f in last_features]
        cat_features = F.concat(*cat_features, dim=1)

        out = cat_features
        for i in range(3):
            out = self.convs[("pose", i)](out)
            if i != 2:
                out = F.relu(out)

        # Global average over the two spatial axes.
        out = out.mean(3).mean(2)

        # Fixed 0.01 output scaling, presumably inherited from the Monodepth2
        # reference implementation.
        out = 0.01 * out.reshape(shape=(-1, self.num_frames_to_predict_for, 1, 6))

        # Split the last axis; slicing only axis 3 keeps every predicted
        # frame regardless of num_frames_to_predict_for.
        axisangle = F.slice_axis(out, axis=3, begin=0, end=3)
        translation = F.slice_axis(out, axis=3, begin=3, end=6)

        return axisangle, translation

    def predict(self, input_features):
        """Imperative (NDArray) pose prediction.

        Runs the same computation as :meth:`hybrid_forward` with
        ``F = mx.nd``, so the two code paths cannot drift apart.

        Parameters
        ----------
        input_features : list
            Same layout as for :meth:`hybrid_forward`.

        Returns
        -------
        axisangle, translation : NDArray
            Each of shape ``(batch, num_frames_to_predict_for, 1, 3)``.
        """
        return self.hybrid_forward(mx.nd, input_features)