
DeiT.py

# pylint: disable=E0401
# pylint: disable=W0201
"""
MindSpore implementation of `DeiT`.
Refer to "Training data-efficient image transformers & distillation through attention"
"""
from enum import Enum
from functools import partial
from typing import Union, Tuple, Optional, Callable

import mindspore as ms
from mindspore import nn
from mindspore.common.initializer import initializer, TruncatedNormal, Normal

from model.layers import DropPath, to_2tuple
from model.registry import register_model
from model.helper import load_pretrained

__all__ = [
    'deit_tiny_patch16_224', 'deit_small_patch16_224', 'deit_base_patch16_224',
    'deit_tiny_distilled_patch16_224', 'deit_small_distilled_patch16_224',
    'deit_base_distilled_patch16_224', 'deit_base_patch16_384',
    'deit_base_distilled_patch16_384',
]


def _cfg(url='', **kwargs):
    return {
        'url': url,
        'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
        'crop_pct': .9, 'interpolation': 'bicubic',
        # 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
        'first_conv': 'patch_embed.proj', 'classifier': 'head',
        **kwargs
    }


class Format(str, Enum):
    """ Image format """
    NCHW = 'NCHW'
    NHWC = 'NHWC'
    NCL = 'NCL'
    NLC = 'NLC'


def nchw_to(x: ms.Tensor, fmt: Format):
    """ Convert an NCHW tensor to the given format """
    if fmt == Format.NHWC:
        x = x.permute(0, 2, 3, 1)
    elif fmt == Format.NLC:
        x = x.flatten(start_dim=2).transpose(0, 2, 1)
    elif fmt == Format.NCL:
        x = x.flatten(start_dim=2)
    return x
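
# Example shapes (illustrative; assumes a 14x14 grid of 768-dim patch features):
#   x = ms.ops.randn((2, 768, 14, 14))     # NCHW
#   nchw_to(x, Format.NLC).shape           # (2, 196, 768)
#   nchw_to(x, Format.NHWC).shape          # (2, 14, 14, 768)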


class Mlp(nn.Cell):
    """ MLP as used in Vision Transformer, MLP-Mixer and related networks
    """
    def __init__(
        self,
        in_features,
        hidden_features=None,
        out_features=None,
        act_layer=nn.GELU,
        norm_layer=None,
        bias=True,
        drop=0.,
        use_conv=False,
    ):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        bias = to_2tuple(bias)
        drop_probs = to_2tuple(drop)
        linear_layer = partial(nn.Conv2d, kernel_size=1) if use_conv else nn.Dense

        self.fc1 = linear_layer(in_features, hidden_features, has_bias=bias[0])
        self.act = act_layer()
        self.drop1 = nn.Dropout(p=drop_probs[0])
        self.norm = norm_layer((hidden_features, )) if norm_layer is not None else nn.Identity()
        self.fc2 = linear_layer(hidden_features, out_features, has_bias=bias[1])
        self.drop2 = nn.Dropout(p=drop_probs[1])

    def construct(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop1(x)
        x = self.norm(x)  # applies the optional norm layer (identity by default)
        x = self.fc2(x)
        x = self.drop2(x)
        return x
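
# Example (illustrative): a (B, N, C) token tensor passes through
# fc1 -> act -> dropout -> norm -> fc2 -> dropout, preserving (B, N, C):
#   mlp = Mlp(in_features=192, hidden_features=768)
#   mlp(ms.ops.randn((2, 197, 192))).shape    # (2, 197, 192)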


class PatchDropout(nn.Cell):
    """ Patch dropout: https://arxiv.org/abs/2212.00794
    """
    def __init__(
        self,
        prob: float = 0.5,
        num_prefix_tokens: int = 1,
        ordered: bool = False,
        return_indices: bool = False,
    ):
        super().__init__()
        assert 0 <= prob < 1.
        self.prob = prob
        self.num_prefix_tokens = num_prefix_tokens  # exclude CLS token (or other prefix tokens)
        self.ordered = ordered
        self.return_indices = return_indices

    def construct(self, x):
        if not self.training or self.prob == 0.:
            if self.return_indices:
                return x, None
            return x

        if self.num_prefix_tokens:
            prefix_tokens, x = x[:, :self.num_prefix_tokens], x[:, self.num_prefix_tokens:]
        else:
            prefix_tokens = None

        B = x.shape[0]
        L = x.shape[1]
        num_keep = max(1, int(L * (1. - self.prob)))
        # random permutation per sample via argsort of random noise
        keep_indices = ms.ops.argsort(ms.ops.randn((B, L)), axis=-1)[:, :num_keep]
        if self.ordered:
            # NOTE does not need to maintain patch order in typical transformer use,
            # but possibly useful for debug / visualization
            keep_indices = keep_indices.sort(axis=-1)[0]
        gather_indices = keep_indices.unsqueeze(-1).broadcast_to((-1, -1) + x.shape[2:])
        x = ms.ops.gather_elements(x, 1, gather_indices)

        if prefix_tokens is not None:
            x = ms.ops.cat((prefix_tokens, x), axis=1)

        if self.return_indices:
            return x, keep_indices
        return x
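
# Example (illustrative): with prob=0.5 and one prefix token, roughly half of
# the 196 patch tokens are dropped during training; the prefix token is kept:
#   drop = PatchDropout(prob=0.5, num_prefix_tokens=1)
#   drop.set_train(True)
#   drop(ms.ops.randn((2, 197, 192))).shape    # (2, 99, 192): 1 prefix + 98 patches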


class LayerScale(nn.Cell):
    """ Layer Scale """
    def __init__(self, dim, init_values=1e-5):
        super().__init__()
        self.gamma = ms.Parameter(init_values * ms.ops.ones((dim, )))

    def construct(self, x):
        return x * self.gamma


class Attention(nn.Cell):
    """ Attention """
    def __init__(
        self,
        dim,
        num_heads=8,
        qkv_bias=False,
        qk_norm=False,
        attn_drop=0.,
        proj_drop=0.,
        norm_layer=nn.LayerNorm,
    ):
        super().__init__()
        assert dim % num_heads == 0, 'dim should be divisible by num_heads'
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim ** -0.5

        self.qkv = nn.Dense(dim, dim * 3, has_bias=qkv_bias)
        self.q_norm = norm_layer((self.head_dim, )) if qk_norm else nn.Identity()
        self.k_norm = norm_layer((self.head_dim, )) if qk_norm else nn.Identity()
        self.attn_drop = nn.Dropout(p=attn_drop)
        self.proj = nn.Dense(dim, dim)
        self.proj_drop = nn.Dropout(p=proj_drop)

    def construct(self, x):
        B, N, C = x.shape
        # (B, N, 3*C) -> (3, B, num_heads, N, head_dim)
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)
        q, k = self.q_norm(q), self.k_norm(k)

        q = q * self.scale
        attn = q @ k.transpose(0, 1, 3, 2)
        attn = ms.ops.softmax(attn, axis=-1)
        attn = self.attn_drop(attn)
        x = attn @ v

        x = x.transpose(0, 2, 1, 3).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x
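
# Scaled dot-product attention in brief (illustrative shapes for dim=192, 3 heads):
#   attn = Attention(dim=192, num_heads=3)
#   attn(ms.ops.randn((2, 197, 192))).shape    # (2, 197, 192)
# Internally q, k, v are each (2, 3, 197, 64) and the attention map is (2, 3, 197, 197).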


class Block(nn.Cell):
    """ Block """
    def __init__(
        self,
        dim,
        num_heads,
        mlp_ratio=4.,
        qkv_bias=False,
        qk_norm=False,
        proj_drop=0.,
        attn_drop=0.,
        init_values=None,
        drop_path=0.,
        act_layer=nn.GELU,
        norm_layer=nn.LayerNorm,
        mlp_layer=Mlp,
    ):
        super().__init__()
        self.norm1 = norm_layer((dim, ))
        self.attn = Attention(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_norm=qk_norm,
            attn_drop=attn_drop,
            proj_drop=proj_drop,
            norm_layer=norm_layer,
        )
        self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
        self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()

        self.norm2 = norm_layer((dim, ))
        self.mlp = mlp_layer(
            in_features=dim,
            hidden_features=int(dim * mlp_ratio),
            act_layer=act_layer,
            drop=proj_drop,
        )
        self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
        self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()

    def construct(self, x):
        x = x + self.drop_path1(self.ls1(self.attn(self.norm1(x))))
        x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x))))
        return x
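
# Each block applies two pre-norm residual updates:
#   x <- x + DropPath(LayerScale(Attn(LN(x))))
#   x <- x + DropPath(LayerScale(MLP(LN(x))))
# LayerScale and DropPath reduce to identity when init_values is None / drop_path == 0.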


class PatchEmbed(nn.Cell):
    """ 2D Image to Patch Embedding
    """
    output_fmt: Format

    def __init__(
        self,
        img_size: Optional[int] = 224,
        patch_size: int = 16,
        in_chans: int = 3,
        embed_dim: int = 768,
        norm_layer: Optional[Callable] = None,
        flatten: bool = True,
        output_fmt: Optional[str] = None,
        bias: bool = True,
    ):
        super().__init__()
        self.patch_size = to_2tuple(patch_size)
        if img_size is not None:
            self.img_size = to_2tuple(img_size)
            self.grid_size = tuple(s // p for (s, p) in zip(self.img_size, self.patch_size))
            self.num_patches = self.grid_size[0] * self.grid_size[1]
        else:
            self.img_size = None
            self.grid_size = None
            self.num_patches = None

        if output_fmt is not None:
            self.flatten = False
            self.output_fmt = Format(output_fmt)
        else:
            # flatten spatial dims and transpose to channels last, kept for bwd compat
            self.flatten = flatten
            self.output_fmt = Format.NCHW

        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, has_bias=bias)
        self.norm = norm_layer((embed_dim, )) if norm_layer else nn.Identity()

    def construct(self, x):
        _, _, H, W = x.shape
        if self.img_size is not None:
            assert H == self.img_size[0], f"Input image height ({H}) doesn't match model ({self.img_size[0]})."
            assert W == self.img_size[1], f"Input image width ({W}) doesn't match model ({self.img_size[1]})."
        x = self.proj(x)
        if self.flatten:
            x = x.flatten(start_dim=2).transpose(0, 2, 1)  # NCHW -> NLC
        elif self.output_fmt != Format.NCHW:
            x = nchw_to(x, self.output_fmt)
        x = self.norm(x)
        return x
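
# Example (illustrative): a 224x224 RGB image becomes a 14x14 grid of patch tokens:
#   embed = PatchEmbed(img_size=224, patch_size=16, embed_dim=768)
#   embed(ms.ops.randn((2, 3, 224, 224))).shape    # (2, 196, 768)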


class VisionTransformer(nn.Cell):
    """ Vision Transformer
    """
    def __init__(
        self,
        img_size: Union[int, Tuple[int, int]] = 224,
        patch_size: Union[int, Tuple[int, int]] = 16,
        in_chans: int = 3,
        num_classes: int = 1000,
        global_pool: str = 'token',
        embed_dim: int = 768,
        depth: int = 12,
        num_heads: int = 12,
        mlp_ratio: float = 4.,
        qkv_bias: bool = True,
        qk_norm: bool = False,
        init_values: Optional[float] = None,
        class_token: bool = True,
        no_embed_class: bool = False,
        pre_norm: bool = False,
        fc_norm: Optional[bool] = None,
        drop_rate: float = 0.,
        pos_drop_rate: float = 0.,
        patch_drop_rate: float = 0.,
        proj_drop_rate: float = 0.,
        attn_drop_rate: float = 0.,
        drop_path_rate: float = 0.,
        weight_init: str = '',
        embed_layer: Callable = PatchEmbed,
        norm_layer: Optional[Callable] = None,
        act_layer: Optional[Callable] = None,
        block_fn: Callable = Block,
        mlp_layer: Callable = Mlp,
    ):
        super().__init__()
        assert global_pool in ('', 'avg', 'token')
        assert class_token or global_pool != 'token'
        use_fc_norm = global_pool == 'avg' if fc_norm is None else fc_norm
        norm_layer = norm_layer or partial(nn.LayerNorm, epsilon=1e-6)
        act_layer = act_layer or nn.GELU

        self.num_classes = num_classes
        self.global_pool = global_pool
        self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models
        self.num_prefix_tokens = 1 if class_token else 0
        self.no_embed_class = no_embed_class
        self.grad_checkpointing = False

        self.patch_embed = embed_layer(
            img_size=img_size,
            patch_size=patch_size,
            in_chans=in_chans,
            embed_dim=embed_dim,
            bias=not pre_norm,  # disable bias if pre-norm is used (e.g. CLIP)
        )
        num_patches = self.patch_embed.num_patches

        self.cls_token = ms.Parameter(ms.ops.zeros((1, 1, embed_dim))) if class_token else None
        embed_len = num_patches if no_embed_class else num_patches + self.num_prefix_tokens
        self.pos_embed = ms.Parameter(ms.ops.randn((1, embed_len, embed_dim)) * .02)
        self.pos_drop = nn.Dropout(p=pos_drop_rate)
        if patch_drop_rate > 0:
            self.patch_drop = PatchDropout(
                patch_drop_rate,
                num_prefix_tokens=self.num_prefix_tokens,
            )
        else:
            self.patch_drop = nn.Identity()
        self.norm_pre = norm_layer((embed_dim, )) if pre_norm else nn.Identity()

        dpr = [float(r) for r in ms.ops.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule
        self.blocks = nn.SequentialCell(*[
            block_fn(
                dim=embed_dim,
                num_heads=num_heads,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                qk_norm=qk_norm,
                init_values=init_values,
                proj_drop=proj_drop_rate,
                attn_drop=attn_drop_rate,
                drop_path=dpr[i],
                norm_layer=norm_layer,
                act_layer=act_layer,
                mlp_layer=mlp_layer,
            )
            for i in range(depth)])
        self.norm = norm_layer((embed_dim, )) if not use_fc_norm else nn.Identity()

        # Classifier Head
        self.fc_norm = norm_layer((embed_dim, )) if use_fc_norm else nn.Identity()
        self.head_drop = nn.Dropout(p=drop_rate)
        self.head = nn.Dense(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()

        self.pos_embed.set_data(initializer(TruncatedNormal(sigma=.02), self.pos_embed.shape, self.pos_embed.dtype))
        if self.cls_token is not None:
            self.cls_token.set_data(initializer(Normal(sigma=1e-6), self.cls_token.shape, self.cls_token.dtype))
        if weight_init != 'skip':
            self.apply(self.init_weights)

    def init_weights(self, cell):
        """ Initialize weights """
        if isinstance(cell, nn.Dense):
            cell.weight.set_data(initializer(TruncatedNormal(sigma=.02), cell.weight.shape, cell.weight.dtype))
            if cell.bias is not None:
                cell.bias.set_data(initializer('zeros', cell.bias.shape, cell.bias.dtype))

    def reset_classifier(self, num_classes: int, global_pool=None):
        """ Reset the classifier head """
        self.num_classes = num_classes
        if global_pool is not None:
            assert global_pool in ('', 'avg', 'token')
            self.global_pool = global_pool
        self.head = nn.Dense(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()

    def _pos_embed(self, x):
        """ Add position embedding (and class token, if any) """
        if self.no_embed_class:
            # deit-3, updated JAX (big vision)
            # position embedding does not overlap with class token, add then concat
            x = x + self.pos_embed
            if self.cls_token is not None:
                cls_token = self.cls_token.broadcast_to((x.shape[0], -1, -1))
                x = ms.ops.cat((cls_token, x), axis=1)
        else:
            # original timm, JAX, and deit vit impl
            # pos_embed has entry for class token, concat then add
            if self.cls_token is not None:
                cls_token = self.cls_token.broadcast_to((x.shape[0], -1, -1))
                x = ms.ops.cat((cls_token, x), axis=1)
            x = x + self.pos_embed
        return self.pos_drop(x)

    def _intermediate_layers(
        self,
        x: ms.Tensor,
        n=1,
    ):
        """ Collect outputs of selected blocks """
        outputs, num_blocks = [], len(self.blocks)
        take_indices = set(range(num_blocks - n, num_blocks) if isinstance(n, int) else n)

        # forward pass
        x = self.patch_embed(x)
        x = self._pos_embed(x)
        x = self.patch_drop(x)
        x = self.norm_pre(x)
        for i, blk in enumerate(self.blocks):
            x = blk(x)
            if i in take_indices:
                outputs.append(x)
        return outputs

    def get_intermediate_layers(
        self,
        x: ms.Tensor,
        n=1,
        reshape: bool = False,
        return_class_token: bool = False,
        norm: bool = False,
    ):
        """ Intermediate layer accessor
        Inspired by DINO / DINOv2 interface
        """
        # take last n blocks if n is an int; if n is a sequence, select by matching indices
        outputs = self._intermediate_layers(x, n)
        if norm:
            outputs = [self.norm(out) for out in outputs]
        class_tokens = [out[:, 0:self.num_prefix_tokens] for out in outputs]
        outputs = [out[:, self.num_prefix_tokens:] for out in outputs]

        if reshape:
            grid_size = self.patch_embed.grid_size
            outputs = [
                out.reshape(x.shape[0], grid_size[0], grid_size[1], -1).permute(0, 3, 1, 2)
                for out in outputs
            ]

        if return_class_token:
            return tuple(zip(outputs, class_tokens))
        return tuple(outputs)

    def construct_features(self, x):
        """ Construct features """
        x = self.patch_embed(x)
        x = self._pos_embed(x)
        x = self.patch_drop(x)
        x = self.norm_pre(x)
        x = self.blocks(x)
        x = self.norm(x)
        return x

    def construct_head(self, x, pre_logits: bool = False):
        """ Construct head """
        if self.global_pool:
            x = x[:, self.num_prefix_tokens:].mean(axis=1) if self.global_pool == 'avg' else x[:, 0]
        x = self.fc_norm(x)
        x = self.head_drop(x)
        return x if pre_logits else self.head(x)

    def construct(self, x):
        x = self.construct_features(x)
        x = self.construct_head(x)
        return x
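
# Minimal usage sketch (illustrative; DeiT-Tiny hyper-parameters):
#   model = VisionTransformer(patch_size=16, embed_dim=192, depth=12, num_heads=3)
#   logits = model(ms.ops.randn((1, 3, 224, 224)))    # (1, 1000)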


class DistilledVisionTransformer(VisionTransformer):
    """ Distilled Vision Transformer """
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.dist_token = ms.Parameter(ms.ops.zeros((1, 1, self.embed_dim)))
        num_patches = self.patch_embed.num_patches
        # +2 positions for the class token and the distillation token
        self.pos_embed = ms.Parameter(ms.ops.zeros((1, num_patches + 2, self.embed_dim)))
        self.head_dist = nn.Dense(self.embed_dim, self.num_classes) if self.num_classes > 0 else nn.Identity()

        self.dist_token.set_data(initializer(TruncatedNormal(sigma=.02), self.dist_token.shape, self.dist_token.dtype))
        self.pos_embed.set_data(initializer(TruncatedNormal(sigma=.02), self.pos_embed.shape, self.pos_embed.dtype))
        self.head_dist.apply(self._init_weights)

    def _init_weights(self, cell):
        """ Initialize weights """
        if isinstance(cell, nn.Dense):
            cell.weight.set_data(initializer(TruncatedNormal(sigma=.02), cell.weight.shape, cell.weight.dtype))
            if cell.bias is not None:
                cell.bias.set_data(initializer('zeros', cell.bias.shape, cell.bias.dtype))
        elif isinstance(cell, nn.LayerNorm):
            cell.gamma.set_data(initializer('ones', cell.gamma.shape, cell.gamma.dtype))
            cell.beta.set_data(initializer('zeros', cell.beta.shape, cell.beta.dtype))

    def construct_features(self, x):
        """ Construct features """
        x = self.patch_embed(x)
        cls_token = self.cls_token.broadcast_to((x.shape[0], -1, -1))
        dist_token = self.dist_token.broadcast_to((x.shape[0], -1, -1))
        x = ms.ops.cat((cls_token, dist_token, x), axis=1)
        x = x + self.pos_embed
        x = self.pos_drop(x)

        for blk in self.blocks:
            x = blk(x)
        x = self.norm(x)
        return x[:, 0], x[:, 1]

    def construct(self, x):
        x, x_dist = self.construct_features(x)
        x = self.head(x)
        x_dist = self.head_dist(x_dist)
        if self.training:
            return x, x_dist
        return (x + x_dist) / 2
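
# During training the two heads are supervised separately (ground truth for the
# class token, teacher predictions for the distillation token); at inference the
# two logit streams are simply averaged, as in the DeiT paper.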


@register_model
def deit_tiny_patch16_224(pretrained=False, **kwargs):
    """ deit-tiny-patch16 with image size 224 """
    default_cfg = _cfg()
    model = VisionTransformer(
        patch_size=16, embed_dim=192, depth=12, num_heads=3, mlp_ratio=4, qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, epsilon=1e-6), **kwargs)
    model.default_cfg = default_cfg
    if pretrained:
        load_pretrained(model, default_cfg)
    return model


@register_model
def deit_small_patch16_224(pretrained=False, **kwargs):
    """ deit-small-patch16 with image size 224 """
    default_cfg = _cfg()
    model = VisionTransformer(
        patch_size=16, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4, qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, epsilon=1e-6), **kwargs)
    model.default_cfg = default_cfg
    if pretrained:
        load_pretrained(model, default_cfg)
    return model


@register_model
def deit_base_patch16_224(pretrained=False, **kwargs):
    """ deit-base-patch16 with image size 224 """
    default_cfg = _cfg()
    model = VisionTransformer(
        patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, epsilon=1e-6), **kwargs)
    model.default_cfg = default_cfg
    if pretrained:
        load_pretrained(model, default_cfg)
    return model


@register_model
def deit_tiny_distilled_patch16_224(pretrained=False, **kwargs):
    """ deit-tiny-distilled-patch16 with image size 224 """
    default_cfg = _cfg()
    model = DistilledVisionTransformer(
        patch_size=16, embed_dim=192, depth=12, num_heads=3, mlp_ratio=4, qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, epsilon=1e-6), **kwargs)
    model.default_cfg = default_cfg
    if pretrained:
        load_pretrained(model, default_cfg)
    return model


@register_model
def deit_small_distilled_patch16_224(pretrained=False, **kwargs):
    """ deit-small-distilled-patch16 with image size 224 """
    default_cfg = _cfg()
    model = DistilledVisionTransformer(
        patch_size=16, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4, qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, epsilon=1e-6), **kwargs)
    model.default_cfg = default_cfg
    if pretrained:
        load_pretrained(model, default_cfg)
    return model


@register_model
def deit_base_distilled_patch16_224(pretrained=False, **kwargs):
    """ deit-base-distilled-patch16 with image size 224 """
    default_cfg = _cfg()
    model = DistilledVisionTransformer(
        patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, epsilon=1e-6), **kwargs)
    model.default_cfg = default_cfg
    if pretrained:
        load_pretrained(model, default_cfg)
    return model


@register_model
def deit_base_patch16_384(pretrained=False, **kwargs):
    """ deit-base-patch16 with image size 384 """
    default_cfg = _cfg()
    model = VisionTransformer(
        img_size=384, patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, epsilon=1e-6), **kwargs)
    model.default_cfg = default_cfg
    if pretrained:
        load_pretrained(model, default_cfg)
    return model


@register_model
def deit_base_distilled_patch16_384(pretrained=False, **kwargs):
    """ deit-base-distilled-patch16 with image size 384 """
    default_cfg = _cfg()
    model = DistilledVisionTransformer(
        img_size=384, patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, epsilon=1e-6), **kwargs)
    model.default_cfg = default_cfg
    if pretrained:
        load_pretrained(model, default_cfg)
    return model


if __name__ == '__main__':
    dummy_input = ms.ops.randn((1, 3, 224, 224))
    net = deit_base_distilled_patch16_224()
    output = net(dummy_input)
    print(output.shape)
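
    # Additional smoke test (illustrative): a non-distilled variant produces
    # logits of the same shape through its single classifier head.
    net_tiny = deit_tiny_patch16_224()
    print(net_tiny(dummy_input).shape)    # expected: (1, 1000)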
