|
- # Copyright (c) OpenMMLab. All rights reserved.
- from collections import OrderedDict
-
- from mmcv.utils import print_log
-
- from mmdet.core import eval_map, eval_recalls
- from .builder import DATASETS
- from .xml_style import XMLDataset
-
-
@DATASETS.register_module()
class VOCDataset(XMLDataset):
    """PASCAL VOC detection dataset evaluated with the VOC protocol.

    The dataset year is inferred from ``img_prefix`` at construction time
    and is later used by :meth:`evaluate` to choose the ``dataset`` argument
    handed to ``eval_map``.
    """

    CLASSES = ('aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus', 'car',
               'cat', 'chair', 'cow', 'diningtable', 'dog', 'horse',
               'motorbike', 'person', 'pottedplant', 'sheep', 'sofa', 'train',
               'tvmonitor')

    def __init__(self, **kwargs):
        """Build the dataset and infer the VOC year from ``img_prefix``.

        Args:
            **kwargs: Forwarded unchanged to the parent class constructor.

        Raises:
            ValueError: If ``img_prefix`` contains neither ``'VOC2007'`` nor
                ``'VOC2012'``.
        """
        super().__init__(**kwargs)
        # 'VOC2007' is checked first, so it wins if both substrings appear.
        if 'VOC2007' in self.img_prefix:
            self.year = 2007
        elif 'VOC2012' in self.img_prefix:
            self.year = 2012
        else:
            raise ValueError('Cannot infer dataset year from img_prefix')

    def evaluate(self,
                 results,
                 metric='mAP',
                 logger=None,
                 proposal_nums=(100, 300, 1000),
                 iou_thr=0.5,
                 scale_ranges=None):
        """Evaluate in VOC protocol.

        Args:
            results (list[list | tuple]): Testing results of the dataset.
            metric (str | list[str]): Metrics to be evaluated. Options are
                'mAP', 'recall'. A non-string value must contain exactly one
                metric name.
            logger (logging.Logger | str, optional): Logger used for printing
                related information during evaluation. Default: None.
            proposal_nums (Sequence[int]): Proposal number used for evaluating
                recalls, such as recall@100, recall@1000.
                Default: (100, 300, 1000).
            iou_thr (float | list[float]): IoU threshold. For the 'mAP'
                metric this must be a float or a list. Default: 0.5.
            scale_ranges (list[tuple], optional): Scale ranges for evaluating
                mAP. NOTE: this argument is currently not forwarded;
                ``eval_map`` is always called with ``scale_ranges=None``.
                Default: None.

        Returns:
            dict[str, float]: AP/recall metrics. For 'mAP': one
            ``AP{int(thr * 100):02d}`` entry per IoU threshold (rounded to
            3 decimals) plus ``mAP``, the unrounded mean over thresholds.
            For 'recall': one ``recall@{num}@{thr}`` entry per proposal
            number and threshold, plus ``AR@{num}`` when more than one
            threshold is evaluated.

        Raises:
            KeyError: If ``metric`` is not one of the supported metrics.
        """

        if not isinstance(metric, str):
            assert len(metric) == 1
            metric = metric[0]
        allowed_metrics = ['mAP', 'recall']
        if metric not in allowed_metrics:
            raise KeyError(f'metric {metric} is not supported')
        annotations = [self.get_ann_info(i) for i in range(len(self))]
        eval_results = OrderedDict()
        # Normalise a single float threshold to a one-element list; any other
        # value is used as given.
        iou_thrs = [iou_thr] if isinstance(iou_thr, float) else iou_thr
        if metric == 'mAP':
            assert isinstance(iou_thrs, list)
            # VOC2007 is identified to eval_map by name; otherwise the class
            # names themselves are passed as the ``dataset`` argument.
            ds_name = 'voc07' if self.year == 2007 else self.CLASSES
            mean_aps = []
            for thr in iou_thrs:
                # Send the header to the same logger as the evaluation
                # output instead of always printing it.
                print_log(
                    f'\n{"-" * 15}iou_thr: {thr}{"-" * 15}', logger=logger)
                # Follow the official implementation,
                # http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCdevkit_18-May-2011.tar
                # we should use the legacy coordinate system in mmdet 1.x,
                # which means w, h should be computed as `x2 - x1 + 1` and
                # `y2 - y1 + 1`
                mean_ap, _ = eval_map(
                    results,
                    annotations,
                    scale_ranges=None,
                    iou_thr=thr,
                    dataset=ds_name,
                    logger=logger,
                    use_legacy_coordinate=True)
                mean_aps.append(mean_ap)
                eval_results[f'AP{int(thr * 100):02d}'] = round(mean_ap, 3)
            eval_results['mAP'] = sum(mean_aps) / len(mean_aps)
        elif metric == 'recall':
            gt_bboxes = [ann['bboxes'] for ann in annotations]
            recalls = eval_recalls(
                gt_bboxes,
                results,
                proposal_nums,
                iou_thrs,
                logger=logger,
                use_legacy_coordinate=True)
            # ``recalls`` is indexed as [proposal number index, IoU index].
            for i, num in enumerate(proposal_nums):
                for j, thr in enumerate(iou_thrs):
                    eval_results[f'recall@{num}@{thr}'] = recalls[i, j]
            # Average recall is only reported when several IoU thresholds
            # were evaluated.
            if recalls.shape[1] > 1:
                ar = recalls.mean(axis=1)
                for i, num in enumerate(proposal_nums):
                    eval_results[f'AR@{num}'] = ar[i]
        return eval_results
|