|
- # Copyright (c) OpenMMLab. All rights reserved.
- import warnings
- from abc import abstractmethod
-
- import torch
- import torch.nn as nn
- from mmcv.cnn import ConvModule
- from mmcv.runner import force_fp32
-
- from mmdet.core import build_bbox_coder, multi_apply
- from mmdet.core.anchor.point_generator import MlvlPointGenerator
- from ..builder import HEADS, build_loss
- from .base_dense_head import BaseDenseHead
- from .dense_test_mixins import BBoxTestMixin
-
-
@HEADS.register_module()
class AnchorFreeHead(BaseDenseHead, BBoxTestMixin):
    """Anchor-free head (FCOS, Fovea, RepPoints, etc.).

    The head runs two parallel towers of stacked 3x3 convolutions (one for
    classification, one for box regression) over every feature level and
    finishes each tower with a 3x3 predictor conv. ``loss`` and
    ``get_targets`` are abstract and must be provided by child classes.

    Args:
        num_classes (int): Number of categories excluding the background
            category.
        in_channels (int): Number of channels in the input feature map.
        feat_channels (int): Number of hidden channels. Used in child classes.
        stacked_convs (int): Number of stacking convs of the head.
        strides (tuple): Downsample factor of each feature map.
        dcn_on_last_conv (bool): If true, use dcn in the last layer of
            towers. Default: False.
        conv_bias (bool | str): If specified as `auto`, it will be decided by
            the norm_cfg. Bias of conv will be set as True if `norm_cfg` is
            None, otherwise False. Default: "auto".
        loss_cls (dict): Config of classification loss.
        loss_bbox (dict): Config of localization loss.
        bbox_coder (dict): Config of bbox coder. Defaults to
            'DistancePointBBoxCoder'.
        conv_cfg (dict): Config dict for convolution layer. Default: None.
        norm_cfg (dict): Config dict for normalization layer. Default: None.
        train_cfg (dict): Training config of anchor-free head.
        test_cfg (dict): Testing config of anchor-free head.
        init_cfg (dict or list[dict], optional): Initialization config dict.
    """  # noqa: W605

    # Stored in the state-dict metadata; checkpoints without a version are
    # treated as legacy ones by `_load_from_state_dict`.
    _version = 1

    def __init__(self,
                 num_classes,
                 in_channels,
                 feat_channels=256,
                 stacked_convs=4,
                 strides=(4, 8, 16, 32, 64),
                 dcn_on_last_conv=False,
                 conv_bias='auto',
                 loss_cls=dict(
                     type='FocalLoss',
                     use_sigmoid=True,
                     gamma=2.0,
                     alpha=0.25,
                     loss_weight=1.0),
                 loss_bbox=dict(type='IoULoss', loss_weight=1.0),
                 bbox_coder=dict(type='DistancePointBBoxCoder'),
                 conv_cfg=None,
                 norm_cfg=None,
                 train_cfg=None,
                 test_cfg=None,
                 init_cfg=dict(
                     type='Normal',
                     layer='Conv2d',
                     std=0.01,
                     override=dict(
                         type='Normal',
                         name='conv_cls',
                         std=0.01,
                         bias_prob=0.01))):
        super().__init__(init_cfg)
        self.num_classes = num_classes
        # A sigmoid classifier predicts one score per foreground class; any
        # other classifier gets one extra channel for the background.
        self.use_sigmoid_cls = loss_cls.get('use_sigmoid', False)
        if self.use_sigmoid_cls:
            self.cls_out_channels = num_classes
        else:
            self.cls_out_channels = num_classes + 1
        self.in_channels = in_channels
        self.feat_channels = feat_channels
        self.stacked_convs = stacked_convs
        self.strides = strides
        self.dcn_on_last_conv = dcn_on_last_conv
        assert conv_bias == 'auto' or isinstance(conv_bias, bool)
        self.conv_bias = conv_bias
        self.loss_cls = build_loss(loss_cls)
        self.loss_bbox = build_loss(loss_bbox)
        self.bbox_coder = build_bbox_coder(bbox_coder)

        self.prior_generator = MlvlPointGenerator(strides)

        # In order to keep a more general interface and be consistent with
        # anchor_head. We can think of point like one anchor
        self.num_base_priors = self.prior_generator.num_base_priors[0]

        self.train_cfg = train_cfg
        self.test_cfg = test_cfg
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        # NOTE(review): presumably consulted by mmcv's fp16 decorators such
        # as `force_fp32` -- confirm against mmcv.
        self.fp16_enabled = False

        self._init_layers()

    def _init_layers(self):
        """Initialize layers of the head."""
        self._init_cls_convs()
        self._init_reg_convs()
        self._init_predictor()

    def _build_conv_tower(self):
        """Build one tower of ``stacked_convs`` 3x3 ``ConvModule`` layers.

        The first layer maps ``in_channels`` to ``feat_channels``; all later
        layers keep ``feat_channels``. When ``dcn_on_last_conv`` is set, the
        last layer uses a ``DCNv2`` conv config instead of ``self.conv_cfg``.

        Returns:
            nn.ModuleList: The freshly created conv layers, in order.
        """
        last_idx = self.stacked_convs - 1
        tower = nn.ModuleList()
        for idx in range(self.stacked_convs):
            in_chn = self.in_channels if idx == 0 else self.feat_channels
            if self.dcn_on_last_conv and idx == last_idx:
                conv_cfg = dict(type='DCNv2')
            else:
                conv_cfg = self.conv_cfg
            tower.append(
                ConvModule(
                    in_chn,
                    self.feat_channels,
                    3,
                    stride=1,
                    padding=1,
                    conv_cfg=conv_cfg,
                    norm_cfg=self.norm_cfg,
                    bias=self.conv_bias))
        return tower

    def _init_cls_convs(self):
        """Initialize classification conv layers of the head."""
        self.cls_convs = self._build_conv_tower()

    def _init_reg_convs(self):
        """Initialize bbox regression conv layers of the head."""
        self.reg_convs = self._build_conv_tower()

    def _init_predictor(self):
        """Initialize predictor layers of the head."""
        self.conv_cls = nn.Conv2d(
            self.feat_channels, self.cls_out_channels, 3, padding=1)
        self.conv_reg = nn.Conv2d(self.feat_channels, 4, 3, padding=1)

    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                              missing_keys, unexpected_keys, error_msgs):
        """Hack some keys of the model state dict so that can load checkpoints
        of previous version.

        When the checkpoint metadata carries no ``version``, every key under
        ``prefix`` whose first module name ends with ``cls``, ``reg`` or
        ``centerness`` (e.g. ``fcos_cls``) is renamed in place to
        ``conv_cls``, ``conv_reg`` or ``conv_centerness`` respectively. All
        other keys are left untouched. The (possibly modified) ``state_dict``
        is then handed to the parent implementation.
        """
        version = local_metadata.get('version', None)
        if version is None:
            # the key is different in early versions
            # for example, 'fcos_cls' become 'conv_cls' now
            suffix_to_name = (
                ('cls', 'conv_cls'),
                ('reg', 'conv_reg'),
                ('centerness', 'conv_centerness'),
            )
            # Collect the renames first so the dict is not mutated while it
            # is being iterated.
            renames = []
            for old_key in state_dict.keys():
                if not old_key.startswith(prefix):
                    continue
                # Look at the module name relative to `prefix` rather than a
                # fixed dotted position, so nested prefixes work as well.
                parts = old_key[len(prefix):].split('.')
                for suffix, conv_name in suffix_to_name:
                    if parts[0].endswith(suffix):
                        parts[0] = conv_name
                        new_key = prefix + '.'.join(parts)
                        if new_key != old_key:
                            renames.append((old_key, new_key))
                        break
                # Keys matching no suffix (conv towers, etc.) are
                # intentionally kept as they are.
            for old_key, new_key in renames:
                state_dict[new_key] = state_dict.pop(old_key)
        super()._load_from_state_dict(state_dict, prefix, local_metadata,
                                      strict, missing_keys, unexpected_keys,
                                      error_msgs)

    def forward(self, feats):
        """Forward features from the upstream network.

        Args:
            feats (tuple[Tensor]): Features from the upstream network, each is
                a 4D-tensor.

        Returns:
            tuple: Usually contain classification scores and bbox predictions.
                cls_scores (list[Tensor]): Box scores for each scale level,
                    each is a 4D-tensor, the channel number is
                    num_points * num_classes.
                bbox_preds (list[Tensor]): Box energies / deltas for each scale
                    level, each is a 4D-tensor, the channel number is
                    num_points * 4.
        """
        # `forward_single` also returns the tower features; only the first
        # two outputs (scores and box predictions) are exposed here.
        return multi_apply(self.forward_single, feats)[:2]

    def forward_single(self, x):
        """Forward features of a single scale level.

        Args:
            x (Tensor): FPN feature maps of the specified stride.

        Returns:
            tuple: Scores for each class, bbox predictions, features
                after classification and regression conv layers, some
                models needs these features like FCOS.
        """
        cls_feat = x
        reg_feat = x

        for cls_layer in self.cls_convs:
            cls_feat = cls_layer(cls_feat)
        cls_score = self.conv_cls(cls_feat)

        for reg_layer in self.reg_convs:
            reg_feat = reg_layer(reg_feat)
        bbox_pred = self.conv_reg(reg_feat)
        return cls_score, bbox_pred, cls_feat, reg_feat

    @abstractmethod
    @force_fp32(apply_to=('cls_scores', 'bbox_preds'))
    def loss(self,
             cls_scores,
             bbox_preds,
             gt_bboxes,
             gt_labels,
             img_metas,
             gt_bboxes_ignore=None):
        """Compute loss of the head.

        Args:
            cls_scores (list[Tensor]): Box scores for each scale level,
                each is a 4D-tensor, the channel number is
                num_points * num_classes.
            bbox_preds (list[Tensor]): Box energies / deltas for each scale
                level, each is a 4D-tensor, the channel number is
                num_points * 4.
            gt_bboxes (list[Tensor]): Ground truth bboxes for each image with
                shape (num_gts, 4) in [tl_x, tl_y, br_x, br_y] format.
            gt_labels (list[Tensor]): class indices corresponding to each box
            img_metas (list[dict]): Meta information of each image, e.g.,
                image size, scaling factor, etc.
            gt_bboxes_ignore (None | list[Tensor]): specify which bounding
                boxes can be ignored when computing the loss.

        Raises:
            NotImplementedError: Always; child classes must override this.
        """

        raise NotImplementedError

    @abstractmethod
    def get_targets(self, points, gt_bboxes_list, gt_labels_list):
        """Compute regression, classification and centerness targets for points
        in multiple images.

        Args:
            points (list[Tensor]): Points of each fpn level, each has shape
                (num_points, 2).
            gt_bboxes_list (list[Tensor]): Ground truth bboxes of each image,
                each has shape (num_gt, 4).
            gt_labels_list (list[Tensor]): Ground truth labels of each box,
                each has shape (num_gt,).

        Raises:
            NotImplementedError: Always; child classes must override this.
        """
        raise NotImplementedError

    def _get_points_single(self,
                           featmap_size,
                           stride,
                           dtype,
                           device,
                           flatten=False):
        """Get points of a single scale level.

        This function will be deprecated soon.

        Args:
            featmap_size (tuple): ``(h, w)`` of the feature map.
            stride: Accepted for interface compatibility; not used here.
            dtype (torch.dtype): Type the coordinate tensors are cast to.
            device (torch.device): Device the coordinates are created on.
            flatten (bool): Whether to flatten both grids to 1D.
                Default: False.

        Returns:
            tuple[Tensor, Tensor]: The ``y`` and ``x`` coordinate grids, in
                that order.
        """

        warnings.warn(
            '`_get_points_single` in `AnchorFreeHead` will be '
            'deprecated soon, we support a multi level point generator now. '
            'You can get points of a single level feature map '
            'with `self.prior_generator.single_level_grid_priors` ')

        h, w = featmap_size
        # First create Range with the default dtype, then convert to
        # target `dtype` for onnx exporting.
        x_range = torch.arange(w, device=device).to(dtype)
        y_range = torch.arange(h, device=device).to(dtype)
        y, x = torch.meshgrid(y_range, x_range)
        if flatten:
            y = y.flatten()
            x = x.flatten()
        return y, x

    def get_points(self, featmap_sizes, dtype, device, flatten=False):
        """Get points according to feature map sizes.

        Args:
            featmap_sizes (list[tuple]): Multi-level feature map sizes.
            dtype (torch.dtype): Type of points.
            device (torch.device): Device of points.
            flatten (bool): Whether to flatten the grids of every level.
                Default: False.

        Returns:
            list[tuple[Tensor, Tensor]]: One ``(y, x)`` pair per feature
                level, as produced by ``_get_points_single``.
        """
        warnings.warn(
            '`get_points` in `AnchorFreeHead` will be '
            'deprecated soon, we support a multi level point generator now. '
            'You can get points of all levels '
            'with `self.prior_generator.grid_priors` ')

        return [
            self._get_points_single(featmap_size, self.strides[lvl], dtype,
                                    device, flatten)
            for lvl, featmap_size in enumerate(featmap_sizes)
        ]

    def aug_test(self, feats, img_metas, rescale=False):
        """Test function with test time augmentation.

        Args:
            feats (list[Tensor]): the outer list indicates test-time
                augmentations and inner Tensor should have a shape NxCxHxW,
                which contains features for all images in the batch.
            img_metas (list[list[dict]]): the outer list indicates test-time
                augs (multiscale, flip, etc.) and the inner list indicates
                images in a batch. each dict has image information.
            rescale (bool, optional): Whether to rescale the results.
                Defaults to False.

        Returns:
            list[ndarray]: bbox results of each class
        """
        return self.aug_test_bboxes(feats, img_metas, rescale=rescale)
|