# Copyright (c) OpenMMLab. All rights reserved.
import warnings

import torch.nn as nn
from mmcv.runner import BaseModule, auto_fp16

from mmdet.models.backbones import ResNet
from mmdet.models.builder import SHARED_HEADS
from mmdet.models.utils import ResLayer as _ResLayer


@SHARED_HEADS.register_module()
class ResLayer(BaseModule):
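    """Shared ResNet stage used as an RoI feature head, e.g. in C4-style
    Faster R-CNN, where ``layer4`` of a ResNet is applied on top of RoI
    features instead of sitting in the backbone.

    Args:
        depth (int): Depth of the ResNet variant (a key of
            ``ResNet.arch_settings``), e.g. 50 or 101.
        stage (int): 0-based index of the stage to build; 3 corresponds to
            ``layer4``. Defaults to 3.
        stride (int): Stride of the first block of the stage. Defaults to 2.
        dilation (int): Dilation used in the stage. Defaults to 1.
        style (str): ``'pytorch'`` or ``'caffe'``; determines which conv in
            a bottleneck carries the stride. Defaults to ``'pytorch'``.
        norm_cfg (dict): Config for the norm layers.
        norm_eval (bool): Whether to keep norm layers in eval mode during
            training. Defaults to True.
        with_cp (bool): Whether to use gradient checkpointing. Defaults to
            False.
        dcn (dict, optional): Config for deformable convolutions.
        pretrained (str, optional): Deprecated; use ``init_cfg`` with
            ``type='Pretrained'`` instead.
        init_cfg (dict or list[dict], optional): Initialization config.
    """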

    def __init__(self,
                 depth,
                 stage=3,
                 stride=2,
                 dilation=1,
                 style='pytorch',
                 norm_cfg=dict(type='BN', requires_grad=True),
                 norm_eval=True,
                 with_cp=False,
                 dcn=None,
                 pretrained=None,
                 init_cfg=None):
        super(ResLayer, self).__init__(init_cfg)

        self.norm_eval = norm_eval
        self.norm_cfg = norm_cfg
        self.stage = stage
        self.fp16_enabled = False
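        # Look up the block type (BasicBlock or Bottleneck) and per-stage
        # block counts for this depth, then derive the stage's input/output
        # widths; e.g. stage 3 of ResNet-50 maps 1024 -> 2048 channels.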
        block, stage_blocks = ResNet.arch_settings[depth]
        stage_block = stage_blocks[stage]
        planes = 64 * 2**stage
        inplanes = 64 * 2**(stage - 1) * block.expansion

        res_layer = _ResLayer(
            block,
            inplanes,
            planes,
            stage_block,
            stride=stride,
            dilation=dilation,
            style=style,
            with_cp=with_cp,
            norm_cfg=self.norm_cfg,
            dcn=dcn)
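        # Register under the same attribute name the backbone would use
        # ('layer4' for stage 3) so that ImageNet-pretrained ResNet
        # checkpoints load into this head without key remapping.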
        self.add_module(f'layer{stage + 1}', res_layer)

        assert not (init_cfg and pretrained), \
            'init_cfg and pretrained cannot be specified at the same time'
        if isinstance(pretrained, str):
            warnings.warn(
                'pretrained is deprecated, please use "init_cfg" instead',
                DeprecationWarning)
            self.init_cfg = dict(type='Pretrained', checkpoint=pretrained)
        elif pretrained is None:
            if init_cfg is None:
                self.init_cfg = [
                    dict(type='Kaiming', layer='Conv2d'),
                    dict(
                        type='Constant',
                        val=1,
                        layer=['_BatchNorm', 'GroupNorm'])
                ]
        else:
            raise TypeError('pretrained must be a str or None')

    @auto_fp16()
    def forward(self, x):
        res_layer = getattr(self, f'layer{self.stage + 1}')
        out = res_layer(x)
        return out

    def train(self, mode=True):
        super(ResLayer, self).train(mode)
        if self.norm_eval:
            for m in self.modules():
                if isinstance(m, nn.BatchNorm2d):
                    m.eval()
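

# Minimal usage sketch (illustrative, not part of the upstream module):
# builds the shared head for stage 3 of ResNet-50 and pushes a dummy batch
# of RoI features through it. The 1024 input channels must match the
# stage's inplanes (64 * 2**2 * Bottleneck.expansion); the 14x14 spatial
# size is the typical RoIAlign output and is an assumption here.
if __name__ == '__main__':
    import torch

    head = ResLayer(depth=50)  # stage=3 -> builds and applies 'layer4'
    head.eval()
    rois = torch.randn(2, 1024, 14, 14)  # (num_rois, C, H, W)
    with torch.no_grad():
        out = head(rois)
    print(out.shape)  # torch.Size([2, 2048, 7, 7]) given the default stride=2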