|
- # Copyright (c) OpenMMLab. All rights reserved.
- import torch
-
- from mmdet.core import bbox2result, bbox_mapping_back
- from ..builder import DETECTORS
- from .single_stage import SingleStageDetector
-
-
@DETECTORS.register_module()
class CornerNet(SingleStageDetector):
    """CornerNet.

    This detector is the implementation of the paper `CornerNet: Detecting
    Objects as Paired Keypoints <https://arxiv.org/abs/1808.01244>`_ .
    """

    def __init__(self,
                 backbone,
                 neck,
                 bbox_head,
                 train_cfg=None,
                 test_cfg=None,
                 pretrained=None,
                 init_cfg=None):
        super().__init__(backbone, neck, bbox_head, train_cfg, test_cfg,
                         pretrained, init_cfg)

    def merge_aug_results(self, aug_results, img_metas):
        """Merge augmented detection bboxes and scores.

        Each augmented result is first mapped back to the original image
        space (undoing the resize and flip recorded in its meta). Then all
        results are concatenated and passed through the head's
        ``_bboxes_nms``.

        Args:
            aug_results (list[list[Tensor]]): ``[det_bboxes, det_labels]``
                for each augmented image. The first 4 columns of
                ``det_bboxes`` are box coordinates; the last column is the
                score.
            img_metas (list[list[dict]]): Meta information of each augmented
                image, e.g. image size, scaling factor, flip flag. Only the
                first dict of each inner list is read.

        Returns:
            tuple[Tensor, Tensor]: ``(bboxes, labels)`` after merging. When
            no boxes were detected at all, the (empty) concatenated tensors
            are returned without calling ``_bboxes_nms``.
        """
        recovered_bboxes, aug_labels = [], []
        for (bboxes, labels), img_info in zip(aug_results, img_metas):
            meta = img_info[0]
            coords, scores = bboxes[:, :4], bboxes[:, -1:]
            coords = bbox_mapping_back(
                coords,
                meta['img_shape'],  # using shape before padding
                meta['scale_factor'],
                meta['flip'],
                # Undo the flip along the axis it was actually applied on.
                # Metas without this key keep the previous (horizontal)
                # behaviour. The value is only consulted when ``flip`` is
                # True.
                flip_direction=meta.get('flip_direction', 'horizontal'))
            recovered_bboxes.append(torch.cat([coords, scores], dim=-1))
            aug_labels.append(labels)

        bboxes = torch.cat(recovered_bboxes, dim=0)
        labels = torch.cat(aug_labels)

        if bboxes.shape[0] == 0:
            # Nothing to suppress; skip NMS on empty input.
            return bboxes, labels

        return self.bbox_head._bboxes_nms(bboxes, labels,
                                          self.bbox_head.test_cfg)

    def aug_test(self, imgs, img_metas, rescale=False):
        """Augment testing of CornerNet.

        Images are processed two at a time as ``(image, flipped image)``
        pairs. The per-image detections of all pairs are merged by
        :meth:`merge_aug_results`.

        Args:
            imgs (list[Tensor]): Augmented images, ordered as consecutive
                (image, flipped image) pairs.
            img_metas (list[list[dict]]): Meta information of each image, e.g.,
                image size, scaling factor, etc.
            rescale (bool): Accepted for API compatibility but not read
                here. Merged boxes are always mapped back to the original
                image space by :meth:`merge_aug_results`. Default: False.

        Note:
            ``imgs`` must include flipped image pairs, so its length must be
            even and at least 2.

        Returns:
            list[list[np.ndarray]]: BBox results of each image and classes.
                The outer list corresponds to each image. The inner list
                corresponds to each class.
        """
        # An odd count would silently drop the last augmentation from both
        # inference and merging, so reject it explicitly.
        assert len(imgs) >= 2 and len(imgs) % 2 == 0, (
            'aug test expects an even number of images forming '
            '(image, flipped image) pairs')
        assert img_metas[0][0]['flip'] or img_metas[1][0]['flip'], (
            'aug test must have flipped image pair')

        aug_results = []
        for ind in range(0, len(imgs), 2):
            flip_ind = ind + 1
            img_pair = torch.cat([imgs[ind], imgs[flip_ind]])
            x = self.extract_feat(img_pair)
            outs = self.bbox_head(x)
            # The two trailing False flags are presumably rescale=False and
            # with_nms=False (verify against the head's get_bboxes).
            # Mapping back and NMS are done jointly in merge_aug_results.
            bbox_list = self.bbox_head.get_bboxes(
                *outs, [img_metas[ind], img_metas[flip_ind]], False, False)
            aug_results.extend(bbox_list[:2])

        bboxes, labels = self.merge_aug_results(aug_results, img_metas)
        bbox_results = bbox2result(bboxes, labels, self.bbox_head.num_classes)

        return [bbox_results]
|