|
- # Copyright (c) OpenMMLab. All rights reserved.
- from ..builder import DETECTORS
- from .faster_rcnn import FasterRCNN
-
-
@DETECTORS.register_module()
class TridentFasterRCNN(FasterRCNN):
    """Implementation of `TridentNet <https://arxiv.org/abs/1901.01892>`_.

    A ``FasterRCNN`` variant whose backbone and RoI head both expose
    ``num_branch`` and ``test_branch_idx``. The two components must agree on
    both values; the agreed values are mirrored on the detector itself.

    Args:
        backbone: Backbone specification, forwarded to ``FasterRCNN``. After
            the parent constructor runs, ``self.backbone`` must expose
            ``num_branch`` and ``test_branch_idx``.
        rpn_head: RPN head specification, forwarded to ``FasterRCNN``.
        roi_head: RoI head specification, forwarded to ``FasterRCNN``. After
            the parent constructor runs, ``self.roi_head`` must expose
            ``num_branch`` and ``test_branch_idx``.
        train_cfg: Training configuration, forwarded to ``FasterRCNN``.
        test_cfg: Testing configuration, forwarded to ``FasterRCNN``.
        neck: Optional neck specification, forwarded to ``FasterRCNN``.
        pretrained: Forwarded unchanged to ``FasterRCNN``.
        init_cfg: Forwarded unchanged to ``FasterRCNN``.
    """

    def __init__(self,
                 backbone,
                 rpn_head,
                 roi_head,
                 train_cfg,
                 test_cfg,
                 neck=None,
                 pretrained=None,
                 init_cfg=None):

        super(TridentFasterRCNN, self).__init__(
            backbone=backbone,
            neck=neck,
            rpn_head=rpn_head,
            roi_head=roi_head,
            train_cfg=train_cfg,
            test_cfg=test_cfg,
            pretrained=pretrained,
            init_cfg=init_cfg)
        # The backbone and the RoI head must be configured with the same
        # branch layout, otherwise the replicated metas/targets built below
        # would not line up with what the heads expect.
        assert self.backbone.num_branch == self.roi_head.num_branch, \
            'backbone and roi_head must have the same num_branch.'
        assert (self.backbone.test_branch_idx ==
                self.roi_head.test_branch_idx), \
            'backbone and roi_head must have the same test_branch_idx.'
        self.num_branch = self.backbone.num_branch
        self.test_branch_idx = self.backbone.test_branch_idx

    def _num_test_branch(self):
        """Return how many times image metas are replicated at test time.

        All ``num_branch`` branches are used when ``test_branch_idx`` is
        ``-1``; any other value selects a single branch.
        """
        return self.num_branch if self.test_branch_idx == -1 else 1

    def simple_test(self, img, img_metas, proposals=None, rescale=False):
        """Test without augmentation.

        Args:
            img: Input passed to ``self.extract_feat``.
            img_metas: Sequence of per-image meta information. It is repeated
                once per active test branch before being handed to the heads.
            proposals: Optional precomputed proposals. When ``None`` the
                proposals are produced by ``self.rpn_head.simple_test_rpn``.
            rescale: Forwarded to ``self.roi_head.simple_test``.

        Returns:
            Whatever ``self.roi_head.simple_test`` returns.
        """
        assert self.with_bbox, 'Bbox head must be implemented.'
        x = self.extract_feat(img)
        # Build the replicated metas before branching on ``proposals`` so the
        # name is always bound; previously it was only assigned on the RPN
        # path and supplying external proposals raised UnboundLocalError.
        trident_img_metas = img_metas * self._num_test_branch()
        if proposals is None:
            proposal_list = self.rpn_head.simple_test_rpn(
                x, trident_img_metas)
        else:
            proposal_list = proposals
        return self.roi_head.simple_test(
            x, proposal_list, trident_img_metas, rescale=rescale)

    def aug_test(self, imgs, img_metas, rescale=False):
        """Test with augmentations.

        If rescale is False, then returned bboxes and masks will fit the scale
        of imgs[0].

        Args:
            imgs: Augmented inputs passed to ``self.extract_feats``.
            img_metas: One sequence of image metas per augmentation. Each
                sequence is repeated once per active test branch for the RPN;
                the RoI head receives the original, non-replicated metas.
            rescale: Forwarded to ``self.roi_head.aug_test``.

        Returns:
            Whatever ``self.roi_head.aug_test`` returns.
        """
        x = self.extract_feats(imgs)
        num_branch = self._num_test_branch()
        # Use a distinct loop name so the ``img_metas`` argument is not
        # shadowed inside the comprehension.
        trident_img_metas = [metas * num_branch for metas in img_metas]
        proposal_list = self.rpn_head.aug_test_rpn(x, trident_img_metas)
        return self.roi_head.aug_test(
            x, proposal_list, img_metas, rescale=rescale)

    def forward_train(self, img, img_metas, gt_bboxes, gt_labels, **kwargs):
        """Make copies of img metas and gts to fit multi-branch.

        ``img_metas``, ``gt_bboxes`` and ``gt_labels`` are each repeated
        ``self.num_branch`` times and converted to tuples before delegating
        to ``FasterRCNN.forward_train``.

        Args:
            img: Input images, forwarded unchanged to the parent.
            img_metas: Sequence of per-image meta information.
            gt_bboxes: Sequence of per-image ground-truth boxes.
            gt_labels: Sequence of per-image ground-truth labels.
            **kwargs: Accepted for signature compatibility only; these are
                NOT forwarded to the parent ``forward_train``.

        Returns:
            Whatever ``FasterRCNN.forward_train`` returns.
        """
        trident_gt_bboxes = tuple(gt_bboxes * self.num_branch)
        trident_gt_labels = tuple(gt_labels * self.num_branch)
        trident_img_metas = tuple(img_metas * self.num_branch)

        return super(TridentFasterRCNN,
                     self).forward_train(img, trident_img_metas,
                                         trident_gt_bboxes, trident_gt_labels)
|