|
- # Copyright (c) OpenMMLab. All rights reserved.
- from __future__ import division
- import copy
- import warnings
-
- import torch
- import torch.nn as nn
- from mmcv import ConfigDict
- from mmcv.ops import DeformConv2d, batched_nms
- from mmcv.runner import BaseModule, ModuleList
-
- from mmdet.core import (RegionAssigner, build_assigner, build_sampler,
- images_to_levels, multi_apply)
- from mmdet.core.utils import select_single_mlvl
- from ..builder import HEADS, build_head
- from .base_dense_head import BaseDenseHead
- from .rpn_head import RPNHead
-
-
class AdaptiveConv(BaseModule):
    """AdaptiveConv used to adapt the sampling location with the anchors.

    Two modes are supported:

    - ``'offset'``: a :class:`DeformConv2d` whose sampling offsets are passed
      to :meth:`forward` (used for arbitrary anchor shapes).
    - ``'dilation'``: a plain :class:`nn.Conv2d` with the given dilation and
      ``padding=dilation``, so a 3x3 kernel with stride 1 keeps the spatial
      size (used for uniform anchors).

    Args:
        in_channels (int): Number of channels in the input image
        out_channels (int): Number of channels produced by the convolution
        kernel_size (int or tuple): Size of the conv kernel. Default: 3
        stride (int or tuple, optional): Stride of the convolution. Default: 1
        padding (int or tuple, optional): Zero-padding added to both sides of
            the input. Default: 1
        dilation (int or tuple, optional): Spacing between kernel elements.
            Default: 3
        groups (int, optional): Number of blocked connections from input
            channels to output channels. Default: 1
        bias (bool, optional): If set True, adds a learnable bias to the
            output. Default: False.
        type (str, optional): Type of adaptive conv, can be either 'offset'
            (arbitrary anchors) or 'dilation' (uniform anchor).
            Default: 'dilation'.
        init_cfg (dict or list[dict], optional): Initialization config dict.

    Note:
        In ``'dilation'`` mode only ``in_channels``, ``out_channels``,
        ``kernel_size`` and ``dilation`` are forwarded to ``nn.Conv2d``;
        ``stride``, ``padding``, ``groups`` and ``bias`` are ignored, so the
        layer keeps ``nn.Conv2d``'s own defaults (including a bias term).
        Changing this would alter the module's parameters (state-dict keys).
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size=3,
                 stride=1,
                 padding=1,
                 dilation=3,
                 groups=1,
                 bias=False,
                 type='dilation',
                 init_cfg=dict(
                     type='Normal', std=0.01, override=dict(name='conv'))):
        super(AdaptiveConv, self).__init__(init_cfg)
        assert type in ['offset', 'dilation']
        self.adapt_type = type

        assert kernel_size == 3, 'Adaptive conv only supports kernels 3'
        if self.adapt_type == 'offset':
            # Both fragments are plain literals so the message reads
            # "padding: 1, stride: 1, groups: 1" (previously the first
            # fragment was not an f-string and printed a literal "{1}").
            assert stride == 1 and padding == 1 and groups == 1, \
                'Adaptive conv offset mode only supports padding: 1, ' \
                'stride: 1, groups: 1'
            self.conv = DeformConv2d(
                in_channels,
                out_channels,
                kernel_size,
                padding=padding,
                stride=stride,
                groups=groups,
                bias=bias)
        else:
            # padding=dilation keeps H and W unchanged for a 3x3 kernel.
            self.conv = nn.Conv2d(
                in_channels,
                out_channels,
                kernel_size,
                padding=dilation,
                dilation=dilation)

    def forward(self, x, offset):
        """Forward function.

        Args:
            x (Tensor): Input feature map of shape (N, C, H, W).
            offset (Tensor | None): Required in ``'offset'`` mode, with shape
                (N, H * W, 18) for the 3x3 kernel; must be ``None`` in
                ``'dilation'`` mode.

        Returns:
            Tensor: Output feature map.
        """
        if self.adapt_type == 'offset':
            N, _, H, W = x.shape
            assert offset is not None
            assert H * W == offset.shape[1]
            # reshape [N, NA, 18] to (N, 18, H, W)
            offset = offset.permute(0, 2, 1).reshape(N, -1, H, W)
            offset = offset.contiguous()
            x = self.conv(x, offset)
        else:
            assert offset is None
            x = self.conv(x)
        return x
-
-
@HEADS.register_module()
class StageCascadeRPNHead(RPNHead):
    """Stage of CascadeRPNHead.

    Args:
        in_channels (int): Number of channels in the input feature map.
        anchor_generator (dict): anchor generator config.
        adapt_cfg (dict): adaptation config, forwarded as keyword arguments
            to :class:`AdaptiveConv`.
        bridged_feature (bool, optional): whether update rpn feature.
            Default: False.
        with_cls (bool, optional): whether use classification branch.
            Default: True.
        sampling (bool, optional): whether use sampling. Default: True.
        init_cfg (dict or list[dict], optional): Initialization config dict.
            Default: None
        **kwargs: Forwarded to :class:`RPNHead`.
    """

    def __init__(self,
                 in_channels,
                 anchor_generator=dict(
                     type='AnchorGenerator',
                     scales=[8],
                     ratios=[1.0],
                     strides=[4, 8, 16, 32, 64]),
                 adapt_cfg=dict(type='dilation', dilation=3),
                 bridged_feature=False,
                 with_cls=True,
                 sampling=True,
                 init_cfg=None,
                 **kwargs):
        # These must be set before the parent __init__, which calls
        # _init_layers() and reads with_cls / adapt_cfg.
        self.with_cls = with_cls
        self.anchor_strides = anchor_generator['strides']
        self.anchor_scales = anchor_generator['scales']
        self.bridged_feature = bridged_feature
        self.adapt_cfg = adapt_cfg
        super(StageCascadeRPNHead, self).__init__(
            in_channels,
            anchor_generator=anchor_generator,
            init_cfg=init_cfg,
            **kwargs)

        # override sampling and sampler
        self.sampling = sampling
        if self.train_cfg:
            self.assigner = build_assigner(self.train_cfg.assigner)
            # use PseudoSampler when sampling is False
            if self.sampling and hasattr(self.train_cfg, 'sampler'):
                sampler_cfg = self.train_cfg.sampler
            else:
                sampler_cfg = dict(type='PseudoSampler')
            self.sampler = build_sampler(sampler_cfg, context=self)

        if init_cfg is None:
            self.init_cfg = dict(
                type='Normal', std=0.01, override=[dict(name='rpn_reg')])
            if self.with_cls:
                self.init_cfg['override'].append(dict(name='rpn_cls'))

    def _init_layers(self):
        """Init layers of a CascadeRPN stage."""
        self.rpn_conv = AdaptiveConv(self.in_channels, self.feat_channels,
                                     **self.adapt_cfg)
        if self.with_cls:
            self.rpn_cls = nn.Conv2d(self.feat_channels,
                                     self.num_anchors * self.cls_out_channels,
                                     1)
        self.rpn_reg = nn.Conv2d(self.feat_channels, self.num_anchors * 4, 1)
        self.relu = nn.ReLU(inplace=True)

    def forward_single(self, x, offset):
        """Forward function of single scale.

        Returns:
            tuple: ``(bridged_x, cls_score, bbox_pred)``. ``bridged_x`` is the
            adapted feature when ``bridged_feature`` is True, otherwise the
            input feature; ``cls_score`` is None when ``with_cls`` is False.
        """
        bridged_x = x
        x = self.relu(self.rpn_conv(x, offset))
        if self.bridged_feature:
            bridged_x = x  # update feature
        cls_score = self.rpn_cls(x) if self.with_cls else None
        bbox_pred = self.rpn_reg(x)
        return bridged_x, cls_score, bbox_pred

    def forward(self, feats, offset_list=None):
        """Forward function over all levels; offsets default to None."""
        if offset_list is None:
            offset_list = [None for _ in range(len(feats))]
        return multi_apply(self.forward_single, feats, offset_list)

    def _region_targets_single(self,
                               anchors,
                               valid_flags,
                               gt_bboxes,
                               gt_bboxes_ignore,
                               gt_labels,
                               img_meta,
                               featmap_sizes,
                               label_channels=1):
        """Get anchor targets based on region for a single image.

        ``anchors`` is the list of per-level anchors of one image; the
        targets are computed on their concatenation.
        """
        assign_result = self.assigner.assign(
            anchors,
            valid_flags,
            gt_bboxes,
            img_meta,
            featmap_sizes,
            self.anchor_scales[0],
            self.anchor_strides,
            gt_bboxes_ignore=gt_bboxes_ignore,
            gt_labels=None,
            allowed_border=self.train_cfg.allowed_border)
        flat_anchors = torch.cat(anchors)
        sampling_result = self.sampler.sample(assign_result, flat_anchors,
                                              gt_bboxes)

        num_anchors = flat_anchors.shape[0]
        bbox_targets = torch.zeros_like(flat_anchors)
        bbox_weights = torch.zeros_like(flat_anchors)
        # Follow the label convention used since mmdet v2.5 (and by
        # _get_bboxes_single below): FG labels are [0, num_classes - 1] and
        # the BG label is num_classes. Every anchor starts as background.
        labels = flat_anchors.new_full((num_anchors, ),
                                       self.num_classes,
                                       dtype=torch.long)
        label_weights = flat_anchors.new_zeros(num_anchors, dtype=torch.float)

        pos_inds = sampling_result.pos_inds
        neg_inds = sampling_result.neg_inds
        if len(pos_inds) > 0:
            if not self.reg_decoded_bbox:
                pos_bbox_targets = self.bbox_coder.encode(
                    sampling_result.pos_bboxes, sampling_result.pos_gt_bboxes)
            else:
                pos_bbox_targets = sampling_result.pos_gt_bboxes
            bbox_targets[pos_inds, :] = pos_bbox_targets
            bbox_weights[pos_inds, :] = 1.0
            if gt_labels is None:
                # Only RPN gives gt_labels as None; foreground is class 0.
                labels[pos_inds] = 0
            else:
                labels[pos_inds] = gt_labels[
                    sampling_result.pos_assigned_gt_inds]
            if self.train_cfg.pos_weight <= 0:
                label_weights[pos_inds] = 1.0
            else:
                label_weights[pos_inds] = self.train_cfg.pos_weight
        if len(neg_inds) > 0:
            label_weights[neg_inds] = 1.0

        return (labels, label_weights, bbox_targets, bbox_weights, pos_inds,
                neg_inds)

    def region_targets(self,
                       anchor_list,
                       valid_flag_list,
                       gt_bboxes_list,
                       img_metas,
                       featmap_sizes,
                       gt_bboxes_ignore_list=None,
                       gt_labels_list=None,
                       label_channels=1,
                       unmap_outputs=True):
        """See :func:`StageCascadeRPNHead.get_targets`.

        ``unmap_outputs`` is accepted for interface compatibility but is not
        used here.
        """
        num_imgs = len(img_metas)
        assert len(anchor_list) == len(valid_flag_list) == num_imgs

        # anchor number of multi levels
        num_level_anchors = [anchors.size(0) for anchors in anchor_list[0]]

        # compute targets for each image
        if gt_bboxes_ignore_list is None:
            gt_bboxes_ignore_list = [None for _ in range(num_imgs)]
        if gt_labels_list is None:
            gt_labels_list = [None for _ in range(num_imgs)]
        (all_labels, all_label_weights, all_bbox_targets, all_bbox_weights,
         pos_inds_list, neg_inds_list) = multi_apply(
             self._region_targets_single,
             anchor_list,
             valid_flag_list,
             gt_bboxes_list,
             gt_bboxes_ignore_list,
             gt_labels_list,
             img_metas,
             featmap_sizes=featmap_sizes,
             label_channels=label_channels)
        # no valid anchors
        if any(labels is None for labels in all_labels):
            return None
        # sampled anchors of all images
        num_total_pos = sum(max(inds.numel(), 1) for inds in pos_inds_list)
        num_total_neg = sum(max(inds.numel(), 1) for inds in neg_inds_list)
        # split targets to a list w.r.t. multiple levels
        labels_list = images_to_levels(all_labels, num_level_anchors)
        label_weights_list = images_to_levels(all_label_weights,
                                              num_level_anchors)
        bbox_targets_list = images_to_levels(all_bbox_targets,
                                             num_level_anchors)
        bbox_weights_list = images_to_levels(all_bbox_weights,
                                             num_level_anchors)
        return (labels_list, label_weights_list, bbox_targets_list,
                bbox_weights_list, num_total_pos, num_total_neg)

    def get_targets(self,
                    anchor_list,
                    valid_flag_list,
                    gt_bboxes,
                    img_metas,
                    featmap_sizes,
                    gt_bboxes_ignore=None,
                    label_channels=1):
        """Compute regression and classification targets for anchors.

        Uses :meth:`region_targets` when the assigner is a
        :class:`RegionAssigner`, otherwise the parent implementation.

        Args:
            anchor_list (list[list]): Multi level anchors of each image.
            valid_flag_list (list[list]): Multi level valid flags of each
                image.
            gt_bboxes (list[Tensor]): Ground truth bboxes of each image.
            img_metas (list[dict]): Meta info of each image.
            featmap_sizes (list[Tensor]): Feature map size of each level.
            gt_bboxes_ignore (list[Tensor]): Ignore bboxes of each images
            label_channels (int): Channel of label.

        Returns:
            cls_reg_targets (tuple)
        """
        if isinstance(self.assigner, RegionAssigner):
            cls_reg_targets = self.region_targets(
                anchor_list,
                valid_flag_list,
                gt_bboxes,
                img_metas,
                featmap_sizes,
                gt_bboxes_ignore_list=gt_bboxes_ignore,
                label_channels=label_channels)
        else:
            cls_reg_targets = super(StageCascadeRPNHead, self).get_targets(
                anchor_list,
                valid_flag_list,
                gt_bboxes,
                img_metas,
                gt_bboxes_ignore_list=gt_bboxes_ignore,
                label_channels=label_channels)
        return cls_reg_targets

    def anchor_offset(self, anchor_list, anchor_strides, featmap_sizes):
        """Get offset for deformable conv based on anchor shape.

        NOTE: currently support deformable kernel_size=3 and dilation=1

        Args:
            anchor_list (list[list[tensor])): [NI, NLVL, NA, 4] list of
                multi-level anchors
            anchor_strides (list[int]): anchor stride of each level
            featmap_sizes (list[tuple]): (h, w) of the feature map of each
                level; each level must hold exactly h * w anchors per image.

        Returns:
            offset_list (list[tensor]): NLVL tensors, each of shape
                [NI, NA_lvl, 18] (i.e. 2 * ks**2 with ks=3): offsets of the
                DeformConv kernel, as consumed by :class:`AdaptiveConv`.
        """
        dtype = anchor_list[0][0].dtype
        device = anchor_list[0][0].device

        def _shape_offset(anchors, stride, ks=3, dilation=1):
            # currently support kernel_size=3 and dilation=1
            assert ks == 3 and dilation == 1
            pad = (ks - 1) // 2
            idx = torch.arange(-pad, pad + 1, dtype=dtype, device=device)
            yy, xx = torch.meshgrid(idx, idx)  # return order matters
            xx = xx.reshape(-1)
            yy = yy.reshape(-1)
            w = (anchors[:, 2] - anchors[:, 0]) / stride
            h = (anchors[:, 3] - anchors[:, 1]) / stride
            w = w / (ks - 1) - dilation
            h = h / (ks - 1) - dilation
            offset_x = w[:, None] * xx  # (NA, ks**2)
            offset_y = h[:, None] * yy  # (NA, ks**2)
            return offset_x, offset_y

        def _ctr_offset(anchors, stride, featmap_size):
            feat_h, feat_w = featmap_size
            assert len(anchors) == feat_h * feat_w

            x = (anchors[:, 0] + anchors[:, 2]) * 0.5
            y = (anchors[:, 1] + anchors[:, 3]) * 0.5
            # compute centers on feature map
            x = x / stride
            y = y / stride
            # compute predefine centers
            xx = torch.arange(0, feat_w, device=anchors.device)
            yy = torch.arange(0, feat_h, device=anchors.device)
            yy, xx = torch.meshgrid(yy, xx)
            xx = xx.reshape(-1).type_as(x)
            yy = yy.reshape(-1).type_as(y)

            offset_x = x - xx  # (NA, )
            offset_y = y - yy  # (NA, )
            return offset_x, offset_y

        num_imgs = len(anchor_list)
        num_lvls = len(anchor_list[0])
        num_level_anchors = [anchors.size(0) for anchors in anchor_list[0]]

        offset_list = []
        for i in range(num_imgs):
            mlvl_offset = []
            for lvl in range(num_lvls):
                c_offset_x, c_offset_y = _ctr_offset(anchor_list[i][lvl],
                                                     anchor_strides[lvl],
                                                     featmap_sizes[lvl])
                s_offset_x, s_offset_y = _shape_offset(anchor_list[i][lvl],
                                                       anchor_strides[lvl])

                # offset = ctr_offset + shape_offset
                offset_x = s_offset_x + c_offset_x[:, None]
                offset_y = s_offset_y + c_offset_y[:, None]

                # offset order (y0, x0, y1, x1, .., y8, x8)
                offset = torch.stack([offset_y, offset_x], dim=-1)
                offset = offset.reshape(offset.size(0), -1)  # [NA, 2*ks**2]
                mlvl_offset.append(offset)
            offset_list.append(torch.cat(mlvl_offset))  # [totalNA, 2*ks**2]
        offset_list = images_to_levels(offset_list, num_level_anchors)
        return offset_list

    def loss_single(self, cls_score, bbox_pred, anchors, labels, label_weights,
                    bbox_targets, bbox_weights, num_total_samples):
        """Loss function on single scale.

        Returns:
            tuple: ``(loss_cls, loss_reg)``; ``loss_cls`` is None when
            ``with_cls`` is False.
        """
        # classification loss
        if self.with_cls:
            labels = labels.reshape(-1)
            label_weights = label_weights.reshape(-1)
            cls_score = cls_score.permute(0, 2, 3,
                                          1).reshape(-1, self.cls_out_channels)
            loss_cls = self.loss_cls(
                cls_score, labels, label_weights, avg_factor=num_total_samples)
        # regression loss
        bbox_targets = bbox_targets.reshape(-1, 4)
        bbox_weights = bbox_weights.reshape(-1, 4)
        bbox_pred = bbox_pred.permute(0, 2, 3, 1).reshape(-1, 4)
        if self.reg_decoded_bbox:
            # When the regression loss (e.g. `IouLoss`, `GIouLoss`)
            # is applied directly on the decoded bounding boxes, it
            # decodes the already encoded coordinates to absolute format.
            anchors = anchors.reshape(-1, 4)
            bbox_pred = self.bbox_coder.decode(anchors, bbox_pred)
        loss_reg = self.loss_bbox(
            bbox_pred,
            bbox_targets,
            bbox_weights,
            avg_factor=num_total_samples)
        if self.with_cls:
            return loss_cls, loss_reg
        return None, loss_reg

    def loss(self,
             anchor_list,
             valid_flag_list,
             cls_scores,
             bbox_preds,
             gt_bboxes,
             img_metas,
             gt_bboxes_ignore=None):
        """Compute losses of the head.

        Args:
            anchor_list (list[list]): Multi level anchors of each image.
            valid_flag_list (list[list]): Multi level valid flags of each
                image.
            cls_scores (list[Tensor]): Box scores for each scale level
                Has shape (N, num_anchors * num_classes, H, W)
            bbox_preds (list[Tensor]): Box energies / deltas for each scale
                level with shape (N, num_anchors * 4, H, W)
            gt_bboxes (list[Tensor]): Ground truth bboxes for each image with
                shape (num_gts, 4) in [tl_x, tl_y, br_x, br_y] format.
            img_metas (list[dict]): Meta information of each image, e.g.,
                image size, scaling factor, etc.
            gt_bboxes_ignore (None | list[Tensor]): specify which bounding
                boxes can be ignored when computing the loss. Default: None

        Returns:
            dict[str, Tensor] | None: A dictionary of loss components, or
            None when no targets could be computed.
        """
        featmap_sizes = [featmap.size()[-2:] for featmap in bbox_preds]
        label_channels = self.cls_out_channels if self.use_sigmoid_cls else 1
        cls_reg_targets = self.get_targets(
            anchor_list,
            valid_flag_list,
            gt_bboxes,
            img_metas,
            featmap_sizes,
            gt_bboxes_ignore=gt_bboxes_ignore,
            label_channels=label_channels)
        if cls_reg_targets is None:
            return None
        (labels_list, label_weights_list, bbox_targets_list, bbox_weights_list,
         num_total_pos, num_total_neg) = cls_reg_targets
        if self.sampling:
            num_total_samples = num_total_pos + num_total_neg
        else:
            # 200 is hard-coded average factor,
            # which follows guided anchoring.
            num_total_samples = sum([label.numel()
                                     for label in labels_list]) / 200.0

        # change per image, per level anchor_list to per_level, per_image
        mlvl_anchor_list = list(zip(*anchor_list))
        # concat mlvl_anchor_list
        mlvl_anchor_list = [
            torch.cat(anchors, dim=0) for anchors in mlvl_anchor_list
        ]

        losses = multi_apply(
            self.loss_single,
            cls_scores,
            bbox_preds,
            mlvl_anchor_list,
            labels_list,
            label_weights_list,
            bbox_targets_list,
            bbox_weights_list,
            num_total_samples=num_total_samples)
        if self.with_cls:
            return dict(loss_rpn_cls=losses[0], loss_rpn_reg=losses[1])
        return dict(loss_rpn_reg=losses[1])

    def get_bboxes(self,
                   anchor_list,
                   cls_scores,
                   bbox_preds,
                   img_metas,
                   cfg,
                   rescale=False):
        """Get proposal predict.

        Args:
            anchor_list (list[list]): Multi level anchors of each image.
            cls_scores (list[Tensor]): Classification scores for all
                scale levels, each is a 4D-tensor, has shape
                (batch_size, num_priors * num_classes, H, W).
            bbox_preds (list[Tensor]): Box energies / deltas for all
                scale levels, each is a 4D-tensor, has shape
                (batch_size, num_priors * 4, H, W).
            img_metas (list[dict], Optional): Image meta info. Default None.
            cfg (mmcv.Config, Optional): Test / postprocessing configuration,
                if None, test_cfg would be used.
            rescale (bool): If True, return boxes in original image space.
                Default: False.

        Returns:
            list[Tensor]: One tensor per image of labeled boxes in shape
            (n, 5), where the first 4 columns are bounding box positions
            (tl_x, tl_y, br_x, br_y) and the 5-th column is a score between
            0 and 1.
        """
        assert len(cls_scores) == len(bbox_preds)

        result_list = []
        for img_id, img_meta in enumerate(img_metas):
            cls_score_list = select_single_mlvl(cls_scores, img_id)
            bbox_pred_list = select_single_mlvl(bbox_preds, img_id)
            img_shape = img_meta['img_shape']
            scale_factor = img_meta['scale_factor']
            proposals = self._get_bboxes_single(cls_score_list, bbox_pred_list,
                                                anchor_list[img_id], img_shape,
                                                scale_factor, cfg, rescale)
            result_list.append(proposals)
        return result_list

    def _get_bboxes_single(self,
                           cls_scores,
                           bbox_preds,
                           mlvl_anchors,
                           img_shape,
                           scale_factor,
                           cfg,
                           rescale=False):
        """Transform outputs of a single image into bbox predictions.

        Args:
            cls_scores (list[Tensor]): Box scores from all scale
                levels of a single image, each item has shape
                (num_anchors * num_classes, H, W).
            bbox_preds (list[Tensor]): Box energies / deltas from
                all scale levels of a single image, each item has
                shape (num_anchors * 4, H, W).
            mlvl_anchors (list[Tensor]): Box reference from all scale
                levels of a single image, each item has shape
                (num_total_anchors, 4).
            img_shape (tuple[int]): Shape of the input image,
                (height, width, 3).
            scale_factor (ndarray): Scale factor of the image arange as
                (w_scale, h_scale, w_scale, h_scale).
            cfg (mmcv.Config): Test / postprocessing configuration,
                if None, test_cfg would be used.
            rescale (bool): If True, return boxes in original image space.
                Default False. (Not used by this implementation.)

        Returns:
            Tensor: Labeled boxes in shape (n, 5), where the first 4 columns
                are bounding box positions (tl_x, tl_y, br_x, br_y) and the
                5-th column is a score between 0 and 1.
        """
        cfg = self.test_cfg if cfg is None else cfg
        # Deep copy because the deprecation handling below writes into cfg.
        cfg = copy.deepcopy(cfg)
        # bboxes from different level should be independent during NMS,
        # level_ids are used as labels for batched NMS to separate them
        level_ids = []
        mlvl_scores = []
        mlvl_bbox_preds = []
        mlvl_valid_anchors = []
        nms_pre = cfg.get('nms_pre', -1)
        for idx in range(len(cls_scores)):
            rpn_cls_score = cls_scores[idx]
            rpn_bbox_pred = bbox_preds[idx]
            assert rpn_cls_score.size()[-2:] == rpn_bbox_pred.size()[-2:]
            rpn_cls_score = rpn_cls_score.permute(1, 2, 0)
            if self.use_sigmoid_cls:
                rpn_cls_score = rpn_cls_score.reshape(-1)
                scores = rpn_cls_score.sigmoid()
            else:
                rpn_cls_score = rpn_cls_score.reshape(-1, 2)
                # We set FG labels to [0, num_class-1] and BG label to
                # num_class in RPN head since mmdet v2.5, which is unified to
                # be consistent with other head since mmdet v2.0. In mmdet v2.0
                # to v2.4 we keep BG label as 0 and FG label as 1 in rpn head.
                scores = rpn_cls_score.softmax(dim=1)[:, 0]
            rpn_bbox_pred = rpn_bbox_pred.permute(1, 2, 0).reshape(-1, 4)
            anchors = mlvl_anchors[idx]

            if 0 < nms_pre < scores.shape[0]:
                # sort is faster than topk
                # _, topk_inds = scores.topk(cfg.nms_pre)
                ranked_scores, rank_inds = scores.sort(descending=True)
                topk_inds = rank_inds[:nms_pre]
                scores = ranked_scores[:nms_pre]
                rpn_bbox_pred = rpn_bbox_pred[topk_inds, :]
                anchors = anchors[topk_inds, :]
            mlvl_scores.append(scores)
            mlvl_bbox_preds.append(rpn_bbox_pred)
            mlvl_valid_anchors.append(anchors)
            level_ids.append(
                scores.new_full((scores.size(0), ), idx, dtype=torch.long))

        scores = torch.cat(mlvl_scores)
        anchors = torch.cat(mlvl_valid_anchors)
        rpn_bbox_pred = torch.cat(mlvl_bbox_preds)
        proposals = self.bbox_coder.decode(
            anchors, rpn_bbox_pred, max_shape=img_shape)
        ids = torch.cat(level_ids)

        if cfg.min_bbox_size >= 0:
            w = proposals[:, 2] - proposals[:, 0]
            h = proposals[:, 3] - proposals[:, 1]
            valid_mask = (w > cfg.min_bbox_size) & (h > cfg.min_bbox_size)
            if not valid_mask.all():
                proposals = proposals[valid_mask]
                scores = scores[valid_mask]
                ids = ids[valid_mask]

        # deprecate arguments warning
        if 'nms' not in cfg or 'max_num' in cfg or 'nms_thr' in cfg:
            warnings.warn(
                'In rpn_proposal or test_cfg, '
                'nms_thr has been moved to a dict named nms as '
                'iou_threshold, max_num has been renamed as max_per_img, '
                'name of original arguments and the way to specify '
                'iou_threshold of NMS will be deprecated.')
        if 'nms' not in cfg:
            cfg.nms = ConfigDict(dict(type='nms', iou_threshold=cfg.nms_thr))
        if 'max_num' in cfg:
            if 'max_per_img' in cfg:
                assert cfg.max_num == cfg.max_per_img, f'You ' \
                    f'set max_num and ' \
                    f'max_per_img at the same time, but get {cfg.max_num} ' \
                    f'and {cfg.max_per_img} respectively. ' \
                    'Please delete max_num which will be deprecated.'
            else:
                cfg.max_per_img = cfg.max_num
        if 'nms_thr' in cfg:
            assert cfg.nms.iou_threshold == cfg.nms_thr, f'You set' \
                f' iou_threshold in nms and ' \
                f'nms_thr at the same time, but get' \
                f' {cfg.nms.iou_threshold} and {cfg.nms_thr}' \
                f' respectively. Please delete the nms_thr ' \
                f'which will be deprecated.'

        if proposals.numel() > 0:
            dets, _ = batched_nms(proposals, scores, ids, cfg.nms)
        else:
            return proposals.new_zeros(0, 5)

        return dets[:cfg.max_per_img]

    def refine_bboxes(self, anchor_list, bbox_preds, img_metas):
        """Refine bboxes through stages.

        Decodes this stage's (detached) regression output on top of the
        current anchors, producing the anchors used by the next stage.

        Args:
            anchor_list (list[list[Tensor]]): Multi-level anchors of each
                image.
            bbox_preds (list[Tensor]): Per-level box deltas of shape
                (N, num_anchors * 4, H, W).
            img_metas (list[dict]): Meta info of each image.

        Returns:
            list[list[Tensor]]: Refined multi-level anchors of each image.
        """
        num_levels = len(bbox_preds)
        new_anchor_list = []
        for img_id in range(len(img_metas)):
            # Image shape does not depend on the level; read it once.
            img_shape = img_metas[img_id]['img_shape']
            mlvl_anchors = []
            for i in range(num_levels):
                bbox_pred = bbox_preds[i][img_id].detach()
                bbox_pred = bbox_pred.permute(1, 2, 0).reshape(-1, 4)
                bboxes = self.bbox_coder.decode(anchor_list[img_id][i],
                                                bbox_pred, img_shape)
                mlvl_anchors.append(bboxes)
            new_anchor_list.append(mlvl_anchors)
        return new_anchor_list
-
-
@HEADS.register_module()
class CascadeRPNHead(BaseDenseHead):
    """The CascadeRPNHead will predict more accurate region proposals, which is
    required for two-stage detectors (such as Fast/Faster R-CNN). CascadeRPN
    consists of a sequence of RPNStage to progressively improve the accuracy of
    the detected proposals.

    More details can be found in ``https://arxiv.org/abs/1909.06720``.

    Args:
        num_stages (int): number of CascadeRPN stages.
        stages (list[dict]): list of configs to build the stages. The dicts
            are copied before ``train_cfg`` / ``test_cfg`` are injected, so
            the caller's configs are left unchanged.
        train_cfg (list[dict]): list of configs at training time each stage.
        test_cfg (dict): config at testing time.
        init_cfg (dict or list[dict], optional): Initialization config dict.
            Default: None
    """

    def __init__(self, num_stages, stages, train_cfg, test_cfg, init_cfg=None):
        super(CascadeRPNHead, self).__init__(init_cfg)
        assert num_stages == len(stages)
        self.num_stages = num_stages
        # Be careful! Pretrained weights cannot be loaded when use
        # nn.ModuleList
        self.stages = ModuleList()
        for i, stage_cfg in enumerate(stages):
            train_cfg_i = train_cfg[i] if train_cfg is not None else None
            # Shallow-copy so the per-stage cfg injection does not leak back
            # into the caller's (possibly shared) config objects.
            stage_cfg = stage_cfg.copy()
            stage_cfg.update(train_cfg=train_cfg_i)
            stage_cfg.update(test_cfg=test_cfg)
            self.stages.append(build_head(stage_cfg))
        self.train_cfg = train_cfg
        self.test_cfg = test_cfg

    def loss(self):
        """loss() is implemented in StageCascadeRPNHead."""
        pass

    def get_bboxes(self):
        """get_bboxes() is implemented in StageCascadeRPNHead."""
        pass

    def _get_offset_list(self, stage, anchor_list, featmap_sizes):
        """Return DeformConv offsets for ``stage`` in 'offset' mode, else None.
        """
        if stage.adapt_cfg['type'] == 'offset':
            return stage.anchor_offset(anchor_list, stage.anchor_strides,
                                       featmap_sizes)
        return None

    def forward_train(self,
                      x,
                      img_metas,
                      gt_bboxes,
                      gt_labels=None,
                      gt_bboxes_ignore=None,
                      proposal_cfg=None):
        """Forward train function.

        Runs every stage in sequence, collecting each stage's losses under a
        ``'s{stage_index}.'`` prefix and refining the anchors between stages.

        Returns:
            dict | tuple: The loss dict, or ``(losses, proposal_list)`` when
            ``proposal_cfg`` is given.
        """
        assert gt_labels is None, 'RPN does not require gt_labels'

        featmap_sizes = [featmap.size()[-2:] for featmap in x]
        device = x[0].device
        anchor_list, valid_flag_list = self.stages[0].get_anchors(
            featmap_sizes, img_metas, device=device)

        losses = dict()

        for i, stage in enumerate(self.stages):
            offset_list = self._get_offset_list(stage, anchor_list,
                                                featmap_sizes)
            x, cls_score, bbox_pred = stage(x, offset_list)
            rpn_loss_inputs = (anchor_list, valid_flag_list, cls_score,
                               bbox_pred, gt_bboxes, img_metas)
            stage_loss = stage.loss(*rpn_loss_inputs)
            for name, value in stage_loss.items():
                losses[f's{i}.{name}'] = value

            # refine boxes
            if i < self.num_stages - 1:
                anchor_list = stage.refine_bboxes(anchor_list, bbox_pred,
                                                  img_metas)
        if proposal_cfg is None:
            return losses
        proposal_list = self.stages[-1].get_bboxes(anchor_list, cls_score,
                                                   bbox_pred, img_metas,
                                                   self.test_cfg)
        return losses, proposal_list

    def simple_test_rpn(self, x, img_metas):
        """Simple forward test function.

        Returns:
            list[Tensor]: Proposals of each image from the last stage.
        """
        featmap_sizes = [featmap.size()[-2:] for featmap in x]
        device = x[0].device
        anchor_list, _ = self.stages[0].get_anchors(
            featmap_sizes, img_metas, device=device)

        for i, stage in enumerate(self.stages):
            offset_list = self._get_offset_list(stage, anchor_list,
                                                featmap_sizes)
            x, cls_score, bbox_pred = stage(x, offset_list)
            if i < self.num_stages - 1:
                anchor_list = stage.refine_bboxes(anchor_list, bbox_pred,
                                                  img_metas)

        proposal_list = self.stages[-1].get_bboxes(anchor_list, cls_score,
                                                   bbox_pred, img_metas,
                                                   self.test_cfg)
        return proposal_list

    def aug_test_rpn(self, x, img_metas):
        """Augmented forward test function."""
        raise NotImplementedError(
            'CascadeRPNHead does not support test-time augmentation')
|