|
- # Copyright (c) OpenMMLab. All rights reserved.
- from ..builder import HEADS
- from .standard_roi_head import StandardRoIHead
-
-
@HEADS.register_module()
class DoubleHeadRoIHead(StandardRoIHead):
    """RoI head for Double Head RCNN.

    https://arxiv.org/abs/1904.06493

    Each box forward pass extracts RoI features twice: once with the
    extractor's default settings and once with
    ``roi_scale_factor=reg_roi_scale_factor``. Both feature sets are then
    handed to ``bbox_head``, which returns the classification scores and
    the box predictions.

    Args:
        reg_roi_scale_factor: Value forwarded as ``roi_scale_factor`` to
            ``bbox_roi_extractor`` for the second (regression) feature
            extraction. Presumably a numeric scale applied to the RoIs;
            confirm against the RoI extractor in use.
        **kwargs: Forwarded unchanged to ``StandardRoIHead.__init__``.
    """

    def __init__(self, reg_roi_scale_factor, **kwargs):
        super(DoubleHeadRoIHead, self).__init__(**kwargs)
        self.reg_roi_scale_factor = reg_roi_scale_factor

    def _bbox_forward(self, x, rois):
        """Run the box branch; shared by the training and testing paths.

        Args:
            x: Sequence of feature maps. Only the first
                ``bbox_roi_extractor.num_inputs`` entries are used.
            rois: RoIs passed straight through to ``bbox_roi_extractor``.

        Returns:
            dict: ``cls_score`` and ``bbox_pred`` as returned by
            ``bbox_head``, plus ``bbox_feats`` holding the features of the
            classification branch (after the shared head, when present).
        """
        roi_extractor = self.bbox_roi_extractor
        used_levels = x[:roi_extractor.num_inputs]

        # Classification features first, then regression features; the
        # latter are extracted with the configured RoI scale factor.
        cls_feats = roi_extractor(used_levels, rois)
        reg_feats = roi_extractor(
            used_levels, rois, roi_scale_factor=self.reg_roi_scale_factor)

        if self.with_shared_head:
            # The same shared head processes both branches, cls first.
            cls_feats = self.shared_head(cls_feats)
            reg_feats = self.shared_head(reg_feats)

        scores, deltas = self.bbox_head(cls_feats, reg_feats)
        return {
            'cls_score': scores,
            'bbox_pred': deltas,
            'bbox_feats': cls_feats,
        }
|