# Copyright (c) OpenMMLab. All rights reserved.
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import ConvModule, caffe2_xavier_init
from mmcv.ops.merge_cells import ConcatCell
from mmcv.runner import BaseModule

from ..builder import NECKS


@NECKS.register_module()
class NASFCOS_FPN(BaseModule):
- """FPN structure in NASFPN.
-
- Implementation of paper `NAS-FCOS: Fast Neural Architecture Search for
- Object Detection <https://arxiv.org/abs/1906.04423>`_
-
- Args:
- in_channels (List[int]): Number of input channels per scale.
- out_channels (int): Number of output channels (used at each scale)
- num_outs (int): Number of output scales.
- start_level (int): Index of the start input backbone level used to
- build the feature pyramid. Default: 0.
- end_level (int): Index of the end input backbone level (exclusive) to
- build the feature pyramid. Default: -1, which means the last level.
- add_extra_convs (bool): It decides whether to add conv
- layers on top of the original feature maps. Default to False.
- If True, its actual mode is specified by `extra_convs_on_inputs`.
- conv_cfg (dict): dictionary to construct and config conv layer.
- norm_cfg (dict): dictionary to construct and config norm layer.
- init_cfg (dict or list[dict], optional): Initialization config dict.
- Default: None
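
    Example:
        >>> import torch
        >>> # A minimal shape-check sketch; the channel counts, spatial
        >>> # sizes and out_channels=16 below are arbitrary.
        >>> in_channels = [256, 512, 1024, 2048]
        >>> inputs = [
        ...     torch.rand(1, c, 32 // 2**i, 32 // 2**i)
        ...     for i, c in enumerate(in_channels)
        ... ]
        >>> self = NASFCOS_FPN(in_channels, 16, num_outs=5).eval()
        >>> outputs = self.forward(inputs)
        >>> for i, out in enumerate(outputs):
        ...     print(f'outputs[{i}].shape = {out.shape}')
        outputs[0].shape = torch.Size([1, 16, 16, 16])
        outputs[1].shape = torch.Size([1, 16, 8, 8])
        outputs[2].shape = torch.Size([1, 16, 4, 4])
        outputs[3].shape = torch.Size([1, 16, 2, 2])
        outputs[4].shape = torch.Size([1, 16, 1, 1])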
- """

    def __init__(self,
                 in_channels,
                 out_channels,
                 num_outs,
                 start_level=1,
                 end_level=-1,
                 add_extra_convs=False,
                 conv_cfg=None,
                 norm_cfg=None,
                 init_cfg=None):
        assert init_cfg is None, 'To prevent abnormal initialization ' \
            'behavior, init_cfg is not allowed to be set'
        super(NASFCOS_FPN, self).__init__(init_cfg)
        assert isinstance(in_channels, list)
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.num_ins = len(in_channels)
        self.num_outs = num_outs
        self.norm_cfg = norm_cfg
        self.conv_cfg = conv_cfg

        if end_level == -1:
            self.backbone_end_level = self.num_ins
            assert num_outs >= self.num_ins - start_level
        else:
            self.backbone_end_level = end_level
            assert end_level <= len(in_channels)
            assert num_outs == end_level - start_level
        self.start_level = start_level
        self.end_level = end_level
        self.add_extra_convs = add_extra_convs

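        # Lateral 1x1 convs projecting each used backbone level to
        # `out_channels` (bias-free conv + BN + non-inplace ReLU).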
        self.adapt_convs = nn.ModuleList()
        for i in range(self.start_level, self.backbone_end_level):
            adapt_conv = ConvModule(
                in_channels[i],
                out_channels,
                1,
                stride=1,
                padding=0,
                bias=False,
                norm_cfg=dict(type='BN'),
                act_cfg=dict(type='ReLU', inplace=False))
            self.adapt_convs.append(adapt_conv)

        # C2 is omitted according to the paper
        extra_levels = num_outs - self.backbone_end_level + self.start_level

        def build_concat_cell(with_input1_conv, with_input2_conv):
            cell_conv_cfg = dict(
                kernel_size=1, padding=0, bias=False, groups=out_channels)
            return ConcatCell(
                in_channels=out_channels,
                out_channels=out_channels,
                with_out_conv=True,
                out_conv_cfg=cell_conv_cfg,
                out_norm_cfg=dict(type='BN'),
                out_conv_order=('norm', 'act', 'conv'),
                with_input1_conv=with_input1_conv,
                with_input2_conv=with_input2_conv,
                input_conv_cfg=conv_cfg,
                input_norm_cfg=norm_cfg,
                upsample_mode='nearest')

        # Denote c3=f0, c4=f1, c5=f2 for convenience
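        # Each key encodes the indices of the two features a cell merges:
        # 'cXY' fuses feats[X] with feats[Y]; forward() parses these digits.
        # Every cell output is appended to `feats`, becoming f3, f4, ...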
        self.fpn = nn.ModuleDict()
        self.fpn['c22_1'] = build_concat_cell(True, True)
        self.fpn['c22_2'] = build_concat_cell(True, True)
        self.fpn['c32'] = build_concat_cell(True, False)
        self.fpn['c02'] = build_concat_cell(True, False)
        self.fpn['c42'] = build_concat_cell(True, True)
        self.fpn['c36'] = build_concat_cell(True, True)
        self.fpn['c61'] = build_concat_cell(True, True)  # f9
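
        # Extra pyramid levels (e.g. P6, P7) come from chained stride-2 3x3
        # convs on the last output; act_cfg=None lets the first conv skip
        # the leading ReLU of its ('act', 'norm', 'conv') order.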
        self.extra_downsamples = nn.ModuleList()
        for i in range(extra_levels):
            extra_act_cfg = None if i == 0 \
                else dict(type='ReLU', inplace=False)
            self.extra_downsamples.append(
                ConvModule(
                    out_channels,
                    out_channels,
                    3,
                    stride=2,
                    padding=1,
                    act_cfg=extra_act_cfg,
                    order=('act', 'norm', 'conv')))

    def forward(self, inputs):
        """Forward function."""
        feats = [
            adapt_conv(inputs[i + self.start_level])
            for i, adapt_conv in enumerate(self.adapt_convs)
        ]
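        # `feats` now holds the adapted backbone levels (f0, f1, f2 when
        # three levels are used); each cell below appends its output.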

        for module_name in self.fpn:
            idx_1, idx_2 = int(module_name[1]), int(module_name[2])
            res = self.fpn[module_name](feats[idx_1], feats[idx_2])
            feats.append(res)

        ret = []
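        # f9, f8 and f7 each receive f5 (bilinearly resized to match), and
        # the sums are rescaled to the sizes of the corresponding backbone
        # levels to form P3, P4 and P5.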
        for (idx, input_idx) in zip([9, 8, 7], [1, 2, 3]):  # add P3, P4, P5
            feats1, feats2 = feats[idx], feats[5]
            feats2_resize = F.interpolate(
                feats2,
                size=feats1.size()[2:],
                mode='bilinear',
                align_corners=False)

            feats_sum = feats1 + feats2_resize
            ret.append(
                F.interpolate(
                    feats_sum,
                    size=inputs[input_idx].size()[2:],
                    mode='bilinear',
                    align_corners=False))

        for submodule in self.extra_downsamples:
            ret.append(submodule(ret[-1]))

        return tuple(ret)

    def init_weights(self):
        """Initialize the weights of module."""
        super(NASFCOS_FPN, self).init_weights()
        for module in self.fpn.values():
            # ConcatCell stores its fusion conv as `out_conv`.
            if hasattr(module, 'out_conv'):
                caffe2_xavier_init(module.out_conv.conv)

        for modules in [
                self.adapt_convs.modules(),
                self.extra_downsamples.modules()
        ]:
            for module in modules:
                if isinstance(module, nn.Conv2d):
                    caffe2_xavier_init(module)