| @@ -1,487 +0,0 @@ | |||
| # pylint: disable=E0401 | |||
| # pylint: disable=W0201 | |||
| """ | |||
| MindSpore implementation of `CeiT`. | |||
| Refer to "Incorporating Convolution Designs into Visual Transformers" | |||
| """ | |||
| import math | |||
| import mindspore as ms | |||
| from mindspore import nn | |||
| from mindspore.common.initializer import initializer, TruncatedNormal | |||
| from model.layers import DropPath, to_2tuple | |||
| from model.helper import load_pretrained | |||
| from model.registry import register_model | |||
| __all__ = [ | |||
| 'ceit_tiny_patch16_224', 'ceit_small_patch16_224', 'ceit_base_patch16_224', | |||
| 'ceit_tiny_patch16_384', 'ceit_small_patch16_384', | |||
| ] | |||
| def _cfg(url='', **kwargs): | |||
| return { | |||
| 'url': url, | |||
| 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None, | |||
| 'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True, | |||
| # 'mean': IMAGENET_INCEPTION_MEAN, 'std': IMAGENET_INCEPTION_STD, | |||
| 'first_conv': 'patch_embed.proj', 'classifier': 'head', | |||
| **kwargs | |||
| } | |||
| class Image2Tokens(nn.Cell): | |||
| """ image to tokens """ | |||
| def __init__(self, in_chans=3, out_chans=64, kernel_size=7, stride=2): | |||
| super().__init__() | |||
| self.conv = nn.Conv2d(in_chans, out_chans, kernel_size=kernel_size, stride=stride, | |||
| pad_mode='pad', padding=kernel_size // 2, has_bias=False) | |||
| self.bn = nn.BatchNorm2d(out_chans) | |||
| self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode='pad', padding=1) | |||
| def construct(self, x): | |||
| x = self.conv(x) | |||
| x = self.bn(x) | |||
| x = self.maxpool(x) | |||
| return x | |||
| class Mlp(nn.Cell): | |||
| """ mlp """ | |||
| def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): | |||
| super().__init__() | |||
| out_features = out_features or in_features | |||
| hidden_features = hidden_features or in_features | |||
| self.fc1 = nn.Dense(in_features, hidden_features) | |||
| self.act = act_layer() | |||
| self.fc2 = nn.Dense(hidden_features, out_features) | |||
| self.drop = nn.Dropout(p=drop) | |||
| def construct(self, x): | |||
| x = self.fc1(x) | |||
| x = self.act(x) | |||
| x = self.drop(x) | |||
| x = self.fc2(x) | |||
| x = self.drop(x) | |||
| return x | |||
| class LocallyEnhancedFeedForward(nn.Cell): | |||
| """ locally enhanced feed forward """ | |||
| def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, | |||
| kernel_size=3, with_bn=True): | |||
| super().__init__() | |||
| out_features = out_features or in_features | |||
| hidden_features = hidden_features or in_features | |||
| # pointwise | |||
| self.conv1 = nn.Conv2d(in_features, hidden_features, kernel_size=1, stride=1) | |||
| # depthwise | |||
| self.conv2 = nn.Conv2d( | |||
| hidden_features, hidden_features, kernel_size=kernel_size, stride=1, | |||
| pad_mode='pad', padding=(kernel_size - 1) // 2, group=hidden_features | |||
| ) | |||
| # pointwise | |||
| self.conv3 = nn.Conv2d(hidden_features, out_features, kernel_size=1, stride=1) | |||
| self.act = act_layer() | |||
| # self.drop = nn.Dropout(drop) | |||
| self.with_bn = with_bn | |||
| if self.with_bn: | |||
| self.bn1 = nn.BatchNorm2d(hidden_features) | |||
| self.bn2 = nn.BatchNorm2d(hidden_features) | |||
| self.bn3 = nn.BatchNorm2d(out_features) | |||
| def construct(self, x): | |||
| b, n, k = x.shape | |||
| cls_token, tokens = ms.ops.split(x, [1, n - 1], axis=1) | |||
| x = tokens.reshape(b, int(math.sqrt(n - 1)), int(math.sqrt(n - 1)), k).permute(0, 3, 1, 2) | |||
| if self.with_bn: | |||
| x = self.conv1(x) | |||
| x = self.bn1(x) | |||
| x = self.act(x) | |||
| x = self.conv2(x) | |||
| x = self.bn2(x) | |||
| x = self.act(x) | |||
| x = self.conv3(x) | |||
| x = self.bn3(x) | |||
| else: | |||
| x = self.conv1(x) | |||
| x = self.act(x) | |||
| x = self.conv2(x) | |||
| x = self.act(x) | |||
| x = self.conv3(x) | |||
| tokens = x.flatten(start_dim=2).permute(0, 2, 1) | |||
| out = ms.ops.cat((cls_token, tokens), axis=1) | |||
| return out | |||
| class Attention(nn.Cell): | |||
| """ self attention """ | |||
| def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.): | |||
| super().__init__() | |||
| self.num_heads = num_heads | |||
| head_dim = dim // num_heads | |||
| # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights | |||
| self.scale = qk_scale or head_dim ** -0.5 | |||
| self.qkv = nn.Dense(dim, dim * 3, has_bias=qkv_bias) | |||
| self.attn_drop = nn.Dropout(p=attn_drop) | |||
| self.proj = nn.Dense(dim, dim) | |||
| self.proj_drop = nn.Dropout(p=proj_drop) | |||
| self.attention_map = None | |||
| def construct(self, x): | |||
| B, N, C = x.shape | |||
| qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) | |||
| q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) | |||
| attn = (q @ k.transpose(0, 1, 3, 2)) * self.scale | |||
| attn = ms.ops.softmax(attn, axis=-1) | |||
| attn = self.attn_drop(attn) | |||
| x = (attn @ v).transpose(0, 2, 1, 3).reshape(B, N, C) | |||
| x = self.proj(x) | |||
| x = self.proj_drop(x) | |||
| return x | |||
| class AttentionLCA(Attention): | |||
| """ attention lca """ | |||
| def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.): | |||
| super().__init__(dim, num_heads, qkv_bias, qk_scale, attn_drop, proj_drop) | |||
| self.dim = dim | |||
| self.qkv_bias = qkv_bias | |||
| def construct(self, x): | |||
| q_weight = self.qkv.weight[:self.dim, :] | |||
| q_bias = None if not self.qkv_bias else self.qkv.bias[:self.dim] | |||
| kv_weight = self.qkv.weight[self.dim:, :] | |||
| kv_bias = None if not self.qkv_bias else self.qkv.bias[self.dim:] | |||
| B, N, C = x.shape | |||
| _, last_token = ms.ops.split(x, [N - 1, 1], axis=1) | |||
| q = (last_token @ q_weight.T + q_bias) \ | |||
| .reshape(B, 1, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) | |||
| kv = (x @ kv_weight.T + kv_bias) \ | |||
| .reshape(B, N, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) | |||
| k, v = kv[0], kv[1] | |||
| attn = (q @ k.transpose(0, 1, 3, 2)) * self.scale | |||
| attn = ms.ops.softmax(attn, axis=-1) | |||
| # self.attention_map = attn | |||
| attn = self.attn_drop(attn) | |||
| x = (attn @ v).transpose(0, 2, 1, 3).reshape(B, 1, C) | |||
| x = self.proj(x) | |||
| x = self.proj_drop(x) | |||
| return x | |||
| class Block(nn.Cell): | |||
| """ lca blocks """ | |||
| def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., | |||
| drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, kernel_size=3, with_bn=True, | |||
| feedforward_type='leff'): | |||
| super().__init__() | |||
| self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() | |||
| self.norm2 = norm_layer((dim, )) | |||
| mlp_hidden_dim = int(dim * mlp_ratio) | |||
| self.norm1 = norm_layer((dim, )) | |||
| self.feedforward_type = feedforward_type | |||
| if feedforward_type == 'leff': | |||
| self.attn = Attention( | |||
| dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) | |||
| self.leff = LocallyEnhancedFeedForward( | |||
| in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, | |||
| kernel_size=kernel_size, with_bn=with_bn, | |||
| ) | |||
| else: # LCA | |||
| self.attn = AttentionLCA( | |||
| dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) | |||
| self.feedforward = Mlp( | |||
| in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop | |||
| ) | |||
| def construct(self, x): | |||
| if self.feedforward_type == 'leff': | |||
| x = x + self.drop_path(self.attn(self.norm1(x))) | |||
| x = x + self.drop_path(self.leff(self.norm2(x))) | |||
| return x, x[:, 0] | |||
| # LCA | |||
| _, last_token = ms.ops.split(x, [x.shape[1] - 1, 1], axis=1) | |||
| x = last_token + self.drop_path(self.attn(self.norm1(x))) | |||
| x = x + self.drop_path(self.feedforward(self.norm2(x))) | |||
| return x | |||
| class HybridEmbed(nn.Cell): | |||
| """ CNN Feature Map Embedding | |||
| Extract feature map from CNN, flatten, project to embedding dim. | |||
| """ | |||
| def __init__(self, backbone, img_size=224, patch_size=16, feature_size=None, in_chans=3, embed_dim=768): | |||
| super().__init__() | |||
| assert isinstance(backbone, nn.Cell) | |||
| img_size = to_2tuple(img_size) | |||
| self.img_size = img_size | |||
| self.backbone = backbone | |||
| if feature_size is None: | |||
| # map for all networks, the feature metadata has reliable channel and stride info, but using | |||
| # stride to calc feature dim requires info about padding of each stage that isn't captured. | |||
| training = backbone.training | |||
| if training: | |||
| backbone.eval() | |||
| o = self.backbone(ms.ops.zeros((1, in_chans, img_size[0], img_size[1]))) | |||
| if isinstance(o, (list, tuple)): | |||
| o = o[-1] # last feature if backbone outputs list/tuple of features | |||
| feature_size = o.shape[-2:] | |||
| feature_dim = o.shape[1] | |||
| # backbone.train(training) | |||
| else: | |||
| feature_size = to_2tuple(feature_size) | |||
| feature_dim = self.backbone.feature_info.channels()[-1] | |||
| self.num_patches = (feature_size[0] // patch_size) * (feature_size[1] // patch_size) | |||
| self.proj = nn.Conv2d(feature_dim, embed_dim, kernel_size=patch_size, stride=patch_size) | |||
| def construct(self, x): | |||
| x = self.backbone(x) | |||
| if isinstance(x, (list, tuple)): | |||
| x = x[-1] # last feature if backbone outputs list/tuple of features | |||
| x = self.proj(x).flatten(start_dim=2).transpose(0, 2, 1) | |||
| return x | |||
| class CeIT(nn.Cell): | |||
| """ CeIT """ | |||
| def __init__(self, | |||
| img_size=224, | |||
| patch_size=16, | |||
| in_chans=3, | |||
| num_classes=1000, | |||
| embed_dim=768, | |||
| depth=12, | |||
| num_heads=12, | |||
| mlp_ratio=4., | |||
| qkv_bias=False, | |||
| qk_scale=None, | |||
| drop_rate=0., | |||
| attn_drop_rate=0., | |||
| drop_path_rate=0., | |||
| hybrid_backbone=None, | |||
| norm_layer=nn.LayerNorm, | |||
| leff_local_size=3, | |||
| leff_with_bn=True): | |||
| """ | |||
| args: | |||
| - img_size (:obj:`int`): input image size | |||
| - patch_size (:obj:`int`): patch size | |||
| - in_chans (:obj:`int`): input channels | |||
| - num_classes (:obj:`int`): number of classes | |||
| - embed_dim (:obj:`int`): embedding dimensions for tokens | |||
| - depth (:obj:`int`): depth of encoder | |||
| - num_heads (:obj:`int`): number of heads in multi-head self-attention | |||
| - mlp_ratio (:obj:`float`): expand ratio in feedforward | |||
| - qkv_bias (:obj:`bool`): whether to add bias for mlp of qkv | |||
| - qk_scale (:obj:`float`): scale ratio for qk, default is head_dim ** -0.5 | |||
| - drop_rate (:obj:`float`): dropout rate in feedforward module after linear operation | |||
| and projection drop rate in attention | |||
| - attn_drop_rate (:obj:`float`): dropout rate for attention | |||
| - drop_path_rate (:obj:`float`): drop_path rate after attention | |||
| - hybrid_backbone (:obj:`nn.Module`): backbone e.g. resnet | |||
| - norm_layer (:obj:`nn.Module`): normalization type | |||
| - leff_local_size (:obj:`int`): kernel size in LocallyEnhancedFeedForward | |||
| - leff_with_bn (:obj:`bool`): whether add bn in LocallyEnhancedFeedForward | |||
| """ | |||
| super().__init__() | |||
| self.num_classes = num_classes | |||
| self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models | |||
| self.i2t = HybridEmbed( | |||
| hybrid_backbone, img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim) | |||
| num_patches = self.i2t.num_patches | |||
| self.cls_token = ms.Parameter(ms.ops.zeros((1, 1, embed_dim))) | |||
| self.pos_embed = ms.Parameter(ms.ops.zeros((1, num_patches + 1, embed_dim))) | |||
| self.pos_drop = nn.Dropout(p=drop_rate) | |||
| dpr = list(ms.ops.linspace(0, drop_path_rate, depth)) # stochastic depth decay rule | |||
| self.blocks = nn.SequentialCell([ | |||
| Block( | |||
| dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, | |||
| drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, | |||
| kernel_size=leff_local_size, with_bn=leff_with_bn) | |||
| for i in range(depth)]) | |||
| # without droppath | |||
| self.lca = Block( | |||
| dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, | |||
| drop=drop_rate, attn_drop=attn_drop_rate, drop_path=0., norm_layer=norm_layer, | |||
| feedforward_type='lca' | |||
| ) | |||
| self.pos_layer_embed = ms.Parameter(ms.ops.zeros((1, depth, embed_dim))) | |||
| self.norm = norm_layer((embed_dim, )) | |||
| # NOTE as per official impl, we could have a pre-logits representation dense layer + tanh here | |||
| # Classifier head | |||
| self.head = nn.Dense(embed_dim, num_classes) if num_classes > 0 else nn.Identity() | |||
| self.pos_embed = initializer(TruncatedNormal(sigma=.02), self.pos_embed.shape, self.pos_embed.dtype) | |||
| self.cls_token = initializer(TruncatedNormal(sigma=.02), self.cls_token.shape, self.cls_token.dtype) | |||
| self.apply(self._init_weights) | |||
| def _init_weights(self, cell): | |||
| if isinstance(cell, nn.Dense): | |||
| cell.weight.set_data(initializer(TruncatedNormal(sigma=.02), cell.weight.shape, cell.weight.dtype)) | |||
| if cell.bias is not None: | |||
| cell.bias.set_data(initializer('zeros', cell.bias.shape, cell.bias.dtype)) | |||
| elif isinstance(cell, nn.LayerNorm): | |||
| cell.gamma.set_data(initializer('ones', cell.gamma.shape, cell.gamma.dtype)) | |||
| cell.beta.set_data(initializer('zeros', cell.beta.shape, cell.beta.dtype)) | |||
| def get_classifier(self): | |||
| """ get classifier """ | |||
| return self.head | |||
| def reset_classifier(self, num_classes): | |||
| """ reset classifier """ | |||
| self.num_classes = num_classes | |||
| self.head = nn.Dense(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() | |||
| def construct_features(self, x): | |||
| """ construct features """ | |||
| B = x.shape[0] | |||
| x = self.i2t(x) | |||
| cls_tokens = self.cls_token | |||
| x = ms.ops.cat((cls_tokens, x), axis=1) | |||
| x = x + self.pos_embed | |||
| x = self.pos_drop(x) | |||
| cls_token_list = [] | |||
| for blk in self.blocks: | |||
| x, curr_cls_token = blk(x) | |||
| cls_token_list.append(curr_cls_token) | |||
| all_cls_token = ms.ops.stack(cls_token_list, axis=1) # B*D*K | |||
| all_cls_token = all_cls_token + self.pos_layer_embed | |||
| # attention over cls tokens | |||
| last_cls_token = self.lca(all_cls_token) | |||
| last_cls_token = self.norm(last_cls_token) | |||
| return last_cls_token.view(B, -1) | |||
| def construct(self, x): | |||
| x = self.construct_features(x) | |||
| x = self.head(x) | |||
| return x | |||
| @register_model | |||
| def ceit_tiny_patch16_224(pretrained=False, **kwargs): | |||
| """ | |||
| convolutional + pooling stem | |||
| local enhanced feedforward | |||
| attention over cls_tokens | |||
| """ | |||
| default_cfg = _cfg(**kwargs) | |||
| i2t = Image2Tokens() | |||
| model = CeIT( | |||
| hybrid_backbone=i2t, | |||
| patch_size=4, embed_dim=192, depth=12, num_heads=3, mlp_ratio=4, qkv_bias=True, | |||
| norm_layer=nn.LayerNorm, **kwargs) | |||
| model.default_cfg = default_cfg | |||
| if pretrained: | |||
| load_pretrained(model, default_cfg) | |||
| return model | |||
| @register_model | |||
| def ceit_small_patch16_224(pretrained=False, **kwargs): | |||
| """ | |||
| convolutional + pooling stem | |||
| local enhanced feedforward | |||
| attention over cls_tokens | |||
| """ | |||
| default_cfg = _cfg() | |||
| i2t = Image2Tokens() | |||
| model = CeIT( | |||
| hybrid_backbone=i2t, | |||
| patch_size=4, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4, qkv_bias=True, | |||
| norm_layer=nn.LayerNorm, **kwargs) | |||
| model.default_cfg = default_cfg | |||
| if pretrained: | |||
| load_pretrained(model, default_cfg) | |||
| return model | |||
| @register_model | |||
| def ceit_base_patch16_224(pretrained=False, **kwargs): | |||
| """ | |||
| convolutional + pooling stem | |||
| local enhanced feedforward | |||
| attention over cls_tokens | |||
| """ | |||
| default_cfg = _cfg() | |||
| i2t = Image2Tokens() | |||
| model = CeIT( | |||
| hybrid_backbone=i2t, | |||
| patch_size=4, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True, | |||
| norm_layer=nn.LayerNorm, **kwargs) | |||
| model.default_cfg = default_cfg | |||
| if pretrained: | |||
| load_pretrained(model, default_cfg) | |||
| return model | |||
| @register_model | |||
| def ceit_tiny_patch16_384(pretrained=False, **kwargs): | |||
| """ | |||
| convolutional + pooling stem | |||
| local enhanced feedforward | |||
| attention over cls_tokens | |||
| """ | |||
| default_cfg = _cfg() | |||
| i2t = Image2Tokens() | |||
| model = CeIT( | |||
| hybrid_backbone=i2t, img_size=384, | |||
| patch_size=4, embed_dim=192, depth=12, num_heads=3, mlp_ratio=4, qkv_bias=True, | |||
| norm_layer=nn.LayerNorm, **kwargs) | |||
| model.default_cfg = default_cfg | |||
| if pretrained: | |||
| load_pretrained(model, default_cfg) | |||
| return model | |||
| @register_model | |||
| def ceit_small_patch16_384(pretrained=False, **kwargs): | |||
| """ | |||
| convolutional + pooling stem | |||
| local enhanced feedforward | |||
| attention over cls_tokens | |||
| """ | |||
| default_cfg = _cfg() | |||
| i2t = Image2Tokens() | |||
| model = CeIT( | |||
| hybrid_backbone=i2t, img_size=384, | |||
| patch_size=4, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4, qkv_bias=True, | |||
| norm_layer=nn.LayerNorm, **kwargs) | |||
| model.default_cfg = default_cfg | |||
| if pretrained: | |||
| load_pretrained(model, default_cfg) | |||
| return model | |||
| if __name__ == '__main__': | |||
| dummy_input = ms.ops.randn((1, 3, 224, 224)) | |||
| ceit = ceit_tiny_patch16_224() | |||
| output = ceit(dummy_input) | |||
| print(output.shape) | |||