import re
import math
import torch.nn as nn

from copy import deepcopy

from timm.utils import *
from timm.models.layers.activations import Swish, HardSwish
from timm.models.layers import CondConv2d, get_condconv_initializer


def parse_ksize(ss):
    """Parse a kernel size spec: a single int ('3') or a '.'-separated list ('3.5.7')."""
    if ss.isdigit():
        return int(ss)
    return [int(k) for k in ss.split('.')]
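
# Usage sketch: parse_ksize('3') -> 3, parse_ksize('3.5.7') -> [3, 5, 7];
# the list form expresses per-split kernel sizes (e.g. MixNet-style mixed kernels).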


def decode_arch_def(
        arch_def,
        depth_multiplier=1.0,
        depth_trunc='ceil',
        experts_multiplier=1):
    """Decode an architecture definition (a list of stages, each a list of block
    strings) into per-stage lists of block-arg dicts, applying depth scaling and
    an optional CondConv expert multiplier."""
    arch_args = []
    for stack_idx, block_strings in enumerate(arch_def):
        assert isinstance(block_strings, list)
        stack_args = []
        repeats = []
        for block_str in block_strings:
            assert isinstance(block_str, str)
            ba, rep = decode_block_str(block_str)
            if ba.get('num_experts', 0) > 0 and experts_multiplier > 1:
                ba['num_experts'] *= experts_multiplier
            stack_args.append(ba)
            repeats.append(rep)
        arch_args.append(
            scale_stage_depth(stack_args, repeats, depth_multiplier, depth_trunc))
    return arch_args
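
# Usage sketch (the two block strings below are the first stages of the
# EfficientNet-B0 definition; any strings in the decode_block_str format work):
#   arch_def = [
#       ['ds_r1_k3_s1_e1_c16_se0.25'],
#       ['ir_r2_k3_s2_e6_c24_se0.25'],
#   ]
#   block_args = decode_arch_def(arch_def, depth_multiplier=1.2)
#   # -> one list of block-arg dicts per stage, with repeats already expanded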


def modify_block_args(block_args, kernel_size, exp_ratio):
    """Overwrite the kernel size and expansion ratio of a decoded block-args
    dict, using the key appropriate to its block type."""
    block_type = block_args['block_type']
    if block_type == 'cn':
        block_args['kernel_size'] = kernel_size
    elif block_type == 'er':
        block_args['exp_kernel_size'] = kernel_size
    else:
        block_args['dw_kernel_size'] = kernel_size

    if block_type == 'ir' or block_type == 'er':
        block_args['exp_ratio'] = exp_ratio
    return block_args
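
# Usage sketch: given ba from decode_block_str('ir_r1_k3_s1_e6_c40'),
# modify_block_args(ba, 5, 3.0) sets ba['dw_kernel_size'] = 5 and
# ba['exp_ratio'] = 3.0 (an 'ir' block keys its kernel as dw_kernel_size).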


def decode_block_str(block_str):
    """ Decode block definition string
    Gets a dict of block args and a repeat count from a string notation of arguments.
    E.g. ir_r2_k3_s2_e1_c16_se0.25_noskip
    All args can exist in any order with the exception of the leading string which
    is assumed to indicate the block type.
    leading string - block type (
      ir = InvertedResidual, ds = DepthwiseSep, dsa = DepthwiseSep with pw act, cn = ConvBnAct)
    r - number of repeat blocks,
    k - kernel size,
    s - strides (1-9),
    e - expansion ratio,
    c - output channels,
    se - squeeze/excitation ratio
    n - activation fn ('re', 'r6', 'hs', or 'sw')
    Args:
        block_str: a string representation of block arguments.
    Returns:
        A tuple of (block args dict, num_repeat) for the block.
    Raises:
        ValueError: if the string def is not properly specified
    """
    assert isinstance(block_str, str)
    ops = block_str.split('_')
    block_type = ops[0]  # take the block type off the front
    ops = ops[1:]
    options = {}
    noskip = False
    for op in ops:
        # string options are checked on an individual basis; combine if they grow
        if op == 'noskip':
            noskip = True
        elif op.startswith('n'):
            # activation fn
            key = op[0]
            v = op[1:]
            if v == 're':
                value = nn.ReLU
            elif v == 'r6':
                value = nn.ReLU6
            elif v == 'hs':
                value = HardSwish
            elif v == 'sw':
                value = Swish
            else:
                continue
            options[key] = value
        else:
            # all numeric options
            splits = re.split(r'(\d.*)', op)
            if len(splits) >= 2:
                key, value = splits[:2]
                options[key] = value

    # if act_layer is None, the model default (passed to model init) will be used
    act_layer = options['n'] if 'n' in options else None
    exp_kernel_size = parse_ksize(options['a']) if 'a' in options else 1
    pw_kernel_size = parse_ksize(options['p']) if 'p' in options else 1
    # FIXME hack to deal with in_chs issue in TPU def
    fake_in_chs = int(options['fc']) if 'fc' in options else 0

    num_repeat = int(options['r'])
    # each type of block has different valid arguments, fill accordingly
    if block_type == 'ir':
        block_args = dict(
            block_type=block_type,
            dw_kernel_size=parse_ksize(options['k']),
            exp_kernel_size=exp_kernel_size,
            pw_kernel_size=pw_kernel_size,
            out_chs=int(options['c']),
            exp_ratio=float(options['e']),
            se_ratio=float(options['se']) if 'se' in options else None,
            stride=int(options['s']),
            act_layer=act_layer,
            noskip=noskip,
        )
        if 'cc' in options:
            block_args['num_experts'] = int(options['cc'])
    elif block_type == 'ds' or block_type == 'dsa':
        block_args = dict(
            block_type=block_type,
            dw_kernel_size=parse_ksize(options['k']),
            pw_kernel_size=pw_kernel_size,
            out_chs=int(options['c']),
            se_ratio=float(options['se']) if 'se' in options else None,
            stride=int(options['s']),
            act_layer=act_layer,
            pw_act=block_type == 'dsa',
            noskip=block_type == 'dsa' or noskip,
        )
    elif block_type == 'cn':
        block_args = dict(
            block_type=block_type,
            kernel_size=int(options['k']),
            out_chs=int(options['c']),
            stride=int(options['s']),
            act_layer=act_layer,
        )
    else:
        raise ValueError('Unknown block type (%s)' % block_type)

    return block_args, num_repeat
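
# Usage sketch:
#   ba, rep = decode_block_str('ir_r2_k3_s2_e6_c24_se0.25')
#   # -> rep == 2 and ba == dict(block_type='ir', dw_kernel_size=3,
#   #    exp_kernel_size=1, pw_kernel_size=1, out_chs=24, exp_ratio=6.0,
#   #    se_ratio=0.25, stride=2, act_layer=None, noskip=False)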


def scale_stage_depth(
        stack_args,
        repeats,
        depth_multiplier=1.0,
        depth_trunc='ceil'):
    """ Per-stage depth scaling
    Scales the block repeats in each stage. This depth scaling impl maintains
    compatibility with the EfficientNet scaling method, while allowing sensible
    scaling for other models that may have multiple block arg definitions in each stage.
    """

    # We scale the total repeat count for each stage; there may be multiple
    # block arg defs per stage, so we need to sum.
    num_repeat = sum(repeats)
    if depth_trunc == 'round':
        # Truncating to int by rounding allows stages with few repeats to remain
        # proportionally smaller for longer. This is a good choice when stage
        # definitions include single-repeat stages that we'd prefer to keep that
        # way as long as possible.
        num_repeat_scaled = max(1, round(num_repeat * depth_multiplier))
    else:
        # The default for EfficientNet truncates repeats to int via 'ceil'.
        # Any multiplier > 1.0 will result in an increased depth for every stage.
        num_repeat_scaled = int(math.ceil(num_repeat * depth_multiplier))

    # Proportionally distribute the scaled repeat count over each block definition
    # in the stage. Allocation is done in reverse as it results in the first block
    # being less likely to be scaled; the first block makes less sense to repeat
    # in most of the arch definitions.
    repeats_scaled = []
    for r in repeats[::-1]:
        rs = max(1, round((r / num_repeat * num_repeat_scaled)))
        repeats_scaled.append(rs)
        num_repeat -= r
        num_repeat_scaled -= rs
    repeats_scaled = repeats_scaled[::-1]

    # Apply the calculated scaling to each block arg in the stage
    sa_scaled = []
    for ba, rep in zip(stack_args, repeats_scaled):
        sa_scaled.extend([deepcopy(ba) for _ in range(rep)])
    return sa_scaled
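
# Worked example: repeats=[1, 2], depth_multiplier=2.0, depth_trunc='ceil'
#   num_repeat = 3, num_repeat_scaled = ceil(3 * 2.0) = 6
#   reverse pass: r=2 -> rs = round(2/3 * 6) = 4; r=1 -> rs = round(1/1 * 2) = 2
#   repeats_scaled = [2, 4]: the stage's first block def is copied 2x, the second 4x.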


def init_weight_goog(m, n='', fix_group_fanout=True, last_bn=None):
    """ Weight initialization as per Tensorflow official implementations.
    Args:
        m (nn.Module): module to init
        n (str): module name
        fix_group_fanout (bool): enable correct (matching Tensorflow TPU impl) fanout calculation w/ group convs
        last_bn (list): names of BatchNorm layers whose gamma should be zero-initialized
    Handles layers in EfficientNet, EfficientNet-CondConv, MixNet, MnasNet, MobileNetV3, etc:
    * https://github.com/tensorflow/tpu/blob/master/models/official/mnasnet/mnasnet_model.py
    * https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/efficientnet_model.py
    """
    last_bn = last_bn or []
    if isinstance(m, CondConv2d):
        fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
        if fix_group_fanout:
            fan_out //= m.groups
        init_weight_fn = get_condconv_initializer(
            lambda w: w.data.normal_(0, math.sqrt(2.0 / fan_out)), m.num_experts, m.weight_shape)
        init_weight_fn(m.weight)
        if m.bias is not None:
            m.bias.data.zero_()
    elif isinstance(m, nn.Conv2d):
        fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
        if fix_group_fanout:
            fan_out //= m.groups
        m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
        if m.bias is not None:
            m.bias.data.zero_()
    elif isinstance(m, nn.BatchNorm2d):
        if n in last_bn:
            # zero-init gamma for the last BN of a block (zero_gamma)
            m.weight.data.zero_()
        else:
            m.weight.data.fill_(1.0)
        m.bias.data.zero_()
    elif isinstance(m, nn.Linear):
        fan_out = m.weight.size(0)  # fan-out
        fan_in = 0
        if 'routing_fn' in n:
            fan_in = m.weight.size(1)
        init_range = 1.0 / math.sqrt(fan_in + fan_out)
        m.weight.data.uniform_(-init_range, init_range)
        m.bias.data.zero_()
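
# Example: for nn.Conv2d(32, 64, 3, groups=32), fan_out = 3*3*64 // 32 = 18,
# so weights are drawn from N(0, sqrt(2/18)) when fix_group_fanout=True;
# without the group correction (fan_out = 576) the init would not match the TF TPU impls.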


def efficientnet_init_weights(
        model: nn.Module,
        init_fn=None,
        zero_gamma=False):
    last_bn = []
    if zero_gamma:
        # Collect the name of the last BatchNorm2d within each parent module;
        # these get zero-initialized gamma via init_weight_goog.
        prev_n = ''
        for n, m in model.named_modules():
            if isinstance(m, nn.BatchNorm2d):
                if '.'.join(prev_n.split('.')[:-1]) != '.'.join(n.split('.')[:-1]):
                    if prev_n:
                        last_bn.append(prev_n)
                prev_n = n
        if prev_n:
            last_bn.append(prev_n)

    init_fn = init_fn or init_weight_goog
    for n, m in model.named_modules():
        init_fn(m, n, last_bn=last_bn)
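
# Usage sketch (model is any EfficientNet-family nn.Module built from the
# decoded block args above; the builder call here is hypothetical):
#   model = build_model(...)  # hypothetical builder
#   efficientnet_init_weights(model, zero_gamma=True)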