|
- # Copyright (c) OpenMMLab. All rights reserved.
- from ..builder import DETECTORS
- from .two_stage import TwoStageDetector
-
-
@DETECTORS.register_module()
class FastRCNN(TwoStageDetector):
    """Implementation of `Fast R-CNN <https://arxiv.org/abs/1504.08083>`_

    All constructor arguments are forwarded unchanged to
    ``TwoStageDetector``. No RPN head is passed on; at test time the
    region proposals are supplied by the caller through ``forward_test``.
    """

    def __init__(self,
                 backbone,
                 roi_head,
                 train_cfg,
                 test_cfg,
                 neck=None,
                 pretrained=None,
                 init_cfg=None):
        """Forward the given components and configs to the base class."""
        super(FastRCNN, self).__init__(
            backbone=backbone,
            neck=neck,
            roi_head=roi_head,
            train_cfg=train_cfg,
            test_cfg=test_cfg,
            pretrained=pretrained,
            init_cfg=init_cfg)

    def forward_test(self, imgs, img_metas, proposals, **kwargs):
        """Run testing on a batch using externally supplied proposals.

        Args:
            imgs (List[Tensor]): the outer list indicates test-time
                augmentations and inner Tensor should have a shape NxCxHxW,
                which contains all images in the batch.
            img_metas (List[List[dict]]): the outer list indicates test-time
                augs (multiscale, flip, etc.) and the inner list indicates
                images in a batch.
            proposals (List[List[Tensor]]): the outer list indicates test-time
                augs (multiscale, flip, etc.) and the inner list indicates
                images in a batch. The Tensor should have a shape Px4, where
                P is the number of proposals.
            **kwargs: extra keyword arguments passed through to
                ``simple_test``.

        Returns:
            Whatever ``self.simple_test`` returns for the single
            augmentation.

        Raises:
            TypeError: if ``imgs`` or ``img_metas`` is not a list.
            ValueError: if ``imgs`` and ``img_metas`` have different lengths.
            NotImplementedError: if more than one test-time augmentation is
                given.
        """
        for var, name in [(imgs, 'imgs'), (img_metas, 'img_metas')]:
            if not isinstance(var, list):
                raise TypeError(f'{name} must be a list, but got {type(var)}')

        num_augs = len(imgs)
        if num_augs != len(img_metas):
            raise ValueError(f'num of augmentations ({len(imgs)}) '
                             f'!= num of image meta ({len(img_metas)})')

        if num_augs == 1:
            return self.simple_test(imgs[0], img_metas[0], proposals[0],
                                    **kwargs)

        # TODO: support test-time augmentation
        # The exception must be raised: ``assert NotImplementedError`` only
        # asserts that the exception class is truthy, so it never fires and
        # the method would silently return None.
        raise NotImplementedError(
            'test-time augmentation is not supported by FastRCNN, '
            f'but got {num_augs} augmentations')
|