- # pylint: disable=E0401
- # pylint: disable=W0201
- """
- MindSpore implementation of `CeiT`.
- Refer to "Incorporating Convolution Designs into Visual Transformers"
- """
- import math
- import mindspore as ms
- from mindspore import nn
- from mindspore.common.initializer import initializer, TruncatedNormal
-
- from model.layers import DropPath, to_2tuple
- from model.helper import load_pretrained
- from model.registry import register_model
-
- __all__ = [
- 'ceit_tiny_patch16_224', 'ceit_small_patch16_224', 'ceit_base_patch16_224',
- 'ceit_tiny_patch16_384', 'ceit_small_patch16_384',
- ]
-
-
- def _cfg(url='', **kwargs):
- return {
- 'url': url,
- 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
- 'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True,
- # 'mean': IMAGENET_INCEPTION_MEAN, 'std': IMAGENET_INCEPTION_STD,
- 'first_conv': 'patch_embed.proj', 'classifier': 'head',
- **kwargs
- }
-
-
- class Image2Tokens(nn.Cell):
- """ image to tokens """
- def __init__(self, in_chans=3, out_chans=64, kernel_size=7, stride=2):
- super().__init__()
- self.conv = nn.Conv2d(in_chans, out_chans, kernel_size=kernel_size, stride=stride,
- pad_mode='pad', padding=kernel_size // 2, has_bias=False)
- self.bn = nn.BatchNorm2d(out_chans)
- self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode='pad', padding=1)
-
- def construct(self, x):
- x = self.conv(x)
- x = self.bn(x)
- x = self.maxpool(x)
- return x
-
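- # Rough shape walk-through for the stem above, assuming the default arguments and a
- # 224x224 input: conv k=7/s=2/p=3 -> 112x112, max-pool k=3/s=2/p=1 -> 56x56 with 64
- # channels, i.e. a 4x spatial reduction before patch embedding.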
-
- class Mlp(nn.Cell):
- """ mlp """
- def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
- super().__init__()
- out_features = out_features or in_features
- hidden_features = hidden_features or in_features
- self.fc1 = nn.Dense(in_features, hidden_features)
- self.act = act_layer()
- self.fc2 = nn.Dense(hidden_features, out_features)
- self.drop = nn.Dropout(p=drop)
-
- def construct(self, x):
- x = self.fc1(x)
- x = self.act(x)
- x = self.drop(x)
- x = self.fc2(x)
- x = self.drop(x)
- return x
-
-
- class LocallyEnhancedFeedForward(nn.Cell):
- """ locally enhanced feed forward """
- def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU,
- kernel_size=3, with_bn=True):
- super().__init__()
- out_features = out_features or in_features
- hidden_features = hidden_features or in_features
- # pointwise
- self.conv1 = nn.Conv2d(in_features, hidden_features, kernel_size=1, stride=1)
- # depthwise
- self.conv2 = nn.Conv2d(
- hidden_features, hidden_features, kernel_size=kernel_size, stride=1,
- pad_mode='pad', padding=(kernel_size - 1) // 2, group=hidden_features
- )
- # pointwise
- self.conv3 = nn.Conv2d(hidden_features, out_features, kernel_size=1, stride=1)
- self.act = act_layer()
- # self.drop = nn.Dropout(drop)
-
- self.with_bn = with_bn
- if self.with_bn:
- self.bn1 = nn.BatchNorm2d(hidden_features)
- self.bn2 = nn.BatchNorm2d(hidden_features)
- self.bn3 = nn.BatchNorm2d(out_features)
-
- def construct(self, x):
- b, n, k = x.shape
- cls_token, tokens = ms.ops.split(x, [1, n - 1], axis=1)
- x = tokens.reshape(b, int(math.sqrt(n - 1)), int(math.sqrt(n - 1)), k).permute(0, 3, 1, 2)
- if self.with_bn:
- x = self.conv1(x)
- x = self.bn1(x)
- x = self.act(x)
- x = self.conv2(x)
- x = self.bn2(x)
- x = self.act(x)
- x = self.conv3(x)
- x = self.bn3(x)
- else:
- x = self.conv1(x)
- x = self.act(x)
- x = self.conv2(x)
- x = self.act(x)
- x = self.conv3(x)
-
- tokens = x.flatten(start_dim=2).permute(0, 2, 1)
- out = ms.ops.cat((cls_token, tokens), axis=1)
- return out
-
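- # Shape sketch for LeFF on a (B, N, C) sequence with N = 1 + H*W: the class token is
- # split off, the H*W patch tokens are reshaped to a (B, C, H, W) map for the
- # 1x1 -> depthwise kxk -> 1x1 convolutions, then flattened back and re-concatenated,
- # so the output is again (B, N, C). This assumes the patch-token count is a perfect square.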
-
- class Attention(nn.Cell):
- """ self attention """
- def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
- super().__init__()
- self.num_heads = num_heads
- head_dim = dim // num_heads
- # NOTE: qk_scale can be set manually to stay compatible with weights trained with a different scaling
- self.scale = qk_scale or head_dim ** -0.5
-
- self.qkv = nn.Dense(dim, dim * 3, has_bias=qkv_bias)
- self.attn_drop = nn.Dropout(p=attn_drop)
- self.proj = nn.Dense(dim, dim)
- self.proj_drop = nn.Dropout(p=proj_drop)
- self.attention_map = None
-
- def construct(self, x):
- B, N, C = x.shape
- qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
- q, k, v = qkv[0], qkv[1], qkv[2]  # each: (B, num_heads, N, head_dim)
-
- attn = (q @ k.transpose(0, 1, 3, 2)) * self.scale
- attn = ms.ops.softmax(attn, axis=-1)
- attn = self.attn_drop(attn)
-
- x = (attn @ v).transpose(0, 2, 1, 3).reshape(B, N, C)
- x = self.proj(x)
- x = self.proj_drop(x)
- return x
-
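- # Reading aid for the attention above: qkv(x) is (B, N, 3*C), reshaped and permuted to
- # (3, B, num_heads, N, head_dim); attn is (B, num_heads, N, N) after the softmax, and
- # the weighted values are merged back to (B, N, C) before the output projection.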
-
- class AttentionLCA(Attention):
- """ attention lca """
- def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
- super().__init__(dim, num_heads, qkv_bias, qk_scale, attn_drop, proj_drop)
- self.dim = dim
- self.qkv_bias = qkv_bias
-
- def construct(self, x):
- q_weight = self.qkv.weight[:self.dim, :]
- q_bias = None if not self.qkv_bias else self.qkv.bias[:self.dim]
- kv_weight = self.qkv.weight[self.dim:, :]
- kv_bias = None if not self.qkv_bias else self.qkv.bias[self.dim:]
-
- B, N, C = x.shape
- _, last_token = ms.ops.split(x, [N - 1, 1], axis=1)
-
- q = (last_token @ q_weight.T + q_bias) \
- .reshape(B, 1, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
- kv = (x @ kv_weight.T + kv_bias) \
- .reshape(B, N, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
- k, v = kv[0], kv[1]
-
- attn = (q @ k.transpose(0, 1, 3, 2)) * self.scale
- attn = ms.ops.softmax(attn, axis=-1)
- # self.attention_map = attn
- attn = self.attn_drop(attn)
-
- x = (attn @ v).transpose(0, 2, 1, 3).reshape(B, 1, C)
- x = self.proj(x)
- x = self.proj_drop(x)
- return x
-
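- # In LCA only the last token (the current class token) forms the query, so q is
- # (B, num_heads, 1, head_dim) while k and v cover all N tokens; the output is a single
- # refined token of shape (B, 1, C). The bias slicing above assumes qkv_bias=True,
- # which is what the ceit_* factories below pass.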
-
- class Block(nn.Cell):
- """ lca blocks """
- def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
- drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, kernel_size=3, with_bn=True,
- feedforward_type='leff'):
- super().__init__()
- self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
- self.norm2 = norm_layer((dim, ))
- mlp_hidden_dim = int(dim * mlp_ratio)
- self.norm1 = norm_layer((dim, ))
- self.feedforward_type = feedforward_type
-
- if feedforward_type == 'leff':
- self.attn = Attention(
- dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
- self.leff = LocallyEnhancedFeedForward(
- in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer,
- kernel_size=kernel_size, with_bn=with_bn,
- )
- else: # LCA
- self.attn = AttentionLCA(
- dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
- self.feedforward = Mlp(
- in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop
- )
-
- def construct(self, x):
- if self.feedforward_type == 'leff':
- x = x + self.drop_path(self.attn(self.norm1(x)))
- x = x + self.drop_path(self.leff(self.norm2(x)))
- return x, x[:, 0]
- # LCA
- _, last_token = ms.ops.split(x, [x.shape[1] - 1, 1], axis=1)
- x = last_token + self.drop_path(self.attn(self.norm1(x)))
- x = x + self.drop_path(self.feedforward(self.norm2(x)))
- return x
-
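- # Note the two return signatures: a LeFF block returns (tokens, class_token) so the
- # encoder can collect a class token per layer, while the LCA block consumes the stacked
- # class tokens and returns only the refined last token.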
-
- class HybridEmbed(nn.Cell):
- """ CNN Feature Map Embedding
- Extract feature map from CNN, flatten, project to embedding dim.
- """
-
- def __init__(self, backbone, img_size=224, patch_size=16, feature_size=None, in_chans=3, embed_dim=768):
- super().__init__()
- assert isinstance(backbone, nn.Cell)
- img_size = to_2tuple(img_size)
- self.img_size = img_size
- self.backbone = backbone
- if feature_size is None:
- # Run a dummy forward pass to determine the backbone's output feature size and channel
- # dim; stride metadata alone is not enough because per-stage padding is not captured.
- training = backbone.training
- if training:
- backbone.set_train(False)  # MindSpore Cells use set_train(False) rather than eval()
- o = self.backbone(ms.ops.zeros((1, in_chans, img_size[0], img_size[1])))
- if isinstance(o, (list, tuple)):
- o = o[-1] # last feature if backbone outputs list/tuple of features
- feature_size = o.shape[-2:]
- feature_dim = o.shape[1]
- backbone.set_train(training)  # restore the backbone's original train/eval mode
- else:
- feature_size = to_2tuple(feature_size)
- feature_dim = self.backbone.feature_info.channels()[-1]
- self.num_patches = (feature_size[0] // patch_size) * (feature_size[1] // patch_size)
- self.proj = nn.Conv2d(feature_dim, embed_dim, kernel_size=patch_size, stride=patch_size)
-
- def construct(self, x):
- x = self.backbone(x)
- if isinstance(x, (list, tuple)):
- x = x[-1] # last feature if backbone outputs list/tuple of features
- x = self.proj(x).flatten(start_dim=2).transpose(0, 2, 1)
- return x
-
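- # With the Image2Tokens stem and the factory defaults (224x224 input, patch_size=4),
- # the backbone emits a 64x56x56 feature map and the 4x4/stride-4 projection yields
- # 14x14 = 196 patch tokens, i.e. an effective 16x16 patch w.r.t. the input image
- # (hence the "patch16" model names).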
-
- class CeIT(nn.Cell):
- """ CeIT """
- def __init__(self,
- img_size=224,
- patch_size=16,
- in_chans=3,
- num_classes=1000,
- embed_dim=768,
- depth=12,
- num_heads=12,
- mlp_ratio=4.,
- qkv_bias=False,
- qk_scale=None,
- drop_rate=0.,
- attn_drop_rate=0.,
- drop_path_rate=0.,
- hybrid_backbone=None,
- norm_layer=nn.LayerNorm,
- leff_local_size=3,
- leff_with_bn=True):
- """
- args:
- - img_size (:obj:`int`): input image size
- - patch_size (:obj:`int`): patch size
- - in_chans (:obj:`int`): input channels
- - num_classes (:obj:`int`): number of classes
- - embed_dim (:obj:`int`): embedding dimensions for tokens
- - depth (:obj:`int`): depth of encoder
- - num_heads (:obj:`int`): number of heads in multi-head self-attention
- - mlp_ratio (:obj:`float`): expand ratio in feedforward
- - qkv_bias (:obj:`bool`): whether to add a bias to the qkv projection
- - qk_scale (:obj:`float`): scale ratio for qk, default is head_dim ** -0.5
- - drop_rate (:obj:`float`): dropout rate in feedforward module after linear operation
- and projection drop rate in attention
- - attn_drop_rate (:obj:`float`): dropout rate for attention
- - drop_path_rate (:obj:`float`): drop_path rate after attention
- - hybrid_backbone (:obj:`nn.Cell`): backbone, e.g. Image2Tokens or a resnet
- - norm_layer (:obj:`nn.Cell`): normalization type
- - leff_local_size (:obj:`int`): kernel size in LocallyEnhancedFeedForward
- - leff_with_bn (:obj:`bool`): whether to add BatchNorm in LocallyEnhancedFeedForward
- """
- super().__init__()
- self.num_classes = num_classes
- self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
-
- self.i2t = HybridEmbed(
- hybrid_backbone, img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
- num_patches = self.i2t.num_patches
-
- self.cls_token = ms.Parameter(ms.ops.zeros((1, 1, embed_dim)))
- self.pos_embed = ms.Parameter(ms.ops.zeros((1, num_patches + 1, embed_dim)))
- self.pos_drop = nn.Dropout(p=drop_rate)
-
- dpr = [float(x) for x in ms.ops.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule
- self.blocks = nn.CellList([
- Block(
- dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
- drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,
- kernel_size=leff_local_size, with_bn=leff_with_bn)
- for i in range(depth)])
-
- # without droppath
- self.lca = Block(
- dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
- drop=drop_rate, attn_drop=attn_drop_rate, drop_path=0., norm_layer=norm_layer,
- feedforward_type='lca'
- )
- self.pos_layer_embed = ms.Parameter(ms.ops.zeros((1, depth, embed_dim)))
-
- self.norm = norm_layer((embed_dim, ))
-
- # NOTE as per official impl, we could have a pre-logits representation dense layer + tanh here
- # Classifier head
- self.head = nn.Dense(embed_dim, num_classes) if num_classes > 0 else nn.Identity()
-
- self.pos_embed.set_data(initializer(TruncatedNormal(sigma=.02), self.pos_embed.shape, self.pos_embed.dtype))
- self.cls_token.set_data(initializer(TruncatedNormal(sigma=.02), self.cls_token.shape, self.cls_token.dtype))
- self.apply(self._init_weights)
-
- def _init_weights(self, cell):
- if isinstance(cell, nn.Dense):
- cell.weight.set_data(initializer(TruncatedNormal(sigma=.02), cell.weight.shape, cell.weight.dtype))
- if cell.bias is not None:
- cell.bias.set_data(initializer('zeros', cell.bias.shape, cell.bias.dtype))
- elif isinstance(cell, nn.LayerNorm):
- cell.gamma.set_data(initializer('ones', cell.gamma.shape, cell.gamma.dtype))
- cell.beta.set_data(initializer('zeros', cell.beta.shape, cell.beta.dtype))
-
- def get_classifier(self):
- """ get classifier """
- return self.head
-
- def reset_classifier(self, num_classes):
- """ reset classifier """
- self.num_classes = num_classes
- self.head = nn.Dense(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
-
- def construct_features(self, x):
- """ construct features """
- B = x.shape[0]
- x = self.i2t(x)
-
- cls_tokens = ms.ops.broadcast_to(self.cls_token, (B, -1, -1))  # expand the class token to the batch size
- x = ms.ops.cat((cls_tokens, x), axis=1)
- x = x + self.pos_embed
- x = self.pos_drop(x)
-
- cls_token_list = []
- for blk in self.blocks:
- x, curr_cls_token = blk(x)
- cls_token_list.append(curr_cls_token)
-
- all_cls_token = ms.ops.stack(cls_token_list, axis=1)  # (B, depth, embed_dim)
- all_cls_token = all_cls_token + self.pos_layer_embed
- # attention over cls tokens
- last_cls_token = self.lca(all_cls_token)
- last_cls_token = self.norm(last_cls_token)
-
- return last_cls_token.view(B, -1)
-
- def construct(self, x):
- x = self.construct_features(x)
- x = self.head(x)
- return x
-
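- # Forward pass in brief: i2t -> prepend the class token and add pos_embed -> `depth`
- # encoder blocks, keeping each block's class token -> stack them to (B, depth, embed_dim)
- # and add pos_layer_embed -> one LCA block + LayerNorm -> classification head.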
-
- @register_model
- def ceit_tiny_patch16_224(pretrained=False, **kwargs):
- """
- CeiT-Tiny @ 224: convolutional + pooling stem (I2T),
- locally-enhanced feedforward (LeFF),
- attention over per-layer cls_tokens (LCA)
- """
- default_cfg = _cfg()
-
- i2t = Image2Tokens()
- model = CeIT(
- hybrid_backbone=i2t,
- patch_size=4, embed_dim=192, depth=12, num_heads=3, mlp_ratio=4, qkv_bias=True,
- norm_layer=nn.LayerNorm, **kwargs)
- model.default_cfg = default_cfg
- if pretrained:
- load_pretrained(model, default_cfg)
- return model
-
-
- @register_model
- def ceit_small_patch16_224(pretrained=False, **kwargs):
- """
- CeiT-Small @ 224: convolutional + pooling stem (I2T),
- locally-enhanced feedforward (LeFF),
- attention over per-layer cls_tokens (LCA)
- """
- default_cfg = _cfg()
- i2t = Image2Tokens()
- model = CeIT(
- hybrid_backbone=i2t,
- patch_size=4, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4, qkv_bias=True,
- norm_layer=nn.LayerNorm, **kwargs)
- model.default_cfg = default_cfg
- if pretrained:
- load_pretrained(model, default_cfg)
-
- return model
-
-
- @register_model
- def ceit_base_patch16_224(pretrained=False, **kwargs):
- """
- CeiT-Base @ 224: convolutional + pooling stem (I2T),
- locally-enhanced feedforward (LeFF),
- attention over per-layer cls_tokens (LCA)
- """
- default_cfg = _cfg()
- i2t = Image2Tokens()
- model = CeIT(
- hybrid_backbone=i2t,
- patch_size=4, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True,
- norm_layer=nn.LayerNorm, **kwargs)
- model.default_cfg = default_cfg
- if pretrained:
- load_pretrained(model, default_cfg)
-
- return model
-
-
- @register_model
- def ceit_tiny_patch16_384(pretrained=False, **kwargs):
- """
- CeiT-Tiny @ 384: convolutional + pooling stem (I2T),
- locally-enhanced feedforward (LeFF),
- attention over per-layer cls_tokens (LCA)
- """
- default_cfg = _cfg(input_size=(3, 384, 384))
- i2t = Image2Tokens()
- model = CeIT(
- hybrid_backbone=i2t, img_size=384,
- patch_size=4, embed_dim=192, depth=12, num_heads=3, mlp_ratio=4, qkv_bias=True,
- norm_layer=nn.LayerNorm, **kwargs)
- model.default_cfg = default_cfg
- if pretrained:
- load_pretrained(model, default_cfg)
-
- return model
-
-
- @register_model
- def ceit_small_patch16_384(pretrained=False, **kwargs):
- """
- CeiT-Small @ 384: convolutional + pooling stem (I2T),
- locally-enhanced feedforward (LeFF),
- attention over per-layer cls_tokens (LCA)
- """
- default_cfg = _cfg(input_size=(3, 384, 384))
- i2t = Image2Tokens()
- model = CeIT(
- hybrid_backbone=i2t, img_size=384,
- patch_size=4, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4, qkv_bias=True,
- norm_layer=nn.LayerNorm, **kwargs)
- model.default_cfg = default_cfg
- if pretrained:
- load_pretrained(model, default_cfg)
-
- return model
-
-
- if __name__ == '__main__':
- dummy_input = ms.ops.randn((1, 3, 224, 224))
- ceit = ceit_tiny_patch16_224()
- output = ceit(dummy_input)
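- # Sanity check (assumes the default 1000-class head): output should be (batch, num_classes).
- assert output.shape == (1, ceit.num_classes)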
- print(output.shape)