Table Of Contents

Source code for mxnet.gluon.block

# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

# coding: utf-8
# pylint: disable= arguments-differ, too-many-lines
"""Base container class for all neural network models."""
__all__ = ['Block', 'HybridBlock', 'SymbolBlock']

import threading
import copy
import warnings
import re
from collections import OrderedDict

from ..base import mx_real_t, MXNetError
from .. import symbol, ndarray, initializer
from ..symbol import Symbol
from ..ndarray import NDArray
from .. import name as _name
from .parameter import Parameter, ParameterDict, DeferredInitializationError
from .utils import _indent, _brief_print_list, HookHandle
from .utils import _check_same_symbol_type, _check_all_np_ndarrays
from .. import numpy_extension as _mx_npx
from .. import numpy as _mx_np
from .. util import is_np_array, np_shape, np_array

class _BlockScope(object):
    """Scope for collecting child `Block` s."""
    _current = threading.local()

    def __init__(self, block):
        self._block = block
        self._counter = {}
        self._old_scope = None
        self._name_scope = None

    def create(prefix, params, hint):
        """Creates prefix and params for new `Block`."""
        current = getattr(_BlockScope._current, "value", None)
        if current is None:
            if prefix is None:
                if not hasattr(_name.NameManager._current, "value"):
                    _name.NameManager._current.value = _name.NameManager()
                prefix = _name.NameManager._current.value.get(None, hint) + '_'
            if params is None:
                params = ParameterDict(prefix)
                params = ParameterDict(params.prefix, params)
            return prefix, params

        if prefix is None:
            count = current._counter.get(hint, 0)
            prefix = '%s%d_'%(hint, count)
            current._counter[hint] = count + 1
        if params is None:
            parent = current._block.params
            params = ParameterDict(parent.prefix+prefix, parent._shared)
            params = ParameterDict(params.prefix, params)
        return current._block.prefix+prefix, params

    def __enter__(self):
        if self._block._empty_prefix:
            return self
        self._old_scope = getattr(_BlockScope._current, "value", None)
        _BlockScope._current.value = self
        self._name_scope = _name.Prefix(self._block.prefix)
        return self

    def __exit__(self, ptype, value, trace):
        if self._block._empty_prefix:
        self._name_scope.__exit__(ptype, value, trace)
        self._name_scope = None
        _BlockScope._current.value = self._old_scope

def _flatten(args, inout_str):
    """Parse the arguments into a flattened list + an additional format array.
    The format array stores the structure of the original arguments to help reconstruct the inputs.

    Parameters
    ----------
    args : NDArray, Symbol, or (nested) list of Symbol or NDArray
        We allow None inside the args.
    inout_str : str
        Label for what is being flattened (callers pass ``"input"`` or
        ``"output"``); only used in the error message.

    Returns
    -------
    flat : list of Symbol or NDArray
        The flatten version of the input args.
    fmts : (nested) list of ints
        Stores the format information of the original structured args.
        Each leaf is an int: ``0`` for an NDArray or a single-output Symbol,
        ``n > 1`` for a Symbol with ``n`` outputs, and ``-1`` for ``None``.
    """
    if isinstance(args, NDArray):
        return [args], int(0)
    if isinstance(args, Symbol):
        # Multi-output Symbols record their output count so that _regroup can
        # slice that many entries back out; single-output Symbols use 0.
        length = len(args.list_outputs())
        length = length if length > 1 else 0
        return [args], int(length)
    if args is None:
        return [None], int(-1)

    assert isinstance(args, (list, tuple)), \
        "HybridBlock {} must be (nested) list of Symbol or NDArray, " \
        "but got {} of type {}".format(inout_str, str(args), str(type(args)))
    flat = []
    fmts = []
    for i in args:
        arg, fmt = _flatten(i, inout_str)
        flat.extend(arg)
        fmts.append(fmt)
    return flat, fmts

def _regroup(args, fmt):
    """Reconstruct the structured arguments based on the flattened version.

    args : NDArray, Symbol, or (nested) list of Symbol or NDArray
        We allow None inside the args.
    fmt : (nested) list of ints
        Stores the format information of the original structured args.

    ret : NDArray, Symbol, or (nested) list of Symbol or NDArray

    def _merger(args, fmt):
        """Recursive call to merge the arguments"""
        if isinstance(fmt, int):
            if fmt < -1:
                raise ValueError("Unsupported encoded format {}.".format(fmt))
            if fmt == 0:
                return args[0], args[1:]
            if fmt == -1:
                if args[0] is not None:
                    raise ValueError('We do not support passing types that are not None'
                                     ' when the initial HybridBlock has received NoneType and'
                                     ' has been hybridized.'
                                     ' Received arg = {}, fmt = {}.'.format(args[0], fmt))
                return None, args[1:]
                return args[:fmt], args[fmt:]

        assert isinstance(args, (list, tuple)), \
            "HybridBlock output must be (nested) list of Symbol or NDArray, " \
            "but got {} of type {}".format(args, type(args))
        ret = []
        for i in fmt:
            res, args = _merger(args, i)
        return ret, args
    return _merger(args, fmt)[0]

class Block(object):
    """Base class for all neural network layers and models. Your models should
    subclass this class.

    :py:class:`Block` can be nested recursively in a tree structure. You can create and
    assign child :py:class:`Block` as regular attributes::

        import mxnet as mx
        from mxnet.gluon import Block, nn
        from mxnet import ndarray as F

        class Model(Block):
            def __init__(self, **kwargs):
                super(Model, self).__init__(**kwargs)
                # use name_scope to give child Blocks appropriate names.
                with self.name_scope():
                    self.dense0 = nn.Dense(20)
                    self.dense1 = nn.Dense(20)

            def forward(self, x):
                x = F.relu(self.dense0(x))
                return F.relu(self.dense1(x))

        model = Model()
        model.initialize(ctx=mx.cpu(0))
        model(F.zeros((10, 10), ctx=mx.cpu(0)))

    Child :py:class:`Block` assigned this way will be registered and :py:meth:`collect_params`
    will collect their Parameters recursively. You can also manually register
    child blocks with :py:meth:`register_child`.

    Parameters
    ----------
    prefix : str
        Prefix acts like a name space. All children blocks created in parent block's
        :py:meth:`name_scope` will have parent block's prefix in their name.
        Please refer to the Gluon naming tutorial for more info on prefix and naming.
    params : ParameterDict or None
        :py:class:`ParameterDict` for sharing weights with the new :py:class:`Block`. For example,
        if you want ``dense1`` to share ``dense0``'s weights, you can do::

            dense0 = nn.Dense(20)
            dense1 = nn.Dense(20, params=dense0.collect_params())
    """
    def __init__(self, prefix=None, params=None):
        # prefix='' is remembered so that name_scope() becomes a no-op for this block.
        self._empty_prefix = prefix == ''
        self._prefix, self._params = _BlockScope.create(prefix, params, self._alias())
        # The name is the prefix without its trailing '_' separator.
        self._name = self._prefix[:-1] if self._prefix.endswith('_') else self._prefix
        self._scope = _BlockScope(self)
        self._children = OrderedDict()
        self._reg_params = {}
        self._forward_hooks = OrderedDict()
        self._forward_pre_hooks = OrderedDict()

    def __repr__(self):
        s = '{name}(\n{modstr}\n)'
        modstr = '\n'.join(['  ({key}): {block}'.format(key=key,
                                                        block=_indent(block.__repr__(), 2))
                            for key, block in self.__dict__.items() if isinstance(block, Block)])
        return s.format(name=self.__class__.__name__, modstr=modstr)

    def __setattr__(self, name, value):
        """Registers parameters.

        Assigning a :py:class:`Block` registers it as a child and assigning a
        :py:class:`Parameter` records it in this block's parameter table.
        Replacing an existing Block/Parameter attribute with a value of a
        different type raises ``TypeError``.
        """
        if hasattr(self, name):
            existing = getattr(self, name)
            if isinstance(existing, (Parameter, Block)) and not isinstance(value, type(existing)):
                raise TypeError('Changing attribute type for {name} from {type1} to {type2} '
                                'is not allowed.'.format(
                                    name=name, type1=type(existing), type2=type(value)))

        if isinstance(value, Block):
            self.register_child(value, name)
        elif isinstance(value, Parameter):
            assert name not in self._reg_params, \
                "Overriding Parameter attribute %s is not allowed. " \
                "If you want to share parameters between blocks, please set " \
                "'params' at Block construction instead." % name
            self._reg_params[name] = value

        super(Block, self).__setattr__(name, value)

    def _check_container_with_block(self):
        """Warn about Blocks stored inside list/tuple/dict attributes.

        Such Blocks are not registered as children, so they are skipped by
        the recursive traversals in this class (e.g. :py:meth:`collect_params`).
        """
        children = set(self._children.values())
        def _find_unregistered_block_in_container(data):
            # Find whether a nested container structure contains unregistered Blocks
            if isinstance(data, (list, tuple)):
                return any(_find_unregistered_block_in_container(ele) for ele in data)
            if isinstance(data, dict):
                return any(_find_unregistered_block_in_container(v) for v in data.values())
            if isinstance(data, Block):
                return data not in children
            return False
        for k, v in self.__dict__.items():
            if isinstance(v, (list, tuple, dict)) and not (k.startswith('__') or k == '_children'):
                if _find_unregistered_block_in_container(v):
                    warnings.warn('"{name}" is an unregistered container with Blocks. '
                                  'Note that Blocks inside the list, tuple or dict will not be '
                                  'registered automatically. Make sure to register them using '
                                  'register_child() or switching to '
                                  'nn.Sequential/nn.HybridSequential instead. '
                                  .format(name=self.__class__.__name__ + "." + k), stacklevel=3)

    def _alias(self):
        """Default naming hint: the lower-cased class name."""
        return self.__class__.__name__.lower()

    @property
    def prefix(self):
        """Prefix of this :py:class:`Block`."""
        return self._prefix

    @property
    def name(self):
        """Name of this :py:class:`Block`, without '_' in the end."""
        return self._name

    def name_scope(self):
        """Returns a name space object managing a child :py:class:`Block` and parameter
        names. Should be used within a ``with`` statement::

            with self.name_scope():
                self.dense = nn.Dense(20)

        Please refer to the Gluon naming tutorial for more info on prefix and naming.
        """
        return self._scope

    @property
    def params(self):
        """Returns this :py:class:`Block`'s parameter dictionary (does not include its
        children's parameters)."""
        return self._params

    def collect_params(self, select=None):
        """Returns a :py:class:`ParameterDict` containing this :py:class:`Block` and all of its
        children's Parameters(default), also can returns the select :py:class:`ParameterDict`
        which match some given regular expressions.

        For example, collect the specified parameters in ['conv1_weight', 'conv1_bias', 'fc_weight',
        'fc_bias']::

            model.collect_params('conv1_weight|conv1_bias|fc_weight|fc_bias')

        or collect all parameters whose names end with 'weight' or 'bias', this can be done
        using regular expressions::

            model.collect_params('.*weight|.*bias')

        Parameters
        ----------
        select : str
            regular expressions

        Returns
        -------
        The selected :py:class:`ParameterDict`
        """
        # We need to check here because blocks inside containers are not supported.
        self._check_container_with_block()
        ret = ParameterDict(self._params.prefix)
        if not select:
            ret.update(self.params)
        else:
            pattern = re.compile(select)
            ret.update({name: value for name, value in self.params.items()
                        if pattern.match(name)})
        for cld in self._children.values():
            ret.update(cld.collect_params(select=select))
        return ret

    def _collect_params_with_prefix(self, prefix=''):
        """Map structural names (e.g. ``'dense0.weight'``) to this tree's registered Parameters."""
        if prefix:
            prefix += '.'
        ret = {prefix + key : val for key, val in self._reg_params.items()}
        for name, child in self._children.items():
            ret.update(child._collect_params_with_prefix(prefix + name))
        return ret

    def save_parameters(self, filename):
        """Save parameters to file.

        Saved parameters can only be loaded with `load_parameters`. Note that this
        method only saves parameters, not model structure. If you want to save
        model structures, please use :py:meth:`HybridBlock.export`.

        Parameters
        ----------
        filename : str
            Path to file.

        References
        ----------
        `Saving and Loading Gluon Models` (MXNet Gluon tutorial).
        """
        params = self._collect_params_with_prefix()
        arg_dict = {key : val._reduce() for key, val in params.items()}
        save_fn = _mx_npx.save if is_np_array() else ndarray.save
        save_fn(filename, arg_dict)

    def save_params(self, filename):
        """[Deprecated] Please use save_parameters. Note that if you want load
        from SymbolBlock later, please use export instead.

        Save parameters to file.

        Parameters
        ----------
        filename : str
            Path to file.
        """
        warnings.warn("save_params is deprecated. Please use save_parameters. "
                      "Note that if you want load from SymbolBlock later, please "
                      "use export instead. For details, see "
                      "https://mxnet.incubator.apache.org/tutorials/gluon/save_load_params.html")
        try:
            self.collect_params().save(filename, strip_prefix=self.prefix)
        except ValueError as e:
            # str(e) instead of e.message: the latter does not exist on Python 3.
            raise ValueError('%s\nsave_params is deprecated. Using ' \
                              'save_parameters may resolve this error.'%str(e))

    def load_parameters(self, filename, ctx=None, allow_missing=False,
                        ignore_extra=False, cast_dtype=False, dtype_source='current'):
        """Load parameters from file previously saved by `save_parameters`.

        Parameters
        ----------
        filename : str
            Path to parameter file.
        ctx : Context or list of Context, default cpu()
            Context(s) to initialize loaded parameters on.
        allow_missing : bool, default False
            Whether to silently skip loading parameters not represents in the file.
        ignore_extra : bool, default False
            Whether to silently ignore parameters from the file that are not
            present in this Block.
        cast_dtype : bool, default False
            Cast the data type of the NDArray loaded from the checkpoint to the dtype
            provided by the Parameter if any.
        dtype_source : str, default 'current'
            must be in {'current', 'saved'}
            Only valid if cast_dtype=True, specify the source of the dtype for casting
            the parameters

        References
        ----------
        `Saving and Loading Gluon Models` (MXNet Gluon tutorial).
        """
        if is_np_array():
            # failure may happen when loading parameters saved as NDArrays within
            # NumPy semantics. Check the failure type and recover from it if it happens.
            try:
                loaded = _mx_npx.load(filename)
            except MXNetError as e:
                err_msg = str(e)
                if 'is_np_shape' in err_msg:
                    # Loading failure due to parameters saved without numpy semantics.
                    # Temporarily disable numpy semantics and load parameters. After it's
                    # done, resume the numpy semantics. This is fine because the cases
                    # numpy ndarray covers is a superset of the legacy ndarray's.
                    with np_array(False):
                        with np_shape(False):
                            loaded_nds = ndarray.load(filename)
                    assert isinstance(loaded_nds, dict),\
                        'expecting a dict type, got {}'.format(str(type(loaded_nds)))
                    loaded = {k: loaded_nds[k].as_np_ndarray() for k in loaded_nds}
                else:
                    raise ValueError(err_msg)
        else:
            loaded = ndarray.load(filename)
        params = self._collect_params_with_prefix()
        if not loaded and not params:
            return

        if not any('.' in i for i in loaded.keys()):
            # legacy loading: files written by save_params use prefixed names
            # rather than structural ('a.b') names.
            del loaded
            self.collect_params().load(
                filename, ctx, allow_missing, ignore_extra, self.prefix,
                cast_dtype=cast_dtype, dtype_source=dtype_source)
            return

        if not allow_missing:
            for name in params.keys():
                assert name in loaded, \
                    "Parameter '%s' is missing in file '%s', which contains parameters: %s. " \
                    "Set allow_missing=True to ignore missing parameters."%(
                        name, filename, _brief_print_list(loaded.keys()))
        for name in loaded:
            if not ignore_extra and name not in params:
                # List the structural names that `name` was compared against.
                raise ValueError(
                    "Parameter '%s' loaded from file '%s' is not present in ParameterDict, " \
                    "which contains parameters %s. Set ignore_extra=True to ignore. "%(
                        name, filename, _brief_print_list(params.keys())))
            if name in params:
                params[name]._load_init(loaded[name], ctx, cast_dtype=cast_dtype,
                                        dtype_source=dtype_source)

    def load_params(self, filename, ctx=None, allow_missing=False,
                    ignore_extra=False):
        """[Deprecated] Please use load_parameters.

        Load parameters from file.

        Parameters
        ----------
        filename : str
            Path to parameter file.
        ctx : Context or list of Context, default cpu()
            Context(s) to initialize loaded parameters on.
        allow_missing : bool, default False
            Whether to silently skip loading parameters not represents in the file.
        ignore_extra : bool, default False
            Whether to silently ignore parameters from the file that are not
            present in this Block.
        """
        warnings.warn("load_params is deprecated. Please use load_parameters.")
        self.load_parameters(filename, ctx, allow_missing, ignore_extra)

    def register_child(self, block, name=None):
        """Registers block as a child of self. :py:class:`Block` s assigned to self as
        attributes will be registered automatically.

        When `name` is None, the child is keyed by its insertion index.
        """
        if name is None:
            name = str(len(self._children))
        self._children[name] = block

    def register_forward_pre_hook(self, hook):
        r"""Registers a forward pre-hook on the block.

        The hook function is called immediately before :func:`forward`.
        It should not modify the input or output.

        Parameters
        ----------
        hook : callable
            The forward hook function of form `hook(block, input) -> None`.

        Returns
        -------
        :class:`mxnet.gluon.utils.HookHandle`
        """
        handle = HookHandle()
        handle.attach(self._forward_pre_hooks, hook)
        return handle

    def register_forward_hook(self, hook):
        r"""Registers a forward hook on the block.

        The hook function is called immediately after :func:`forward`.
        It should not modify the input or output.

        Parameters
        ----------
        hook : callable
            The forward hook function of form `hook(block, input, output) -> None`.

        Returns
        -------
        :class:`mxnet.gluon.utils.HookHandle`
        """
        handle = HookHandle()
        handle.attach(self._forward_hooks, hook)
        return handle

    def apply(self, fn):
        r"""Applies ``fn`` recursively to every child block as well as self.

        Children are visited before their parent.

        Parameters
        ----------
        fn : callable
            Function to be applied to each submodule, of form `fn(block)`.

        Returns
        -------
        this block
        """
        for cld in self._children.values():
            cld.apply(fn)
        fn(self)
        return self

    def initialize(self, init=initializer.Uniform(), ctx=None, verbose=False,
                   force_reinit=False):
        """Initializes :py:class:`Parameter` s of this :py:class:`Block` and its children.
        Equivalent to ``block.collect_params().initialize(...)``

        Parameters
        ----------
        init : Initializer
            Global default Initializer to be used when :py:meth:`Parameter.init` is ``None``.
            Otherwise, :py:meth:`Parameter.init` takes precedence.
        ctx : Context or list of Context
            Keeps a copy of Parameters on one or many context(s).
        verbose : bool, default False
            Whether to verbosely print out details on initialization.
        force_reinit : bool, default False
            Whether to force re-initialization if parameter is already initialized.
        """
        self.collect_params().initialize(init, ctx, verbose, force_reinit)

    def hybridize(self, active=True, **kwargs):
        """Activates or deactivates :py:class:`HybridBlock` s recursively. Has no effect on
        non-hybrid children.

        Parameters
        ----------
        active : bool, default True
            Whether to turn hybrid on or off.
        static_alloc : bool, default False
            Statically allocate memory to improve speed. Memory usage may increase.
        static_shape : bool, default False
            Optimize for invariant input shapes between iterations. Must also
            set static_alloc to True. Change of input shapes is still allowed
            but slower.
        """
        for cld in self._children.values():
            cld.hybridize(active, **kwargs)

    def cast(self, dtype):
        """Cast this Block to use another data type.

        Parameters
        ----------
        dtype : str or numpy.dtype
            The new data type.
        """
        for child in self._children.values():
            child.cast(dtype)
        for param in self.params.values():
            param.cast(dtype)

    def __call__(self, *args):
        """Calls forward. Only accepts positional arguments."""
        for hook in self._forward_pre_hooks.values():
            hook(self, args)

        out = self.forward(*args)

        for hook in self._forward_hooks.values():
            hook(self, args, out)
        if _mx_npx.is_np_array():
            _check_all_np_ndarrays(out)
        return out

    def forward(self, *args):
        """Overrides to implement forward computation using :py:class:`NDArray`. Only
        accepts positional arguments.

        Parameters
        ----------
        *args : list of NDArray
            Input tensors.
        """
        # pylint: disable= invalid-name
        raise NotImplementedError

    def register_op_hook(self, callback, monitor_all=False):
        """Install callback monitor.

        The base implementation only forwards the registration to the children.

        Parameters
        ----------
        callback : function
            Takes a string and a NDArrayHandle.
        monitor_all : bool, default False
            If true, monitor both input and output, otherwise monitor output only.
        """
        for cld in self._children.values():
            cld.register_op_hook(callback, monitor_all)

    def summary(self, *inputs):
        """Print the summary of the model's output and parameters.

        The network must have been initialized, and must not have been hybridized.

        Parameters
        ----------
        inputs : object
            Any input that the model supports. For any tensor in the input, only
            :class:`mxnet.ndarray.NDArray` is supported.
        """
        summary = OrderedDict()
        seen = set()
        hooks = []

        def _get_shape_str(args):
            # Replace every NDArray in a (nested) structure by its shape and render it.
            def flatten(args):
                if not isinstance(args, (list, tuple)):
                    return [args], int(0)
                flat = []
                fmts = []
                for i in args:
                    arg, fmt = flatten(i)
                    flat.extend(arg)
                    fmts.append(fmt)
                return flat, fmts

            def regroup(args, fmt):
                if isinstance(fmt, int):
                    if fmt == 0:
                        return args[0], args[1:]
                    return args[:fmt], args[fmt:]
                ret = []
                for i in fmt:
                    res, args = regroup(args, i)
                    ret.append(res)
                return ret, args

            flat_args, fmts = flatten(args)
            flat_arg_shapes = [x.shape if isinstance(x, ndarray.NDArray) else x
                               for x in flat_args]
            shapes = regroup(flat_arg_shapes, fmts)[0]
            if isinstance(shapes, list):
                shape_str = str(shapes)[1:-1]
            else:
                shape_str = str(shapes)
            # Strips the Python 2 long-integer suffix from shape reprs.
            return shape_str.replace('L', '')

        def _register_summary_hook(block):
            assert not isinstance(block, HybridBlock) or not block._active, \
                    '"{}" must not be hybridized to print summary.'.format(block.name)
            def _summary_hook(block, _, outputs):
                class_name = block.__class__.__name__
                block_idx = len(summary) - 1

                m_key = '%s-%i' % (class_name, block_idx+1)
                summary[m_key] = OrderedDict()
                summary[m_key]['output_shape'] = _get_shape_str(outputs)

                params = 0
                summary[m_key]['trainable'] = 0
                summary[m_key]['shared'] = 0
                for p in block.params.values():
                    params += p.data().size
                    summary[m_key]['trainable'] += 0 if p.grad_req == 'null' else p.data().size
                    # A Parameter seen in an earlier layer is counted as shared.
                    if p in seen:
                        summary[m_key]['shared'] += p.data().size
                    else:
                        seen.add(p)
                summary[m_key]['n_params'] = params

            from .nn.basic_layers import Sequential, HybridSequential
            # Containers are skipped; their children report their own rows.
            if not isinstance(block, (Sequential, HybridSequential)):
                hooks.append(block.register_forward_hook(_summary_hook))

        summary['Input'] = OrderedDict()
        summary['Input']['output_shape'] = _get_shape_str(inputs)
        summary['Input']['n_params'] = 0
        summary['Input']['trainable'] = 0
        summary['Input']['shared'] = 0

        try:
            self.apply(_register_summary_hook)
            self(*inputs)

            line_format = '{:>20}  {:>42} {:>15}'
            print('-'*80)
            print(line_format.format('Layer (type)', 'Output Shape', 'Param #'))
            print('='*80)
            total_params = 0
            trainable_params = 0
            shared_params = 0
            for layer in summary:
                print(line_format.format(layer,
                                         str(summary[layer]['output_shape']),
                                         summary[layer]['n_params']))
                total_params += summary[layer]['n_params']
                trainable_params += summary[layer]['trainable']
                shared_params += summary[layer]['shared']
            print('='*80)
            print('Parameters in forward computation graph, duplicate included')
            print('   Total params: ' + str(total_params))
            print('   Trainable params: ' + str(trainable_params))
            print('   Non-trainable params: ' + str(total_params - trainable_params))
            print('Shared params in forward computation graph: ' + str(shared_params))
            print('Unique parameters in model: ' + str(total_params - shared_params))
            print('-'*80)
        finally:
            # Always remove the temporary hooks, even if the forward pass failed.
            for h in hooks:
                h.detach()

[docs]class HybridBlock(Block): """`HybridBlock` supports forwarding with both Symbol and NDArray. `HybridBlock` is similar to `Block`, with a few differences:: import mxnet as mx from mxnet.gluon import HybridBlock, nn class Model(HybridBlock): def __init__(self, **kwargs): super(Model, self).__init__(**kwargs) # use name_scope to give child Blocks appropriate names. with self.name_scope(): self.dense0 = nn.Dense(20) self.dense1 = nn.Dense(20) def hybrid_forward(self, F, x): x = F.relu(self.dense0(x)) return F.relu(self.dense1(x)) model = Model() model.initialize(ctx=mx.cpu(0)) model.hybridize() model(mx.nd.zeros((10, 10), ctx=mx.cpu(0))) Forward computation in :py:class:`HybridBlock` must be static to work with :py:class:`Symbol` s, i.e. you cannot call :py:meth:`NDArray.asnumpy`, :py:attr:`NDArray.shape`, :py:attr:`NDArray.dtype`, `NDArray` indexing (`x[i]`) etc on tensors. Also, you cannot use branching or loop logic that bases on non-constant expressions like random numbers or intermediate results, since they change the graph structure for each iteration. Before activating with :py:meth:`hybridize()`, :py:class:`HybridBlock` works just like normal :py:class:`Block`. After activation, :py:class:`HybridBlock` will create a symbolic graph representing the forward computation and cache it. On subsequent forwards, the cached graph will be used instead of :py:meth:`hybrid_forward`. Please see references for detailed tutorial. 
References ---------- `Hybrid - Faster training and easy deployment <>`_ """ def __init__(self, prefix=None, params=None): super(HybridBlock, self).__init__(prefix=prefix, params=params) self._cached_graph = () self._cached_op = None self._out_format = None self._in_format = None self._active = False self._flags = [] self._callback = None self._monitor_all = False def __setattr__(self, name, value): """Registers parameters.""" super(HybridBlock, self).__setattr__(name, value) if isinstance(value, HybridBlock): self._clear_cached_op() def _get_graph(self, *args): if not self._cached_graph: flatten_args, self._in_format = _flatten(args, "input") flatten_inputs = [] symbol_inputs = [] cnt = 0 real_arg_num = sum([ele is not None for ele in flatten_args]) if real_arg_num == 0: raise ValueError('All args are None and we do not support such a case.' ' Received args={}'.format(args)) for arg in flatten_args: if arg is not None: if real_arg_num > 1: arg_sym = symbol.var('data{}'.format(cnt)) else: arg_sym = symbol.var('data') if isinstance(arg, _mx_np.ndarray): arg_sym = arg_sym.as_np_ndarray() cnt += 1 flatten_inputs.append(arg_sym) symbol_inputs.append(arg_sym) else: flatten_inputs.append(None) grouped_inputs = _regroup(flatten_inputs, self._in_format) params = {i: j.var() for i, j in self._reg_params.items()} with self.name_scope(): out = self.hybrid_forward(symbol, *grouped_inputs, **params) # pylint: disable=no-value-for-parameter out, self._out_format = _flatten(out, "output") self._cached_graph = symbol_inputs, symbol.Group(out, _check_same_symbol_type(out)) return self._cached_graph def _build_cache(self, *args): data, out = self._get_graph(*args) data_names = { i for i, data in enumerate(data)} params = self.collect_params() input_names = out.list_inputs() param_names = set(params.keys()) expected_names = set(input_names) for name in expected_names: assert name in param_names or name in data_names, \ "Unknown input to HybridBlock: %s" %name used_data_names = [i for 
i in data_names if i in expected_names] if len(used_data_names) != len(data_names): unused = ', '.join(['%d-th'%i for name, i in data_names.items() if name not in expected_names]) warnings.warn("The %s input to HybridBlock is not used by any " "computation. Is this intended?"%unused, stacklevel=4) used_param_names = [i for i in param_names if i in expected_names] if len(used_param_names) != len(param_names): unused = ', '.join(list(param_names - set(used_param_names))) warnings.warn("Parameter %s is not used by any computation. " "Is this intended?"%unused, stacklevel=4) data_indices = [] param_indices = [] self._cached_op_args = [] for i, name in enumerate(input_names): if name in data_names: data_indices.append(i) self._cached_op_args.append((True, data_names[name])) else: param_indices.append(i) self._cached_op_args.append((False, params[name])) flags = [('data_indices', data_indices), ('param_indices', param_indices)] + \ self._flags self._cached_op = ndarray.CachedOp(out, flags) def _deferred_infer_shape(self, *args): try: self.infer_shape(*args) except Exception as e: error_msg = "Deferred initialization failed because shape"\ " cannot be inferred. {}".format(e) raise ValueError(error_msg) def _call_cached_op(self, *args): if self._cached_op is None: self._build_cache(*args) assert self._cached_op, "cached op is not None" if self._callback: self._cached_op._register_op_hook(self._callback, self._monitor_all) if len(self._flags) >= 2 and (self._flags[1] or self._flags[0]): warnings.warn("register_op_hook is experimental when static_alloc=True / static_shape=True " " and may not work correctly") args, fmt = _flatten(args, "input") if fmt != self._in_format: # Do not raise in the case that the fmt or stored_fmt ends with None and # We are relying on the default values. 
if len(self._in_format) > len(fmt): valid = all([self._in_format[i] == -1 for i in range(len(fmt), len(self._in_format))]) valid = valid and (fmt == self._in_format[:len(fmt)]) elif len(self._in_format) < len(fmt): valid = all([fmt[i] == -1 for i in range(len(self._in_format), len(fmt))]) valid = valid and (fmt[:len(self._in_format)] == self._in_format) else: valid = False if not valid: raise ValueError("The argument structure of HybridBlock does not match" " the cached version. Stored format = {}, input format = {}" .format(fmt, self._in_format)) args_without_none = [ele for ele in args if ele is not None] try: cargs = [args_without_none[i] if is_arg else for is_arg, i in self._cached_op_args] except DeferredInitializationError: self._deferred_infer_shape(*args) cargs = [] for is_arg, i in self._cached_op_args: if is_arg: cargs.append(args_without_none[i]) else: i._finish_deferred_init() cargs.append( out = self._cached_op(*cargs) if isinstance(out, NDArray): out = [out] return _regroup(out, self._out_format) def _clear_cached_op(self): self._cached_graph = () self._cached_op = None
[docs] def register_child(self, block, name=None): if not isinstance(block, HybridBlock): raise ValueError( "Children of HybridBlock must also be HybridBlock, " \ "but %s has type %s. If you are using Sequential, " \ "please try HybridSequential instead."%( str(block), str(type(block)))) super(HybridBlock, self).register_child(block, name) self._clear_cached_op()
[docs] def hybridize(self, active=True, **kwargs): self._active = active self._flags = list(kwargs.items()) self._clear_cached_op() if active and self._forward_hooks or self._forward_pre_hooks: warnings.warn('"{block}" is being hybridized while still having forward hook/pre-hook. ' 'If "{block}" is a child of HybridBlock, the hooks will not take effect.' .format(block=self)) super(HybridBlock, self).hybridize(active, **kwargs)
[docs] def cast(self, dtype): self._clear_cached_op() super(HybridBlock, self).cast(dtype)
def _infer_attrs(self, infer_fn, attr, *args):
    """Generic infer attributes.

    Runs a Symbol inference method on this block's traced graph, using the
    given example inputs. Each collected Parameter's attribute is then set from
    the inferred values.

    Parameters
    ----------
    infer_fn : str
        Name of the Symbol method to call, e.g. ``'infer_shape'`` or
        ``'infer_type'``.
    attr : str
        Attribute read from each input and written to each Parameter,
        e.g. ``'shape'`` or ``'dtype'``.
    *args
        Example inputs. ``None`` entries are skipped.

    Raises
    ------
    ValueError
        If the inference method cannot infer the argument attributes.
    """
    inputs, out = self._get_graph(*args)
    args, _ = _flatten(args, "input")
    args_without_none = [ele for ele in args if ele is not None]
    # Keyword arguments for the inference call: graph input name -> attribute.
    known = {var.name: getattr(value, attr)
             for var, value in zip(inputs, args_without_none)}
    with warnings.catch_warnings(record=True) as caught:
        # Record every warning, even one already emitted from this location;
        # otherwise the default filters may drop it and ``caught`` is empty.
        warnings.simplefilter("always")
        arg_attrs, _, aux_attrs = getattr(out, infer_fn)(**known)
    if arg_attrs is None:
        if caught:
            raise ValueError(caught[0].message)
        raise ValueError("Cannot infer %s of Parameters from the given inputs." % attr)
    sdict = dict(zip(out.list_arguments(), arg_attrs))
    sdict.update(zip(out.list_auxiliary_states(), aux_attrs))
    for param in self.collect_params().values():
        setattr(param, attr, sdict[param.name])
def infer_shape(self, *args):
    """Infer the ``shape`` of every Parameter from the example inputs ``args``.

    Thin wrapper that delegates to ``_infer_attrs`` with the shape-specific
    inference method and attribute name.
    """
    method_name, attribute = 'infer_shape', 'shape'
    self._infer_attrs(method_name, attribute, *args)
def infer_type(self, *args):
    """Infer the ``dtype`` of every Parameter from the example inputs ``args``.

    Thin wrapper that delegates to ``_infer_attrs`` with the type-specific
    inference method and attribute name.
    """
    method_name, attribute = 'infer_type', 'dtype'
    self._infer_attrs(method_name, attribute, *args)
def export(self, path, epoch=0, remove_amp_cast=True):
    """Export HybridBlock to json format that can be loaded by
    `SymbolBlock.imports`, `mxnet.mod.Module` or the C++ interface.

    .. note:: When there is only one input, it will be named `data`. When there
              are more than one inputs, they will be named `data0`, `data1`, etc.

    Parameters
    ----------
    path : str
        Path to save model. Two files `path-symbol.json` and `path-xxxx.params`
        will be created, where xxxx is the 4 digits epoch number.
    epoch : int
        Epoch number of saved model.
    remove_amp_cast : bool, default True
        Passed through to ``Symbol.save`` when writing the symbol file.

    Raises
    ------
    RuntimeError
        If no cached graph exists yet, i.e. the block was not hybridized and
        run forward at least once.
    """
    if not self._cached_graph:
        raise RuntimeError(
            "Please first call block.hybridize() and then run forward with "
            "this block at least once before calling export.")
    sym = self._cached_graph[1]
    sym.save('%s-symbol.json' % path, remove_amp_cast=remove_amp_cast)

    arg_names = set(sym.list_arguments())
    aux_names = set(sym.list_auxiliary_states())
    arg_dict = {}
    for name, param in self.collect_params().items():
        if name in arg_names:
            arg_dict['arg:%s' % name] = param._reduce()
        else:
            assert name in aux_names
            arg_dict['aux:%s' % name] = param._reduce()
    # Use the numpy-extension serializer when numpy semantics are active.
    save_fn = _mx_npx.save if is_np_array() else ndarray.save
    save_fn('%s-%04d.params' % (path, epoch), arg_dict)
def register_op_hook(self, callback, monitor_all=False):
    """Install op hook for block recursively.

    The callback is stored on this block and on every descendant block. Any
    of them that later runs its own cached op registers the hook with it.

    Parameters
    ----------
    callback : function
        Takes a string and a NDArrayHandle.
    monitor_all : bool, default False
        If true, monitor both input and output, otherwise monitor output only.
    """
    self._callback = callback
    self._monitor_all = monitor_all
    # Delegate to each child so the hook reaches descendants at every depth,
    # not only the direct children.
    for child in self._children.values():
        child.register_op_hook(callback, monitor_all)
def forward(self, x, *args):
    """Defines the forward computation. Arguments can be either
    :py:class:`NDArray` or :py:class:`Symbol`.

    All non-``None`` inputs must be of the same kind.

    - With NDArray inputs, the computation runs under the context of the last
      NDArray input. It uses the cached op when the block is hybridized, and
      otherwise calls ``hybrid_forward`` with the ``ndarray`` module and the
      parameter data.
    - With Symbol inputs, ``hybrid_forward`` is called with the ``symbol``
      module and parameter variables, inside this block's name scope.

    Raises
    ------
    ValueError
        If NDArrays and Symbols are mixed, or if no NDArray/Symbol is given.
    """
    flatten_args = _flatten([x] + list(args), 'inputs')[0]

    def _mixed_types_error():
        # Same message for both orders of mixing NDArray and Symbol inputs.
        return ValueError('In HybridBlock, we do not support mixed NDArrays and Symbols'
                          ' types for the input.\n'
                          'Received types are: {}.'
                          .format([type(ele) for ele in flatten_args]))

    is_ndarray = None  # None until the first NDArray/Symbol input is seen
    ctx = None
    for ele in flatten_args:
        if isinstance(ele, NDArray):
            if is_ndarray is False:
                raise _mixed_types_error()
            is_ndarray = True
            ctx = ele.context
        elif isinstance(ele, Symbol):
            if is_ndarray:
                raise _mixed_types_error()
            is_ndarray = False
        else:
            assert ele is None, 'Only support None, NDArray and Symbol as the input'
    if is_ndarray is None:
        raise ValueError('There must be at least one NDArray or Symbol in the input, '
                         'received types are: {}.'
                         .format([type(ele) for ele in flatten_args]))

    if is_ndarray:
        with ctx:
            if self._active:
                return self._call_cached_op(x, *args)

            try:
                params = {k: v.data(ctx) for k, v in self._reg_params.items()}
            except DeferredInitializationError:
                # Shapes were unknown at initialization; infer them from the
                # actual inputs, finish initialization and retry.
                self._deferred_infer_shape(x, *args)
                for param in self.params.values():
                    param._finish_deferred_init()
                params = {k: v.data(ctx) for k, v in self._reg_params.items()}

            return self.hybrid_forward(ndarray, x, *args, **params)

    params = {i: j.var() for i, j in self._reg_params.items()}
    with self.name_scope():
        return self.hybrid_forward(symbol, x, *args, **params)
def hybrid_forward(self, F, x, *args, **kwargs):
    """Override this method to build the computation of this `Block`.

    Parameters
    ----------
    F : module
        ``forward`` passes the ``ndarray`` module for NDArray inputs and the
        ``symbol`` module for Symbol inputs.
    x : Symbol or NDArray
        The first input tensor.
    *args : list of Symbol or list of NDArray
        Additional input tensors.
    **kwargs
        This block's registered parameters: their data for NDArray inputs,
        their variables for Symbol inputs.

    Raises
    ------
    NotImplementedError
        Always; subclasses must provide the implementation.
    """
    # pylint: disable= invalid-name
    raise NotImplementedError()
def _common_prefix(names):
    """Get the common prefix for all names ('' for an empty sequence).

    The prefix is compared character by character, not by path component.
    """
    import os  # local import: only this helper needs it
    # os.path.commonprefix works character-wise on plain strings.
    return os.path.commonprefix(list(names))


class SymbolBlock(HybridBlock):
    """Construct block from symbol. This is useful for using pre-trained models
    as feature extractors. For example, you may want to extract the output
    from fc2 layer in AlexNet.

    Parameters
    ----------
    outputs : Symbol or list of Symbol
        The desired output for SymbolBlock.
    inputs : Symbol or list of Symbol
        The Variables in output's argument that should be used as inputs.
    params : ParameterDict
        Parameter dictionary for arguments and auxiliary states of outputs
        that are not inputs.

    Examples
    --------
    >>> # To extract the feature from fc1 and fc2 layers of AlexNet:
    >>> alexnet = gluon.model_zoo.vision.alexnet(pretrained=True, ctx=mx.cpu(),
    ...                                          prefix='model_')
    >>> inputs = mx.sym.var('data')
    >>> out = alexnet(inputs)
    >>> internals = out.get_internals()
    >>> print(internals.list_outputs())
    ['data', ..., 'model_dense0_relu_fwd_output', ..., 'model_dense1_relu_fwd_output', ...]
    >>> outputs = [internals['model_dense0_relu_fwd_output'],
    ...            internals['model_dense1_relu_fwd_output']]
    >>> # Create SymbolBlock that shares parameters with alexnet
    >>> feat_model = gluon.SymbolBlock(outputs, inputs, params=alexnet.collect_params())
    >>> x = mx.nd.random.normal(shape=(16, 3, 224, 224))
    >>> print(feat_model(x))
    """
    @staticmethod
    def imports(symbol_file, input_names, param_file=None, ctx=None):
        """Import model previously saved by `HybridBlock.export` or
        `Module.save_checkpoint` as a SymbolBlock for use in Gluon.

        Parameters
        ----------
        symbol_file : str
            Path to symbol file.
        input_names : list of str
            List of input variable names
        param_file : str, optional
            Path to parameter file.
        ctx : Context, default None
            The context to initialize SymbolBlock on.

        Returns
        -------
        SymbolBlock
            SymbolBlock loaded from symbol and parameter files.

        Examples
        --------
        >>> net1 = gluon.model_zoo.vision.resnet18_v1(
        ...     prefix='resnet', pretrained=True)
        >>> net1.hybridize()
        >>> x = mx.nd.random.normal(shape=(1, 3, 32, 32))
        >>> out1 = net1(x)
        >>> net1.export('net1', epoch=1)
        >>>
        >>> net2 = gluon.SymbolBlock.imports(
        ...     'net1-symbol.json', ['data'], 'net1-0001.params')
        >>> out2 = net2(x)
        """
        sym = symbol.load(symbol_file)
        if isinstance(input_names, str):
            input_names = [input_names]
        if param_file is None:
            # Get a valid type inference by using fp32
            inputs = [symbol.var(i, dtype=mx_real_t) for i in input_names]
        else:
            # Do not specify type, rely on saved params type instead
            inputs = [symbol.var(i) for i in input_names]
        ret = SymbolBlock(sym, inputs)
        if param_file is not None:
            ret.collect_params().load(param_file, ctx=ctx, cast_dtype=True, dtype_source='saved')
        return ret

    def __repr__(self):
        s = '{name}(\n{modstr}\n)'
        modstr = '\n'.join(['{block} : {numinputs} -> {numoutputs}'.format(
            block=self._cached_graph[1],
            numinputs=len(self._cached_graph[0]),
            numoutputs=len(self._cached_graph[1].list_outputs()))])
        return s.format(name=self.__class__.__name__, modstr=modstr)

    def __init__(self, outputs, inputs, params=None):
        super(SymbolBlock, self).__init__(prefix=None, params=None)
        self._prefix = ''
        self._params = ParameterDict('', params)
        if isinstance(inputs, symbol.Symbol) and len(inputs.list_outputs()) == 1:
            inputs = [inputs]
        if isinstance(outputs, (list, tuple)) and len(outputs) == 1:
            outputs = outputs[0]

        syms, self._in_format = _flatten(inputs, "input")
        out, self._out_format = _flatten(outputs, "output")
        out = symbol.Group(out, _check_same_symbol_type(out))

        # Names of the graph inputs; every other argument/aux becomes a Parameter.
        input_names = set()
        for i in syms:
            assert len(i.get_internals().list_outputs()) == 1, \
                "Input symbols must be variable, but %s is an output of operators"%str(i)
            input_names.add(i.name)

        # check if any symbol is row_sparse
        row_sparse_storage = ndarray.ndarray._STORAGE_TYPE_STR_TO_ID['row_sparse']
        for i in out:
            for j in i.get_internals():
                assert(j.attr("__storage_type__") != str(row_sparse_storage)), \
                    "SymbolBlock doesn't support Parameter '%s' because its storage " \
                    "type is 'row_sparse'." % j.name

        # Infer type of parameters. Without this, every parameter will be created with
        # default type i.e., fp32
        arg_params = out.list_arguments()
        aux_params = out.list_auxiliary_states()

        arg_types, aux_types = _infer_param_types(syms, out, arg_params, aux_params)

        for i, arg in enumerate(arg_params):
            if arg not in input_names:
                self.params.get(arg, allow_deferred_init=True, dtype=arg_types[i])

        for i, aux in enumerate(aux_params):
            if aux not in input_names:
                self.params.get(aux, grad_req='null', allow_deferred_init=True,
                                dtype=aux_types[i])

        self._cached_graph = syms, out
        # Strip the shared name prefix so registered parameter keys are short.
        len_prefix = len(_common_prefix(list(self._params.keys())))
        self._reg_params = {key[len_prefix:]: val for key, val in self._params.items()}

    def forward(self, x, *args):
        if isinstance(x, NDArray):
            with x.context:
                return self._call_cached_op(x, *args)

        assert isinstance(x, Symbol), \
            "HybridBlock requires the first argument to forward be either " \
            "Symbol or NDArray, but got %s"%type(x)
        args, in_fmt = _flatten([x] + list(args), "input")
        assert in_fmt == self._in_format, "Invalid input format"
        # Compose a copy of the stored graph, binding each input variable by name.
        ret = copy.copy(self._cached_graph[1])
        ret._compose(**{k.name: v for k, v in zip(self._cached_graph[0], args)})
        return _regroup(list(ret), self._out_format)

    def _clear_cached_op(self):
        # The graph was supplied at construction and cannot be re-traced, so
        # keep it while dropping the cached op.
        tmp = self._cached_graph
        super(SymbolBlock, self)._clear_cached_op()
        self._cached_graph = tmp

    def cast(self, dtype):
        self._clear_cached_op()
        super(SymbolBlock, self).cast(dtype)

    def hybrid_forward(self, F, x, *args, **kwargs):
        raise NotImplementedError


def _infer_param_types(in_params, out_params, arg_params, aux_params, default_dtype=mx_real_t):
    """Utility function that helps in inferring DType of args and auxs params
    from given input param.

    Parameters
    ----------
    in_params: List of Symbol
        List of input symbol variables.
    out_params: Symbol
        Output symbol variable.
    arg_params: List of Str
        List of names of argument parameters.
    aux_params: List of Str
        List of names of auxiliary parameters.
    default_dtype: numpy.dtype or str, default 'float32'
        Default data type for arg_params and aux_params, if unable to infer the type.

    Returns
    -------
    arg_types: List of numpy.dtype
        List of arg_params type. Order is same as arg_params.
        Defaults to 'float32', if unable to infer type.
    aux_types: List of numpy.dtype
        List of aux_params type. Order is same as aux_params.
        Defaults to 'float32', if unable to infer type.
    """
    arg_types = None
    aux_types = None

    # Get Input symbol details. This will be used to infer types of
    # other parameters.
    input_sym_names = [in_param.name for in_param in in_params]

    # Try to infer input types. If not successful, we will set default dtype.
    # If successful, we will try to infer other params in the graph.
    input_sym_arg_types = []
    can_infer_input_type = True
    for in_param in in_params:
        input_sym_arg_type = in_param.infer_type()[0]
        if not input_sym_arg_type:
            can_infer_input_type = False
            break
        input_sym_arg_types.append(input_sym_arg_type[0])

    # Try to infer types of other parameters.
    if can_infer_input_type:
        params = dict(zip(input_sym_names, input_sym_arg_types))
        try:
            arg_types, _, aux_types = out_params.infer_type(**params)
        except MXNetError:
            # Cannot infer type with current input
            arg_types, aux_types = None, None

    if arg_types is None or len(arg_types) != len(arg_params):
        arg_types = [default_dtype] * len(arg_params)

    if aux_types is None or len(aux_types) != len(aux_params):
        aux_types = [default_dtype] * len(aux_params)

    return (arg_types, aux_types)