You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650
  1. from typing import Optional, Tuple, List
  2. import warnings
  3. import math
  4. import jittor as jt
  5. from jittor import Var
  6. from jittor.nn import Module, Linear, softmax, pad, linear, dropout
  7. from jittor.init import xavier_uniform_, xavier_gauss_, constant_
  8. def _canonical_mask(
  9. mask: Optional[Var],
  10. mask_name: str,
  11. other_type,
  12. other_name: str,
  13. target_type,
  14. check_other: bool = True,
  15. ) -> Optional[Var]:
  16. if mask is not None:
  17. _mask_dtype = mask.dtype
  18. _mask_is_float = mask.dtype == jt.float16 or mask.dtype == jt.float32 or mask.dtype == jt.float64
  19. if _mask_dtype != jt.bool and not _mask_is_float:
  20. raise AssertionError(
  21. f"only bool and floating types of {mask_name} are supported")
  22. if check_other and other_type is not None:
  23. if _mask_dtype != other_type:
  24. warnings.warn(
  25. f"Support for mismatched {mask_name} and {other_name} "
  26. "is deprecated. Use same type for both instead.")
  27. if not _mask_is_float:
  28. # WARNING(514flowey): Check Here
  29. new_mask = jt.zeros_like(mask, dtype=target_type)
  30. new_mask[mask] = float("-inf")
  31. mask = new_mask
  32. return mask
  33. def _none_or_dtype(input: Optional[Var]):
  34. if input is None:
  35. return None
  36. elif isinstance(input, jt.Var):
  37. return input.dtype
  38. def baddbmm(input_var: jt.Var,
  39. batch1: jt.Var,
  40. batch2: jt.Var,
  41. beta=1,
  42. alpha=1) -> jt.Var:
  43. # WARNING(514flowey): Check here
  44. return beta * input_var + alpha * (batch1 @ batch2)
  45. def scaled_dot_product_attention(query,
  46. key,
  47. value,
  48. attn_mask=None,
  49. dropout_p=0.0,
  50. is_causal=False,
  51. scale=None,
  52. training=True) -> jt.Var:
  53. # Efficient implementation equivalent to the following:
  54. L, S = query.size(-2), key.size(-2)
  55. scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale
  56. attn_bias = jt.zeros(L, S, dtype=query.dtype)
  57. if is_causal:
  58. assert attn_mask is None
  59. temp_mask = jt.ones(L, S, dtype=jt.bool).tril(diagonal=0)
  60. attn_bias[jt.logical_not(temp_mask)] = float("-inf")
  61. # attn_bias.to(query.dtype)
  62. attn_bias = jt.array(attn_bias, query.dtype)
  63. if attn_mask is not None:
  64. if attn_mask.dtype == jt.bool:
  65. attn_bias[jt.logical_not(temp_mask)] = float("-inf")
  66. else:
  67. attn_bias += attn_mask
  68. attn_weight = query @ key.transpose(-2, -1) * scale_factor
  69. attn_weight += attn_bias
  70. attn_weight = softmax(attn_weight, dim=-1)
  71. attn_weight = dropout(attn_weight, dropout_p, is_train=training)
  72. return attn_weight @ value
  73. def _mha_shape_check(query: Var, key: Var, value: Var,
  74. key_padding_mask: Optional[Var], attn_mask: Optional[Var],
  75. num_heads: int):
  76. if query.dim() == 3:
  77. is_batched = True
  78. assert key.dim() == 3 and value.dim() == 3, \
  79. ("For batched (3-D) `query`, expected `key` and `value` to be 3-D"
  80. f" but found {key.dim()}-D and {value.dim()}-D Vars respectively")
  81. if key_padding_mask is not None:
  82. assert key_padding_mask.dim() == 2, \
  83. ("For batched (3-D) `query`, expected `key_padding_mask` to be `None` or 2-D"
  84. f" but found {key_padding_mask.dim()}-D Var instead")
  85. if attn_mask is not None:
  86. assert attn_mask.dim() in (2, 3), \
  87. ("For batched (3-D) `query`, expected `attn_mask` to be `None`, 2-D or 3-D"
  88. f" but found {attn_mask.dim()}-D Var instead")
  89. elif query.dim() == 2:
  90. is_batched = False
  91. assert key.dim() == 2 and value.dim() == 2, \
  92. ("For unbatched (2-D) `query`, expected `key` and `value` to be 2-D"
  93. f" but found {key.dim()}-D and {value.dim()}-D Vars respectively")
  94. if key_padding_mask is not None:
  95. assert key_padding_mask.dim() == 1, \
  96. ("For unbatched (2-D) `query`, expected `key_padding_mask` to be `None` or 1-D"
  97. f" but found {key_padding_mask.dim()}-D Var instead")
  98. if attn_mask is not None:
  99. assert attn_mask.dim() in (2, 3), \
  100. ("For unbatched (2-D) `query`, expected `attn_mask` to be `None`, 2-D or 3-D"
  101. f" but found {attn_mask.dim()}-D Var instead")
  102. if attn_mask.dim() == 3:
  103. expected_shape = (num_heads, query.shape[0], key.shape[0])
  104. assert attn_mask.shape == expected_shape, \
  105. (f"Expected `attn_mask` shape to be {expected_shape} but got {attn_mask.shape}")
  106. else:
  107. raise AssertionError(
  108. f"query should be unbatched 2D or batched 3D Var but received {query.dim()}-D query Var"
  109. )
  110. return is_batched
  111. def _in_projection_packed(
  112. q: Var,
  113. k: Var,
  114. v: Var,
  115. w: Var,
  116. b: Optional[Var] = None,
  117. ) -> List[Var]:
  118. E = q.size(-1)
  119. if k is v:
  120. if q is k:
  121. # self-attention
  122. proj = linear(q, w, b)
  123. # reshape to 3, E and not E, 3 is deliberate for better memory coalescing and keeping same order as chunk()
  124. # proj = proj.unflatten(-1, (3, E)).unsqueeze(0).transpose(0, -2).squeeze(-2).contiguous()
  125. nshape = proj.shape[:-1] + (3, E)
  126. proj = proj.reshape(nshape).unsqueeze(0).transpose(0,
  127. -2).squeeze(-2)
  128. return proj[0], proj[1], proj[2]
  129. else:
  130. # encoder-decoder attention
  131. w_q, w_kv = w.split([E, E * 2])
  132. if b is None:
  133. b_q = b_kv = None
  134. else:
  135. b_q, b_kv = b.split([E, E * 2])
  136. q_proj = linear(q, w_q, b_q)
  137. kv_proj = linear(k, w_kv, b_kv)
  138. # reshape to 2, E and not E, 2 is deliberate for better memory coalescing and keeping same order as chunk()
  139. # kv_proj = kv_proj.unflatten(-1, (2, E)).unsqueeze(0).transpose(0, -2).squeeze(-2).contiguous()
  140. nshape = kv_proj.shape[:-1] + (2, E)
  141. kv_proj = kv_proj.reshape(nshape).unsqueeze(0).transpose(
  142. 0, -2).squeeze(-2)
  143. return (q_proj, kv_proj[0], kv_proj[1])
  144. else:
  145. w_q, w_k, w_v = w.chunk(3)
  146. if b is None:
  147. b_q = b_k = b_v = None
  148. else:
  149. b_q, b_k, b_v = b.chunk(3)
  150. return linear(q, w_q, b_q), linear(k, w_k, b_k), linear(v, w_v, b_v)
  151. def _in_projection(
  152. q: Var,
  153. k: Var,
  154. v: Var,
  155. w_q: Var,
  156. w_k: Var,
  157. w_v: Var,
  158. b_q: Optional[Var] = None,
  159. b_k: Optional[Var] = None,
  160. b_v: Optional[Var] = None,
  161. ) -> Tuple[Var, Var, Var]:
  162. Eq, Ek, Ev = q.size(-1), k.size(-1), v.size(-1)
  163. assert w_q.shape == (
  164. Eq, Eq
  165. ), f"expecting query weights shape of {(Eq, Eq)}, but got {w_q.shape}"
  166. assert w_k.shape == (
  167. Eq,
  168. Ek), f"expecting key weights shape of {(Eq, Ek)}, but got {w_k.shape}"
  169. assert w_v.shape == (
  170. Eq, Ev
  171. ), f"expecting value weights shape of {(Eq, Ev)}, but got {w_v.shape}"
  172. assert b_q is None or b_q.shape == (
  173. Eq, ), f"expecting query bias shape of {(Eq,)}, but got {b_q.shape}"
  174. assert b_k is None or b_k.shape == (
  175. Eq, ), f"expecting key bias shape of {(Eq,)}, but got {b_k.shape}"
  176. assert b_v is None or b_v.shape == (
  177. Eq, ), f"expecting value bias shape of {(Eq,)}, but got {b_v.shape}"
  178. return linear(q, w_q, b_q), linear(k, w_k, b_k), linear(v, w_v, b_v)
  179. def multi_head_attention_forward(
  180. query: Var,
  181. key: Var,
  182. value: Var,
  183. embed_dim_to_check: int,
  184. num_heads: int,
  185. in_proj_weight: Optional[Var],
  186. in_proj_bias: Optional[Var],
  187. bias_k: Optional[Var],
  188. bias_v: Optional[Var],
  189. add_zero_attn: bool,
  190. dropout_p: float,
  191. out_proj_weight: Var,
  192. out_proj_bias: Optional[Var],
  193. training: bool = True,
  194. key_padding_mask: Optional[Var] = None,
  195. need_weights: bool = True,
  196. attn_mask: Optional[Var] = None,
  197. use_separate_proj_weight: bool = False,
  198. q_proj_weight: Optional[Var] = None,
  199. k_proj_weight: Optional[Var] = None,
  200. v_proj_weight: Optional[Var] = None,
  201. static_k: Optional[Var] = None,
  202. static_v: Optional[Var] = None,
  203. average_attn_weights: bool = True,
  204. is_causal: bool = False,
  205. ) -> Tuple[Var, Optional[Var]]:
  206. is_batched = _mha_shape_check(query, key, value, key_padding_mask,
  207. attn_mask, num_heads)
  208. # For unbatched input, we unsqueeze at the expected batch-dim to pretend that the input
  209. # is batched, run the computation and before returning squeeze the
  210. # batch dimension so that the output doesn't carry this temporary batch dimension.
  211. if not is_batched:
  212. # unsqueeze if the input is unbatched
  213. query = query.unsqueeze(1)
  214. key = key.unsqueeze(1)
  215. value = value.unsqueeze(1)
  216. if key_padding_mask is not None:
  217. key_padding_mask = key_padding_mask.unsqueeze(0)
  218. # set up shape vars
  219. tgt_len, bsz, embed_dim = query.shape
  220. src_len, _, _ = key.shape
  221. key_padding_mask = _canonical_mask(mask=key_padding_mask,
  222. mask_name="key_padding_mask",
  223. other_type=_none_or_dtype(attn_mask),
  224. other_name="attn_mask",
  225. target_type=query.dtype)
  226. if is_causal and attn_mask is None:
  227. raise RuntimeError(
  228. "Need attn_mask if specifying the is_causal hint. "
  229. "You may use the Transformer module method "
  230. "`generate_square_subsequent_mask` to create this mask.")
  231. if is_causal and key_padding_mask is None and not need_weights:
  232. # when we have a kpm or need weights, we need attn_mask
  233. # Otherwise, we use the is_causal hint go as is_causal
  234. # indicator to SDPA.
  235. attn_mask = None
  236. else:
  237. attn_mask = _canonical_mask(
  238. mask=attn_mask,
  239. mask_name="attn_mask",
  240. other_type=None,
  241. other_name="",
  242. target_type=query.dtype,
  243. check_other=False,
  244. )
  245. if key_padding_mask is not None:
  246. # We have the attn_mask, and use that to merge kpm into it.
  247. # Turn off use of is_causal hint, as the merged mask is no
  248. # longer causal.
  249. is_causal = False
  250. assert embed_dim == embed_dim_to_check, \
  251. f"was expecting embedding dimension of {embed_dim_to_check}, but got {embed_dim}"
  252. if isinstance(embed_dim, jt.Var):
  253. # embed_dim can be a Var when JIT tracing
  254. head_dim = embed_dim.div(num_heads, rounding_mode='trunc')
  255. else:
  256. head_dim = embed_dim // num_heads
  257. assert head_dim * num_heads == embed_dim, f"embed_dim {embed_dim} not divisible by num_heads {num_heads}"
  258. if use_separate_proj_weight:
  259. # allow MHA to have different embedding dimensions when separate projection weights are used
  260. assert key.shape[:2] == value.shape[:2], \
  261. f"key's sequence and batch dims {key.shape[:2]} do not match value's {value.shape[:2]}"
  262. else:
  263. assert key.shape == value.shape, f"key shape {key.shape} does not match value shape {value.shape}"
  264. #
  265. # compute in-projection
  266. #
  267. if not use_separate_proj_weight:
  268. assert in_proj_weight is not None, "use_separate_proj_weight is False but in_proj_weight is None"
  269. q, k, v = _in_projection_packed(query, key, value, in_proj_weight,
  270. in_proj_bias)
  271. else:
  272. assert q_proj_weight is not None, "use_separate_proj_weight is True but q_proj_weight is None"
  273. assert k_proj_weight is not None, "use_separate_proj_weight is True but k_proj_weight is None"
  274. assert v_proj_weight is not None, "use_separate_proj_weight is True but v_proj_weight is None"
  275. if in_proj_bias is None:
  276. b_q = b_k = b_v = None
  277. else:
  278. b_q, b_k, b_v = in_proj_bias.chunk(3)
  279. q, k, v = _in_projection(query, key, value, q_proj_weight,
  280. k_proj_weight, v_proj_weight, b_q, b_k, b_v)
  281. # prep attention mask
  282. if attn_mask is not None:
  283. # ensure attn_mask's dim is 3
  284. if attn_mask.dim() == 2:
  285. correct_2d_size = (tgt_len, src_len)
  286. if attn_mask.shape != correct_2d_size:
  287. raise RuntimeError(
  288. f"The shape of the 2D attn_mask is {attn_mask.shape}, but should be {correct_2d_size}."
  289. )
  290. attn_mask = attn_mask.unsqueeze(0)
  291. elif attn_mask.dim() == 3:
  292. correct_3d_size = (bsz * num_heads, tgt_len, src_len)
  293. if attn_mask.shape != correct_3d_size:
  294. raise RuntimeError(
  295. f"The shape of the 3D attn_mask is {attn_mask.shape}, but should be {correct_3d_size}."
  296. )
  297. else:
  298. raise RuntimeError(
  299. f"attn_mask's dimension {attn_mask.dim()} is not supported")
  300. # add bias along batch dimension (currently second)
  301. if bias_k is not None and bias_v is not None:
  302. assert static_k is None, "bias cannot be added to static key."
  303. assert static_v is None, "bias cannot be added to static value."
  304. k = jt.concat([k, bias_k.repeat(1, bsz, 1)])
  305. v = jt.concat([v, bias_v.repeat(1, bsz, 1)])
  306. if attn_mask is not None:
  307. attn_mask = pad(attn_mask, (0, 1))
  308. if key_padding_mask is not None:
  309. key_padding_mask = pad(key_padding_mask, (0, 1))
  310. else:
  311. assert bias_k is None
  312. assert bias_v is None
  313. #
  314. # reshape q, k, v for multihead attention and make em batch first
  315. #
  316. q = q.view(tgt_len, bsz * num_heads, head_dim).transpose(0, 1)
  317. if static_k is None:
  318. k = k.view(k.shape[0], bsz * num_heads, head_dim).transpose(0, 1)
  319. else:
  320. # TODO finish disentangling control flow so we don't do in-projections when statics are passed
  321. assert static_k.size(0) == bsz * num_heads, \
  322. f"expecting static_k.size(0) of {bsz * num_heads}, but got {static_k.size(0)}"
  323. assert static_k.size(2) == head_dim, \
  324. f"expecting static_k.size(2) of {head_dim}, but got {static_k.size(2)}"
  325. k = static_k
  326. if static_v is None:
  327. v = v.view(v.shape[0], bsz * num_heads, head_dim).transpose(0, 1)
  328. else:
  329. # TODO finish disentangling control flow so we don't do in-projections when statics are passed
  330. assert static_v.size(0) == bsz * num_heads, \
  331. f"expecting static_v.size(0) of {bsz * num_heads}, but got {static_v.size(0)}"
  332. assert static_v.size(2) == head_dim, \
  333. f"expecting static_v.size(2) of {head_dim}, but got {static_v.size(2)}"
  334. v = static_v
  335. # add zero attention along batch dimension (now first)
  336. if add_zero_attn:
  337. zero_attn_shape = (bsz * num_heads, 1, head_dim)
  338. k = jt.concat([k, jt.zeros(zero_attn_shape, dtype=k.dtype)], dim=1)
  339. v = jt.concat([v, jt.zeros(zero_attn_shape, dtype=v.dtype)], dim=1)
  340. if attn_mask is not None:
  341. attn_mask = pad(attn_mask, (0, 1))
  342. if key_padding_mask is not None:
  343. key_padding_mask = pad(key_padding_mask, (0, 1))
  344. # update source sequence length after adjustments
  345. src_len = k.size(1)
  346. # merge key padding and attention masks
  347. if key_padding_mask is not None:
  348. assert key_padding_mask.shape == (bsz, src_len), \
  349. f"expecting key_padding_mask shape of {(bsz, src_len)}, but got {key_padding_mask.shape}"
  350. key_padding_mask = key_padding_mask.view(bsz, 1, 1, src_len). \
  351. expand(-1, num_heads, -1, -1).reshape(bsz * num_heads, 1, src_len)
  352. if attn_mask is None:
  353. attn_mask = key_padding_mask
  354. else:
  355. attn_mask = attn_mask + key_padding_mask
  356. # adjust dropout probability
  357. if not training:
  358. dropout_p = 0.0
  359. #
  360. # (deep breath) calculate attention and out projection
  361. #
  362. if need_weights:
  363. B, Nt, E = q.shape
  364. q_scaled = q / math.sqrt(E)
  365. assert not (is_causal and attn_mask is None
  366. ), "FIXME: is_causal not implemented for need_weights"
  367. if attn_mask is not None:
  368. attn_output_weights = baddbmm(attn_mask, q_scaled,
  369. k.transpose(-2, -1))
  370. else:
  371. attn_output_weights = jt.bmm(q_scaled, k.transpose(-2, -1))
  372. attn_output_weights = softmax(attn_output_weights, dim=-1)
  373. if dropout_p > 0.0:
  374. attn_output_weights = dropout(attn_output_weights, p=dropout_p)
  375. attn_output = jt.bmm(attn_output_weights, v)
  376. attn_output = attn_output.transpose(0, 1).contiguous().view(
  377. tgt_len * bsz, embed_dim)
  378. attn_output = linear(attn_output, out_proj_weight, out_proj_bias)
  379. attn_output = attn_output.view(tgt_len, bsz, attn_output.size(1))
  380. # optionally average attention weights over heads
  381. attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len,
  382. src_len)
  383. if average_attn_weights:
  384. attn_output_weights = attn_output_weights.mean(dim=1)
  385. if not is_batched:
  386. # squeeze the output if input was unbatched
  387. attn_output = attn_output.squeeze(1)
  388. attn_output_weights = attn_output_weights.squeeze(0)
  389. return attn_output, attn_output_weights
  390. else:
  391. # attn_mask can be either (L,S) or (N*num_heads, L, S)
  392. # if attn_mask's shape is (1, L, S) we need to unsqueeze to (1, 1, L, S)
  393. # in order to match the input for SDPA of (N, num_heads, L, S)
  394. if attn_mask is not None:
  395. if attn_mask.size(0) == 1 and attn_mask.dim() == 3:
  396. attn_mask = attn_mask.unsqueeze(0)
  397. else:
  398. attn_mask = attn_mask.view(bsz, num_heads, -1, src_len)
  399. q = q.view(bsz, num_heads, tgt_len, head_dim)
  400. k = k.view(bsz, num_heads, src_len, head_dim)
  401. v = v.view(bsz, num_heads, src_len, head_dim)
  402. attn_output = scaled_dot_product_attention(q,
  403. k,
  404. v,
  405. attn_mask,
  406. dropout_p,
  407. is_causal,
  408. training=training)
  409. attn_output = attn_output.permute(2, 0, 1, 3).contiguous().view(
  410. bsz * tgt_len, embed_dim)
  411. attn_output = linear(attn_output, out_proj_weight, out_proj_bias)
  412. attn_output = attn_output.view(tgt_len, bsz, attn_output.size(1))
  413. if not is_batched:
  414. # squeeze the output if input was unbatched
  415. attn_output = attn_output.squeeze(1)
  416. return attn_output, None
  417. class MultiheadAttention(Module):
  418. __constants__ = ['batch_first']
  419. bias_k: Optional[jt.Var]
  420. bias_v: Optional[jt.Var]
  421. def __init__(self,
  422. embed_dim,
  423. num_heads,
  424. dropout=0.,
  425. bias=True,
  426. add_bias_kv=False,
  427. add_zero_attn=False,
  428. kdim=None,
  429. vdim=None,
  430. batch_first=False,
  431. dtype=jt.float32) -> None:
  432. if embed_dim <= 0 or num_heads <= 0:
  433. raise ValueError(
  434. f"embed_dim and num_heads must be greater than 0,"
  435. f" got embed_dim={embed_dim} and num_heads={num_heads} instead"
  436. )
  437. factory_kwargs = {'dtype': dtype}
  438. super().__init__()
  439. self.embed_dim = embed_dim
  440. self.kdim = kdim if kdim is not None else embed_dim
  441. self.vdim = vdim if vdim is not None else embed_dim
  442. self._qkv_same_embed_dim = self.kdim == embed_dim and self.vdim == embed_dim
  443. self.num_heads = num_heads
  444. self.dropout = dropout
  445. self.batch_first = batch_first
  446. self.head_dim = embed_dim // num_heads
  447. assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"
  448. if not self._qkv_same_embed_dim:
  449. self.q_proj_weight = jt.empty((embed_dim, embed_dim),
  450. **factory_kwargs)
  451. self.k_proj_weight = jt.empty((embed_dim, self.kdim),
  452. **factory_kwargs)
  453. self.v_proj_weight = jt.empty((embed_dim, self.vdim),
  454. **factory_kwargs)
  455. self.in_proj_weight = None
  456. else:
  457. self.q_proj_weight = None
  458. self.k_proj_weight = None
  459. self.v_proj_weight = None
  460. self.in_proj_weight = jt.empty((3 * embed_dim, embed_dim),
  461. **factory_kwargs)
  462. if bias:
  463. self.in_proj_bias = jt.empty(3 * embed_dim, **factory_kwargs)
  464. else:
  465. self.in_proj_bias = None
  466. self.out_proj = Linear(embed_dim, embed_dim, bias=bias)
  467. if add_bias_kv:
  468. self.bias_k = jt.empty((1, 1, embed_dim), **factory_kwargs)
  469. self.bias_v = jt.empty((1, 1, embed_dim), **factory_kwargs)
  470. else:
  471. self.bias_k = self.bias_v = None
  472. self.add_zero_attn = add_zero_attn
  473. self._reset_parameters()
  474. def _reset_parameters(self):
  475. if self._qkv_same_embed_dim:
  476. xavier_uniform_(self.in_proj_weight)
  477. else:
  478. xavier_uniform_(self.q_proj_weight)
  479. xavier_uniform_(self.k_proj_weight)
  480. xavier_uniform_(self.v_proj_weight)
  481. if self.in_proj_bias is not None:
  482. constant_(self.in_proj_bias, 0.)
  483. constant_(self.out_proj.bias, 0.)
  484. if self.bias_k is not None:
  485. xavier_gauss_(self.bias_k)
  486. if self.bias_v is not None:
  487. xavier_gauss_(self.bias_v)
  488. def __setstate__(self, state):
  489. # Support loading old MultiheadAttention checkpoints generated by v1.1.0
  490. if '_qkv_same_embed_dim' not in state:
  491. state['_qkv_same_embed_dim'] = True
  492. super().__setstate__(state)
  493. def execute(self,
  494. query: Var,
  495. key: Var,
  496. value: Var,
  497. key_padding_mask: Optional[Var] = None,
  498. need_weights: bool = True,
  499. attn_mask: Optional[Var] = None,
  500. average_attn_weights: bool = True,
  501. is_causal: bool = False) -> Tuple[Var, Optional[Var]]:
  502. #####
  503. # Fast Path is not Supported.
  504. #####
  505. is_batched = query.dim() == 3
  506. key_padding_mask = _canonical_mask(
  507. mask=key_padding_mask,
  508. mask_name="key_padding_mask",
  509. other_type=_none_or_dtype(attn_mask),
  510. other_name="attn_mask",
  511. target_type=query.dtype)
  512. attn_mask = _canonical_mask(
  513. mask=attn_mask,
  514. mask_name="attn_mask",
  515. other_type=None,
  516. other_name="",
  517. target_type=query.dtype,
  518. check_other=False,
  519. )
  520. if self.batch_first and is_batched:
  521. # make sure that the transpose op does not affect the "is" property
  522. if key is value:
  523. if query is key:
  524. query = key = value = query.transpose(1, 0)
  525. else:
  526. query, key = (x.transpose(1, 0) for x in (query, key))
  527. value = key
  528. else:
  529. query, key, value = (x.transpose(1, 0)
  530. for x in (query, key, value))
  531. if not self._qkv_same_embed_dim:
  532. attn_output, attn_output_weights = multi_head_attention_forward(
  533. query,
  534. key,
  535. value,
  536. self.embed_dim,
  537. self.num_heads,
  538. self.in_proj_weight,
  539. self.in_proj_bias,
  540. self.bias_k,
  541. self.bias_v,
  542. self.add_zero_attn,
  543. self.dropout,
  544. self.out_proj.weight,
  545. self.out_proj.bias,
  546. training=self.is_training(),
  547. key_padding_mask=key_padding_mask,
  548. need_weights=need_weights,
  549. attn_mask=attn_mask,
  550. use_separate_proj_weight=True,
  551. q_proj_weight=self.q_proj_weight,
  552. k_proj_weight=self.k_proj_weight,
  553. v_proj_weight=self.v_proj_weight,
  554. average_attn_weights=average_attn_weights,
  555. is_causal=is_causal)
  556. else:
  557. attn_output, attn_output_weights = multi_head_attention_forward(
  558. query,
  559. key,
  560. value,
  561. self.embed_dim,
  562. self.num_heads,
  563. self.in_proj_weight,
  564. self.in_proj_bias,
  565. self.bias_k,
  566. self.bias_v,
  567. self.add_zero_attn,
  568. self.dropout,
  569. self.out_proj.weight,
  570. self.out_proj.bias,
  571. training=self.is_training(),
  572. key_padding_mask=key_padding_mask,
  573. need_weights=need_weights,
  574. attn_mask=attn_mask,
  575. average_attn_weights=average_attn_weights,
  576. is_causal=is_causal)
  577. if self.batch_first and is_batched:
  578. return attn_output.transpose(1, 0), attn_output_weights
  579. else:
  580. return attn_output, attn_output_weights

冻结 ViT-B/32 版本 CLIP 模型中的全部图像层，使用 Adan 优化器训练模型，共训练 100 个 epoch，每隔 5 个 epoch 保存一次模型；完成 CLIP 模型训练后，运行 test_clip.py，使用测试集中的数据和自定义的提示词对保存的模型进行测试，选取测试精度最好的模型及对应的提示词；然后运行 predict.py 文件，选择“min_loss.pth”模型，提交官方系统测试，top-1 精度为 0.6788。

Contributors (1)