|
- # Copyright (c) OpenMMLab. All rights reserved.
- import torch
-
- from mmdet.core import bbox2result
- from mmdet.models.builder import DETECTORS
- from ...core.utils import flip_tensor
- from .single_stage import SingleStageDetector
-
-
@DETECTORS.register_module()
class CenterNet(SingleStageDetector):
    """CenterNet detector ("Objects as Points").

    Paper: <https://arxiv.org/abs/1904.07850>.

    The class adds flip-pair test-time augmentation on top of
    ``SingleStageDetector``; construction is delegated entirely to the
    parent class.
    """

    def __init__(self,
                 backbone,
                 neck,
                 bbox_head,
                 train_cfg=None,
                 test_cfg=None,
                 pretrained=None,
                 init_cfg=None):
        # All arguments are forwarded positionally, unchanged, to the parent.
        super(CenterNet, self).__init__(backbone, neck, bbox_head, train_cfg,
                                        test_cfg, pretrained, init_cfg)

    def merge_aug_results(self, aug_results, with_nms):
        """Concatenate the detections of all augmentations, optionally
        followed by NMS.

        Only the first entry of every element of ``aug_results`` is used.

        Args:
            aug_results (list[list[Tensor]]): Det_bboxes and det_labels of
                each image.
            with_nms (bool): If True, run the head's ``_bboxes_nms`` on the
                concatenated detections before returning them.

        Returns:
            tuple: (out_bboxes, out_labels)
        """
        # Each augmentation result holds a (bboxes, labels) pair at index 0.
        first_dets = [single_result[0] for single_result in aug_results]
        merged_bboxes = torch.cat([dets[0] for dets in first_dets],
                                  dim=0).contiguous()
        merged_labels = torch.cat([dets[1]
                                   for dets in first_dets]).contiguous()

        if not with_nms:
            return merged_bboxes, merged_labels

        kept_bboxes, kept_labels = self.bbox_head._bboxes_nms(
            merged_bboxes, merged_labels, self.bbox_head.test_cfg)
        return kept_bboxes, kept_labels

    def aug_test(self, imgs, img_metas, rescale=True):
        """Test with flip augmentation.

        Images are consumed as consecutive (image, flipped image) pairs.
        Unlike CornerNet, the pair is fused by averaging the predicted
        feature maps (center heatmap and wh) rather than by merging
        detected boxes; the offset prediction of the first image of the
        pair is used as is.

        Args:
            imgs (list[Tensor]): Augmented images.
            img_metas (list[list[dict]]): Meta information of each image,
                e.g., image size, scaling factor, etc.
            rescale (bool): If True, return boxes in original image space.
                Default: True.

        Note:
            ``imgs`` must include flipped image pairs.

        Returns:
            list[list[np.ndarray]]: BBox results of each image and classes.
                The outer list corresponds to each image. The inner list
                corresponds to each class.
        """
        num_imgs = len(imgs)
        first_flip = img_metas[0][0]['flip']
        second_flip = img_metas[1][0]['flip']
        assert first_flip + second_flip, (
            'aug test must have flipped image pair')

        def _fuse_with_flipped(pred, direction):
            # Row 0 is the first image of the pair, row 1 its partner;
            # flip the partner's map back and take the mean of the two.
            return (pred[0:1] + flip_tensor(pred[1:2], direction)) / 2

        aug_results = []
        # Walk over (ind, ind + 1) pairs; an unpaired trailing image is
        # ignored.
        for ind in range(0, num_imgs - 1, 2):
            flip_ind = ind + 1
            flip_direction = img_metas[flip_ind][0]['flip_direction']

            # Run both images of the pair through the network as one batch.
            paired_batch = torch.cat([imgs[ind], imgs[flip_ind]])
            feats = self.extract_feat(paired_batch)
            center_heatmap_preds, wh_preds, offset_preds = self.bbox_head(
                feats)
            assert all(
                len(preds) == 1
                for preds in (center_heatmap_preds, wh_preds, offset_preds))

            # Feature map averaging
            center_heatmap_preds[0] = _fuse_with_flipped(
                center_heatmap_preds[0], flip_direction)
            wh_preds[0] = _fuse_with_flipped(wh_preds[0], flip_direction)

            # NMS is deferred until the results of all pairs are merged.
            pair_dets = self.bbox_head.get_bboxes(
                center_heatmap_preds,
                wh_preds, [offset_preds[0][0:1]],
                img_metas[ind],
                rescale=rescale,
                with_nms=False)
            aug_results.append(pair_dets)

        # NMS is applied only when the head's test config has 'nms_cfg'.
        nms_cfg = self.bbox_head.test_cfg.get('nms_cfg', None)
        with_nms = nms_cfg is not None

        det_bboxes, det_labels = self.merge_aug_results(
            aug_results, with_nms)
        return [
            bbox2result(det_bboxes, det_labels, self.bbox_head.num_classes)
        ]
|