|
- # Copyright (c) OpenMMLab. All rights reserved.
- import torch
- from mmcv.ops import batched_nms
-
- from mmdet.core import (bbox2result, bbox2roi, bbox_mapping, merge_aug_bboxes,
- multiclass_nms)
- from mmdet.models.roi_heads.standard_roi_head import StandardRoIHead
- from ..builder import HEADS
-
-
@HEADS.register_module()
class TridentRoIHead(StandardRoIHead):
    """Trident roi head.

    Args:
        num_branch (int): Number of branches in TridentNet.
        test_branch_idx (int): In inference, all ``num_branch`` branches
            will be used if `test_branch_idx==-1`, otherwise only branch
            with index `test_branch_idx` will be used.
    """

    def __init__(self, num_branch, test_branch_idx, **kwargs):
        self.num_branch = num_branch
        self.test_branch_idx = test_branch_idx
        super(TridentRoIHead, self).__init__(**kwargs)

    def merge_trident_bboxes(self, trident_det_bboxes, trident_det_labels):
        """Merge bbox predictions of each branch.

        Detections of all branches belonging to one image are pooled and
        de-duplicated with ``batched_nms`` (labels are used as the grouping
        indices, so boxes of different classes do not suppress each other),
        using ``self.test_cfg['nms']``. The result is truncated to
        ``self.test_cfg['max_per_img']`` detections when that value is
        positive.

        Args:
            trident_det_bboxes (Tensor): Concatenated detections of all
                branches; columns 0-3 are box coordinates and column 4 is
                the score.
            trident_det_labels (Tensor): Class label of each detection.

        Returns:
            tuple[Tensor, Tensor]: Merged ``(det_bboxes, det_labels)``. For
            empty input a ``(0, 5)`` bbox tensor and an empty ``long`` label
            tensor are returned.
        """
        if trident_det_bboxes.numel() == 0:
            det_bboxes = trident_det_bboxes.new_zeros((0, 5))
            det_labels = trident_det_bboxes.new_zeros((0, ), dtype=torch.long)
        else:
            nms_bboxes = trident_det_bboxes[:, :4]
            nms_scores = trident_det_bboxes[:, 4].contiguous()
            nms_inds = trident_det_labels
            nms_cfg = self.test_cfg['nms']
            det_bboxes, keep = batched_nms(nms_bboxes, nms_scores, nms_inds,
                                           nms_cfg)
            det_labels = trident_det_labels[keep]
            if self.test_cfg['max_per_img'] > 0:
                det_labels = det_labels[:self.test_cfg['max_per_img']]
                det_bboxes = det_bboxes[:self.test_cfg['max_per_img']]

        return det_bboxes, det_labels

    def simple_test(self,
                    x,
                    proposal_list,
                    img_metas,
                    proposals=None,
                    rescale=False):
        """Test without augmentation as follows:

        1. Compute prediction bbox and label per branch.
        2. Merge predictions of each branch according to scores of
           bboxes, i.e., bboxes with higher score are kept to give
           top-k prediction.

        Args:
            x: Features passed through to ``simple_test_bboxes``.
            proposal_list (list): Proposals passed through to
                ``simple_test_bboxes``; one entry per (image, branch).
            img_metas (list): Meta info, one entry per (image, branch), so
                its length is the number of images times the number of
                active branches.
            proposals: Unused here.
            rescale (bool): Forwarded to ``simple_test_bboxes``.

        Returns:
            list: One ``bbox2result`` result per image.
        """
        assert self.with_bbox, 'Bbox head must be implemented.'
        det_bboxes_list, det_labels_list = self.simple_test_bboxes(
            x, img_metas, proposal_list, self.test_cfg, rescale=rescale)
        num_branch = self.num_branch if self.test_branch_idx == -1 else 1

        # Give empty per-branch results an explicit (0, 5) shape so they can
        # be concatenated with non-empty (N, 5) results of other branches.
        for idx, branch_bboxes in enumerate(det_bboxes_list):
            if branch_bboxes.shape[0] == 0:
                det_bboxes_list[idx] = branch_bboxes.new_empty((0, 5))

        # Results are laid out image-major: entries
        # [i * num_branch, (i + 1) * num_branch) belong to image i.
        det_bboxes, det_labels = [], []
        for i in range(len(img_metas) // num_branch):
            start, end = i * num_branch, (i + 1) * num_branch
            merged_bboxes, merged_labels = self.merge_trident_bboxes(
                torch.cat(det_bboxes_list[start:end]),
                torch.cat(det_labels_list[start:end]))
            det_bboxes.append(merged_bboxes)
            det_labels.append(merged_labels)

        return [
            bbox2result(bboxes, labels, self.bbox_head.num_classes)
            for bboxes, labels in zip(det_bboxes, det_labels)
        ]

    def aug_test_bboxes(self, feats, img_metas, proposal_list, rcnn_test_cfg):
        """Test det bboxes with test time augmentation.

        For every augmentation, the proposals of each branch are mapped into
        that augmentation's image space and run through the bbox head on the
        matching branch's features. Predictions of all branches and all
        augmentations are merged with ``merge_aug_bboxes`` and filtered with
        ``multiclass_nms``.

        Args:
            feats (list): Features of each augmentation. For one
                augmentation the batch dimension is assumed to hold one
                entry per branch, in the same order as ``proposal_list``
                (the per-branch layout ``simple_test`` also relies on).
                NOTE(review): confirm against the calling detector.
            img_metas (list): Meta info of each augmentation; only the first
                entry of each is read (one image per batch).
            proposal_list (list): Proposals of each branch, one tensor per
                branch; only their first four columns (box coordinates) are
                used.
            rcnn_test_cfg: Test config providing ``score_thr``, ``nms`` and
                ``max_per_img``.

        Returns:
            tuple[Tensor, Tensor]: ``(det_bboxes, det_labels)`` from
            ``multiclass_nms``.
        """
        aug_bboxes = []
        aug_scores = []
        for x, img_meta in zip(feats, img_metas):
            # only one image in the batch
            img_shape = img_meta[0]['img_shape']
            scale_factor = img_meta[0]['scale_factor']
            flip = img_meta[0]['flip']
            flip_direction = img_meta[0]['flip_direction']

            trident_bboxes, trident_scores = [], []
            for branch_idx, branch_proposals in enumerate(proposal_list):
                # Each branch uses its own proposals. Previously
                # proposal_list[0] was used on every iteration, so branch 0
                # was evaluated len(proposal_list) times.
                proposals = bbox_mapping(branch_proposals[:, :4], img_shape,
                                         scale_factor, flip, flip_direction)
                rois = bbox2roi([proposals])
                # bbox2roi tags a single-element list with batch index 0;
                # point the RoIs at this branch's slice of the features.
                rois[:, 0] = branch_idx
                bbox_results = self._bbox_forward(x, rois)
                bboxes, scores = self.bbox_head.get_bboxes(
                    rois,
                    bbox_results['cls_score'],
                    bbox_results['bbox_pred'],
                    img_shape,
                    scale_factor,
                    rescale=False,
                    cfg=None)
                trident_bboxes.append(bboxes)
                trident_scores.append(scores)

            aug_bboxes.append(torch.cat(trident_bboxes, 0))
            aug_scores.append(torch.cat(trident_scores, 0))
        # after merging, bboxes will be rescaled to the original image size
        merged_bboxes, merged_scores = merge_aug_bboxes(
            aug_bboxes, aug_scores, img_metas, rcnn_test_cfg)
        det_bboxes, det_labels = multiclass_nms(merged_bboxes, merged_scores,
                                                rcnn_test_cfg.score_thr,
                                                rcnn_test_cfg.nms,
                                                rcnn_test_cfg.max_per_img)
        return det_bboxes, det_labels
|