gluoncv.loss

Custom losses. All losses are subclasses of gluon.loss.Loss, which is itself a HybridBlock.

FocalLoss

Focal Loss for imbalanced classification.

SSDMultiBoxLoss

Single-Shot Multibox Object Detection Loss.

API Reference

class gluoncv.loss.DistillationSoftmaxCrossEntropyLoss(temperature=1, hard_weight=0.5, sparse_label=True, **kwargs)[source]

SoftmaxCrossEntropyLoss with teacher model prediction.

Parameters
  • temperature (float, default 1) – The temperature used to soften the teacher prediction.

  • hard_weight (float, default 0.5) – The weight for loss on the one-hot label.

  • sparse_label (bool, default True) – Whether the one-hot label is sparse.

hybrid_forward(F, output, label, soft_target)[source]

Compute loss
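As a rough illustration of how temperature and hard_weight typically interact in a distillation loss, the NumPy sketch below blends a cross-entropy term on the integer label with a cross-entropy term against the temperature-softened teacher output. The name distillation_loss_sketch, the (1 - hard_weight) weighting and the temperature ** 2 scaling of the soft term follow the common distillation formulation and are assumptions here, not a transcription of the gluoncv implementation.

```python
import numpy as np

def log_softmax(x, axis=-1):
    # Numerically stable log-softmax.
    x = x - x.max(axis=axis, keepdims=True)
    return x - np.log(np.exp(x).sum(axis=axis, keepdims=True))

def distillation_loss_sketch(output, label, soft_target,
                             temperature=1.0, hard_weight=0.5):
    """Hypothetical helper: per-sample blend of hard-label and teacher-guided loss.

    output      -- student logits, shape (batch, num_class)
    label       -- integer class ids, shape (batch,) (the sparse_label=True case)
    soft_target -- teacher logits, shape (batch, num_class)
    """
    batch = np.arange(output.shape[0])
    # Hard term: ordinary cross entropy against the integer label.
    hard = -log_softmax(output)[batch, label]
    # Soft term: cross entropy against the teacher distribution,
    # with both sides softened by the temperature.
    teacher_prob = np.exp(log_softmax(soft_target / temperature))
    soft = -(teacher_prob * log_softmax(output / temperature)).sum(axis=-1)
    # Assumed weighting: temperature ** 2 keeps the soft term on a comparable scale.
    return hard_weight * hard + (1.0 - hard_weight) * temperature ** 2 * soft

student = np.array([[2.0, 0.5, -1.0]])
teacher = np.array([[3.0, 1.0, -2.0]])
label = np.array([0])
loss = distillation_loss_sketch(student, label, teacher,
                                temperature=2.0, hard_weight=0.5)
```

With hard_weight=1 the sketch reduces to plain cross entropy on the label, and with hard_weight=0 only the teacher term remains.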

class gluoncv.loss.FocalLoss(axis=-1, alpha=0.25, gamma=2, sparse_label=True, from_logits=False, batch_axis=0, weight=None, num_class=None, eps=1e-12, size_average=True, **kwargs)[source]

Focal Loss for imbalanced classification. Focal loss is described in https://arxiv.org/abs/1708.02002

Parameters
  • axis (int, default -1) – The axis to sum over when computing softmax and entropy.

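  • alpha (float, default 0.25) – The balancing factor that weights positive against negative examples in the loss.

  • gamma (float, default 2) – The focusing parameter; larger values reduce the loss contribution of well-classified examples.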

  • sparse_label (bool, default True) – Whether label is an integer array instead of probability distribution.

  • from_logits (bool, default False) – Whether input is a log probability (usually from log_softmax) instead of unnormalized numbers.

  • batch_axis (int, default 0) – The axis that represents mini-batch.

  • weight (float or None) – Global scalar weight for loss.

  • num_class (int) – Number of classification categories. It is required if sparse_label is True.

  • eps (float) – Small constant used to avoid numerical issues.

  • size_average (bool, default True) – If True, will take mean of the output loss on every axis except batch_axis.

  • Inputs

    • pred: the prediction tensor, where the batch_axis dimension ranges over batch size and axis dimension ranges over the number of classes.

    • label: the truth tensor. When sparse_label is True, label’s shape should be pred’s shape with the axis dimension removed. For example, for pred with shape (1,2,3,4) and axis = 2, label’s shape should be (1,2,4) and values should be integers between 0 and 2. If sparse_label is False, label’s shape must be the same as pred and values should be floats in the range [0, 1].

    • sample_weight: element-wise weighting tensor. Must be broadcastable to the same shape as label. For example, if label has shape (64, 10) and you want to weigh each sample in the batch separately, sample_weight should have shape (64, 1).

  • Outputs

    • loss: loss tensor with shape (batch_size,). Dimensions other than batch_axis are averaged out.
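The shape rules for label and sample_weight described above can be checked with plain NumPy. The arrays below are made-up placeholders and no gluoncv call is involved; the snippet only mirrors the shape relationships.

```python
import numpy as np

axis = 2
pred = np.random.rand(1, 2, 3, 4)            # class scores live on axis 2 (3 classes)

# sparse_label=True: one integer class id per position, so the class axis disappears.
sparse_label = pred.argmax(axis=axis)
assert sparse_label.shape == (1, 2, 4)
assert sparse_label.min() >= 0 and sparse_label.max() <= 2

# sparse_label=False: same shape as pred, with float values in [0, 1].
dense_label = np.eye(3)[sparse_label]        # one-hot, class axis last: (1, 2, 4, 3)
dense_label = np.moveaxis(dense_label, -1, axis)
assert dense_label.shape == pred.shape

# sample_weight only needs to broadcast against the label.
label = np.zeros((64, 10))
sample_weight = np.random.rand(64, 1)        # one weight per sample in the batch
weighted = label * sample_weight
assert weighted.shape == (64, 10)
```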

hybrid_forward(F, pred, label, sample_weight=None)[source]

Loss forward
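To make the roles of alpha and gamma concrete, here is a NumPy sketch of the focal loss from the paper referenced above, FL(p_t) = -alpha_t (1 - p_t)^gamma log(p_t), in a per-class sigmoid form. The name focal_loss_sketch and the choice of sigmoid activations are assumptions made for illustration; the gluoncv class may differ in details such as reduction and weighting.

```python
import numpy as np

def focal_loss_sketch(logits, label, num_class, alpha=0.25, gamma=2.0, eps=1e-12):
    """Hypothetical helper: element-wise sigmoid focal loss for integer labels.

    logits -- raw scores, shape (batch, num_class)
    label  -- integer class ids, shape (batch,)
    """
    prob = 1.0 / (1.0 + np.exp(-logits))
    one_hot = np.eye(num_class)[label].astype(bool)
    # p_t is the probability assigned to the correct decision for each class.
    pt = np.where(one_hot, prob, 1.0 - prob)
    alpha_t = np.where(one_hot, alpha, 1.0 - alpha)
    # (1 - p_t) ** gamma shrinks the contribution of well-classified entries.
    return -alpha_t * (1.0 - pt) ** gamma * np.log(np.minimum(pt + eps, 1.0))

logits = np.array([[4.0, -4.0],     # confident and correct for class 0
                   [-0.5, 0.5]])    # uncertain, true class is 0
label = np.array([0, 0])
loss = focal_loss_sketch(logits, label, num_class=2)
easy, hard = loss[0].sum(), loss[1].sum()
```

In this sketch the confident sample (easy) contributes far less than the uncertain one (hard). Setting gamma=0 removes the focusing term and leaves an alpha-weighted binary cross entropy.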

class gluoncv.loss.ICNetLoss(weights=(0.4, 0.4, 1.0), height=None, width=None, crop_size=480, ignore_label=-1, **kwargs)[source]

Weighted SoftmaxCrossEntropyLoss2D for ICNet training

Parameters
  • weights (tuple, default (0.4, 0.4, 1.0)) – The weight for cascade label guidance.

  • ignore_label (int, default -1) – The label to ignore.

hybrid_forward(F, *inputs)[source]

Compute loss

class gluoncv.loss.MixSoftmaxCrossEntropyLoss(aux=True, mixup=False, aux_weight=0.2, ignore_label=-1, **kwargs)[source]

SoftmaxCrossEntropyLoss2D with Auxiliary Loss

Parameters
  • aux (bool, default True) – Whether to use auxiliary loss.

  • aux_weight (float, default 0.2) – The weight for aux loss.

  • ignore_label (int, default -1) – The label to ignore.

hybrid_forward(F, *inputs, **kwargs)[source]

Compute loss
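The way aux_weight and ignore_label usually enter a loss of this kind can be sketched in NumPy: a per-pixel cross entropy that skips ignored pixels is computed for the main and the auxiliary prediction, and the two are combined. The helper names pixel_ce and mix_loss_sketch and the reduction (mean over non-ignored pixels) are assumptions, not the gluoncv code.

```python
import numpy as np

def pixel_ce(pred, label, ignore_label=-1):
    """Hypothetical helper: mean cross entropy over pixels whose label is not ignored.

    pred  -- logits, shape (batch, num_class, H, W)
    label -- integer class ids, shape (batch, H, W)
    """
    shifted = pred - pred.max(axis=1, keepdims=True)
    logp = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    valid = label != ignore_label
    safe_label = np.where(valid, label, 0)      # any in-range id; masked out below
    picked = np.take_along_axis(logp, safe_label[:, None], axis=1)[:, 0]
    return -(picked * valid).sum() / max(valid.sum(), 1)

def mix_loss_sketch(main_pred, aux_pred, label, aux_weight=0.2, ignore_label=-1):
    # Main loss plus a down-weighted loss on the auxiliary head.
    main = pixel_ce(main_pred, label, ignore_label)
    aux = pixel_ce(aux_pred, label, ignore_label)
    return main + aux_weight * aux

rng = np.random.default_rng(0)
main_pred = rng.normal(size=(2, 3, 4, 4))
aux_pred = rng.normal(size=(2, 3, 4, 4))
label = rng.integers(0, 3, size=(2, 4, 4))
label[0, 0, 0] = -1                              # this pixel is ignored
total = mix_loss_sketch(main_pred, aux_pred, label)
```

Because the ignored pixel is masked out, changing the prediction at that location leaves the loss unchanged.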

class gluoncv.loss.MixSoftmaxCrossEntropyOHEMLoss(aux=True, aux_weight=0.2, ignore_label=-1, **kwargs)[source]

SoftmaxCrossEntropyLoss2D with Auxiliary Loss and online hard example mining (OHEM)

Parameters
  • aux (bool, default True) – Whether to use auxiliary loss.

  • aux_weight (float, default 0.2) – The weight for aux loss.

  • ignore_label (int, default -1) – The label to ignore.

hybrid_forward(F, *inputs, **kwargs)[source]

Compute loss

class gluoncv.loss.SSDMultiBoxLoss(negative_mining_ratio=3, rho=1.0, lambd=1.0, min_hard_negatives=0, **kwargs)[source]

Single-Shot Multibox Object Detection Loss.

Note

Computing batch-wise statistics requires cross-device synchronization, which makes this loss slightly sub-optimal compared with a non-synchronized version. However, we find that it gives better converged model performance.

Parameters
  • negative_mining_ratio (float, default is 3) – Ratio of negative vs. positive samples.

  • rho (float, default is 1.0) – Threshold for trimmed mean estimators. This is the smooth parameter for the L1-L2 transition.

  • lambd (float, default is 1.0) – Relative weight between classification and box regression loss. The overall loss is computed as \(L = loss_{class} + \lambda \times loss_{loc}\).

  • min_hard_negatives (int, default is 0) – Minimum number of negative samples.
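The parameters above can be read together in a small NumPy sketch for a single image: rho sets the transition point of a smooth L1 box loss, negative_mining_ratio and min_hard_negatives decide how many of the hardest negatives are kept, and lambd combines the two terms as \(L = loss_{class} + \lambda \times loss_{loc}\). The helper names, the per-anchor input layout and the normalization by the number of positives are assumptions for illustration, not the gluoncv implementation.

```python
import numpy as np

def smooth_l1(diff, rho=1.0):
    # Quadratic below rho, linear above it (the L1-L2 transition set by rho).
    a = np.abs(diff)
    return np.where(a > rho, a - 0.5 * rho, (0.5 / rho) * a ** 2)

def ssd_loss_sketch(cls_loss, box_pred, box_target, positive,
                    negative_mining_ratio=3, rho=1.0, lambd=1.0,
                    min_hard_negatives=0):
    """Hypothetical helper for a single image.

    cls_loss   -- per-anchor classification loss, shape (N,)
    box_pred   -- predicted box offsets, shape (N, 4)
    box_target -- target box offsets, shape (N, 4)
    positive   -- boolean mask of anchors matched to an object, shape (N,)
    """
    num_pos = int(positive.sum())
    num_neg = max(min_hard_negatives, int(negative_mining_ratio * num_pos))
    # Hard negative mining: keep only the negatives with the largest loss.
    neg_loss = np.where(positive, -np.inf, cls_loss)
    hard_neg = np.argsort(-neg_loss)[:num_neg]
    keep = positive.copy()
    keep[hard_neg] = True
    loss_class = cls_loss[keep].sum()
    # Box regression only counts anchors matched to an object.
    loss_loc = smooth_l1(box_pred - box_target, rho)[positive].sum()
    # Assumed normalization by the number of positive anchors.
    return (loss_class + lambd * loss_loc) / max(num_pos, 1)

cls_loss = np.array([0.2, 1.5, 0.1, 0.9, 0.05, 0.7, 0.3, 0.01])
positive = np.array([True] + [False] * 7)
box_pred = np.zeros((8, 4))
box_target = np.zeros((8, 4))
box_target[0] = [0.5, 0.0, 2.0, 0.0]
total = ssd_loss_sketch(cls_loss, box_pred, box_target, positive)
```

With one positive anchor and a ratio of 3, only the three negatives with the largest classification loss enter the class term.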

forward(cls_pred, box_pred, cls_target, box_target)[source]

Compute loss in entire batch across devices.

Parameters
  • cls_pred (mxnet.nd.NDArray) – Predicted classes.

  • box_pred (mxnet.nd.NDArray) – Predicted bounding-boxes.

  • cls_target (mxnet.nd.NDArray) – Ground-truth classes.

  • box_target (mxnet.nd.NDArray) – Ground-truth bounding-boxes.

Returns

  • sum_losses: array containing the sum of class prediction and bounding-box regression loss.

  • cls_losses: array of class prediction loss.

  • box_losses: array of box regression L1 loss.

Return type

tuple of NDArrays

class gluoncv.loss.SegmentationMultiLosses(size_average=True, ignore_label=-1, **kwargs)[source]

2D Cross Entropy Loss with Multi-Loss

hybrid_forward(F, *inputs, **kwargs)[source]

Compute loss

class gluoncv.loss.SiamRPNLoss(batch_size=128, **kwargs)[source]

Weighted L1 loss and cross entropy loss for SiamRPN training

Parameters
  • batch_size (int, default 128) – Training batch size per device (CPU/GPU).

cross_entropy_loss(F, pred, label, pos_index, neg_index)[source]

Compute cross_entropy_loss

get_cls_loss(F, pred, label, select)[source]

Compute SoftmaxCrossEntropyLoss

hybrid_forward(F, cls_pred, loc_pred, label_cls, pos_index, neg_index, label_loc, label_loc_weight)[source]

Compute loss

weight_l1_loss(F, pred_loc, label_loc, loss_weight)[source]

Compute weight_l1_loss

class gluoncv.loss.YOLOV3Loss(batch_axis=0, weight=None, **kwargs)[source]

Losses of YOLO v3.

Parameters
  • batch_axis (int, default 0) – The axis that represents mini-batch.

  • weight (float or None) – Global scalar weight for loss.

hybrid_forward(F, objness, box_centers, box_scales, cls_preds, objness_t, center_t, scale_t, weight_t, class_t, class_mask)[source]

Compute YOLOv3 losses.

Parameters
  • objness (mxnet.nd.NDArray) – Predicted objectness (B, N), range (0, 1).

  • box_centers (mxnet.nd.NDArray) – Predicted box centers (x, y) (B, N, 2), range (0, 1).

  • box_scales (mxnet.nd.NDArray) – Predicted box scales (width, height) (B, N, 2).

  • cls_preds (mxnet.nd.NDArray) – Class predictions (B, N, num_class), range (0, 1).

  • objness_t (mxnet.nd.NDArray) – Objectness target, (B, N), 0 for negative, 1 for positive, -1 for ignore.

  • center_t (mxnet.nd.NDArray) – Center (x, y) targets (B, N, 2).

  • scale_t (mxnet.nd.NDArray) – Scale (width, height) targets (B, N, 2).

  • weight_t (mxnet.nd.NDArray) – Loss multipliers for center and scale targets (B, N, 2).

  • class_t (mxnet.nd.NDArray) – Class targets (B, N, num_class). Each target is a relaxed one-hot vector, e.g. (1, 0, 1, 0, 0), so it can contain more than one positive class.

  • class_mask (mxnet.nd.NDArray) – 0 or 1 mask array to mask out ignored samples (B, N, num_class).

Returns

  • obj_loss: sum of objectness logistic loss.

  • center_loss: sum of box center logistic regression loss.

  • scale_loss: sum of box scale L1 loss.

  • cls_loss: sum of per-class logistic loss.

Return type

tuple of NDArrays
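The target conventions above (-1 for ignored objectness entries, relaxed one-hot class targets and a 0/1 class_mask) can be illustrated with a short NumPy sketch of masked logistic losses. The helper sigmoid_bce and all array values are made up for illustration; the actual YOLOV3Loss may scale and normalize its terms differently.

```python
import numpy as np

def sigmoid_bce(pred, target, eps=1e-12):
    # Element-wise logistic loss for predictions that are already in (0, 1).
    return -(target * np.log(pred + eps)
             + (1.0 - target) * np.log(1.0 - pred + eps))

B, N, num_class = 1, 4, 5
objness = np.array([[0.9, 0.2, 0.6, 0.1]])        # predicted objectness (B, N)
objness_t = np.array([[1.0, 0.0, -1.0, 0.0]])     # 1 positive, 0 negative, -1 ignore

# Entries marked -1 contribute nothing to the objectness loss.
obj_mask = objness_t >= 0
obj_loss = (sigmoid_bce(objness, np.maximum(objness_t, 0.0)) * obj_mask).sum()

# Relaxed one-hot class targets: more than one class may be positive for a box.
cls_preds = np.full((B, N, num_class), 0.5)
class_t = np.zeros((B, N, num_class))
class_t[0, 0] = [1, 0, 1, 0, 0]
class_mask = np.zeros((B, N, num_class))
class_mask[0, 0] = 1                               # only the positive box has class targets
cls_loss = (sigmoid_bce(cls_preds, class_t) * class_mask).sum()
```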