vip_mlp

model/mlp/vip_mlp.py
# pylint: disable=E0401
# pylint: disable=W0201
"""
MindSpore implementation of 'vip_mlp'
Refer to "Vision Permutator: A Permutable MLP-Like Architecture for Visual Recognition"
"""
import mindspore as ms
from mindspore import nn
from mindspore.common.initializer import TruncatedNormal, initializer

from model.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from model.layers import DropPath
from model.registry import register_model


def _cfg(url='', **kwargs):
    return {
        'url': url,
        'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
        'crop_pct': .96, 'interpolation': 'bicubic',
        'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, 'classifier': 'head',
        **kwargs
    }


default_cfgs = {
    'ViP_S': _cfg(crop_pct=0.9),
    'ViP_M': _cfg(crop_pct=0.9),
    'ViP_L': _cfg(crop_pct=0.875),
}


class Mlp(nn.Cell):
    """ mlp """
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Dense(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Dense(hidden_features, out_features)
        self.drop = nn.Dropout(p=drop)
    def construct(self, x):
        x = self.drop(self.act(self.fc1(x)))
        # no activation after the second projection
        return self.drop(self.fc2(x))
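    # Note: this Mlp is reused both as the token-wise feed-forward block in
    # PermutatorBlock and, with out_features = 3 * dim, as the branch-reweighting
    # head inside WeightedPermuteMLP below.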


class WeightedPermuteMLP(nn.Cell):
    """ weighted permute mlp """
    def __init__(self, dim, segment_dim=8, qkv_bias=False, proj_drop=0.):
        super().__init__()
        self.segment_dim = segment_dim
        self.mlp_c = nn.Dense(dim, dim, has_bias=qkv_bias)
        self.mlp_h = nn.Dense(dim, dim, has_bias=qkv_bias)
        self.mlp_w = nn.Dense(dim, dim, has_bias=qkv_bias)

        self.flatten = nn.Flatten(2, 3)

        self.reweight = Mlp(dim, dim // 4, dim * 3)
        self.proj = nn.Dense(dim, dim)
        self.proj_drop = nn.Dropout(p=proj_drop)

    def construct(self, x):
        B, H, W, C = x.shape

        S = C // self.segment_dim
        # height branch: fold H into each channel segment so mlp_h mixes across rows
        h = x.reshape(B, H, W, self.segment_dim, S).permute(0, 3, 2, 1, 4).reshape(B, self.segment_dim, W, H * S)
        h = self.mlp_h(h).reshape(B, self.segment_dim, W, H, S).permute(0, 3, 2, 1, 4).reshape(B, H, W, C)

        # width branch: fold W into each channel segment so mlp_w mixes across columns
        w = x.reshape(B, H, W, self.segment_dim, S).permute(0, 1, 3, 2, 4).reshape(B, H, self.segment_dim, W * S)
        w = self.mlp_w(w).reshape(B, H, self.segment_dim, W, S).permute(0, 1, 3, 2, 4).reshape(B, H, W, C)

        # channel branch
        c = self.mlp_c(x)

        # reweight the three branches with a softmax over the branch axis
        a = self.flatten((h + w + c).permute(0, 3, 1, 2)).mean(2)
        a = self.reweight(a).reshape(B, C, 3).permute(2, 0, 1)
        a = ms.ops.softmax(a, axis=0).unsqueeze(2).unsqueeze(2)

        x = h * a[0] + w * a[1] + c * a[2]

        return self.proj_drop(self.proj(x))
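    # Worked shape example (illustrative, for the ViP-S/14 setting below): a 224x224
    # image with patch_size=14 gives 16x16 tokens with dim=384 and segment_dim=16,
    # so S = 384 // 16 = 24 and the height branch reshapes to (B, 16, 16, 16 * 24 = 384)
    # before mlp_h, letting a single Dense layer mix information across all 16 rows
    # within each channel segment.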


class PermutatorBlock(nn.Cell):
    """ permutator block """
    def __init__(self, dim, segment_dim, mlp_ratio=4., qkv_bias=False, attn_drop=0., drop_path=0.,
                 act_layer=nn.GELU, norm_layer=nn.LayerNorm, skip_lam=1.0, mlp_fn=WeightedPermuteMLP):
        super().__init__()
        self.norm1 = norm_layer((dim,))
        self.attn = mlp_fn(dim, segment_dim, qkv_bias, attn_drop)

        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()

        self.norm2 = norm_layer((dim,))
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(dim, mlp_hidden_dim, act_layer=act_layer)
        self.skip_lam = skip_lam

    def construct(self, x):
        x = x + self.drop_path(self.attn(self.norm1(x))) / self.skip_lam
        return x + self.drop_path(self.mlp(self.norm2(x))) / self.skip_lam
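    # Note: skip_lam rescales the two residual branches; with the default of 1.0
    # (used by every factory function below) the block reduces to the usual
    # pre-norm residual form x + DropPath(f(LN(x))).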


class PatchEmbed(nn.Cell):
    """ patch embed """
    def __init__(self, patch_size=16, in_chans=3, embed_dim=768):
        super().__init__()
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)

    def construct(self, x):
        return self.proj(x)


class Downsample(nn.Cell):
    """ downsample """
    def __init__(self, in_embed_dim, out_embed_dim, patch_size):
        super().__init__()
        self.proj = nn.Conv2d(in_embed_dim, out_embed_dim, kernel_size=patch_size, stride=patch_size)

    def construct(self, x):
        # tokens flow through the network in (B, H, W, C) layout, so move channels
        # to NCHW for the strided convolution and back afterwards
        x = x.permute(0, 3, 1, 2)
        x = self.proj(x)
        return x.permute(0, 2, 3, 1)


def basic_blocks(dim, index, layers, segment_dim, mlp_ratio=3., qkv_bias=False, attn_drop=0,
                 drop_path_rate=0., skip_lam=1.0, mlp_fn=WeightedPermuteMLP):
    """ basic blocks """
    blocks = []

    for block_idx in range(layers[index]):
        block_dpr = drop_path_rate * (block_idx + sum(layers[:index])) / (sum(layers) - 1)
        blocks.append(PermutatorBlock(dim, segment_dim, mlp_ratio, qkv_bias,
                                      attn_drop, drop_path=block_dpr, skip_lam=skip_lam, mlp_fn=mlp_fn))
    blocks = nn.SequentialCell(*blocks)
    return blocks
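# Example of the stochastic-depth schedule above: with layers=[4, 3, 8, 3] and
# drop_path_rate=0.1, block_dpr grows linearly from 0.0 for the first of the
# 18 blocks to 0.1 for the last one.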


class VisionPermutator(nn.Cell):
    """ vision permutator """
    def __init__(self, layers, patch_size=4, in_chans=3, num_classes=1000,
                 embed_dims=None, transitions=None, segment_dim=None, mlp_ratios=None, skip_lam=1.0,
                 qkv_bias=False, attn_drop_rate=0., drop_path_rate=0.,
                 norm_layer=nn.LayerNorm):
        super().__init__()
        self.num_classes = num_classes

        self.patch_embed = PatchEmbed(patch_size, in_chans, embed_dims[0])

        network = []
        for i in range(len(layers)):
            stage = basic_blocks(embed_dims[i], i, layers, segment_dim[i], mlp_ratio=mlp_ratios[i], qkv_bias=qkv_bias,
                                 attn_drop=attn_drop_rate, drop_path_rate=drop_path_rate, skip_lam=skip_lam)
            network.append(stage)
            if i >= len(layers) - 1:
                break
            if transitions[i] or embed_dims[i] != embed_dims[i + 1]:
                patch_size = 2 if transitions[i] else 1
                network.append(Downsample(embed_dims[i], embed_dims[i + 1], patch_size))

        self.network = nn.SequentialCell(network)
        self.norm = norm_layer((embed_dims[-1],))

        self.head = nn.Dense(embed_dims[-1], num_classes) if num_classes > 0 else nn.Identity()
        self.apply(self._init_weights)

    def _init_weights(self, cell):
        """ initialize weights """
        if isinstance(cell, nn.Dense):
            cell.weight.set_data(initializer(TruncatedNormal(0.02), cell.weight.shape, cell.weight.dtype))
            if cell.bias is not None:
                cell.bias.set_data(initializer('zeros', cell.bias.shape, cell.bias.dtype))
        if isinstance(cell, nn.LayerNorm):
            # LayerNorm scale (gamma) starts at one and shift (beta) at zero
            cell.gamma.set_data(initializer('ones', cell.gamma.shape, cell.gamma.dtype))
            cell.beta.set_data(initializer('zeros', cell.beta.shape, cell.beta.dtype))

    def forward_embeddings(self, x):
        """ forward embeddings """
        x = self.patch_embed(x)
        x = x.permute(0, 2, 3, 1)
        return x

    def forward_tokens(self, x):
        """ forward tokens """
        for _, block in enumerate(self.network):
            x = block(x)
        B, _, _, C = x.shape
        x = x.reshape(B, -1, C)
        return x

    def construct(self, x):
        x = self.forward_embeddings(x)
        x = self.forward_tokens(x)
        x = self.norm(x)
        return self.head(x.mean(1))
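    # Data layout through the model: NCHW image -> patch_embed (NCHW) -> permute to
    # (B, H, W, C) for the permutator stages -> flatten to (B, N, C) -> LayerNorm ->
    # mean over tokens -> classification head.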


@register_model
def vip_s14(**kwargs):
    """ vip s14 """
    layers = [4, 3, 8, 3]
    transitions = [False, False, False, False]
    segment_dim = [16, 16, 16, 16]
    mlp_ratios = [3, 3, 3, 3]
    embed_dims = [384, 384, 384, 384]
    model = VisionPermutator(layers=layers, embed_dims=embed_dims, patch_size=14, transitions=transitions,
                             segment_dim=segment_dim, mlp_ratios=mlp_ratios, **kwargs)
    model.default_cfg = default_cfgs['ViP_S']
    return model


@register_model
def vip_s7(**kwargs):
    """ vip s7 """
    layers = [4, 3, 8, 3]
    transitions = [True, False, False, False]
    segment_dim = [32, 16, 16, 16]
    mlp_ratios = [3, 3, 3, 3]
    embed_dims = [192, 384, 384, 384]
    model = VisionPermutator(layers=layers, embed_dims=embed_dims, patch_size=7, transitions=transitions,
                             segment_dim=segment_dim, mlp_ratios=mlp_ratios, **kwargs)
    model.default_cfg = default_cfgs['ViP_S']
    return model


@register_model
def vip_m7(**kwargs):
    """ vip m7 """
    layers = [4, 3, 14, 3]
    transitions = [False, True, False, False]
    segment_dim = [32, 32, 16, 16]
    mlp_ratios = [3, 3, 3, 3]
    embed_dims = [256, 256, 512, 512]
    model = VisionPermutator(layers=layers, embed_dims=embed_dims, patch_size=7, transitions=transitions,
                             segment_dim=segment_dim, mlp_ratios=mlp_ratios, **kwargs)
    model.default_cfg = default_cfgs['ViP_M']
    return model


@register_model
def vip_l7(**kwargs):
    """ vip l7 """
    layers = [8, 8, 16, 4]
    transitions = [True, False, False, False]
    segment_dim = [32, 16, 16, 16]
    mlp_ratios = [3, 3, 3, 3]
    embed_dims = [256, 512, 512, 512]
    model = VisionPermutator(layers=layers, embed_dims=embed_dims, patch_size=7, transitions=transitions,
                             segment_dim=segment_dim, mlp_ratios=mlp_ratios, **kwargs)
    model.default_cfg = default_cfgs['ViP_L']
    return model


if __name__ == "__main__":
    dummy_input = ms.ops.randn((1, 3, 224, 224))
    vip = vip_s14()
    output = vip(dummy_input)
    print(output.shape)
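    # Expected output shape: (1, 1000) -- one logit vector over num_classes for the
    # single dummy image. The other registered variants (vip_s7, vip_m7, vip_l7)
    # accept the same 224x224 input and can be smoke-tested the same way, e.g.
    # print(vip_s7()(dummy_input).shape).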
