|
- # Copyright (c) OpenMMLab. All rights reserved.
- import torch
- import torch.nn as nn
- from mmcv.cnn import (ConvModule, bias_init_with_prob, constant_init, is_norm,
- normal_init)
- from mmcv.runner import force_fp32
-
- from mmdet.core import anchor_inside_flags, multi_apply, reduce_mean, unmap
- from ..builder import HEADS
- from .anchor_head import AnchorHead
-
# Upper bound used to clamp exp() of logits in YOLOFHead.forward_single so
# the implicit-objectness log term stays finite.
INF = 10.0**8
-
-
def levels_to_images(mlvl_tensor):
    """Concat multi-level feature maps by image.

    [feature_level0, feature_level1...] -> [feature_image0, feature_image1...]
    Each element of ``mlvl_tensor`` of shape (N, C, H, W) is flattened to
    (N, H*W, C); for every image the per-level (H*W, C) slices are then
    concatenated along the first dimension in level order.

    Args:
        mlvl_tensor (list[torch.Tensor]): One tensor per feature level, each
            of shape (N, C, H, W). N and C are taken from the first level.

    Returns:
        list[torch.Tensor]: N tensors; the i-th has shape
        (sum over levels of H*W, C) and holds image i's rows.
    """
    batch_size = mlvl_tensor[0].size(0)
    channels = mlvl_tensor[0].size(1)
    # ``reshape`` returns a view whenever the permuted strides allow it and
    # only copies when they do not, so non-contiguous inputs are accepted
    # (``view`` would raise) and no redundant ``.contiguous()`` copy is made
    # before ``torch.cat`` copies the data anyway.
    flat_levels = [
        level.permute(0, 2, 3, 1).reshape(batch_size, -1, channels)
        for level in mlvl_tensor
    ]
    # zip(*...) regroups the per-level slices by image; torch.cat gives each
    # image its own freshly allocated tensor.
    return [
        torch.cat(image_parts, 0)
        for image_parts in zip(*(level.unbind(0) for level in flat_levels))
    ]
-
-
@HEADS.register_module()
class YOLOFHead(AnchorHead):
    """YOLOFHead Paper link: https://arxiv.org/abs/2103.09460.

    The head works on a single feature level. It has a classification branch
    (``cls_subnet`` + ``cls_score``) and a regression branch
    (``bbox_subnet`` + ``bbox_pred`` + ``object_pred``). The objectness
    predicted on the regression branch is fused into the class scores in
    :meth:`forward_single`.

    Args:
        num_classes (int): The number of object classes (w/o background).
        in_channels (int): The number of input channels. It is also the
            width of every conv inside the two subnets.
        num_cls_convs (int): The number of convolutions of the cls branch.
            Default: 2.
        num_reg_convs (int): The number of convolutions of the reg branch.
            Default: 4.
        norm_cfg (dict): Dictionary to construct and config norm layer.
            Default: dict(type='BN', requires_grad=True).
        **kwargs: Forwarded unchanged to :class:`AnchorHead`.
    """

    def __init__(self,
                 num_classes,
                 in_channels,
                 num_cls_convs=2,
                 num_reg_convs=4,
                 norm_cfg=dict(type='BN', requires_grad=True),
                 **kwargs):
        # These must be set before the parent constructor runs, presumably
        # because it calls ``_init_layers`` which reads them -- confirm in
        # AnchorHead.
        self.num_cls_convs = num_cls_convs
        self.num_reg_convs = num_reg_convs
        self.norm_cfg = norm_cfg
        super().__init__(num_classes, in_channels, **kwargs)

    def _init_layers(self):
        """Build the two conv subnets and the three prediction convs.

        Every conv uses a 3x3 kernel with padding 1, so the spatial size is
        preserved; only the final prediction convs change the channel count.
        """

        def conv_stack(num_convs):
            # ``num_convs`` 3x3 ConvModules (with the norm layer described by
            # ``self.norm_cfg``) that keep ``self.in_channels`` channels.
            return nn.Sequential(*[
                ConvModule(
                    self.in_channels,
                    self.in_channels,
                    kernel_size=3,
                    padding=1,
                    norm_cfg=self.norm_cfg) for _ in range(num_convs)
            ])

        # Built in the same order as before (cls convs, then reg convs, then
        # the prediction convs) so random initialisation order is unchanged.
        self.cls_subnet = conv_stack(self.num_cls_convs)
        self.bbox_subnet = conv_stack(self.num_reg_convs)
        self.cls_score = nn.Conv2d(
            self.in_channels,
            self.num_base_priors * self.num_classes,
            kernel_size=3,
            stride=1,
            padding=1)
        self.bbox_pred = nn.Conv2d(
            self.in_channels,
            self.num_base_priors * 4,
            kernel_size=3,
            stride=1,
            padding=1)
        self.object_pred = nn.Conv2d(
            self.in_channels,
            self.num_base_priors,
            kernel_size=3,
            stride=1,
            padding=1)

    def init_weights(self):
        """Initialize weights of the head.

        Every ``nn.Conv2d`` (including those inside the subnets) gets
        ``normal_init(mean=0, std=0.01)``, norm layers get
        ``constant_init(1)``, and the classification bias is set from a
        prior probability of 0.01.
        """
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                normal_init(m, mean=0, std=0.01)
            if is_norm(m):
                constant_init(m, 1)

        # Use prior in model initialization to improve stability
        bias_cls = bias_init_with_prob(0.01)
        torch.nn.init.constant_(self.cls_score.bias, bias_cls)

    def forward_single(self, feature):
        """Forward a single feature map.

        Args:
            feature (Tensor): Feature map of shape (N, C, H, W) with
                C == ``self.in_channels``.

        Returns:
            tuple[Tensor, Tensor]:
                - normalized_cls_score: class logits fused with objectness,
                  shape (N, num_base_priors * num_classes, H, W).
                - bbox_reg: box deltas, shape (N, num_base_priors * 4, H, W).
        """
        cls_score = self.cls_score(self.cls_subnet(feature))
        N, _, H, W = cls_score.shape
        cls_score = cls_score.view(N, -1, self.num_classes, H, W)

        reg_feat = self.bbox_subnet(feature)
        bbox_reg = self.bbox_pred(reg_feat)
        objectness = self.object_pred(reg_feat)

        # implicit objectness: with class logit c and objectness logit o,
        # c + o - log(1 + e^c + e^o) is the logit whose sigmoid equals
        # sigmoid(c) * sigmoid(o) (up to the clamping). The exponentials are
        # clamped at INF so the log term stays finite. The size-1 objectness
        # axis broadcasts over the num_classes axis.
        objectness = objectness.view(N, -1, 1, H, W)
        normalized_cls_score = cls_score + objectness - torch.log(
            1. + torch.clamp(cls_score.exp(), max=INF) +
            torch.clamp(objectness.exp(), max=INF))
        normalized_cls_score = normalized_cls_score.view(N, -1, H, W)
        return normalized_cls_score, bbox_reg

    @force_fp32(apply_to=('cls_scores', 'bbox_preds'))
    def loss(self,
             cls_scores,
             bbox_preds,
             gt_bboxes,
             gt_labels,
             img_metas,
             gt_bboxes_ignore=None):
        """Compute losses of the head.

        Only a single feature level is supported (checked by assertions).

        Args:
            cls_scores (list[Tensor]): Box scores for each scale level,
                each of shape (batch, num_anchors * num_classes, h, w).
            bbox_preds (list[Tensor]): Box energies / deltas for each scale
                level, each of shape (batch, num_anchors * 4, h, w).
            gt_bboxes (list[Tensor]): Ground truth bboxes for each image with
                shape (num_gts, 4) in [tl_x, tl_y, br_x, br_y] format.
            gt_labels (list[Tensor]): Class indices corresponding to each box.
            img_metas (list[dict]): Meta information of each image, e.g.,
                image size, scaling factor, etc.
            gt_bboxes_ignore (None | list[Tensor]): Specify which bounding
                boxes can be ignored when computing the loss. Default: None.

        Returns:
            dict[str, Tensor] | None: ``loss_cls`` and ``loss_bbox``, or None
            when ``get_targets`` finds an image without valid anchors.
        """
        assert len(cls_scores) == 1
        assert self.prior_generator.num_levels == 1

        device = cls_scores[0].device
        featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
        anchor_list, valid_flag_list = self.get_anchors(
            featmap_sizes, img_metas, device=device)

        # The output level is always 1
        anchor_list = [anchors[0] for anchors in anchor_list]
        valid_flag_list = [valid_flags[0] for valid_flags in valid_flag_list]

        cls_scores_list = levels_to_images(cls_scores)
        bbox_preds_list = levels_to_images(bbox_preds)

        label_channels = self.cls_out_channels if self.use_sigmoid_cls else 1
        cls_reg_targets = self.get_targets(
            cls_scores_list,
            bbox_preds_list,
            anchor_list,
            valid_flag_list,
            gt_bboxes,
            img_metas,
            gt_bboxes_ignore_list=gt_bboxes_ignore,
            gt_labels_list=gt_labels,
            label_channels=label_channels)
        if cls_reg_targets is None:
            return None
        (batch_labels, batch_label_weights, num_total_pos, num_total_neg,
         batch_bbox_weights, batch_pos_predicted_boxes,
         batch_target_boxes) = cls_reg_targets

        flatten_labels = batch_labels.reshape(-1)
        batch_label_weights = batch_label_weights.reshape(-1)
        # (batch, A * C, h, w) -> (batch * h * w * A, C). This presumably
        # matches the (image, h, w, anchor) order of the flattened labels --
        # confirm against the prior generator's anchor layout.
        cls_score = cls_scores[0].permute(0, 2, 3,
                                          1).reshape(-1, self.cls_out_channels)

        num_total_samples = (num_total_pos +
                             num_total_neg) if self.sampling else num_total_pos
        # Averaging factor shared by both losses, clamped to at least 1;
        # reduce_mean presumably averages it across distributed workers.
        num_total_samples = reduce_mean(
            cls_score.new_tensor(num_total_samples)).clamp_(1.0).item()

        # classification loss
        loss_cls = self.loss_cls(
            cls_score,
            flatten_labels,
            batch_label_weights,
            avg_factor=num_total_samples)

        # regression loss
        if batch_pos_predicted_boxes.shape[0] == 0:
            # no pos sample: a zero loss that stays attached to the graph
            loss_bbox = batch_pos_predicted_boxes.sum() * 0
        else:
            loss_bbox = self.loss_bbox(
                batch_pos_predicted_boxes,
                batch_target_boxes,
                batch_bbox_weights.float(),
                avg_factor=num_total_samples)

        return dict(loss_cls=loss_cls, loss_bbox=loss_bbox)

    def get_targets(self,
                    cls_scores_list,
                    bbox_preds_list,
                    anchor_list,
                    valid_flag_list,
                    gt_bboxes_list,
                    img_metas,
                    gt_bboxes_ignore_list=None,
                    gt_labels_list=None,
                    label_channels=1,
                    unmap_outputs=True):
        """Compute regression and classification targets for anchors in
        multiple images.

        Args:
            cls_scores_list (list[Tensor]): Classification scores of each
                image, each a 2D tensor of shape
                (h * w, num_anchors * num_classes). Not used to compute the
                targets; only ``bbox_preds_list`` is forwarded per image.
            bbox_preds_list (list[Tensor]): Bbox preds of each image, each a
                2D tensor of shape (h * w, num_anchors * 4).
            anchor_list (list[Tensor]): Anchors of each image. Each element
                is a tensor of shape (h * w * num_anchors, 4).
            valid_flag_list (list[Tensor]): Valid flags of each image. Each
                element is a tensor of shape (h * w * num_anchors, ).
            gt_bboxes_list (list[Tensor]): Ground truth bboxes of each image.
            img_metas (list[dict]): Meta info of each image.
            gt_bboxes_ignore_list (list[Tensor]): Ground truth bboxes to be
                ignored.
            gt_labels_list (list[Tensor]): Ground truth labels of each box.
            label_channels (int): Channel of label.
            unmap_outputs (bool): Whether to map outputs back to the original
                set of anchors.

        Returns:
            tuple | None: None when some image has no valid anchor; otherwise
            a tuple containing learning targets:

                - batch_labels (Tensor): Labels of all images, of shape
                  (batch, h * w * num_anchors) when ``unmap_outputs`` is True.
                - batch_label_weights (Tensor): Label weights of all images,
                  of shape (batch, h * w * num_anchors) when
                  ``unmap_outputs`` is True.
                - num_total_pos (int): Number of positive samples in all
                  images (each image counts as at least 1).
                - num_total_neg (int): Number of negative samples in all
                  images (each image counts as at least 1).
                - additional_returns: every value ``self._get_targets_single``
                  returns after its fifth output (here: positive bbox weights,
                  predicted boxes and target boxes), each concatenated over
                  images along dim 0.
        """
        num_imgs = len(img_metas)
        assert len(anchor_list) == len(valid_flag_list) == num_imgs

        # compute targets for each image
        if gt_bboxes_ignore_list is None:
            gt_bboxes_ignore_list = [None for _ in range(num_imgs)]
        if gt_labels_list is None:
            gt_labels_list = [None for _ in range(num_imgs)]
        results = multi_apply(
            self._get_targets_single,
            bbox_preds_list,
            anchor_list,
            valid_flag_list,
            gt_bboxes_list,
            gt_bboxes_ignore_list,
            gt_labels_list,
            img_metas,
            label_channels=label_channels,
            unmap_outputs=unmap_outputs)
        (all_labels, all_label_weights, pos_inds_list, neg_inds_list,
         sampling_results_list) = results[:5]
        rest_results = list(results[5:])  # user-added return values
        # no valid anchors
        if any(labels is None for labels in all_labels):
            return None
        # sampled anchors of all images; each image counts at least once so
        # the totals are never zero
        num_total_pos = sum(max(inds.numel(), 1) for inds in pos_inds_list)
        num_total_neg = sum(max(inds.numel(), 1) for inds in neg_inds_list)

        batch_labels = torch.stack(all_labels, 0)
        batch_label_weights = torch.stack(all_label_weights, 0)

        res = (batch_labels, batch_label_weights, num_total_pos, num_total_neg)
        # user-added return values, concatenated across images
        rest_results = [torch.cat(rests, 0) for rests in rest_results]

        return res + tuple(rest_results)

    def _get_targets_single(self,
                            bbox_preds,
                            flat_anchors,
                            valid_flags,
                            gt_bboxes,
                            gt_bboxes_ignore,
                            gt_labels,
                            img_meta,
                            label_channels=1,
                            unmap_outputs=True):
        """Compute regression and classification targets for anchors in a
        single image.

        Args:
            bbox_preds (Tensor): Bbox prediction of the image, of shape
                (h * w, num_anchors * 4); it is reshaped to (-1, 4).
            flat_anchors (Tensor): Anchors of the image, of shape
                (h * w * num_anchors, 4).
            valid_flags (Tensor): Valid flags of the image, of shape
                (h * w * num_anchors, ).
            gt_bboxes (Tensor): Ground truth bboxes of the image,
                shape (num_gts, 4).
            gt_bboxes_ignore (Tensor): Ground truth bboxes to be
                ignored, shape (num_ignored_gts, 4).
            gt_labels (Tensor): Ground truth labels of each box,
                shape (num_gts,).
            img_meta (dict): Meta info of the image.
            label_channels (int): Channel of label. Unused here.
            unmap_outputs (bool): Whether to map outputs back to the original
                set of anchors.

        Returns:
            tuple: Eight Nones when no anchor is inside the image, otherwise:
                labels (Tensor): Labels of image, of shape
                    (h * w * num_anchors, ) when ``unmap_outputs`` is True;
                    background is ``self.num_classes``.
                label_weights (Tensor): Label weights of image, same shape
                    as ``labels``.
                pos_inds (Tensor): Pos index of image.
                neg_inds (Tensor): Neg index of image.
                sampling_result (obj:`SamplingResult`): Sampling result.
                pos_bbox_weights (Tensor): The weight used to calculate
                    the bbox branch loss, of shape (num, ).
                pos_predicted_boxes (Tensor): Predicted boxes used to
                    calculate the bbox branch loss, of shape (num, 4).
                pos_target_boxes (Tensor): Target boxes used to calculate
                    the bbox branch loss, of shape (num, 4).
        """
        inside_flags = anchor_inside_flags(flat_anchors, valid_flags,
                                           img_meta['img_shape'][:2],
                                           self.train_cfg.allowed_border)
        if not inside_flags.any():
            return (None, ) * 8
        # assign gt and sample anchors
        anchors = flat_anchors[inside_flags, :]
        bbox_preds = bbox_preds.reshape(-1, 4)
        bbox_preds = bbox_preds[inside_flags, :]

        # decoded bbox
        decoder_bbox_preds = self.bbox_coder.decode(anchors, bbox_preds)
        assign_result = self.assigner.assign(
            decoder_bbox_preds, anchors, gt_bboxes, gt_bboxes_ignore,
            None if self.sampling else gt_labels)

        # Regression targets are read from extra properties that the
        # assigner is expected to attach to its result.
        pos_bbox_weights = assign_result.get_extra_property('pos_idx')
        pos_predicted_boxes = assign_result.get_extra_property(
            'pos_predicted_boxes')
        pos_target_boxes = assign_result.get_extra_property('target_boxes')

        sampling_result = self.sampler.sample(assign_result, anchors,
                                              gt_bboxes)
        num_valid_anchors = anchors.shape[0]
        # Every inside anchor starts as background (label num_classes) with
        # zero weight; sampled positives / negatives get weights below.
        labels = anchors.new_full((num_valid_anchors, ),
                                  self.num_classes,
                                  dtype=torch.long)
        label_weights = anchors.new_zeros(num_valid_anchors, dtype=torch.float)

        pos_inds = sampling_result.pos_inds
        neg_inds = sampling_result.neg_inds
        if len(pos_inds) > 0:
            if gt_labels is None:
                # Only rpn gives gt_labels as None
                # Foreground is the first class since v2.5.0
                labels[pos_inds] = 0
            else:
                labels[pos_inds] = gt_labels[
                    sampling_result.pos_assigned_gt_inds]
            # A non-positive pos_weight means "use the default weight 1.0".
            pos_weight = self.train_cfg.pos_weight
            label_weights[pos_inds] = 1.0 if pos_weight <= 0 else pos_weight
        if len(neg_inds) > 0:
            label_weights[neg_inds] = 1.0

        # map up to original set of anchors
        if unmap_outputs:
            num_total_anchors = flat_anchors.size(0)
            labels = unmap(
                labels, num_total_anchors, inside_flags,
                fill=self.num_classes)  # fill bg label
            label_weights = unmap(label_weights, num_total_anchors,
                                  inside_flags)

        return (labels, label_weights, pos_inds, neg_inds, sampling_result,
                pos_bbox_weights, pos_predicted_boxes, pos_target_boxes)
|