|
- # Copyright (c) OpenMMLab. All rights reserved.
- import torch
- import torch.nn as nn
- import torch.nn.functional as F
- from mmcv.cnn import constant_init, xavier_init
- from mmcv.runner import BaseModule, ModuleList
-
- from ..builder import NECKS, build_backbone
- from .fpn import FPN
-
-
class ASPP(BaseModule):
    """ASPP (Atrous Spatial Pyramid Pooling)

    This is an implementation of the ASPP module used in DetectoRS
    (https://arxiv.org/pdf/2006.02334.pdf)

    Each entry of ``dilations`` creates one convolution branch mapping
    ``in_channels`` to ``out_channels``:

    * a dilation greater than 1 gives a 3x3 atrous conv whose padding equals
      its dilation;
    * any other dilation gives a 1x1 conv without padding.

    Every branch uses stride 1, so it keeps the spatial size of its input.
    The last branch is special: it is applied to the globally average-pooled
    input, and its 1x1 output is broadcast back to the input's spatial size.
    All branch outputs go through ReLU and are concatenated along dim 1. The
    module therefore produces ``out_channels * len(dilations)`` channels.

    Args:
        in_channels (int): Number of input channels.
        out_channels (int): Number of channels produced by each branch.
        dilations (tuple[int]): Dilations of the branches, one branch per
            entry. The last entry must be 1 because it is the image-pooling
            branch. Default: (1, 3, 6, 1)
        init_cfg (dict or list[dict], optional): Initialization config dict.
            Default: dict(type='Kaiming', layer='Conv2d')
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 dilations=(1, 3, 6, 1),
                 init_cfg=dict(type='Kaiming', layer='Conv2d')):
        super().__init__(init_cfg)
        # The last branch sees a 1x1 pooled map, so it must be a plain 1x1
        # conv.
        assert dilations[-1] == 1, \
            'The last dilation (image-pooling branch) must be 1'
        self.aspp = nn.ModuleList()
        for dilation in dilations:
            # With padding == dilation, a 3x3 atrous conv keeps H and W.
            # A 1x1 conv needs no padding to keep them.
            kernel_size = 3 if dilation > 1 else 1
            padding = dilation if dilation > 1 else 0
            conv = nn.Conv2d(
                in_channels,
                out_channels,
                kernel_size=kernel_size,
                stride=1,
                dilation=dilation,
                padding=padding,
                bias=True)
            self.aspp.append(conv)
        self.gap = nn.AdaptiveAvgPool2d(1)

    def forward(self, x):
        """Apply every branch to ``x`` and concatenate the results.

        Args:
            x (Tensor): Input feature map, assumed to be (N, C, H, W).

        Returns:
            Tensor: Concatenation of all branch outputs along dim 1, with
            ``out_channels * len(dilations)`` channels and the same H, W
            as ``x``.
        """
        pooled = self.gap(x)
        last_idx = len(self.aspp) - 1
        outs = []
        for idx, conv in enumerate(self.aspp):
            # Only the last (image-pooling) branch consumes the pooled input.
            inp = pooled if idx == last_idx else x
            outs.append(F.relu_(conv(inp)))
        # Broadcast the 1x1 pooled-branch output to the input's spatial size.
        # Using x (rather than the previous branch's output) as the reference
        # also works when there is only a single branch.
        outs[-1] = outs[-1].expand(-1, -1, *x.shape[-2:])
        return torch.cat(outs, dim=1)
-
-
@NECKS.register_module()
class RFP(FPN):
    """RFP (Recursive Feature Pyramid)

    This is an implementation of RFP in `DetectoRS
    <https://arxiv.org/pdf/2006.02334.pdf>`_. Unlike a standard FPN, RFP
    expects its input to start with the backbone's original input image,
    followed by the multi-level features.

    The first step is an ordinary FPN pass. Each of the remaining
    ``rfp_steps - 1`` steps works as follows:

    1. Feed the current pyramid back through ASPP. Level 0 is passed through
       unchanged.
    2. Call ``rfp_forward(img, feedback)`` on a dedicated backbone copy.
    3. Run the FPN again on that result.
    4. Blend the new pyramid with the previous one. The per-pixel sigmoid
       gate comes from ``rfp_weight``.

    Args:
        rfp_steps (int): Number of unrolled steps of RFP. This includes the
            initial plain FPN pass, so ``rfp_steps - 1`` extra backbones are
            built.
        rfp_backbone (dict): Configuration of the backbone for RFP.
        aspp_out_channels (int): Number of output channels of each ASPP
            branch.
        aspp_dilations (tuple[int]): Dilation rates of the ASPP branches.
            Default: (1, 3, 6, 1)
        init_cfg (dict or list[dict], optional): Must stay None; any other
            value triggers an AssertionError. Default: None
    """

    def __init__(self,
                 rfp_steps,
                 rfp_backbone,
                 aspp_out_channels,
                 aspp_dilations=(1, 3, 6, 1),
                 init_cfg=None,
                 **kwargs):
        assert init_cfg is None, (
            'To prevent abnormal initialization '
            'behavior, init_cfg is not allowed to be set')
        super().__init__(init_cfg=init_cfg, **kwargs)
        self.rfp_steps = rfp_steps
        # mmcv's ModuleList is used on purpose: per the original authors,
        # pretrained weights cannot be loaded when torch.nn.ModuleList is used.
        self.rfp_modules = ModuleList()
        for _ in range(rfp_steps - 1):
            self.rfp_modules.append(build_backbone(rfp_backbone))
        self.rfp_aspp = ASPP(self.out_channels, aspp_out_channels,
                             aspp_dilations)
        # 1x1 conv producing a single-channel gating logit per pixel.
        self.rfp_weight = nn.Conv2d(
            self.out_channels, 1, kernel_size=1, stride=1, padding=0,
            bias=True)

    def init_weights(self):
        """Initialize FPN convs, the extra backbones and the fusion gate.

        ``super().init_weights()`` is deliberately not called. It could
        override the default initialization of modules inside
        ``self.rfp_modules`` whose keys are missing from the pretrained
        checkpoint. Note that ``self.rfp_aspp`` is not re-initialized here.
        """
        for conv_group in (self.lateral_convs, self.fpn_convs):
            for module in conv_group.modules():
                if isinstance(module, nn.Conv2d):
                    xavier_init(module, distribution='uniform')
        for step in range(self.rfp_steps - 1):
            self.rfp_modules[step].init_weights()
        # Zero weights make the gate input start from the conv bias
        # (presumably also zeroed, giving an initial gate of 0.5).
        constant_init(self.rfp_weight, 0)

    def forward(self, inputs):
        """Run the initial FPN pass followed by the recursive RFP steps.

        Args:
            inputs (Sequence[Tensor]): The backbone's input image followed by
                one feature map per entry of ``self.in_channels``.

        Returns:
            The fused multi-level features. When no recursive step runs,
            this is whatever ``FPN.forward`` returns; otherwise it is a list.
        """
        feats = list(inputs)
        # One extra leading entry: the input image.
        assert len(feats) == len(self.in_channels) + 1
        img = feats.pop(0)
        pyramid = super().forward(tuple(feats))
        for step in range(self.rfp_steps - 1):
            # Level 0 is fed back unchanged; higher levels go through ASPP.
            feedback = [pyramid[0]]
            feedback.extend(self.rfp_aspp(level) for level in pyramid[1:])
            refined = self.rfp_modules[step].rfp_forward(img, feedback)
            refined = super().forward(refined)
            blended = []
            for new_level, old_level in zip(refined, pyramid):
                gate = torch.sigmoid(self.rfp_weight(new_level))
                blended.append(gate * new_level + (1 - gate) * old_level)
            pyramid = blended
        return pyramid
|