.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "build/examples_tracking/train_siamrpn.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        Click :ref:`here `
        to download the full example code

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_build_examples_tracking_train_siamrpn.py:

02. Train SiamRPN on COCO, VID, DET, Youtube_bb
==================================================

This is a step-by-step Single Object Tracking tutorial using the GluonCV toolkit.
Readers should have basic knowledge of deep learning and be familiar with the Gluon API.
New users may first go through `A 60-minute Gluon Crash Course `_.
You can `Start Training Now`_ or `Dive into Deep`_.

Start Training Now
~~~~~~~~~~~~~~~~~~

.. note::

    Feel free to skip the tutorial because the training script is self-contained and ready to launch.

    :download:`Download Full Python Script: train.py<../../../scripts/tracking/train.py>`

    :download:`Download Full Python Script: test.py<../../../scripts/tracking/test.py>`

    Example training command::

        python train.py --ngpus 8 --epochs 50 --base-lr 0.005

    Example test command::

        python test.py --model-path --results-path

    Please check out the `model_zoo <../model_zoo/index.html#single_object_tracking>`_
    for the training and test commands that reproduce the pretrained model.

Network Structure
-----------------

First, let's import the necessary libraries into Python.

.. GENERATED FROM PYTHON SOURCE LINES 36-48

.. code-block:: default

    import mxnet as mx
    import time
    import numpy as np
    from mxnet import gluon, nd, autograd
    from mxnet.contrib import amp

    import gluoncv
    from gluoncv.utils import LRScheduler, LRSequential, split_and_load
    from gluoncv.data.tracking_data.track import TrkDataset
    from gluoncv.model_zoo import get_model
    from gluoncv.loss import SiamRPNLoss

..
GENERATED FROM PYTHON SOURCE LINES 49-52

`SiamRPN `_ is a widely adopted Single Object Tracking method.
It sends the template frame and the detection frame through the siamese network,
then obtains the score map and the coordinate regression of the anchors from
the RPN network and the cross-correlation layers.

.. GENERATED FROM PYTHON SOURCE LINES 53-89

.. code-block:: default

    # number of GPUs to use (this demo runs on the CPU)
    num_gpus = 1
    ctx = [mx.cpu(0)]
    batch_size = 32  # adjust to 128 if memory is sufficient
    epochs = 1
    # Get the SiamRPN model with an AlexNet backbone
    net = get_model('siamrpn_alexnet_v2_otb15', bz=batch_size, is_train=True, ctx=ctx)
    net.collect_params().reset_ctx(ctx)
    print(net)

    # We provide Single Object Tracking datasets in :class:`gluoncv.data`.
    # For example, we can easily get the VID, DET and COCO datasets:
    #   ``python scripts/datasets/ilsvrc_det.py``
    #   ``python scripts/datasets/ilsvrc_vid.py``
    #   ``python scripts/datasets/coco_tracking.py``
    # To download the Youtube_bb dataset, follow the instructions at the following `link `:

    # prepare dataset and dataloader
    train_dataset = TrkDataset(train_epoch=epochs)
    print('Training images:', len(train_dataset))
    workers = 0
    train_loader = gluon.data.DataLoader(train_dataset,
                                         batch_size=batch_size,
                                         last_batch='discard',
                                         num_workers=workers)

    def train_batch_fn(data, ctx):
        """Split each batch field along the batch axis and load it onto the devices in ctx."""
        template = split_and_load(data[0], ctx_list=ctx, batch_axis=0)
        search = split_and_load(data[1], ctx_list=ctx, batch_axis=0)
        label_cls = split_and_load(data[2], ctx_list=ctx, batch_axis=0)
        label_loc = split_and_load(data[3], ctx_list=ctx, batch_axis=0)
        label_loc_weight = split_and_load(data[4], ctx_list=ctx, batch_axis=0)
        return template, search, label_cls, label_loc, label_loc_weight

.. rst-class:: sphx-glr-script-out

Out:

..
code-block:: none

    SiamRPN(
      (backbone): AlexNetLegacy(
        (features): HybridSequential(
          (0): Conv2D(None -> 96, kernel_size=(11, 11), stride=(2, 2))
          (1): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=None)
          (2): MaxPool2D(size=(3, 3), stride=(2, 2), padding=(0, 0), ceil_mode=False, global_pool=False, pool_type=max, layout=NCHW)
          (3): Activation(relu)
          (4): Conv2D(None -> 256, kernel_size=(5, 5), stride=(1, 1))
          (5): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=None)
          (6): MaxPool2D(size=(3, 3), stride=(2, 2), padding=(0, 0), ceil_mode=False, global_pool=False, pool_type=max, layout=NCHW)
          (7): Activation(relu)
          (8): Conv2D(None -> 384, kernel_size=(3, 3), stride=(1, 1))
          (9): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=None)
          (10): Activation(relu)
          (11): Conv2D(None -> 384, kernel_size=(3, 3), stride=(1, 1))
          (12): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=None)
          (13): Activation(relu)
          (14): Conv2D(None -> 256, kernel_size=(3, 3), stride=(1, 1))
          (15): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=None)
        )
      )
      (rpn_head): DepthwiseRPN(
        (cls): DepthwiseXCorr(
          (conv_kernel): HybridSequential(
            (0): Conv2D(None -> 256, kernel_size=(3, 3), stride=(1, 1), bias=False)
            (1): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=None)
            (2): Activation(relu)
          )
          (conv_search): HybridSequential(
            (0): Conv2D(None -> 256, kernel_size=(3, 3), stride=(1, 1), bias=False)
            (1): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=None)
            (2): Activation(relu)
          )
          (head): HybridSequential(
            (0): Conv2D(None -> 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
            (1): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=None)
            (2): Activation(relu)
            (3): Conv2D(None -> 10, kernel_size=(1, 1), stride=(1, 1))
          )
        )
        (loc): DepthwiseXCorr(
          (conv_kernel): HybridSequential(
            (0): Conv2D(None -> 256, kernel_size=(3, 3), stride=(1, 1), bias=False)
            (1): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=None)
            (2): Activation(relu)
          )
          (conv_search): HybridSequential(
            (0): Conv2D(None -> 256, kernel_size=(3, 3), stride=(1, 1), bias=False)
            (1): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=None)
            (2): Activation(relu)
          )
          (head): HybridSequential(
            (0): Conv2D(None -> 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
            (1): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=None)
            (2): Activation(relu)
            (3): Conv2D(None -> 20, kernel_size=(1, 1), stride=(1, 1))
          )
        )
      )
    )
    Training images: 600000

.. GENERATED FROM PYTHON SOURCE LINES 90-97

Training Details
----------------

- Training Losses:

  We apply Softmax Cross Entropy Loss for classification and a weighted L1 loss
  for box regression to train SiamRPN.

.. GENERATED FROM PYTHON SOURCE LINES 97-100

.. code-block:: default

    criterion = SiamRPNLoss(batch_size)

.. GENERATED FROM PYTHON SOURCE LINES 101-102

- Learning Rate and Scheduling:

.. GENERATED FROM PYTHON SOURCE LINES 102-105

.. code-block:: default

    lr_scheduler = LRScheduler(mode='step', base_lr=0.005, step_epoch=[0],
                               nepochs=epochs, iters_per_epoch=len(train_loader), power=0.9)

.. GENERATED FROM PYTHON SOURCE LINES 106-108

- Optimizer: stochastic gradient descent (SGD) with momentum and weight decay.
  Training is data-parallel across the devices in ``ctx``; this demo uses the CPU only.

.. GENERATED FROM PYTHON SOURCE LINES 108-118

.. code-block:: default

    optimizer = 'sgd'
    # Set the optimizer parameters
    optimizer_params = {'lr_scheduler': lr_scheduler,
                        'wd': 1e-4,
                        'momentum': 0.9,
                        'learning_rate': 0.005}
    trainer = gluon.Trainer(net.collect_params(), optimizer, optimizer_params)
    # relative weights of the classification and localization losses
    cls_weight = 1.0
    loc_weight = 1.2

..
GENERATED FROM PYTHON SOURCE LINES 119-128

Training
--------

After all the preparations, we can finally start training!
The training script follows.

.. note::

  In your own experiments, we recommend setting ``epochs=50`` for the dataset.
  This tutorial skips the training by setting ``epochs = 0``.

.. GENERATED FROM PYTHON SOURCE LINES 128-172

.. code-block:: default

    epochs = 0

    for epoch in range(epochs):
        # running sums of the per-iteration losses, used for the averages printed below
        loss_total_val = 0
        loss_loc_val = 0
        loss_cls_val = 0
        batch_time = time.time()
        for i, data in enumerate(train_loader):
            template, search, label_cls, label_loc, label_loc_weight = train_batch_fn(data, ctx)
            cls_losses = []
            loc_losses = []
            total_losses = []
            with autograd.record():
                for j in range(len(ctx)):
                    cls, loc = net(template[j], search[j])
                    # indices of positive (label 1) and negative (label 0) anchors
                    label_cls_temp = label_cls[j].reshape(-1).asnumpy()
                    pos_index = np.argwhere(label_cls_temp == 1).reshape(-1)
                    neg_index = np.argwhere(label_cls_temp == 0).reshape(-1)
                    if len(pos_index):
                        pos_index = nd.array(pos_index, ctx=ctx[j])
                    else:
                        pos_index = nd.array(np.array([]), ctx=ctx[j])
                    if len(neg_index):
                        neg_index = nd.array(neg_index, ctx=ctx[j])
                    else:
                        neg_index = nd.array(np.array([]), ctx=ctx[j])
                    cls_loss, loc_loss = criterion(cls, loc, label_cls[j], pos_index, neg_index,
                                                   label_loc[j], label_loc_weight[j])
                    # weighted sum of the classification and localization losses
                    total_loss = cls_weight * cls_loss + loc_weight * loc_loss
                    cls_losses.append(cls_loss)
                    loc_losses.append(loc_loss)
                    total_losses.append(total_loss)

                autograd.backward(total_losses)
            trainer.step(batch_size)
            # average each loss over the devices before accumulating it
            loss_total_val += sum([l.mean().asscalar() for l in total_losses]) / len(total_losses)
            loss_loc_val += sum([l.mean().asscalar() for l in loc_losses]) / len(loc_losses)
            loss_cls_val += sum([l.mean().asscalar() for l in cls_losses]) / len(cls_losses)
            print('Epoch %d iteration %04d/%04d: loc loss %.3f, cls loss %.3f, '
                  'training loss %.3f, batch time %.3f' %
                  (epoch, i, len(train_loader), loss_loc_val/(i+1), loss_cls_val/(i+1),
                   loss_total_val/(i+1), time.time()-batch_time))
            batch_time = time.time()
            mx.nd.waitall()

..
GENERATED FROM PYTHON SOURCE LINES 173-182

You can `Start Training Now`_.

References
----------

Bo Li, Junjie Yan, Wei Wu, Zheng Zhu, Xiaolin Hu.
"High Performance Visual Tracking With Siamese Region Proposal Network."
Proceedings of the IEEE conference on computer vision and pattern recognition. 2018.

.. rst-class:: sphx-glr-timing

   **Total running time of the script:** ( 0 minutes 0.340 seconds)

.. _sphx_glr_download_build_examples_tracking_train_siamrpn.py:

.. only :: html

 .. container:: sphx-glr-footer
    :class: sphx-glr-footer-example

  .. container:: sphx-glr-download sphx-glr-download-python

     :download:`Download Python source code: train_siamrpn.py `

  .. container:: sphx-glr-download sphx-glr-download-jupyter

     :download:`Download Jupyter notebook: train_siamrpn.ipynb `

.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery `_