|
- # Copyright (c) OpenMMLab. All rights reserved.
- import torch.utils.checkpoint as cp
- from mmcv.cnn import ConvModule
- from mmcv.runner import BaseModule
-
- from .se_layer import SELayer
-
-
class InvertedResidual(BaseModule):
    """Inverted residual block assembled from ``ConvModule`` layers.

    The forward pass applies, in order:

    1. an optional 1x1 expand conv (``in_channels`` -> ``mid_channels``);
    2. a depthwise conv (``groups=mid_channels``) with the given kernel size
       and stride;
    3. an optional SE layer;
    4. a 1x1 linear conv (``mid_channels`` -> ``out_channels``) built with
       ``act_cfg=None``, i.e. without an activation.

    The block input is added to the result when ``stride == 1`` and
    ``in_channels == out_channels``.

    Args:
        in_channels (int): Number of channels of the block input.
        out_channels (int): Number of channels of the block output.
        mid_channels (int): Number of channels consumed and produced by the
            depthwise convolution.
        kernel_size (int): Kernel size of the depthwise convolution. Padding
            is set to ``kernel_size // 2``. Default: 3.
        stride (int): Stride of the depthwise convolution; must be 1 or 2.
            Default: 1.
        se_cfg (dict, optional): Keyword arguments forwarded to ``SELayer``.
            ``None`` means no SE layer is built. Default: None.
        with_expand_conv (bool): Whether to build the 1x1 expand conv. If
            False, ``mid_channels`` must be equal to ``in_channels``.
            Default: True.
        conv_cfg (dict, optional): Convolution layer config shared by every
            ``ConvModule`` in the block. Default: None, which means using
            conv2d.
        norm_cfg (dict): Normalization layer config shared by every
            ``ConvModule`` in the block. Default: dict(type='BN').
        act_cfg (dict): Activation config for the expand and depthwise
            convs. Default: dict(type='ReLU').
        with_cp (bool): Whether to run the forward pass under
            ``torch.utils.checkpoint`` (only when the input requires grad),
            which saves some memory at the cost of slower training.
            Default: False.
        init_cfg (dict or list[dict], optional): Initialization config
            forwarded to ``BaseModule``. Default: None.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 mid_channels,
                 kernel_size=3,
                 stride=1,
                 se_cfg=None,
                 with_expand_conv=True,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN'),
                 act_cfg=dict(type='ReLU'),
                 with_cp=False,
                 init_cfg=None):
        super().__init__(init_cfg)
        # The identity shortcut is only shape-compatible when the spatial
        # size and the channel count are both unchanged.
        self.with_res_shortcut = stride == 1 and in_channels == out_channels
        assert stride in [1, 2], \
            f'stride must in [1, 2]. But received {stride}.'
        self.with_cp = with_cp
        self.with_se = se_cfg is not None
        self.with_expand_conv = with_expand_conv

        if self.with_se:
            assert isinstance(se_cfg, dict)
        if not self.with_expand_conv:
            # Without the expand conv the depthwise conv sees the raw input.
            assert mid_channels == in_channels

        # Every ConvModule in the block shares the same conv/norm configs.
        shared = dict(conv_cfg=conv_cfg, norm_cfg=norm_cfg)
        pointwise = dict(kernel_size=1, stride=1, padding=0)

        # Submodules are registered in the same order the forward pass
        # uses them.
        if self.with_expand_conv:
            self.expand_conv = ConvModule(
                in_channels=in_channels,
                out_channels=mid_channels,
                act_cfg=act_cfg,
                **pointwise,
                **shared)

        self.depthwise_conv = ConvModule(
            in_channels=mid_channels,
            out_channels=mid_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=kernel_size // 2,
            groups=mid_channels,
            act_cfg=act_cfg,
            **shared)

        if self.with_se:
            self.se = SELayer(**se_cfg)

        # The final projection deliberately has no activation.
        self.linear_conv = ConvModule(
            in_channels=mid_channels,
            out_channels=out_channels,
            act_cfg=None,
            **pointwise,
            **shared)

    def forward(self, x):
        """Run the block on ``x``, with optional gradient checkpointing."""

        def _block(inp):
            feat = self.expand_conv(inp) if self.with_expand_conv else inp
            feat = self.depthwise_conv(feat)
            if self.with_se:
                feat = self.se(feat)
            feat = self.linear_conv(feat)
            return inp + feat if self.with_res_shortcut else feat

        # Checkpointing is only applied when the input itself requires grad.
        use_checkpoint = self.with_cp and x.requires_grad
        return cp.checkpoint(_block, x) if use_checkpoint else _block(x)
|