|
- # Copyright (c) OpenMMLab. All rights reserved.
- import numpy as np
- import torch
- import torch.nn as nn
- import torch.nn.functional as F
- from mmcv.cnn import ConvModule
- from mmcv.runner import BaseModule
-
- from mmdet.models.builder import HEADS, build_loss
-
-
@HEADS.register_module()
class GridHead(BaseModule):
    """Grid head of Grid R-CNN / Grid R-CNN Plus.

    Predicts one heatmap per grid point for every RoI, fusing each point's
    features with those of its 4-neighbours, and decodes the heatmaps back
    into boxes in ``get_bboxes``.

    Args:
        grid_points (int): Number of grid points. Must be a perfect square
            and at least 4.
        num_convs (int): Number of ``ConvModule`` layers applied to the RoI
            features; only the first one has stride 2.
        roi_feat_size (int): Spatial size of the square input RoI features.
            Must be an ``int``.
        in_channels (int): Channels of the input RoI features.
        conv_kernel_size (int): Kernel size of the ``ConvModule`` layers.
        point_feat_channels (int): Feature channels per grid point. The conv
            trunk outputs ``point_feat_channels * grid_points`` channels.
        deconv_kernel_size (int): Kernel size of the two transposed convs.
        class_agnostic (bool): Stored as ``self.class_agnostic``; not read by
            the methods of this class.
        loss_grid (dict): Config passed to ``build_loss``.
        conv_cfg (dict | None): Conv config passed to ``ConvModule``.
        norm_cfg (dict | None): Norm config passed to ``ConvModule``. With
            ``type='GN'``, ``point_feat_channels * grid_points`` must be
            divisible by ``num_groups``.
        init_cfg (list[dict]): Initialization config forwarded to
            ``BaseModule``.
    """

    def __init__(self,
                 grid_points=9,
                 num_convs=8,
                 roi_feat_size=14,
                 in_channels=256,
                 conv_kernel_size=3,
                 point_feat_channels=64,
                 deconv_kernel_size=4,
                 class_agnostic=False,
                 loss_grid=dict(
                     type='CrossEntropyLoss', use_sigmoid=True,
                     loss_weight=15),
                 conv_cfg=None,
                 norm_cfg=dict(type='GN', num_groups=36),
                 init_cfg=[
                     dict(type='Kaiming', layer=['Conv2d', 'Linear']),
                     dict(
                         type='Normal',
                         layer='ConvTranspose2d',
                         std=0.001,
                         override=dict(
                             type='Normal',
                             name='deconv2',
                             std=0.001,
                             bias=-np.log(0.99 / 0.01)))
                 ]):
        super().__init__(init_cfg)
        self.grid_points = grid_points
        self.num_convs = num_convs
        self.roi_feat_size = roi_feat_size
        self.in_channels = in_channels
        self.conv_kernel_size = conv_kernel_size
        self.point_feat_channels = point_feat_channels
        self.conv_out_channels = self.point_feat_channels * self.grid_points
        self.class_agnostic = class_agnostic
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        if isinstance(norm_cfg, dict) and norm_cfg['type'] == 'GN':
            assert self.conv_out_channels % norm_cfg['num_groups'] == 0

        assert self.grid_points >= 4
        self.grid_size = int(np.sqrt(self.grid_points))
        if self.grid_size * self.grid_size != self.grid_points:
            raise ValueError('grid_points must be a square number')

        # the predicted heatmap is half of whole_map_size
        if not isinstance(self.roi_feat_size, int):
            raise ValueError('Only square RoIs are supported in Grid R-CNN')
        self.whole_map_size = self.roi_feat_size * 4

        # compute point-wise sub-regions
        self.sub_regions = self.calc_sub_regions()

        # conv trunk: the first conv maps in_channels -> conv_out_channels and
        # downsamples 2x; the remaining ones keep channels and resolution.
        padding = (self.conv_kernel_size - 1) // 2
        convs = []
        for i in range(self.num_convs):
            convs.append(
                ConvModule(
                    self.in_channels if i == 0 else self.conv_out_channels,
                    self.conv_out_channels,
                    self.conv_kernel_size,
                    stride=2 if i == 0 else 1,
                    padding=padding,
                    conv_cfg=self.conv_cfg,
                    norm_cfg=self.norm_cfg,
                    bias=True))
        self.convs = nn.Sequential(*convs)

        # two grouped (one group per grid point) 2x upsampling deconvs; the
        # second one outputs one heatmap channel per grid point.
        self.deconv1 = nn.ConvTranspose2d(
            self.conv_out_channels,
            self.conv_out_channels,
            kernel_size=deconv_kernel_size,
            stride=2,
            padding=(deconv_kernel_size - 2) // 2,
            groups=grid_points)
        self.norm1 = nn.GroupNorm(grid_points, self.conv_out_channels)
        self.deconv2 = nn.ConvTranspose2d(
            self.conv_out_channels,
            grid_points,
            kernel_size=deconv_kernel_size,
            stride=2,
            padding=(deconv_kernel_size - 2) // 2,
            groups=grid_points)

        # find the 4-neighbor of each grid point; point index = i * size + j
        self.neighbor_points = []
        grid_size = self.grid_size
        for i in range(grid_size):  # i-th column
            for j in range(grid_size):  # j-th row
                neighbors = []
                if i > 0:  # left: (i - 1, j)
                    neighbors.append((i - 1) * grid_size + j)
                if j > 0:  # up: (i, j - 1)
                    neighbors.append(i * grid_size + j - 1)
                if j < grid_size - 1:  # down: (i, j + 1)
                    neighbors.append(i * grid_size + j + 1)
                if i < grid_size - 1:  # right: (i + 1, j)
                    neighbors.append((i + 1) * grid_size + j)
                self.neighbor_points.append(tuple(neighbors))
        # total edges in the grid
        self.num_edges = sum(len(p) for p in self.neighbor_points)

        self.forder_trans = nn.ModuleList()  # first-order feature transition
        self.sorder_trans = nn.ModuleList()  # second-order feature transition
        for neighbors in self.neighbor_points:
            fo_trans = nn.ModuleList()
            so_trans = nn.ModuleList()
            for _ in range(len(neighbors)):
                # first- and second-order modules are created alternately,
                # as before, so module construction order is unchanged.
                fo_trans.append(self._build_transition())
                so_trans.append(self._build_transition())
            self.forder_trans.append(fo_trans)
            self.sorder_trans.append(so_trans)

        self.loss_grid = build_loss(loss_grid)

    def _build_transition(self):
        """Build one transition module: 5x5 depth-wise conv + 1x1 conv."""
        c = self.point_feat_channels
        return nn.Sequential(
            nn.Conv2d(c, c, 5, stride=1, padding=2, groups=c),
            nn.Conv2d(c, c, 1))

    def forward(self, x):
        """Predict grid-point heatmaps.

        Args:
            x (Tensor): RoI features whose last two dims must both equal
                ``roi_feat_size`` (asserted). Channel count is expected to
                be ``in_channels``.

        Returns:
            dict: ``fused`` — heatmaps from neighbour-fused features;
            ``unfused`` — heatmaps from the trunk features without fusion
            in training mode, otherwise the same tensor as ``fused``.
        """
        assert x.shape[-1] == x.shape[-2] == self.roi_feat_size
        # RoI feature transformation, downsample 2x
        x = self.convs(x)

        c = self.point_feat_channels
        # first-order fusion: own feature slice + transformed slices of the
        # neighbouring points' trunk features
        x_fo = [None for _ in range(self.grid_points)]
        for i, points in enumerate(self.neighbor_points):
            x_fo[i] = x[:, i * c:(i + 1) * c]
            for j, point_idx in enumerate(points):
                x_fo[i] = x_fo[i] + self.forder_trans[i][j](
                    x[:, point_idx * c:(point_idx + 1) * c])

        # second-order fusion: own trunk feature slice + transformed
        # first-order features of the neighbouring points
        x_so = [None for _ in range(self.grid_points)]
        for i, points in enumerate(self.neighbor_points):
            x_so[i] = x[:, i * c:(i + 1) * c]
            for j, point_idx in enumerate(points):
                x_so[i] = x_so[i] + self.sorder_trans[i][j](x_fo[point_idx])

        # predicted heatmap with fused features
        x2 = torch.cat(x_so, dim=1)
        x2 = self.deconv1(x2)
        x2 = F.relu(self.norm1(x2), inplace=True)
        heatmap = self.deconv2(x2)

        # predicted heatmap with original features (applicable during training)
        if self.training:
            x1 = self.deconv1(x)
            x1 = F.relu(self.norm1(x1), inplace=True)
            heatmap_unfused = self.deconv2(x1)
        else:
            heatmap_unfused = heatmap

        return dict(fused=heatmap, unfused=heatmap_unfused)

    def calc_sub_regions(self):
        """Compute point specific representation regions.

        See Grid R-CNN Plus (https://arxiv.org/abs/1906.05688) for details.

        Returns:
            list[tuple[int]]: One ``(x1, y1, x2, y2)`` window per grid point
            in ``whole_map_size`` coordinates, each ``half_size`` wide and
            high. Points on the first row/column start at 0, points on the
            last row/column start at ``half_size``.
        """
        # to make it consistent with the original implementation, half_size
        # is computed as 2 * quarter_size, which is smaller
        half_size = self.whole_map_size // 4 * 2
        sub_regions = []
        for i in range(self.grid_points):
            x_idx = i // self.grid_size
            y_idx = i % self.grid_size
            if x_idx == 0:
                sub_x1 = 0
            elif x_idx == self.grid_size - 1:
                sub_x1 = half_size
            else:
                ratio = x_idx / (self.grid_size - 1) - 0.25
                sub_x1 = max(int(ratio * self.whole_map_size), 0)

            if y_idx == 0:
                sub_y1 = 0
            elif y_idx == self.grid_size - 1:
                sub_y1 = half_size
            else:
                ratio = y_idx / (self.grid_size - 1) - 0.25
                sub_y1 = max(int(ratio * self.whole_map_size), 0)
            sub_regions.append(
                (sub_x1, sub_y1, sub_x1 + half_size, sub_y1 + half_size))
        return sub_regions

    def get_targets(self, sampling_results, rcnn_train_cfg):
        """Build binary grid-point heatmap targets for positive samples.

        Args:
            sampling_results (list): Objects exposing ``pos_bboxes`` and
                ``pos_gt_bboxes`` tensors of identical shape (asserted).
            rcnn_train_cfg: Config object; only ``pos_radius`` is read.

        Returns:
            Tensor: Targets of shape
            ``(num_pos, grid_points, half_size, half_size)`` on the device of
            ``sampling_results[0].pos_bboxes``. RoIs whose expanded width or
            height is ``<= grid_size`` get all-zero targets.
        """
        # mix all samples (across images) together.
        pos_bboxes = torch.cat([res.pos_bboxes for res in sampling_results],
                               dim=0).cpu()
        pos_gt_bboxes = torch.cat(
            [res.pos_gt_bboxes for res in sampling_results], dim=0).cpu()
        assert pos_bboxes.shape == pos_gt_bboxes.shape

        # expand pos_bboxes to 2x of original size
        x1 = pos_bboxes[:, 0] - (pos_bboxes[:, 2] - pos_bboxes[:, 0]) / 2
        y1 = pos_bboxes[:, 1] - (pos_bboxes[:, 3] - pos_bboxes[:, 1]) / 2
        x2 = pos_bboxes[:, 2] + (pos_bboxes[:, 2] - pos_bboxes[:, 0]) / 2
        y2 = pos_bboxes[:, 3] + (pos_bboxes[:, 3] - pos_bboxes[:, 1]) / 2
        pos_bboxes = torch.stack([x1, y1, x2, y2], dim=-1)
        pos_bbox_ws = (pos_bboxes[:, 2] - pos_bboxes[:, 0]).unsqueeze(-1)
        pos_bbox_hs = (pos_bboxes[:, 3] - pos_bboxes[:, 1]).unsqueeze(-1)

        num_rois = pos_bboxes.shape[0]
        map_size = self.whole_map_size
        # this is not the final target shape
        targets = torch.zeros((num_rois, self.grid_points, map_size, map_size),
                              dtype=torch.float)

        # pre-compute interpolation factors for all grid points.
        # the first item is the factor of x-dim, and the second is y-dim.
        # for a 9-point grid, factors are like (1, 0), (0.5, 0.5), (0, 1)
        factors = []
        for j in range(self.grid_points):
            x_idx = j // self.grid_size
            y_idx = j % self.grid_size
            factors.append((1 - x_idx / (self.grid_size - 1),
                            1 - y_idx / (self.grid_size - 1)))

        radius = rcnn_train_cfg.pos_radius
        radius2 = radius**2
        for i in range(num_rois):
            # ignore small bboxes
            if (pos_bbox_ws[i] <= self.grid_size
                    or pos_bbox_hs[i] <= self.grid_size):
                continue
            # for each grid point, mark a small circle as positive
            for j in range(self.grid_points):
                factor_x, factor_y = factors[j]
                gridpoint_x = factor_x * pos_gt_bboxes[i, 0] + (
                    1 - factor_x) * pos_gt_bboxes[i, 2]
                gridpoint_y = factor_y * pos_gt_bboxes[i, 1] + (
                    1 - factor_y) * pos_gt_bboxes[i, 3]

                cx = int((gridpoint_x - pos_bboxes[i, 0]) / pos_bbox_ws[i] *
                         map_size)
                cy = int((gridpoint_y - pos_bboxes[i, 1]) / pos_bbox_hs[i] *
                         map_size)

                for x in range(cx - radius, cx + radius + 1):
                    for y in range(cy - radius, cy + radius + 1):
                        if 0 <= x < map_size and 0 <= y < map_size:
                            if (x - cx)**2 + (y - cy)**2 <= radius2:
                                targets[i, j, y, x] = 1
        # reduce the target heatmap size by a half
        # proposed in Grid R-CNN Plus (https://arxiv.org/abs/1906.05688).
        sub_targets = []
        for i in range(self.grid_points):
            sub_x1, sub_y1, sub_x2, sub_y2 = self.sub_regions[i]
            sub_targets.append(targets[:, [i], sub_y1:sub_y2, sub_x1:sub_x2])
        sub_targets = torch.cat(sub_targets, dim=1)
        sub_targets = sub_targets.to(sampling_results[0].pos_bboxes.device)
        return sub_targets

    def loss(self, grid_pred, grid_targets):
        """Sum ``loss_grid`` over the fused and unfused heatmaps.

        Args:
            grid_pred (dict): Output of ``forward``.
            grid_targets (Tensor): Output of ``get_targets``.

        Returns:
            dict: ``{'loss_grid': loss_fused + loss_unfused}``.
        """
        loss_fused = self.loss_grid(grid_pred['fused'], grid_targets)
        loss_unfused = self.loss_grid(grid_pred['unfused'], grid_targets)
        loss_grid = loss_fused + loss_unfused
        return dict(loss_grid=loss_grid)

    @staticmethod
    def _weighted_vote(coords, scores, inds):
        """Score-weighted mean of ``coords[:, inds]`` (keepdim on dim 1)."""
        weights = scores[:, inds]
        return (coords[:, inds] * weights).sum(
            dim=1, keepdim=True) / weights.sum(
                dim=1, keepdim=True)

    def get_bboxes(self, det_bboxes, grid_pred, img_metas):
        """Decode grid heatmaps into refined boxes.

        Args:
            det_bboxes (Tensor): ``(R, 5)`` detections; columns 0-3 are the
                box and column 4 is the score.
            grid_pred (Tensor): Heatmap logits of shape
                ``(R, grid_points, half_size, half_size)`` (asserted).
            img_metas (list[dict]): Only ``img_metas[0]['img_shape']`` is
                read; its ``[0]``/``[1]`` entries bound y/x coordinates.

        Returns:
            Tensor: ``(R, 5)`` CPU tensor ``[x1, y1, x2, y2, score]`` with
            coordinates clamped to the image.
        """
        # TODO: refactoring
        assert det_bboxes.shape[0] == grid_pred.shape[0]
        det_bboxes = det_bboxes.cpu()
        cls_scores = det_bboxes[:, [4]]
        det_bboxes = det_bboxes[:, :4]
        grid_pred = grid_pred.sigmoid().cpu()

        R, c, h, w = grid_pred.shape
        half_size = self.whole_map_size // 4 * 2
        assert h == w == half_size
        assert c == self.grid_points

        # find the point with max scores in the half-sized heatmap
        grid_pred = grid_pred.view(R * c, h * w)
        pred_scores, pred_position = grid_pred.max(dim=1)
        xs = pred_position % w
        ys = pred_position // w

        # get the position in the whole heatmap instead of half-sized heatmap
        for i in range(self.grid_points):
            xs[i::self.grid_points] += self.sub_regions[i][0]
            ys[i::self.grid_points] += self.sub_regions[i][1]

        # reshape to (num_rois, grid_points)
        pred_scores, xs, ys = tuple(
            map(lambda t: t.view(R, c), [pred_scores, xs, ys]))

        # get expanded pos_bboxes
        widths = (det_bboxes[:, 2] - det_bboxes[:, 0]).unsqueeze(-1)
        heights = (det_bboxes[:, 3] - det_bboxes[:, 1]).unsqueeze(-1)
        x1 = (det_bboxes[:, 0, None] - widths / 2)
        y1 = (det_bboxes[:, 1, None] - heights / 2)
        # map the grid point to the absolute coordinates
        abs_xs = (xs.float() + 0.5) / w * widths + x1
        abs_ys = (ys.float() + 0.5) / h * heights + y1

        # get the grid points indices that fall on the bbox boundaries
        x1_inds = list(range(self.grid_size))
        y1_inds = [i * self.grid_size for i in range(self.grid_size)]
        x2_inds = [
            self.grid_points - self.grid_size + i
            for i in range(self.grid_size)
        ]
        y2_inds = [(i + 1) * self.grid_size - 1 for i in range(self.grid_size)]

        # voting of all grid points on some boundary
        bboxes_x1 = self._weighted_vote(abs_xs, pred_scores, x1_inds)
        bboxes_y1 = self._weighted_vote(abs_ys, pred_scores, y1_inds)
        bboxes_x2 = self._weighted_vote(abs_xs, pred_scores, x2_inds)
        bboxes_y2 = self._weighted_vote(abs_ys, pred_scores, y2_inds)

        bbox_res = torch.cat(
            [bboxes_x1, bboxes_y1, bboxes_x2, bboxes_y2, cls_scores], dim=1)
        # Indexing with a list returns a copy, so an in-place clamp_ on
        # bbox_res[:, [0, 2]] would be lost; write the clamped values back.
        img_shape = img_metas[0]['img_shape']
        bbox_res[:, [0, 2]] = bbox_res[:, [0, 2]].clamp(
            min=0, max=img_shape[1])
        bbox_res[:, [1, 3]] = bbox_res[:, [1, 3]].clamp(
            min=0, max=img_shape[0])

        return bbox_res
|