# Copyright (c) OpenMMLab. All rights reserved.
import torch.nn as nn
from mmcv.cnn import ConvModule
from mmcv.runner import BaseModule

from ..builder import NECKS


@NECKS.register_module()
class ChannelMapper(BaseModule):
    r"""Channel Mapper to reduce/increase channels of backbone features.

    Each input scale is passed through its own ConvModule so that every
    output feature map has ``out_channels`` channels.

    Args:
        in_channels (List[int]): Number of input channels per scale.
        out_channels (int): Number of output channels (used at each scale).
        kernel_size (int, optional): Kernel size of the conv used at each
            scale to map channels. Default: 3.
        conv_cfg (dict, optional): Config dict for convolution layer.
            Default: None.
        norm_cfg (dict, optional): Config dict for normalization layer.
            Default: None.
        act_cfg (dict, optional): Config dict for activation layer in
            ConvModule. Default: dict(type='ReLU').
        num_outs (int, optional): Number of output feature maps. Extra
            stride-2 convs are appended when ``num_outs`` is larger than
            the length of ``in_channels``. Default: None, which means the
            number of outputs equals the number of input scales.
        init_cfg (dict or list[dict], optional): Initialization config dict.

    Example:
        >>> import torch
        >>> in_channels = [2, 3, 5, 7]
        >>> scales = [340, 170, 84, 43]
        >>> inputs = [torch.rand(1, c, s, s)
        ...           for c, s in zip(in_channels, scales)]
        >>> self = ChannelMapper(in_channels, 11, 3).eval()
        >>> outputs = self.forward(inputs)
        >>> for i in range(len(outputs)):
        ...     print(f'outputs[{i}].shape = {outputs[i].shape}')
        outputs[0].shape = torch.Size([1, 11, 340, 340])
        outputs[1].shape = torch.Size([1, 11, 170, 170])
        outputs[2].shape = torch.Size([1, 11, 84, 84])
        outputs[3].shape = torch.Size([1, 11, 43, 43])
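        >>> # Sketch of the extra-conv branch, reusing the inputs above:
        >>> # requesting num_outs=5 appends a stride-2 3x3 conv on the last
        >>> # input, roughly halving its spatial size.
        >>> self = ChannelMapper(in_channels, 11, 3, num_outs=5).eval()
        >>> outputs = self.forward(inputs)
        >>> print(outputs[-1].shape)
        torch.Size([1, 11, 22, 22])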
- """
-
    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size=3,
                 conv_cfg=None,
                 norm_cfg=None,
                 act_cfg=dict(type='ReLU'),
                 num_outs=None,
                 init_cfg=dict(
                     type='Xavier', layer='Conv2d', distribution='uniform')):
        super(ChannelMapper, self).__init__(init_cfg)
        assert isinstance(in_channels, list)
        self.extra_convs = None
        if num_outs is None:
            num_outs = len(in_channels)
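        # One ConvModule per input scale maps that scale to ``out_channels``.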
        self.convs = nn.ModuleList()
        for in_channel in in_channels:
            self.convs.append(
                ConvModule(
                    in_channel,
                    out_channels,
                    kernel_size,
                    padding=(kernel_size - 1) // 2,
                    conv_cfg=conv_cfg,
                    norm_cfg=norm_cfg,
                    act_cfg=act_cfg))
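        # Extra stride-2 3x3 convs provide the additional outputs when
        # num_outs exceeds the number of input scales.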
        if num_outs > len(in_channels):
            self.extra_convs = nn.ModuleList()
            for i in range(len(in_channels), num_outs):
                if i == len(in_channels):
                    in_channel = in_channels[-1]
                else:
                    in_channel = out_channels
                self.extra_convs.append(
                    ConvModule(
                        in_channel,
                        out_channels,
                        3,
                        stride=2,
                        padding=1,
                        conv_cfg=conv_cfg,
                        norm_cfg=norm_cfg,
                        act_cfg=act_cfg))

    def forward(self, inputs):
        """Forward function."""
        assert len(inputs) == len(self.convs)
        outs = [self.convs[i](inputs[i]) for i in range(len(inputs))]
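        # The first extra conv consumes the last raw input; each subsequent
        # extra conv consumes the previous output.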
        if self.extra_convs:
            for i in range(len(self.extra_convs)):
                if i == 0:
                    outs.append(self.extra_convs[0](inputs[-1]))
                else:
                    outs.append(self.extra_convs[i](outs[-1]))
        return tuple(outs)
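

# A hedged usage sketch (channel values are illustrative, not tied to any
# particular config): since ChannelMapper is registered in NECKS, a detector
# config can build it through the registry, e.g.
#   neck=dict(
#       type='ChannelMapper',
#       in_channels=[512, 1024, 2048],
#       out_channels=256,
#       num_outs=4)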