|
- # Copyright (c) OpenMMLab. All rights reserved.
- import warnings
-
- import mmcv
- import torch
- from mmcv.image import tensor2imgs
-
- from mmdet.core import bbox_mapping
- from ..builder import DETECTORS, build_backbone, build_head, build_neck
- from .base import BaseDetector
-
-
@DETECTORS.register_module()
class RPN(BaseDetector):
    """Implementation of Region Proposal Network.

    Composes a backbone, an optional neck and an RPN head; the detector's
    output is the head's proposal list rather than class-labelled boxes.
    """

    def __init__(self,
                 backbone,
                 neck,
                 rpn_head,
                 train_cfg,
                 test_cfg,
                 pretrained=None,
                 init_cfg=None):
        super(RPN, self).__init__(init_cfg)
        if pretrained:
            warnings.warn('DeprecationWarning: pretrained is deprecated, '
                          'please use "init_cfg" instead')
            backbone.pretrained = pretrained
        self.backbone = build_backbone(backbone)
        self.neck = build_neck(neck) if neck is not None else None
        # Only the ``rpn`` sub-config is forwarded to the head. Both configs
        # are guarded the same way so a missing one does not crash here.
        rpn_train_cfg = train_cfg.rpn if train_cfg is not None else None
        rpn_test_cfg = test_cfg.rpn if test_cfg is not None else None
        rpn_head.update(train_cfg=rpn_train_cfg)
        rpn_head.update(test_cfg=rpn_test_cfg)
        self.rpn_head = build_head(rpn_head)
        self.train_cfg = train_cfg
        self.test_cfg = test_cfg

    def extract_feat(self, img):
        """Extract features.

        Args:
            img (torch.Tensor): Image tensor with shape (n, c, h ,w).

        Returns:
            list[torch.Tensor]: Multi-level features that may have
                different resolutions.
        """
        x = self.backbone(img)
        if self.with_neck:
            x = self.neck(x)
        return x

    def forward_dummy(self, img):
        """Dummy forward function.

        Runs feature extraction and the RPN head's forward pass, returning
        the raw head outputs (used e.g. for FLOPs counting / tracing).
        """
        x = self.extract_feat(img)
        rpn_outs = self.rpn_head(x)
        return rpn_outs

    def forward_train(self,
                      img,
                      img_metas,
                      gt_bboxes=None,
                      gt_bboxes_ignore=None):
        """Compute RPN training losses for a batch of images.

        Args:
            img (Tensor): Input images of shape (N, C, H, W).
                Typically these should be mean centered and std scaled.
            img_metas (list[dict]): A List of image info dict where each dict
                has: 'img_shape', 'scale_factor', 'flip', and may also contain
                'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
                For details on the values of these keys see
                :class:`mmdet.datasets.pipelines.Collect`.
            gt_bboxes (list[Tensor]): Each item are the truth boxes for each
                image in [tl_x, tl_y, br_x, br_y] format.
            gt_bboxes_ignore (None | list[Tensor]): Specify which bounding
                boxes can be ignored when computing the loss.

        Returns:
            dict[str, Tensor]: A dictionary of loss components.
        """
        # Guard the ``train_cfg is None`` case so the optional debug hook
        # cannot raise AttributeError.
        rpn_train_cfg = (self.train_cfg.rpn
                         if self.train_cfg is not None else None)
        if (isinstance(rpn_train_cfg, dict)
                and rpn_train_cfg.get('debug', False)):
            self.rpn_head.debug_imgs = tensor2imgs(img)

        x = self.extract_feat(img)
        # The positional ``None`` presumably fills the head's gt_labels slot
        # (RPN is class-agnostic) -- confirm against rpn_head.forward_train.
        losses = self.rpn_head.forward_train(x, img_metas, gt_bboxes, None,
                                             gt_bboxes_ignore)
        return losses

    def simple_test(self, img, img_metas, rescale=False):
        """Test function without test time augmentation.

        Args:
            img (torch.Tensor): Batched input images of shape (N, C, H, W).
            img_metas (list[dict]): List of image information.
            rescale (bool, optional): Whether to rescale the results back to
                the original image scale (divides the box coordinates by
                each image's ``scale_factor``, in place). Defaults to False.

        Returns:
            list[np.ndarray]: proposals, one array per image. During ONNX
                export the proposal tensors are returned unconverted.
        """
        x = self.extract_feat(img)
        # get origin input shape to onnx dynamic input shape
        if torch.onnx.is_in_onnx_export():
            img_shape = torch._shape_as_tensor(img)[2:]
            img_metas[0]['img_shape_for_onnx'] = img_shape
        proposal_list = self.rpn_head.simple_test_rpn(x, img_metas)
        if rescale:
            for proposals, meta in zip(proposal_list, img_metas):
                proposals[:, :4] /= proposals.new_tensor(meta['scale_factor'])
        if torch.onnx.is_in_onnx_export():
            return proposal_list

        return [proposal.cpu().numpy() for proposal in proposal_list]

    def aug_test(self, imgs, img_metas, rescale=False):
        """Test function with test time augmentation.

        Args:
            imgs (list[torch.Tensor]): List of multiple images
            img_metas (list[dict]): List of image information.
            rescale (bool, optional): Whether to rescale the results.
                Defaults to False.

        Returns:
            list[np.ndarray]: proposals
        """
        proposal_list = self.rpn_head.aug_test_rpn(
            self.extract_feats(imgs), img_metas)
        if not rescale:
            # NOTE(review): the merged proposals appear to be in original
            # image coordinates; when rescale is False they are mapped back
            # into the frame of the first augmentation (img_metas[0] holds
            # that augmentation's per-image metas) -- confirm with
            # aug_test_rpn / bbox_mapping.
            for proposals, img_meta in zip(proposal_list, img_metas[0]):
                img_shape = img_meta['img_shape']
                scale_factor = img_meta['scale_factor']
                flip = img_meta['flip']
                flip_direction = img_meta['flip_direction']
                proposals[:, :4] = bbox_mapping(proposals[:, :4], img_shape,
                                                scale_factor, flip,
                                                flip_direction)
        return [proposal.cpu().numpy() for proposal in proposal_list]

    def show_result(self, data, result, top_k=20, **kwargs):
        """Show RPN proposals on the image.

        Args:
            data (str or np.ndarray): Image filename or loaded image.
            result (np.ndarray): Proposals to draw over ``data``, one box
                per row with the coordinates in the first four columns.
            top_k (int): Plot the first k bboxes only
               if set positive. Default: 20
            **kwargs: Forwarded to ``mmcv.imshow_bboxes``. Keyword arguments
                that only the generic detector visualisation understands
                (``score_thr``, ``text_color``, ``mask_color``,
                ``font_size``) are dropped; ``bbox_color`` is renamed to
                ``colors`` (default 'green').

        Returns:
            np.ndarray: The image with bboxes drawn on it.
        """
        # ``**kwargs`` is always a dict, so no None check is needed. Drop the
        # arguments ``imshow_bboxes`` does not accept, otherwise it raises
        # TypeError.
        for unsupported in ('score_thr', 'text_color', 'mask_color',
                            'font_size'):
            kwargs.pop(unsupported, None)
        kwargs['colors'] = kwargs.pop('bbox_color', 'green')
        return mmcv.imshow_bboxes(data, result, top_k=top_k, **kwargs)
|