Source code for gluoncv.data.recordio.detection
"""Detection dataset from RecordIO files."""
from __future__ import absolute_import
from __future__ import division
import numpy as np
from mxnet import gluon
def _transform_label(label, height=None, width=None):
label = np.array(label).ravel()
header_len = int(label[0]) # label header
label_width = int(label[1]) # the label width for each object, >= 5
if label_width < 5:
raise ValueError(
"Label info for each object should >= 5, given {}".format(label_width))
min_len = header_len + 5
if len(label) < min_len:
raise ValueError(
"Expected label length >= {}, got {}".format(min_len, len(label)))
if (len(label) - header_len) % label_width:
raise ValueError(
"Broken label of size {}, cannot reshape into (N, {}) "
"if header length {} is excluded".format(len(label), label_width, header_len))
gcv_label = label[header_len:].reshape(-1, label_width)
# swap columns, gluon-cv requires [xmin-ymin-xmax-ymax-id-extra0-extra1-xxx]
ids = gcv_label[:, 0].copy()
gcv_label[:, :4] = gcv_label[:, 1:5]
gcv_label[:, 4] = ids
# restore to absolute coordinates
if height is not None:
gcv_label[:, (0, 2)] *= width
if width is not None:
gcv_label[:, (1, 3)] *= height
return gcv_label
[docs]class RecordFileDetection(gluon.data.vision.ImageRecordDataset):
"""Detection dataset loaded from record file.
The supported record file is using the same format used by
:py:meth:`mxnet.image.ImageDetIter` and :py:meth:`mxnet.io.ImageDetRecordIter`.
Checkout :ref:`lst_record_dataset` for tutorial of how to prepare this file.
.. note::
We suggest you to use ``RecordFileDetection`` only if you are familiar with
the record files.
Parameters
----------
filename : str
Path of the record file. It require both *.rec and *.idx file in the same
directory, where raw image and labels are stored in *.rec file for better
IO performance, *.idx file is used to provide random access to the binary file.
coord_normalized : boolean
Indicate whether bounding box coordinates have been normalized to (0, 1) in labels.
If so, we will rescale back to absolute coordinates by multiplying width or height.
Examples
--------
>>> record_dataset = RecordFileDetection('train.rec')
>>> img, label = record_dataset[0]
>>> print(img.shape, label.shape)
(512, 512, 3) (1, 5)
"""
def __init__(self, filename, coord_normalized=True):
super(RecordFileDetection, self).__init__(filename)
self._coord_normalized = coord_normalized
def __getitem__(self, idx):
img, label = super(RecordFileDetection, self).__getitem__(idx)
h, w, _ = img.shape
if self._coord_normalized:
label = _transform_label(label, h, w)
else:
label = _transform_label(label)
return img, label