# Source code for gluoncv.utils.viz.network
"""Visualize network structure"""
import tempfile
try:
import graphviz
except ImportError:
graphviz = None
import mxnet as mx
from mxnet import gluon
def plot_network(block, shape=(1, 3, 224, 224), save_prefix=None):
    """Plot network to visualize internal structures.

    The block is called on a symbolic placeholder named ``data``. The
    resulting symbol graph is drawn with :func:`mxnet.viz.plot_network` and
    opened via graphviz's ``view``. It is additionally rendered to disk when
    `save_prefix` is a string.

    Parameters
    ----------
    block : mxnet.gluon.HybridBlock
        A hybridizable network to be visualized.
    shape : tuple of int
        Desired input shape, default is (1, 3, 224, 224).
    save_prefix : str or None
        If not `None`, will save rendered pdf to disk with prefix.

    Raises
    ------
    RuntimeError
        If the ``graphviz`` package could not be imported.
    ValueError
        If `block` is not a :class:`mxnet.gluon.HybridBlock`.
    """
    if graphviz is None:
        raise RuntimeError("Cannot import graphviz.")
    if not isinstance(block, gluon.HybridBlock):
        raise ValueError("block must be HybridBlock, given {}".format(type(block)))
    data = mx.sym.var('data')
    sym = block(data)
    # Multi-output networks may return several symbols (as a tuple or a
    # list); group them so the whole graph is drawn in a single plot.
    if isinstance(sym, (tuple, list)):
        sym = mx.sym.Group(sym)
    a = mx.viz.plot_network(sym, shape={'data': shape},
                            node_attrs={'shape': 'rect', 'fixedsize': 'false'})
    # tempfile.mktemp() is deprecated and race-prone (the name may be taken
    # between generation and use). Create the file atomically instead and
    # let graphviz overwrite it with the graph source.
    with tempfile.NamedTemporaryFile(suffix='.gv', delete=False) as tmp:
        tmp_path = tmp.name
    a.view(tmp_path)
    if isinstance(save_prefix, str):
        a.render(save_prefix)
def plot_mxboard(block, logdir='./logs'):
    """Plot network to visualize internal structures.

    The block is called on a symbolic placeholder named ``data``. The
    resulting graph is written to `logdir` with mxboard's ``SummaryWriter``,
    and a TensorBoard command line for viewing it is printed.

    Parameters
    ----------
    block : mxnet.gluon.HybridBlock
        A hybridizable network to be visualized.
    logdir : str
        The directory to save.

    Raises
    ------
    ImportError
        If ``mxboard`` is not installed (an install hint is printed first).
    ValueError
        If `block` is not a :class:`mxnet.gluon.HybridBlock`.
    """
    try:
        from mxboard import SummaryWriter
    except ImportError:
        print('mxboard is required. Please install via `pip install mxboard` ' +
              'or refer to https://github.com/awslabs/mxboard.')
        raise
    # Mirror plot_network's validation: feeding a symbolic input requires a
    # HybridBlock, and failing early gives a clear message instead of an
    # obscure error from inside the block's forward pass.
    if not isinstance(block, gluon.HybridBlock):
        raise ValueError("block must be HybridBlock, given {}".format(type(block)))
    data = mx.sym.var('data')
    sym = block(data)
    # Multi-output networks may return several symbols (as a tuple or a
    # list); group them so a single graph is logged.
    if isinstance(sym, (tuple, list)):
        sym = mx.sym.Group(sym)
    with SummaryWriter(logdir=logdir) as sw:
        sw.add_graph(sym)
    usage = '`tensorboard --logdir={} --host=127.0.0.1 --port=8888`'.format(logdir)
    print('Log saved. Use: {} to visualize it'.format(usage))