# pylint: disable=E0401
# pylint: disable=W0201
"""
MindSpore implementation of `LeViT`.
Refer to "LeViT: a Vision Transformer in ConvNet’s Clothing for Faster Inference"
"""
import itertools

import mindspore as ms
from mindspore import Tensor
from mindspore import nn
from mindspore import context
from mindspore import numpy as np
from mindspore.common import initializer as init

from model.helper import load_pretrained
from model.registry import register_model

__all__ = [
    "LeViT",
    "LeViT_128S",
    "LeViT_128",
    "LeViT_192",
    "LeViT_256",
    "LeViT_384",
]


def _cfg(url='', **kwargs):
    return {
        'url': url,
        'num_classes': 1000,
        'input_size': (3, 224, 224),
        # 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
        'first_conv': 'patch_embed.proj', 'classifier': 'head',
        **kwargs
    }


default_cfgs = {
    'LeViT_128S': _cfg(url=''),
    'LeViT_128': _cfg(url=''),
    'LeViT_192': _cfg(url=''),
    'LeViT_256': _cfg(url=''),
    'LeViT_384': _cfg(url='')
}

FLOPS_COUNTER = 0


class Conv2d_BN(nn.Cell):
    """ Conv2d followed by BatchNorm2d """

    def __init__(self,
                 a: int,
                 b: int,
                 ks: int = 1,
                 stride: int = 1,
                 pad: int = 0,
                 dilation: int = 1,
                 group: int = 1) -> None:
        super().__init__()
        self.conv = nn.Conv2d(in_channels=a,
                              out_channels=b,
                              kernel_size=ks,
                              stride=stride,
                              padding=pad,
                              dilation=dilation,
                              group=group,
                              has_bias=False,
                              pad_mode="pad")
        self.bn = nn.BatchNorm2d(num_features=b,
                                 gamma_init="ones",
                                 beta_init="zeros",
                                 use_batch_statistics=True,
                                 momentum=0.9)  # MindSpore momentum=0.9 corresponds to PyTorch momentum=0.1

    def construct(self, input_data: Tensor) -> Tensor:
        x = self.conv(input_data)
        x = self.bn(x)
        return x
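

# Linear_BN applies BatchNorm1d over the token dimension: the (B, N, C) activations
# are flattened to (B * N, C) so the 1-D batch norm normalizes every token with the
# same per-channel statistics, then the result is reshaped back to (B, N, C).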
class Linear_BN(nn.Cell):
    """ Dense followed by BatchNorm1d """

    def __init__(self,
                 a: int,
                 b: int) -> None:
        super().__init__()
        self.linear = nn.Dense(a,
                               b,
                               weight_init='Uniform',
                               has_bias=False)
        self.bn1d = nn.BatchNorm1d(num_features=b,
                                   gamma_init="ones",
                                   beta_init="zeros",
                                   momentum=0.9)

    def construct(self, input_data: Tensor) -> Tensor:
        x = self.linear(input_data)
        x1, x2, x3 = x.shape
        x = self.bn1d(ms.ops.reshape(x, (x1 * x2, x3))).reshape(x1, x2, x3)
        return x


class BN_Linear(nn.Cell):
    """ BatchNorm1d followed by Dense """

    def __init__(self,
                 a: int,
                 b: int,
                 bias: bool = True,
                 std: float = 0.02) -> None:
        super().__init__()
        self.bn1d = nn.BatchNorm1d(num_features=a,
                                   gamma_init="ones",
                                   beta_init="zeros",
                                   momentum=0.9)
        self.linear = nn.Dense(a,
                               b,
                               weight_init=init.TruncatedNormal(sigma=std),
                               bias_init='zeros',
                               has_bias=bias)

    def construct(self, input_data: Tensor) -> Tensor:
        x = self.bn1d(input_data)
        x = self.linear(x)
        return x
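

# Residual wraps a block with a skip connection. When `drop` > 0 and the cell is in
# training mode, the residual branch is dropped per sample with probability `drop`
# and rescaled by 1 / (1 - drop), i.e. stochastic depth (drop-path) regularization.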
class Residual(nn.Cell):
    """ Residual connection with optional drop-path """

    def __init__(self,
                 m: nn.Cell = None,
                 drop: float = 0.) -> None:
        super().__init__()
        self.m = m
        self.drop = drop

    def construct(self, x: Tensor) -> Tensor:
        if self.training and self.drop > 0:
            # keep each sample's residual branch with probability 1 - drop
            mask = (ms.ops.rand((x.shape[0], 1, 1)) > self.drop).astype(x.dtype) / (1 - self.drop)
            return x + self.m(x) * mask
        return x + self.m(x)
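

# b16 builds the convolutional stem ("patch embedding"): four 3x3 stride-2 Conv2d_BN
# layers reduce the spatial size by a factor of 16 (224x224 -> 14x14) while expanding
# the channels to n, which matches the first stage's embedding dimension.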
def b16(n, activation=nn.HSwish):
    """ b16 """
    return nn.SequentialCell(
        Conv2d_BN(3, n // 8, 3, 2, 1),
        activation(),
        Conv2d_BN(n // 8, n // 4, 3, 2, 1),
        activation(),
        Conv2d_BN(n // 4, n // 2, 3, 2, 1),
        activation(),
        Conv2d_BN(n // 2, n, 3, 2, 1))


class Subsample(nn.Cell):
    """ DownSample """

    def __init__(self,
                 stride: int,
                 resolution: int) -> None:
        super().__init__()
        self.stride = stride
        self.resolution = resolution

    def construct(self, x: Tensor) -> Tensor:
        B, _, C = x.shape
        x = x.view(B, self.resolution, self.resolution, C)[
            :, ::self.stride, ::self.stride].reshape(B, -1, C)
        return x
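

# Attention with learnable relative-position biases: every (query, key) pair on the
# resolution x resolution grid is mapped to a slot indexed by the absolute offset
# between the two positions, and one bias per head is learned for each distinct
# offset. `self.ab` caches the gathered bias table (taken at construction time) for
# the inference path.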
class Attention(nn.Cell):
    """ Attention """

    def __init__(self,
                 dim: int,
                 key_dim: int,
                 num_heads: int = 8,
                 attn_ratio: int = 4,
                 activation: type = None,
                 resolution: int = 14) -> None:
        super().__init__()
        self.num_heads = num_heads
        self.scale = key_dim ** -0.5
        self.key_dim = key_dim
        self.nh_kd = nh_kd = key_dim * num_heads
        self.d = int(attn_ratio * key_dim)
        self.dh = self.d * num_heads
        self.attn_ratio = attn_ratio
        h = self.dh + nh_kd * 2
        self.qkv = Linear_BN(dim, h)
        self.proj = nn.SequentialCell(activation(), Linear_BN(self.dh, dim))

        # enumerate every position of the resolution x resolution grid
        points = list(itertools.product(range(resolution), range(resolution)))
        self.N = len(points)
        self.softmax = nn.Softmax(axis=-1)
        attention_offsets = {}
        idxs = []
        for p1 in points:
            for p2 in points:
                offset = (abs(p1[0] - p2[0]), abs(p1[1] - p2[1]))
                if offset not in attention_offsets:
                    attention_offsets[offset] = len(attention_offsets)
                idxs.append(attention_offsets[offset])
        self.attention_biases = ms.Parameter(
            np.zeros((num_heads, len(attention_offsets)), ms.float32))
        attention_bias_idxs = ms.Tensor(idxs, dtype=ms.int64).view(self.N, self.N)
        self.attention_bias_idxs = ms.Parameter(attention_bias_idxs, requires_grad=False)
        self.ab = self.attention_biases[:, self.attention_bias_idxs]

    def construct(self, x: Tensor) -> Tensor:
        B, N, _ = x.shape
        atte = self.qkv(x).view(B, N, self.num_heads, -1)
        q, k, v = ms.ops.split(atte, [self.key_dim, self.key_dim, self.d], axis=3)
        q = ms.ops.transpose(q, (0, 2, 1, 3))
        k = ms.ops.transpose(k, (0, 2, 1, 3))
        v = ms.ops.transpose(v, (0, 2, 1, 3))
        attn = (
            ms.ops.matmul(q, ms.ops.transpose(k, (0, 1, 3, 2))) * self.scale
            + (self.attention_biases[:, self.attention_bias_idxs]
               if self.training else self.ab)
        )
        attn = self.softmax(attn)
        x = ms.ops.transpose(ms.ops.matmul(attn, v), (0, 2, 1, 3))
        x = x.reshape(B, N, self.dh)
        x = self.proj(x)
        return x
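

# AttentionSubsample reduces the spatial resolution between stages: keys and values
# are computed from all resolution**2 input tokens, while queries are computed only
# on the strided resolution_**2 sub-grid, so the output sequence is shorter and the
# channel width grows from in_dim to out_dim.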
class AttentionSubsample(nn.Cell):
    """ Attention with spatial subsampling """

    def __init__(self,
                 in_dim: int,
                 out_dim: int,
                 key_dim: int,
                 num_heads: int = 8,
                 attn_ratio: int = 2,
                 activation: type = None,
                 stride: int = 2,
                 resolution: int = 14,
                 resolution_: int = 7) -> None:
        super().__init__()
        self.num_heads = num_heads
        self.scale = key_dim ** -0.5
        self.key_dim = key_dim
        self.nh_kd = nh_kd = key_dim * num_heads
        self.d = int(attn_ratio * key_dim)
        self.dh = self.d * self.num_heads
        self.attn_ratio = attn_ratio
        self.resolution_ = resolution_
        self.resolution_2 = resolution_ ** 2
        h = self.dh + nh_kd
        self.kv = Linear_BN(in_dim, h)
        self.q = nn.SequentialCell(
            Subsample(stride, resolution),
            Linear_BN(in_dim, nh_kd))
        self.proj = nn.SequentialCell(activation(), Linear_BN(self.dh, out_dim))
        self.softmax = nn.Softmax(axis=-1)
        self.stride = stride
        self.resolution = resolution
        points = list(itertools.product(range(resolution), range(resolution)))
        points_ = list(itertools.product(range(resolution_), range(resolution_)))
        N = len(points)
        N_ = len(points_)
        attention_offsets = {}
        idxs = []
        for p1 in points_:
            for p2 in points:
                size = 1
                offset = (
                    abs(p1[0] * stride - p2[0] + (size - 1) / 2),
                    abs(p1[1] * stride - p2[1] + (size - 1) / 2))
                if offset not in attention_offsets:
                    attention_offsets[offset] = len(attention_offsets)
                idxs.append(attention_offsets[offset])
        self.attention_biases = ms.Parameter(
            np.zeros((num_heads, len(attention_offsets)), ms.float32))
        attention_bias_idxs = ms.Tensor(idxs, dtype=ms.int64).view((N_, N))
        self.attention_bias_idxs = ms.Parameter(attention_bias_idxs, requires_grad=False)
        self.ab = self.attention_biases[:, self.attention_bias_idxs]

    def construct(self, x: Tensor) -> Tensor:
        B, N, _ = x.shape
        atte = self.kv(x).view(B, N, self.num_heads, -1)
        k, v = ms.ops.split(atte, [self.key_dim, self.d], axis=3)
        k = ms.ops.transpose(k, (0, 2, 1, 3))
        v = ms.ops.transpose(v, (0, 2, 1, 3))
        q = self.q(x).view(B, self.resolution_2, self.num_heads, self.key_dim)
        q = ms.ops.transpose(q, (0, 2, 1, 3))
        attn = (
            ms.ops.matmul(q, ms.ops.transpose(k, (0, 1, 3, 2))) * self.scale
            + (self.attention_biases[:, self.attention_bias_idxs]
               if self.training else self.ab)
        )
        attn = self.softmax(attn)
        x = ms.ops.transpose(ms.ops.matmul(attn, v), (0, 2, 1, 3))
        x = x.reshape(B, -1, self.dh)
        x = self.proj(x)
        return x
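

# LeViT stacks three attention stages on top of the convolutional stem. The per-stage
# hyper-parameters (embed_dim, key_dim, depth, num_heads, attn_ratio, mlp_ratio) are
# given as parallel lists, and each entry of `down_ops` describes the shrinking block
# between two stages in the form
#   ['Subsample', key_dim, num_heads, attn_ratio, mlp_ratio, stride].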
class LeViT(nn.Cell):
    """ Vision Transformer with support for patch or hybrid CNN input stage
    """

    def __init__(self,
                 img_size: int = 224,
                 patch_size: int = 16,
                 num_classes: int = 1000,
                 embed_dim=None,
                 key_dim=None,
                 depth=None,
                 num_heads=None,
                 attn_ratio=None,
                 mlp_ratio=None,
                 hybrid_backbone: nn.Cell = None,
                 down_ops=None,
                 attention_activation: type = nn.HSwish,
                 mlp_activation: type = nn.HSwish,
                 distillation: bool = True,
                 drop_path: float = 0.) -> None:
        super().__init__()
        if embed_dim is None:
            embed_dim = [128, 256, 384]
        if key_dim is None:
            key_dim = [16, 16, 16]
        if depth is None:
            depth = [2, 3, 4]
        if num_heads is None:
            num_heads = [4, 6, 8]
        if attn_ratio is None:
            attn_ratio = [2, 2, 2]
        if mlp_ratio is None:
            mlp_ratio = [2, 2, 2]
        if down_ops is None:
            down_ops = [['Subsample', 16, 128 // 16, 4, 2, 2], ['Subsample', 16, 256 // 16, 4, 2, 2]]
        if hybrid_backbone is None:
            # default stem matching the first embedding dimension
            hybrid_backbone = b16(embed_dim[0], activation=nn.HSwish)
        self.num_classes = num_classes
        self.num_features = embed_dim[-1]
        self.embed_dim = embed_dim
        self.distillation = distillation
        self.patch_embed = hybrid_backbone
        self.blocks = []
        down_ops.append([''])
        resolution = img_size // patch_size
        for i, (ed, kd, dpth, nh, ar, mr, do) in enumerate(
                zip(embed_dim, key_dim, depth, num_heads, attn_ratio, mlp_ratio, down_ops)):
            for _ in range(dpth):
                self.blocks.append(
                    Residual(Attention(
                        ed, kd, nh,
                        attn_ratio=ar,
                        activation=attention_activation,
                        resolution=resolution,
                    ), drop_path))
                if mr > 0:
                    h = int(ed * mr)
                    self.blocks.append(
                        Residual(nn.SequentialCell(
                            Linear_BN(ed, h),
                            mlp_activation(),
                            Linear_BN(h, ed),
                        ), drop_path))
            if do[0] == 'Subsample':
                resolution_ = (resolution - 1) // do[5] + 1
                self.blocks.append(
                    AttentionSubsample(
                        *embed_dim[i:i + 2], key_dim=do[1], num_heads=do[2],
                        attn_ratio=do[3],
                        activation=attention_activation,
                        stride=do[5],
                        resolution=resolution,
                        resolution_=resolution_))
                resolution = resolution_
                if do[4] > 0:  # mlp_ratio
                    h = int(embed_dim[i + 1] * do[4])
                    self.blocks.append(
                        Residual(nn.SequentialCell(
                            Linear_BN(embed_dim[i + 1], h),
                            mlp_activation(),
                            Linear_BN(h, embed_dim[i + 1]),
                        ), drop_path))
        self.blocks = nn.SequentialCell(*self.blocks)

        # Classifier head
        if num_classes > 0:
            self.head = BN_Linear(embed_dim[-1], num_classes)
            if distillation:
                self.head_dist = BN_Linear(embed_dim[-1], num_classes)
        else:
            self.head = nn.Identity()
            if distillation:
                self.head_dist = nn.Identity()
        self.apply(self._init_weights)

    def _init_weights(self, cell):
        """ initialize weights """
        if isinstance(cell, nn.Dense):
            cell.weight.set_data(init.initializer(init.TruncatedNormal(sigma=.02), cell.weight.shape))
            if cell.bias is not None:
                cell.bias.set_data(init.initializer('zeros', cell.bias.shape))
        elif isinstance(cell, nn.LayerNorm):
            cell.gamma.set_data(init.initializer('ones', cell.gamma.shape))
            cell.beta.set_data(init.initializer('zeros', cell.beta.shape))

    def construct(self, x: Tensor) -> Tensor:
        x = self.patch_embed(x)
        B, C, H, W = x.shape
        x = x.reshape(B, C, H * W)
        x = ms.ops.transpose(x, (0, 2, 1))
        x = self.blocks(x)
        x = x.mean(1)
        if self.distillation:
            x = self.head(x), self.head_dist(x)
            if not self.training:
                # average the classification and distillation heads at inference
                x = (x[0] + x[1]) / 2
        else:
            x = self.head(x)
        return x
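

# Model factories. The hyper-parameters below follow the LeViT-128S/128/192/256/384
# configurations from the paper; `pretrained=True` loads a checkpoint through
# `load_pretrained` using the (currently empty) URLs in `default_cfgs`.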
@register_model
def LeViT_128S(pretrained: bool = False, num_classes: int = 1000, in_channels=3, **kwargs) -> LeViT:
    """ LeViT_128S """
    default_cfg = default_cfgs['LeViT_128S']
    model = LeViT(num_classes=num_classes,
                  embed_dim=[128, 256, 384],
                  num_heads=[4, 6, 8],
                  key_dim=[16, 16, 16],
                  depth=[2, 3, 4],
                  down_ops=[
                      ['Subsample', 16, 128 // 16, 4, 2, 2],
                      ['Subsample', 16, 256 // 16, 4, 2, 2],
                  ],
                  hybrid_backbone=b16(128),
                  **kwargs)
    model.default_cfg = default_cfg
    if pretrained:
        load_pretrained(model, default_cfg, num_classes=num_classes, in_channels=in_channels)
    return model


@register_model
def LeViT_128(pretrained: bool = False, num_classes: int = 1000, in_channels=3, **kwargs) -> LeViT:
    """ LeViT_128 """
    default_cfg = default_cfgs['LeViT_128']
    model = LeViT(num_classes=num_classes,
                  embed_dim=[128, 256, 384],
                  num_heads=[4, 8, 12],
                  key_dim=[16, 16, 16],
                  depth=[4, 4, 4],
                  down_ops=[
                      ['Subsample', 16, 128 // 16, 4, 2, 2],
                      ['Subsample', 16, 256 // 16, 4, 2, 2],
                  ],
                  hybrid_backbone=b16(128),
                  **kwargs)
    model.default_cfg = default_cfg
    if pretrained:
        load_pretrained(model, default_cfg, num_classes=num_classes, in_channels=in_channels)
    return model


@register_model
def LeViT_192(pretrained: bool = False, num_classes: int = 1000, in_channels=3, **kwargs) -> LeViT:
    """ LeViT_192 """
    default_cfg = default_cfgs['LeViT_192']
    model = LeViT(num_classes=num_classes,
                  embed_dim=[192, 288, 384],
                  num_heads=[3, 5, 6],
                  key_dim=[32, 32, 32],
                  depth=[4, 4, 4],
                  down_ops=[
                      ['Subsample', 32, 192 // 32, 4, 2, 2],
                      ['Subsample', 32, 288 // 32, 4, 2, 2],
                  ],
                  hybrid_backbone=b16(192),
                  **kwargs)
    model.default_cfg = default_cfg
    if pretrained:
        load_pretrained(model, default_cfg, num_classes=num_classes, in_channels=in_channels)
    return model


@register_model
def LeViT_256(pretrained: bool = False, num_classes: int = 1000, in_channels=3, **kwargs) -> LeViT:
    """ LeViT_256 """
    default_cfg = default_cfgs['LeViT_256']
    model = LeViT(num_classes=num_classes,
                  embed_dim=[256, 384, 512],
                  num_heads=[4, 6, 8],
                  key_dim=[32, 32, 32],
                  depth=[4, 4, 4],
                  down_ops=[
                      ['Subsample', 32, 256 // 32, 4, 2, 2],
                      ['Subsample', 32, 384 // 32, 4, 2, 2],
                  ],
                  hybrid_backbone=b16(256),
                  **kwargs)
    model.default_cfg = default_cfg
    if pretrained:
        load_pretrained(model, default_cfg, num_classes=num_classes, in_channels=in_channels)
    return model


@register_model
def LeViT_384(pretrained: bool = False, num_classes: int = 1000, in_channels=3, **kwargs) -> LeViT:
    """ LeViT_384 """
    default_cfg = default_cfgs['LeViT_384']
    model = LeViT(num_classes=num_classes,
                  embed_dim=[384, 512, 768],
                  num_heads=[6, 9, 12],
                  key_dim=[32, 32, 32],
                  depth=[4, 4, 4],
                  down_ops=[
                      ['Subsample', 32, 384 // 32, 4, 2, 2],
                      ['Subsample', 32, 512 // 32, 4, 2, 2],
                  ],
                  hybrid_backbone=b16(384),
                  **kwargs)
    model.default_cfg = default_cfg
    if pretrained:
        load_pretrained(model, default_cfg, num_classes=num_classes, in_channels=in_channels)
    return model


if __name__ == '__main__':
    context.set_context(mode=context.PYNATIVE_MODE, device_target="CPU")
    net = LeViT_128S()
    net.set_train(False)  # inference mode: the two heads are averaged
    # print(net)
    dummy_input = ms.ops.rand((4, 3, 224, 224))
    output = net(dummy_input)
    print(output.shape)  # expected: (4, 1000)