- # pylint: disable=E0401
- # pylint: disable=W0201
- """
- MindSpore implementation of `DeiT`.
- Refer to "Training data-efficient image transformers & distillation through attention"
- """
- from enum import Enum
- from functools import partial
- from typing import Union, Tuple, Optional, Callable
-
- import mindspore as ms
- from mindspore import nn
- from mindspore.common.initializer import initializer, TruncatedNormal, Normal
-
- from model.layers import DropPath, to_2tuple
- from model.registry import register_model
- from model.helper import load_pretrained
-
- __all__ = [
- 'deit_tiny_patch16_224', 'deit_small_patch16_224', 'deit_base_patch16_224',
- 'deit_tiny_distilled_patch16_224', 'deit_small_distilled_patch16_224',
- 'deit_base_distilled_patch16_224', 'deit_base_patch16_384',
- 'deit_base_distilled_patch16_384',
- ]
-
-
- def _cfg(url='', **kwargs):
- return {
- 'url': url,
- 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
- 'crop_pct': .9, 'interpolation': 'bicubic',
- # 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
- 'first_conv': 'patch_embed.proj', 'classifier': 'head',
- **kwargs
- }
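- # Example: _cfg(input_size=(3, 384, 384)) keeps all other defaults and only overrides the input size.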
-
-
- class Format(str, Enum):
- """ image format """
- NCHW = 'NCHW'
- NHWC = 'NHWC'
- NCL = 'NCL'
- NLC = 'NLC'
-
-
- def nchw_to(x: ms.Tensor, fmt: Format):
- """ switch image format """
- if fmt == Format.NHWC:
- x = x.permute(0, 2, 3, 1)
- elif fmt == Format.NLC:
- x = x.flatten(start_dim=2).transpose(0, 2, 1)
- elif fmt == Format.NCL:
- x = x.flatten(start_dim=2)
- return x
-
-
- class Mlp(nn.Cell):
- """ MLP as used in Vision Transformer, MLP-Mixer and related networks
- """
- def __init__(
- self,
- in_features,
- hidden_features=None,
- out_features=None,
- act_layer=nn.GELU,
- norm_layer=None,
- bias=True,
- drop=0.,
- use_conv=False,
- ):
- super().__init__()
- out_features = out_features or in_features
- hidden_features = hidden_features or in_features
- bias = to_2tuple(bias)
- drop_probs = to_2tuple(drop)
- linear_layer = partial(nn.Conv2d, kernel_size=1) if use_conv else nn.Dense
-
- self.fc1 = linear_layer(in_features, hidden_features, has_bias=bias[0])
- self.act = act_layer()
- self.drop1 = nn.Dropout(p=drop_probs[0])
- self.norm = norm_layer((hidden_features, )) if norm_layer is not None else nn.Identity()
- self.fc2 = linear_layer(hidden_features, out_features, has_bias=bias[1])
- self.drop2 = nn.Dropout(p=drop_probs[1])
-
- def construct(self, x):
- x = self.fc1(x)
- x = self.act(x)
- x = self.drop1(x)
- x = self.fc2(x)
- x = self.drop2(x)
- return x
-
-
- class PatchDropout(nn.Cell):
- """ https://arxiv.org/abs/2212.00794
- """
- def __init__(
- self,
- prob: float = 0.5,
- num_prefix_tokens: int = 1,
- ordered: bool = False,
- return_indices: bool = False,
- ):
- super().__init__()
- assert 0 <= prob < 1.
- self.prob = prob
- self.num_prefix_tokens = num_prefix_tokens # exclude CLS token (or other prefix tokens)
- self.ordered = ordered
- self.return_indices = return_indices
-
- def construct(self, x):
- if not self.training or self.prob == 0.:
- if self.return_indices:
- return x, None
- return x
-
- if self.num_prefix_tokens:
- prefix_tokens, x = x[:, :self.num_prefix_tokens], x[:, self.num_prefix_tokens:]
- else:
- prefix_tokens = None
-
- B = x.shape[0]
- L = x.shape[1]
- num_keep = max(1, int(L * (1. - self.prob)))
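- # e.g. prob=0.5 on 196 patch tokens keeps 98 of them; prefix tokens are never dropped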
- keep_indices = ms.ops.argsort(ms.ops.randn((B, L)), axis=-1)[:, :num_keep]
- if self.ordered:
- # NOTE does not need to maintain patch order in typical transformer use,
- # but possibly useful for debug / visualization
- keep_indices = keep_indices.sort(axis=-1)[0]
- x = ms.ops.gather_elements(x, 1, ms.ops.broadcast_to(keep_indices.unsqueeze(-1), (-1, -1) + x.shape[2:]))
-
- if prefix_tokens is not None:
- x = ms.ops.cat((prefix_tokens, x), axis=1)
-
- if self.return_indices:
- return x, keep_indices
- return x
-
-
- class LayerScale(nn.Cell):
- """ Layer Scale """
- def __init__(self, dim, init_values=1e-5):
- super().__init__()
- self.gamma = ms.Parameter(init_values * ms.ops.ones((dim, )))
-
- def construct(self, x):
- return x * self.gamma
-
-
- class Attention(nn.Cell):
- """ Attention """
- def __init__(
- self,
- dim,
- num_heads=8,
- qkv_bias=False,
- qk_norm=False,
- attn_drop=0.,
- proj_drop=0.,
- norm_layer=nn.LayerNorm,
- ):
- super().__init__()
- assert dim % num_heads == 0, 'dim should be divisible by num_heads'
- self.num_heads = num_heads
- self.head_dim = dim // num_heads
- self.scale = self.head_dim ** -0.5
-
- self.qkv = nn.Dense(dim, dim * 3, has_bias=qkv_bias)
- self.q_norm = norm_layer((self.head_dim, )) if qk_norm else nn.Identity()
- self.k_norm = norm_layer((self.head_dim, )) if qk_norm else nn.Identity()
- self.attn_drop = nn.Dropout(p=attn_drop)
- self.proj = nn.Dense(dim, dim)
- self.proj_drop = nn.Dropout(p=proj_drop)
-
- def construct(self, x):
- B, N, C = x.shape
- qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
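- # qkv: (3, B, num_heads, N, head_dim) after the reshape + permute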
- q, k, v = qkv.unbind(0)
- q, k = self.q_norm(q), self.k_norm(k)
-
- q = q * self.scale
- attn = q @ k.transpose(0, 1, 3, 2)
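- # attn: (B, num_heads, N, N) scaled dot-product scores (q was pre-scaled by 1/sqrt(head_dim))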
- attn = ms.ops.softmax(attn, axis=-1)
- attn = self.attn_drop(attn)
- x = attn @ v
-
- x = x.transpose(0, 2, 1, 3).reshape(B, N, C)
- x = self.proj(x)
- x = self.proj_drop(x)
- return x
-
-
- class Block(nn.Cell):
- """ Block """
- def __init__(
- self,
- dim,
- num_heads,
- mlp_ratio=4.,
- qkv_bias=False,
- qk_norm=False,
- proj_drop=0.,
- attn_drop=0.,
- init_values=None,
- drop_path=0.,
- act_layer=nn.GELU,
- norm_layer=nn.LayerNorm,
- mlp_layer=Mlp,
- ):
- super().__init__()
- self.norm1 = norm_layer((dim, ))
- self.attn = Attention(
- dim,
- num_heads=num_heads,
- qkv_bias=qkv_bias,
- qk_norm=qk_norm,
- attn_drop=attn_drop,
- proj_drop=proj_drop,
- norm_layer=norm_layer,
- )
- self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
- self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
-
- self.norm2 = norm_layer((dim, ))
- self.mlp = mlp_layer(
- in_features=dim,
- hidden_features=int(dim * mlp_ratio),
- act_layer=act_layer,
- drop=proj_drop,
- )
- self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
- self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
-
- def construct(self, x):
- x = x + self.drop_path1(self.ls1(self.attn(self.norm1(x))))
- x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x))))
- return x
-
-
- class PatchEmbed(nn.Cell):
- """ 2D Image to Patch Embedding
- """
- output_fmt: Format
-
- def __init__(
- self,
- img_size: Optional[int] = 224,
- patch_size: int = 16,
- in_chans: int = 3,
- embed_dim: int = 768,
- norm_layer: Optional[Callable] = None,
- flatten: bool = True,
- output_fmt: Optional[str] = None,
- bias: bool = True,
- ):
- super().__init__()
- self.patch_size = to_2tuple(patch_size)
- if img_size is not None:
- self.img_size = to_2tuple(img_size)
- self.grid_size = tuple(s // p for (s, p) in zip(self.img_size, self.patch_size))
- self.num_patches = self.grid_size[0] * self.grid_size[1]
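- # e.g. a 224x224 image with 16x16 patches gives a 14x14 grid, i.e. 196 patches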
- else:
- self.img_size = None
- self.grid_size = None
- self.num_patches = None
-
- if output_fmt is not None:
- self.flatten = False
- self.output_fmt = Format(output_fmt)
- else:
- # flatten spatial dim and transpose to channels last, kept for bwd compat
- self.flatten = flatten
- self.output_fmt = Format.NCHW
-
- self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, has_bias=bias)
- self.norm = norm_layer((embed_dim, )) if norm_layer else nn.Identity()
-
- def construct(self, x):
- _, _, H, W = x.shape
- if self.img_size is not None:
- assert H == self.img_size[0], f"Input image height ({H}) doesn't match model ({self.img_size[0]})."
- assert W == self.img_size[1], f"Input image width ({W}) doesn't match model ({self.img_size[1]})."
-
- x = self.proj(x)
- if self.flatten:
- x = x.flatten(start_dim=2).transpose(0, 2, 1) # NCHW -> NLC
- elif self.output_fmt != Format.NCHW:
- x = nchw_to(x, self.output_fmt)
- x = self.norm(x)
- return x
-
-
- class VisionTransformer(nn.Cell):
- """ Vision Transformer
- """
-
- def __init__(
- self,
- img_size: Union[int, Tuple[int, int]] = 224,
- patch_size: Union[int, Tuple[int, int]] = 16,
- in_chans: int = 3,
- num_classes: int = 1000,
- global_pool: str = 'token',
- embed_dim: int = 768,
- depth: int = 12,
- num_heads: int = 12,
- mlp_ratio: float = 4.,
- qkv_bias: bool = True,
- qk_norm: bool = False,
- init_values: Optional[float] = None,
- class_token: bool = True,
- no_embed_class: bool = False,
- pre_norm: bool = False,
- fc_norm: Optional[bool] = None,
- drop_rate: float = 0.,
- pos_drop_rate: float = 0.,
- patch_drop_rate: float = 0.,
- proj_drop_rate: float = 0.,
- attn_drop_rate: float = 0.,
- drop_path_rate: float = 0.,
- weight_init: str = '',
- embed_layer: Callable = PatchEmbed,
- norm_layer: Optional[Callable] = None,
- act_layer: Optional[Callable] = None,
- block_fn: Callable = Block,
- mlp_layer: Callable = Mlp,
- ):
- super().__init__()
- assert global_pool in ('', 'avg', 'token')
- assert class_token or global_pool != 'token'
- use_fc_norm = global_pool == 'avg' if fc_norm is None else fc_norm
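- # if fc_norm is unspecified, the final norm moves after pooling only when global average pooling is used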
- norm_layer = norm_layer or partial(nn.LayerNorm, epsilon=1e-6)
- act_layer = act_layer or nn.GELU
-
- self.num_classes = num_classes
- self.global_pool = global_pool
- self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
- self.num_prefix_tokens = 1 if class_token else 0
- self.no_embed_class = no_embed_class
- self.grad_checkpointing = False
-
- self.patch_embed = embed_layer(
- img_size=img_size,
- patch_size=patch_size,
- in_chans=in_chans,
- embed_dim=embed_dim,
- bias=not pre_norm, # disable bias if pre-norm is used (e.g. CLIP)
- )
- num_patches = self.patch_embed.num_patches
-
- self.cls_token = ms.Parameter(ms.ops.zeros((1, 1, embed_dim))) if class_token else None
- embed_len = num_patches if no_embed_class else num_patches + self.num_prefix_tokens
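- # e.g. 196 patches + 1 class token -> embed_len = 197, unless no_embed_class is set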
- self.pos_embed = ms.Parameter(ms.ops.randn((1, embed_len, embed_dim)) * .02)
- self.pos_drop = nn.Dropout(p=pos_drop_rate)
- if patch_drop_rate > 0:
- self.patch_drop = PatchDropout(
- patch_drop_rate,
- num_prefix_tokens=self.num_prefix_tokens,
- )
- else:
- self.patch_drop = nn.Identity()
- self.norm_pre = norm_layer((embed_dim, )) if pre_norm else nn.Identity()
-
- dpr = [float(x) for x in ms.ops.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule
- self.blocks = nn.SequentialCell(*[
- block_fn(
- dim=embed_dim,
- num_heads=num_heads,
- mlp_ratio=mlp_ratio,
- qkv_bias=qkv_bias,
- qk_norm=qk_norm,
- init_values=init_values,
- proj_drop=proj_drop_rate,
- attn_drop=attn_drop_rate,
- drop_path=dpr[i],
- norm_layer=norm_layer,
- act_layer=act_layer,
- mlp_layer=mlp_layer,
- )
- for i in range(depth)])
- self.norm = norm_layer((embed_dim, )) if not use_fc_norm else nn.Identity()
-
- # Classifier Head
- self.fc_norm = norm_layer((embed_dim, )) if use_fc_norm else nn.Identity()
- self.head_drop = nn.Dropout(p=drop_rate)
- self.head = nn.Dense(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
-
- self.pos_embed.set_data(initializer(TruncatedNormal(sigma=.02), self.pos_embed.shape, self.pos_embed.dtype))
- if self.cls_token is not None:
- self.cls_token.set_data(initializer(Normal(sigma=1e-6), self.cls_token.shape, self.cls_token.dtype))
- if weight_init != 'skip':
- self.apply(self.init_weights)
-
- def init_weights(self, cell):
- """ initialize weight """
- if isinstance(cell, nn.Dense):
- cell.weight.set_data(initializer(TruncatedNormal(sigma=.02), cell.weight.shape, cell.weight.dtype))
- if cell.bias is not None:
- cell.bias.set_data(initializer('zeros', cell.bias.shape, cell.bias.dtype))
-
- def reset_classifier(self, num_classes: int, global_pool=None):
- """ reset classifier """
- self.num_classes = num_classes
- if global_pool is not None:
- assert global_pool in ('', 'avg', 'token')
- self.global_pool = global_pool
- self.head = nn.Dense(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
-
- def _pos_embed(self, x):
- """ position embedding """
- if self.no_embed_class:
- # deit-3, updated JAX (big vision)
- # position embedding does not overlap with class token, add then concat
- x = x + self.pos_embed
- if self.cls_token is not None:
- cls_token = ms.ops.broadcast_to(self.cls_token, (x.shape[0], -1, -1))
- x = ms.ops.cat((cls_token, x), axis=1)
- else:
- # original timm, JAX, and deit vit impl
- # pos_embed has entry for class token, concat then add
- if self.cls_token is not None:
- cls_token = ms.ops.broadcast_to(self.cls_token, (x.shape[0], -1, -1))
- x = ms.ops.cat((cls_token, x), axis=1)
- x = x + self.pos_embed
- return self.pos_drop(x)
-
- def _intermediate_layers(
- self,
- x: ms.Tensor,
- n=1,
- ):
- """ intermediate layers """
- outputs, num_blocks = [], len(self.blocks)
- take_indices = set(range(num_blocks - n, num_blocks) if isinstance(n, int) else n)
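- # e.g. n=4 collects the outputs of the last four blocks; a sequence of indices selects specific blocks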
-
- # forward pass
- x = self.patch_embed(x)
- x = self._pos_embed(x)
- x = self.patch_drop(x)
- x = self.norm_pre(x)
- for i, blk in enumerate(self.blocks):
- x = blk(x)
- if i in take_indices:
- outputs.append(x)
-
- return outputs
-
- def get_intermediate_layers(
- self,
- x: ms.Tensor,
- n=1,
- reshape: bool = False,
- return_class_token: bool = False,
- norm: bool = False,
- ):
- """ Intermediate layer accessor
- Inspired by DINO / DINOv2 interface
- """
- # take last n blocks if n is an int; if n is a sequence, select by matching indices
- outputs = self._intermediate_layers(x, n)
- if norm:
- outputs = [self.norm(out) for out in outputs]
- class_tokens = [out[:, 0:self.num_prefix_tokens] for out in outputs]
- outputs = [out[:, self.num_prefix_tokens:] for out in outputs]
-
- if reshape:
- grid_size = self.patch_embed.grid_size
- outputs = [
- out.reshape(x.shape[0], grid_size[0], grid_size[1], -1).permute(0, 3, 1, 2).contiguous()
- for out in outputs
- ]
-
- if return_class_token:
- return tuple(zip(outputs, class_tokens))
- return tuple(outputs)
-
- def construct_features(self, x):
- """ construct features """
- x = self.patch_embed(x)
- x = self._pos_embed(x)
- x = self.patch_drop(x)
- x = self.norm_pre(x)
- x = self.blocks(x)
- x = self.norm(x)
- return x
-
- def construct_head(self, x, pre_logits: bool = False):
- """ construct head """
- if self.global_pool:
- x = x[:, self.num_prefix_tokens:].mean(axis=1) if self.global_pool == 'avg' else x[:, 0]
- x = self.fc_norm(x)
- x = self.head_drop(x)
- return x if pre_logits else self.head(x)
-
- def construct(self, x):
- x = self.construct_features(x)
- x = self.construct_head(x)
- return x
-
-
- class DistilledVisionTransformer(VisionTransformer):
- """ Distilled Vision Transformer """
- def __init__(self, *args, **kwargs):
- super().__init__(*args, **kwargs)
- self.dist_token = ms.Parameter(ms.ops.zeros((1, 1, self.embed_dim)))
- num_patches = self.patch_embed.num_patches
- self.pos_embed = ms.Parameter(ms.ops.zeros((1, num_patches + 2, self.embed_dim)))
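- # num_patches + 2: one extra position each for the class token and the distillation token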
- self.head_dist = nn.Dense(self.embed_dim, self.num_classes) if self.num_classes > 0 else nn.Identity()
-
- self.dist_token.set_data(initializer(TruncatedNormal(sigma=.02), self.dist_token.shape, self.dist_token.dtype))
- self.pos_embed.set_data(initializer(TruncatedNormal(sigma=.02), self.pos_embed.shape, self.pos_embed.dtype))
- self.head_dist.apply(self._init_weights)
-
- def _init_weights(self, cell):
- """ initialize weight """
- if isinstance(cell, nn.Dense):
- cell.weight.set_data(initializer(TruncatedNormal(sigma=.02), cell.weight.shape, cell.weight.dtype))
- if cell.bias is not None:
- cell.bias.set_data(initializer('zeros', cell.bias.shape, cell.bias.dtype))
- elif isinstance(cell, nn.LayerNorm):
- cell.gamma.set_data(initializer('ones', cell.gamma.shape, cell.gamma.dtype))
- cell.beta.set_data(initializer('zeros', cell.beta.shape, cell.beta.dtype))
-
- def construct_features(self, x):
- """ construct features """
- x = self.patch_embed(x)
-
- cls_token = ms.ops.broadcast_to(self.cls_token, (x.shape[0], -1, -1))
- dist_token = ms.ops.broadcast_to(self.dist_token, (x.shape[0], -1, -1))
- x = ms.ops.cat((cls_token, dist_token, x), axis=1)
-
- x = x + self.pos_embed
- x = self.pos_drop(x)
-
- for blk in self.blocks:
- x = blk(x)
-
- x = self.norm(x)
- return x[:, 0], x[:, 1]
-
- def construct(self, x):
- x, x_dist = self.construct_features(x)
- x = self.head(x)
- x_dist = self.head_dist(x_dist)
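- # during training both heads are returned; at inference their predictions are averaged, as in the DeiT paper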
- if self.training:
- return x, x_dist
- return (x + x_dist) / 2
-
-
- @register_model
- def deit_tiny_patch16_224(pretrained=False, **kwargs):
- """ deit-tiny-patch16 with image size 224 """
- default_cfg = _cfg()
- model = VisionTransformer(
- patch_size=16, embed_dim=192, depth=12, num_heads=3, mlp_ratio=4, qkv_bias=True,
- norm_layer=partial(nn.LayerNorm, epsilon=1e-6), **kwargs)
- model.default_cfg = default_cfg
- if pretrained:
- load_pretrained(model, default_cfg)
-
- return model
-
-
- @register_model
- def deit_small_patch16_224(pretrained=False, **kwargs):
- """ deit-small-patch16 with image size 224 """
- default_cfg = _cfg()
- model = VisionTransformer(
- patch_size=16, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4, qkv_bias=True,
- norm_layer=partial(nn.LayerNorm, epsilon=1e-6), **kwargs)
- model.default_cfg = default_cfg
- if pretrained:
- load_pretrained(model, default_cfg)
-
- return model
-
-
- @register_model
- def deit_base_patch16_224(pretrained=False, **kwargs):
- """ deit-base-patch16 with image size 224 """
- default_cfg = _cfg()
- model = VisionTransformer(
- patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True,
- norm_layer=partial(nn.LayerNorm, epsilon=1e-6), **kwargs)
- model.default_cfg = default_cfg
- if pretrained:
- load_pretrained(model, default_cfg)
-
- return model
-
-
- @register_model
- def deit_tiny_distilled_patch16_224(pretrained=False, **kwargs):
- """ deit-tiny-distilled-patch16 with image size 224 """
- default_cfg = _cfg()
- model = DistilledVisionTransformer(
- patch_size=16, embed_dim=192, depth=12, num_heads=3, mlp_ratio=4, qkv_bias=True,
- norm_layer=partial(nn.LayerNorm, epsilon=1e-6), **kwargs)
- model.default_cfg = default_cfg
- if pretrained:
- load_pretrained(model, default_cfg)
-
- return model
-
-
- @register_model
- def deit_small_distilled_patch16_224(pretrained=False, **kwargs):
- """ deit-small-distilled-patch16 with image size 224 """
- default_cfg = _cfg()
- model = DistilledVisionTransformer(
- patch_size=16, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4, qkv_bias=True,
- norm_layer=partial(nn.LayerNorm, epsilon=1e-6), **kwargs)
- model.default_cfg = default_cfg
- if pretrained:
- load_pretrained(model, default_cfg)
-
- return model
-
-
- @register_model
- def deit_base_distilled_patch16_224(pretrained=False, **kwargs):
- """ deit-base-distilled-patch16 with image size 224 """
- default_cfg = _cfg()
- model = DistilledVisionTransformer(
- patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True,
- norm_layer=partial(nn.LayerNorm, epsilon=1e-6), **kwargs)
- model.default_cfg = default_cfg
- if pretrained:
- load_pretrained(model, default_cfg)
-
- return model
-
-
- @register_model
- def deit_base_patch16_384(pretrained=False, **kwargs):
- """ deit-base-patch16 with image size 384 """
- default_cfg = _cfg()
- model = VisionTransformer(
- img_size=384, patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True,
- norm_layer=partial(nn.LayerNorm, epsilon=1e-6), **kwargs)
- model.default_cfg = default_cfg
- if pretrained:
- load_pretrained(model, default_cfg)
-
- return model
-
-
- @register_model
- def deit_base_distilled_patch16_384(pretrained=False, **kwargs):
- """ deit-base-distilled-patch16 with image size 384 """
- default_cfg = _cfg()
- model = DistilledVisionTransformer(
- img_size=384, patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True,
- norm_layer=partial(nn.LayerNorm, epsilon=1e-6), **kwargs)
- model.default_cfg = default_cfg
- if pretrained:
- load_pretrained(model, default_cfg)
-
- return model
-
-
- if __name__ == '__main__':
- dummy_input = ms.ops.randn((1, 3, 224, 224))
- net = deit_base_distilled_patch16_224()
- net.set_train(False)  # eval mode: the distilled model returns averaged logits rather than a (cls, dist) tuple
- output = net(dummy_input)
- print(output.shape)
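- # A minimal usage sketch (hypothetical settings: 100 classes, batch of 2, no pretrained weights):
- #   model = deit_tiny_patch16_224(pretrained=False, num_classes=100)
- #   model.set_train(False)
- #   logits = model(ms.ops.randn((2, 3, 224, 224)))  # -> (2, 100)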