# pylint: disable=E0401
# pylint: disable=W0201
"""
MindSpore implementation of 'vip_mlp'
Refer to "Vision Permutator: A Permutable MLP-Like Architecture for Visual Recognition"
"""
import mindspore as ms
from mindspore import nn
from mindspore.common.initializer import TruncatedNormal, initializer

from model.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from model.layers import DropPath
from model.registry import register_model


def _cfg(url='', **kwargs):
    return {
        'url': url,
        'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
        'crop_pct': .96, 'interpolation': 'bicubic',
        'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, 'classifier': 'head',
        **kwargs
    }


default_cfgs = {
    'ViP_S': _cfg(crop_pct=0.9),
    'ViP_M': _cfg(crop_pct=0.9),
    'ViP_L': _cfg(crop_pct=0.875),
}


class Mlp(nn.Cell):
    """ feed-forward block: Dense -> activation -> dropout -> Dense -> dropout """

    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Dense(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Dense(hidden_features, out_features)
        self.drop = nn.Dropout(p=drop)

    def construct(self, x):
        x = self.drop(self.act(self.fc1(x)))
        # no activation after the second projection
        return self.drop(self.fc2(x))


class WeightedPermuteMLP(nn.Cell):
    """ weighted permute mlp: mixes tokens along height, width and channel, then reweights the three branches """

    def __init__(self, dim, segment_dim=8, qkv_bias=False, proj_drop=0.):
        super().__init__()
        self.segment_dim = segment_dim
        self.mlp_c = nn.Dense(dim, dim, has_bias=qkv_bias)
        self.mlp_h = nn.Dense(dim, dim, has_bias=qkv_bias)
        self.mlp_w = nn.Dense(dim, dim, has_bias=qkv_bias)
        self.flatten = nn.Flatten(2, 3)
        self.reweight = Mlp(dim, dim // 4, dim * 3)
        self.proj = nn.Dense(dim, dim)
        self.proj_drop = nn.Dropout(p=proj_drop)

    def construct(self, x):
        B, H, W, C = x.shape
        S = C // self.segment_dim
        # height mixing: regroup channels into segment_dim segments and mix along H * S
        h = x.reshape(B, H, W, self.segment_dim, S).permute(0, 3, 2, 1, 4).reshape(B, self.segment_dim, W, H * S)
        h = self.mlp_h(h).reshape(B, self.segment_dim, W, H, S).permute(0, 3, 2, 1, 4).reshape(B, H, W, C)
        # width mixing: regroup channels into segment_dim segments and mix along W * S
        w = x.reshape(B, H, W, self.segment_dim, S).permute(0, 1, 3, 2, 4).reshape(B, H, self.segment_dim, W * S)
        w = self.mlp_w(w).reshape(B, H, self.segment_dim, W, S).permute(0, 1, 3, 2, 4).reshape(B, H, W, C)
        # channel mixing
        c = self.mlp_c(x)
        # split attention: pool the summed branches and predict one weight per branch
        a = self.flatten((h + w + c).permute(0, 3, 1, 2)).mean(2)
        a = self.reweight(a).reshape(B, C, 3).permute(2, 0, 1)
        a = ms.ops.softmax(a, axis=0).unsqueeze(2).unsqueeze(2)
        x = h * a[0] + w * a[1] + c * a[2]
        return self.proj_drop(self.proj(x))
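

# Illustrative sanity check (not part of the original model code; the helper name below
# is hypothetical and is never called by the model). It demonstrates the property the
# permute branches rely on: WeightedPermuteMLP preserves the (B, H, W, C) layout whenever
# C is divisible by segment_dim, since each branch reshapes C into segment_dim groups of
# size C // segment_dim before and after mixing.
def _weighted_permute_mlp_shape_check():
    """ tiny shape check for WeightedPermuteMLP (illustration only) """
    b, h, w, c = 1, 8, 8, 32
    layer = WeightedPermuteMLP(dim=c, segment_dim=8)
    out = layer(ms.ops.randn((b, h, w, c)))
    assert out.shape == (b, h, w, c)
    return out.shape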


class PermutatorBlock(nn.Cell):
    """ permutator block """

    def __init__(self, dim, segment_dim, mlp_ratio=4., qkv_bias=False, attn_drop=0., drop_path=0.,
                 act_layer=nn.GELU, norm_layer=nn.LayerNorm, skip_lam=1.0, mlp_fn=WeightedPermuteMLP):
        super().__init__()
        self.norm1 = norm_layer((dim,))
        self.attn = mlp_fn(dim, segment_dim, qkv_bias, attn_drop)
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer((dim,))
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(dim, mlp_hidden_dim, act_layer=act_layer)
        self.skip_lam = skip_lam

    def construct(self, x):
        x = x + self.drop_path(self.attn(self.norm1(x))) / self.skip_lam
        return x + self.drop_path(self.mlp(self.norm2(x))) / self.skip_lam


class PatchEmbed(nn.Cell):
    """ patch embed """

    def __init__(self, patch_size=16, in_chans=3, embed_dim=768):
        super().__init__()
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)

    def construct(self, x):
        return self.proj(x)


class Downsample(nn.Cell):
    """ downsample """

    def __init__(self, in_embed_dim, out_embed_dim, patch_size):
        super().__init__()
        self.proj = nn.Conv2d(in_embed_dim, out_embed_dim, kernel_size=patch_size, stride=patch_size)

    def construct(self, x):
        # tokens are stored channel-last (B, H, W, C); the convolution expects (B, C, H, W)
        x = x.permute(0, 3, 1, 2)
        x = self.proj(x)
        return x.permute(0, 2, 3, 1)


def basic_blocks(dim, index, layers, segment_dim, mlp_ratio=3., qkv_bias=False, attn_drop=0.,
                 drop_path_rate=0., skip_lam=1.0, mlp_fn=WeightedPermuteMLP):
    """ basic blocks """
    blocks = []
    for block_idx in range(layers[index]):
        # stochastic depth rate grows linearly with the global block index
        block_dpr = drop_path_rate * (block_idx + sum(layers[:index])) / (sum(layers) - 1)
        blocks.append(PermutatorBlock(dim, segment_dim, mlp_ratio, qkv_bias,
                                      attn_drop, drop_path=block_dpr, skip_lam=skip_lam, mlp_fn=mlp_fn))
    return nn.SequentialCell(*blocks)


class VisionPermutator(nn.Cell):
    """ vision permutator """

    def __init__(self, layers, patch_size=4, in_chans=3, num_classes=1000,
                 embed_dims=None, transitions=None, segment_dim=None, mlp_ratios=None, skip_lam=1.0,
                 qkv_bias=False, attn_drop_rate=0., drop_path_rate=0.,
                 norm_layer=nn.LayerNorm):
        super().__init__()
        self.num_classes = num_classes
        self.patch_embed = PatchEmbed(patch_size, in_chans, embed_dims[0])

        network = []
        for i in range(len(layers)):
            stage = basic_blocks(embed_dims[i], i, layers, segment_dim[i], mlp_ratio=mlp_ratios[i],
                                 qkv_bias=qkv_bias, attn_drop=attn_drop_rate,
                                 drop_path_rate=drop_path_rate, skip_lam=skip_lam)
            network.append(stage)
            if i >= len(layers) - 1:
                break
            if transitions[i] or embed_dims[i] != embed_dims[i + 1]:
                patch_size = 2 if transitions[i] else 1
                network.append(Downsample(embed_dims[i], embed_dims[i + 1], patch_size))
        self.network = nn.SequentialCell(network)

        self.norm = norm_layer((embed_dims[-1],))
        self.head = nn.Dense(embed_dims[-1], num_classes) if num_classes > 0 else nn.Identity()
        self.apply(self._init_weights)

    def _init_weights(self, cell):
        """ initialize weights """
        if isinstance(cell, nn.Dense):
            cell.weight.set_data(initializer(TruncatedNormal(0.02), cell.weight.shape, cell.weight.dtype))
            if cell.bias is not None:
                cell.bias.set_data(initializer('zeros', cell.bias.shape, cell.bias.dtype))
        elif isinstance(cell, nn.LayerNorm):
            # LayerNorm starts as the identity transform: scale (gamma) at one, shift (beta) at zero
            cell.gamma.set_data(initializer('ones', cell.gamma.shape, cell.gamma.dtype))
            cell.beta.set_data(initializer('zeros', cell.beta.shape, cell.beta.dtype))

    def forward_embeddings(self, x):
        """ forward embeddings """
        x = self.patch_embed(x)
        # (B, C, H, W) -> (B, H, W, C): the permutator blocks operate on channel-last tokens
        x = x.permute(0, 2, 3, 1)
        return x

    def forward_tokens(self, x):
        """ forward tokens """
        for block in self.network:
            x = block(x)
        B, _, _, C = x.shape
        x = x.reshape(B, -1, C)
        return x

    def construct(self, x):
        x = self.forward_embeddings(x)
        x = self.forward_tokens(x)
        x = self.norm(x)
        return self.head(x.mean(1))


@register_model
def vip_s14(**kwargs):
    """ vip s14 """
    layers = [4, 3, 8, 3]
    transitions = [False, False, False, False]
    segment_dim = [16, 16, 16, 16]
    mlp_ratios = [3, 3, 3, 3]
    embed_dims = [384, 384, 384, 384]
    model = VisionPermutator(layers=layers, embed_dims=embed_dims, patch_size=14, transitions=transitions,
                             segment_dim=segment_dim, mlp_ratios=mlp_ratios, **kwargs)
    model.default_cfg = default_cfgs['ViP_S']
    return model


@register_model
def vip_s7(**kwargs):
    """ vip s7 """
    layers = [4, 3, 8, 3]
    transitions = [True, False, False, False]
    segment_dim = [32, 16, 16, 16]
    mlp_ratios = [3, 3, 3, 3]
    embed_dims = [192, 384, 384, 384]
    model = VisionPermutator(layers=layers, embed_dims=embed_dims, patch_size=7, transitions=transitions,
                             segment_dim=segment_dim, mlp_ratios=mlp_ratios, **kwargs)
    model.default_cfg = default_cfgs['ViP_S']
    return model


@register_model
def vip_m7(**kwargs):
    """ vip m7 """
    layers = [4, 3, 14, 3]
    transitions = [False, True, False, False]
    segment_dim = [32, 32, 16, 16]
    mlp_ratios = [3, 3, 3, 3]
    embed_dims = [256, 256, 512, 512]
    model = VisionPermutator(layers=layers, embed_dims=embed_dims, patch_size=7, transitions=transitions,
                             segment_dim=segment_dim, mlp_ratios=mlp_ratios, **kwargs)
    model.default_cfg = default_cfgs['ViP_M']
    return model


@register_model
def vip_l7(**kwargs):
    """ vip l7 """
    layers = [8, 8, 16, 4]
    transitions = [True, False, False, False]
    segment_dim = [32, 16, 16, 16]
    mlp_ratios = [3, 3, 3, 3]
    embed_dims = [256, 512, 512, 512]
    model = VisionPermutator(layers=layers, embed_dims=embed_dims, patch_size=7, transitions=transitions,
                             segment_dim=segment_dim, mlp_ratios=mlp_ratios, **kwargs)
    model.default_cfg = default_cfgs['ViP_L']
    return model


if __name__ == "__main__":
    dummy_input = ms.ops.randn((1, 3, 224, 224))
    vip = vip_s14()
    output = vip(dummy_input)
    print(output.shape)
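    # Additional illustrative check (an assumption, not in the original file): a single
    # PermutatorBlock should preserve the channel-last token layout.
    block = PermutatorBlock(dim=384, segment_dim=16)
    tokens = ms.ops.randn((1, 16, 16, 384))
    print(block(tokens).shape)  # expected: (1, 16, 16, 384)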