|
- # Copyright (c) OpenMMLab. All rights reserved.
- import os.path as osp
- import warnings
- from collections import OrderedDict
-
- import mmcv
- import numpy as np
- from mmcv.utils import print_log
- from terminaltables import AsciiTable
- from torch.utils.data import Dataset
-
- from mmdet.core import eval_map, eval_recalls
- from .builder import DATASETS
- from .pipelines import Compose
-
-
@DATASETS.register_module()
class CustomDataset(Dataset):
    """Custom dataset for detection.

    The annotation format is shown as follows. The `ann` field is optional for
    testing.

    .. code-block:: none

        [
            {
                'filename': 'a.jpg',
                'width': 1280,
                'height': 720,
                'ann': {
                    'bboxes': <np.ndarray> (n, 4) in (x1, y1, x2, y2) order.
                    'labels': <np.ndarray> (n, ),
                    'bboxes_ignore': <np.ndarray> (k, 4), (optional field)
                    'labels_ignore': <np.ndarray> (k, ) (optional field)
                }
            },
            ...
        ]

    Args:
        ann_file (str): Annotation file path.
        pipeline (list[dict]): Processing pipeline.
        classes (str | Sequence[str], optional): Specify classes to load.
            If is None, ``cls.CLASSES`` will be used. A string is read as a
            file with one class name per line; if it cannot be read, the
            string itself is used as the only class name. Default: None.
        data_root (str, optional): Data root for ``ann_file``,
            ``img_prefix``, ``seg_prefix``, ``proposal_file`` if specified.
            Relative paths are joined onto it; absolute paths are kept.
        img_prefix (str, optional): Image path prefix handed to the pipeline
            as ``results['img_prefix']``. Default: ''.
        seg_prefix (str, optional): Segmentation-map path prefix handed to
            the pipeline as ``results['seg_prefix']``. Default: None.
        proposal_file (str, optional): File loaded with ``mmcv.load``; its
            entries are indexed in the same order as the annotations.
            Default: None.
        test_mode (bool, optional): If True, ``__getitem__`` returns test
            data built without annotations and no sampler group flag is
            computed. Annotation files are still loaded. Default: False.
        filter_empty_gt (bool, optional): Kept for config compatibility.
            Image filtering is disabled in this implementation, so this
            option currently has no effect. Default: True.
    """

    CLASSES = None

    def __init__(self,
                 ann_file,
                 pipeline,
                 classes=None,
                 data_root=None,
                 img_prefix='',
                 seg_prefix=None,
                 proposal_file=None,
                 test_mode=False,
                 filter_empty_gt=True):
        self.ann_file = ann_file
        self.data_root = data_root
        self.img_prefix = img_prefix
        self.seg_prefix = seg_prefix
        self.proposal_file = proposal_file
        self.test_mode = test_mode
        self.filter_empty_gt = filter_empty_gt
        self.CLASSES = self.get_classes(classes)

        # join paths if data_root is specified
        if self.data_root is not None:
            if not osp.isabs(self.ann_file):
                self.ann_file = osp.join(self.data_root, self.ann_file)
            if not (self.img_prefix is None or osp.isabs(self.img_prefix)):
                self.img_prefix = osp.join(self.data_root, self.img_prefix)
            if not (self.seg_prefix is None or osp.isabs(self.seg_prefix)):
                self.seg_prefix = osp.join(self.data_root, self.seg_prefix)
            if not (self.proposal_file is None
                    or osp.isabs(self.proposal_file)):
                self.proposal_file = osp.join(self.data_root,
                                              self.proposal_file)
        # load annotations (and proposals)
        self.data_infos = self.load_annotations(self.ann_file)

        if self.proposal_file is not None:
            self.proposals = self.load_proposals(self.proposal_file)
        else:
            self.proposals = None

        if not test_mode:
            # NOTE: filtering of small / empty images via ``_filter_imgs`` is
            # disabled in this version, so every loaded image is kept and
            # ``self.proposals`` stays aligned with ``self.data_infos``.
            # set group flag for the sampler
            self._set_group_flag()

        # processing pipeline
        self.pipeline = Compose(pipeline)

    def __len__(self):
        """Total number of samples of data."""
        return len(self.data_infos)

    def load_annotations(self, ann_file):
        """Load annotation from annotation file."""
        return mmcv.load(ann_file)

    def load_proposals(self, proposal_file):
        """Load proposal from proposal file."""
        return mmcv.load(proposal_file)

    def get_ann_info(self, idx):
        """Get annotation by index.

        Args:
            idx (int): Index of data.

        Returns:
            dict: Annotation info of specified index.
        """

        return self.data_infos[idx]['ann']

    def get_cat_ids(self, idx):
        """Get category ids by index.

        Args:
            idx (int): Index of data.

        Returns:
            list[int]: All categories in the image of specified index.
        """

        # ``np.int`` was removed in NumPy 1.24; use an explicit integer dtype.
        # ``tolist`` converts the values to plain Python ints either way.
        return self.data_infos[idx]['ann']['labels'].astype(np.int64).tolist()

    def pre_pipeline(self, results):
        """Prepare results dict for pipeline."""
        results['img_prefix'] = self.img_prefix
        results['seg_prefix'] = self.seg_prefix
        results['proposal_file'] = self.proposal_file
        results['bbox_fields'] = []
        results['mask_fields'] = []
        results['seg_fields'] = []

    def _filter_imgs(self, min_size=32):
        """Return indices of images whose shorter side is >= ``min_size``.

        Empty-gt filtering is not supported here; a warning is emitted when
        ``filter_empty_gt`` is set. Not called by ``__init__`` in this
        version.
        """
        if self.filter_empty_gt:
            warnings.warn(
                'CustomDataset does not support filtering empty gt images.')
        valid_inds = []
        for i, img_info in enumerate(self.data_infos):
            if min(img_info['width'], img_info['height']) >= min_size:
                valid_inds.append(i)
        return valid_inds

    def _set_group_flag(self):
        """Set flag according to image aspect ratio.

        Images with aspect ratio greater than 1 will be set as group 1,
        otherwise group 0.
        """
        self.flag = np.zeros(len(self), dtype=np.uint8)
        for i in range(len(self)):
            img_info = self.data_infos[i]
            if img_info['width'] / img_info['height'] > 1:
                self.flag[i] = 1

    def _rand_another(self, idx):
        """Get another random index from the same group as the given index."""
        pool = np.where(self.flag == self.flag[idx])[0]
        return np.random.choice(pool)

    def __getitem__(self, idx):
        """Get training/test data after pipeline.

        In training mode, a sample for which the pipeline returns None is
        replaced by a random sample from the same aspect-ratio group.

        Args:
            idx (int): Index of data.

        Returns:
            dict: Training/test data (with annotations unless `test_mode`
                is True).
        """

        if self.test_mode:
            return self.prepare_test_img(idx)
        while True:
            data = self.prepare_train_img(idx)
            if data is None:
                idx = self._rand_another(idx)
                continue
            return data

    def prepare_train_img(self, idx):
        """Get training data and annotations after pipeline.

        Args:
            idx (int): Index of data.

        Returns:
            dict: Training data and annotation after pipeline with new keys \
                introduced by pipeline.
        """

        img_info = self.data_infos[idx]
        ann_info = self.get_ann_info(idx)
        results = dict(img_info=img_info, ann_info=ann_info)
        if self.proposals is not None:
            results['proposals'] = self.proposals[idx]
        self.pre_pipeline(results)
        return self.pipeline(results)

    def prepare_test_img(self, idx):
        """Get testing data after pipeline.

        Args:
            idx (int): Index of data.

        Returns:
            dict: Testing data after pipeline with new keys introduced by \
                pipeline.
        """

        img_info = self.data_infos[idx]
        results = dict(img_info=img_info)
        if self.proposals is not None:
            results['proposals'] = self.proposals[idx]
        self.pre_pipeline(results)
        return self.pipeline(results)

    @classmethod
    def get_classes(cls, classes=None):
        """Get class names of current dataset.

        Args:
            classes (Sequence[str] | str | None): If classes is None, use
                default CLASSES defined by builtin dataset. If classes is a
                string, take it as a file name. The file contains the name of
                classes where each line contains one class name; if the file
                cannot be read, the string itself becomes the only class
                name. If classes is a tuple or list, override the CLASSES
                defined by the dataset.

        Returns:
            tuple[str] or list[str]: Names of categories of the dataset.

        Raises:
            ValueError: If ``classes`` has an unsupported type.
        """
        if classes is None:
            return cls.CLASSES

        if isinstance(classes, str):
            # take it as a file path; fall back to a single class name.
            # ``except Exception`` (not a bare ``except``) keeps this
            # best-effort without swallowing KeyboardInterrupt/SystemExit.
            try:
                class_names = mmcv.list_from_file(classes)
            except Exception:
                class_names = [classes]
        elif isinstance(classes, (tuple, list)):
            class_names = classes
        else:
            raise ValueError(f'Unsupported type {type(classes)} of classes.')

        return class_names

    def format_results(self, results, **kwargs):
        """Place holder to format result to dataset specific output."""

    def evaluate(self,
                 results,
                 metric='mAP',
                 logger=None,
                 proposal_nums=(100, 300, 1000),
                 iou_thr=0.5,
                 scale_ranges=None):
        """Evaluate the dataset.

        Args:
            results (list): Testing results of the dataset.
            metric (str | list[str]): Metrics to be evaluated. A list must
                contain exactly one metric.
            logger (logging.Logger | None | str): Logger used for printing
                related information during evaluation. Default: None.
            proposal_nums (Sequence[int]): Proposal number used for evaluating
                recalls, such as recall@100, recall@1000.
                Default: (100, 300, 1000).
            iou_thr (float | Sequence[float]): IoU threshold(s).
                Default: 0.5.
            scale_ranges (list[tuple] | None): Scale ranges for evaluating mAP.
                Default: None.

        Returns:
            OrderedDict: For ``'mAP'``, ``APxx`` per threshold plus the mean
                ``mAP``; for ``'recall'``, ``recall@num@iou`` entries and,
                with several thresholds, ``AR@num``.

        Raises:
            KeyError: If ``metric`` is not ``'mAP'`` or ``'recall'``.
        """

        if not isinstance(metric, str):
            assert len(metric) == 1
            metric = metric[0]
        allowed_metrics = ['mAP', 'recall']
        if metric not in allowed_metrics:
            raise KeyError(f'metric {metric} is not supported')
        annotations = [self.get_ann_info(i) for i in range(len(self))]
        eval_results = OrderedDict()
        # Normalise to a list so tuples / arrays of thresholds work too.
        iou_thrs = [iou_thr] if isinstance(iou_thr, float) else list(iou_thr)
        if metric == 'mAP':
            mean_aps = []
            for thr in iou_thrs:
                # Route the header through the same logger as eval_map's table.
                print_log(f'\n{"-" * 15}iou_thr: {thr}{"-" * 15}',
                          logger=logger)
                mean_ap, _ = eval_map(
                    results,
                    annotations,
                    scale_ranges=scale_ranges,
                    iou_thr=thr,
                    dataset=self.CLASSES,
                    logger=logger)
                mean_aps.append(mean_ap)
                eval_results[f'AP{int(thr * 100):02d}'] = round(mean_ap, 3)
            eval_results['mAP'] = sum(mean_aps) / len(mean_aps)
        elif metric == 'recall':
            gt_bboxes = [ann['bboxes'] for ann in annotations]
            recalls = eval_recalls(
                gt_bboxes, results, proposal_nums, iou_thrs, logger=logger)
            for i, num in enumerate(proposal_nums):
                for j, iou in enumerate(iou_thrs):
                    eval_results[f'recall@{num}@{iou}'] = recalls[i, j]
            if recalls.shape[1] > 1:
                ar = recalls.mean(axis=1)
                for i, num in enumerate(proposal_nums):
                    eval_results[f'AR@{num}'] = ar[i]
        return eval_results

    def __repr__(self):
        """Print the number of instance number."""
        dataset_type = 'Test' if self.test_mode else 'Train'
        result = (f'\n{self.__class__.__name__} {dataset_type} dataset '
                  f'with number of images {len(self)}, '
                  f'and instance counts: \n')
        if self.CLASSES is None:
            result += 'Category names are not provided. \n'
            return result
        instance_count = np.zeros(len(self.CLASSES) + 1).astype(int)
        # count the instance number in each image
        for idx in range(len(self)):
            label = self.get_ann_info(idx)['labels']
            unique, counts = np.unique(label, return_counts=True)
            if len(unique) > 0:
                # add the occurrence number to each class
                instance_count[unique] += counts
            else:
                # background is the last index; counts images with no labels
                instance_count[-1] += 1
        # create a table with category count, five (category, count) pairs
        # per row
        table_data = [['category', 'count'] * 5]
        row_data = []
        for cls, count in enumerate(instance_count):
            if cls < len(self.CLASSES):
                row_data += [f'{cls} [{self.CLASSES[cls]}]', f'{count}']
            else:
                # add the background number
                row_data += ['-1 background', f'{count}']
            if len(row_data) == 10:
                table_data.append(row_data)
                row_data = []
        if len(row_data) >= 2:
            # drop a trailing zero-count background entry
            if row_data[-1] == '0':
                row_data = row_data[:-2]
            if len(row_data) >= 2:
                table_data.append([])
                table_data.append(row_data)

        table = AsciiTable(table_data)
        result += table.table
        return result
|