|
- from ...utils.util import *
-
- from collections import OrderedDict
- from timm.models.efficientnet_blocks import *
-
-
class ChildNetBuilder:
    """Build the trunk of a child network from decoded per-block argument dicts.

    Calling an instance turns a list of stages (each a list of block-argument
    dicts) into a list of ``nn.Sequential`` stacks. Along the way it:

    * rounds channel counts via ``round_channels``;
    * converts strides to dilation once ``output_stride`` would be exceeded;
    * assigns each block a drop-path rate that scales linearly with its
      global index;
    * records feature-extraction points in ``self.features`` when
      ``feature_location`` is set.
    """

    def __init__(
            self,
            channel_multiplier=1.0,
            channel_divisor=8,
            channel_min=None,
            output_stride=32,
            pad_type='',
            act_layer=None,
            se_kwargs=None,
            norm_layer=nn.BatchNorm2d,
            norm_kwargs=None,
            drop_path_rate=0.,
            feature_location='',
            verbose=False,
            logger=None):
        """
        Args:
            channel_multiplier: multiplier forwarded to ``round_channels``.
            channel_divisor: divisor forwarded to ``round_channels``.
            channel_min: minimum forwarded to ``round_channels``.
            output_stride: maximum overall stride of the trunk; strides that
                would exceed it are converted into dilation.
            pad_type: padding type forwarded to every block.
            act_layer: default activation layer, used when a block's args do
                not specify one.
            se_kwargs: squeeze-excite kwargs forwarded to 'ir'/'ds'/'dsa' blocks.
            norm_layer: normalization layer forwarded to every block.
            norm_kwargs: normalization kwargs forwarded to every block.
            drop_path_rate: scale of the per-block drop-path rate; block ``i``
                of ``n`` gets ``drop_path_rate * i / n``.
            feature_location: one of '', 'pre_pwl', 'post_exp'; selects where
                feature-extraction points are recorded ('' disables it).
            verbose: log each stage/block while building.
            logger: logger used when ``verbose``; when None, a module-level
                logger is used so verbose mode never fails on a missing logger.
        """
        import logging  # local import: the file's star imports may not provide it

        assert feature_location in ('pre_pwl', 'post_exp', '')
        self.channel_multiplier = channel_multiplier
        self.channel_divisor = channel_divisor
        self.channel_min = channel_min
        self.output_stride = output_stride
        self.pad_type = pad_type
        self.act_layer = act_layer
        self.se_kwargs = se_kwargs
        self.norm_layer = norm_layer
        self.norm_kwargs = norm_kwargs
        self.drop_path_rate = drop_path_rate
        self.feature_location = feature_location
        self.verbose = verbose
        self.in_chs = None  # channels fed to the next block; set by __call__
        self.features = OrderedDict()  # feature_idx -> dict(name=..., num_chs=...)
        self.logger = logger if logger is not None else logging.getLogger(__name__)

    def _round_channels(self, chs):
        """Round ``chs`` using this builder's multiplier/divisor/min settings."""
        return round_channels(
            chs,
            self.channel_multiplier,
            self.channel_divisor,
            self.channel_min)

    def _make_block(self, ba, block_idx, block_count):
        """Instantiate one block from its argument dict.

        Args:
            ba: block-argument dict; must contain 'block_type' and 'out_chs'.
                It is copied, so the caller's dict is not modified.
            block_idx: global index of this block across all stages.
            block_count: total number of blocks in the trunk.

        Returns:
            The constructed block. ``self.in_chs`` is advanced to its output
            channel count so the next block chains correctly.

        Raises:
            ValueError: if 'block_type' is not one of 'ir', 'ds', 'dsa', 'cn'.
        """
        ba = dict(ba)  # copy: popping/overwriting keys must not leak to the caller
        drop_path_rate = self.drop_path_rate * block_idx / block_count
        bt = ba.pop('block_type')
        ba['in_chs'] = self.in_chs
        ba['out_chs'] = self._round_channels(ba['out_chs'])
        if ba.get('fake_in_chs'):
            ba['fake_in_chs'] = self._round_channels(ba['fake_in_chs'])
        ba['norm_layer'] = self.norm_layer
        ba['norm_kwargs'] = self.norm_kwargs
        ba['pad_type'] = self.pad_type
        # block act fn overrides the model default
        if ba.get('act_layer') is None:
            ba['act_layer'] = self.act_layer
        assert ba['act_layer'] is not None

        if bt in ('ir', 'ds', 'dsa'):
            # only these block types take drop-path and squeeze-excite settings
            ba['drop_path_rate'] = drop_path_rate
            ba['se_kwargs'] = self.se_kwargs

        if bt == 'ir':
            label, block_cls = 'InvertedResidual', InvertedResidual
        elif bt in ('ds', 'dsa'):
            label, block_cls = 'DepthwiseSeparable', DepthwiseSeparableConv
        elif bt == 'cn':
            label, block_cls = 'ConvBnAct', ConvBnAct
        else:
            # explicit raise: an assert here is stripped under -O, leaving
            # `block` unbound and producing a confusing UnboundLocalError
            raise ValueError(
                'Unknown block type (%s) while building model.' % bt)

        if self.verbose:
            self.logger.info(
                ' {} {}, Args: {}'.format(label, block_idx, str(ba)))
        block = block_cls(**ba)
        self.in_chs = ba['out_chs']  # update in_chs for arg of next block
        return block

    def _feature_location_for(self, model_block_args, stage_idx, stride,
                              last_stack, last_block):
        """Return ``self.feature_location`` if this block is an extraction point, else ''.

        'pre_pwl': extract at the last block of a stage when the next stage
        starts with a stride > 1, or when there is no next stage.
        'post_exp': extract at any block with stride > 1 and at the very last
        block of the trunk. ``stride`` is the block's stride after non-first
        blocks have been forced to 1 but before any dilation conversion.
        """
        if self.feature_location == 'pre_pwl':
            if not last_block:
                return ''
            next_stage_idx = stage_idx + 1
            do_extract = (next_stage_idx >= len(model_block_args)
                          or model_block_args[next_stage_idx][0]['stride'] > 1)
        elif self.feature_location == 'post_exp':
            do_extract = stride > 1 or (last_stack and last_block)
        else:
            do_extract = False
        return self.feature_location if do_extract else ''

    def __call__(self, in_chs, model_block_args):
        """Build the blocks.

        Args:
            in_chs: Number of input channels passed to the first block.
            model_block_args: A list of lists. The outer list defines stages.
                Each inner list contains the block-argument dicts for the
                blocks of that stage. The dicts are copied, not modified.

        Returns:
            List of block stacks (each stack wrapped in nn.Sequential).
        """
        if self.verbose:
            self.logger.info(
                'Building model trunk with %d stages...' %
                len(model_block_args))
        self.in_chs = in_chs
        total_block_count = sum(len(stage) for stage in model_block_args)
        total_block_idx = 0
        # NOTE(review): starts at 2, presumably because the stem already
        # downsamples by 2 — confirm against the model that uses this builder
        current_stride = 2
        current_dilation = 1
        feature_idx = 0
        stages = []
        num_stages = len(model_block_args)
        # outer list of block_args defines the stacks ('stages' by some conventions)
        for stage_idx, stage_block_args in enumerate(model_block_args):
            last_stack = stage_idx == num_stages - 1
            if self.verbose:
                self.logger.info('Stack: {}'.format(stage_idx))
            assert isinstance(stage_block_args, list)

            blocks = []
            # each stack (stage) contains a list of block arguments
            for block_idx, block_args in enumerate(stage_block_args):
                block_args = dict(block_args)  # copy: leave the caller's config intact
                last_block = block_idx == len(stage_block_args) - 1
                if self.verbose:
                    self.logger.info(' Block: {}'.format(block_idx))

                # Sort out stride, dilation, and feature extraction details
                assert block_args['stride'] in (1, 2)
                if block_idx >= 1:
                    # only the first block in any stack can have a stride > 1
                    block_args['stride'] = 1

                extract_features = self._feature_location_for(
                    model_block_args, stage_idx, block_args['stride'],
                    last_stack, last_block)

                next_dilation = current_dilation
                if block_args['stride'] > 1:
                    next_output_stride = current_stride * block_args['stride']
                    if next_output_stride > self.output_stride:
                        # target output stride reached: trade stride for dilation
                        next_dilation = current_dilation * block_args['stride']
                        block_args['stride'] = 1
                        if self.verbose:
                            self.logger.info(
                                ' Converting stride to dilation to maintain output_stride=={}'.format(
                                    self.output_stride))
                    else:
                        current_stride = next_output_stride
                # this block uses the dilation in effect before it; an
                # increase only applies to the blocks that follow
                block_args['dilation'] = current_dilation
                current_dilation = next_dilation

                # create the block
                block = self._make_block(
                    block_args, total_block_idx, total_block_count)
                blocks.append(block)

                # stash feature module name and channel info for model feature extraction
                if extract_features:
                    feature_module = block.feature_module(extract_features)
                    if feature_module:
                        feature_module = 'blocks.{}.{}.'.format(
                            stage_idx, block_idx) + feature_module
                    feature_channels = block.feature_channels(extract_features)
                    self.features[feature_idx] = dict(
                        name=feature_module,
                        num_chs=feature_channels,
                    )
                    feature_idx += 1

                # incr global block idx (across all stacks)
                total_block_idx += 1
            stages.append(nn.Sequential(*blocks))
        return stages
|