from typing import Optional, Tuple, List
import warnings
import math

import jittor as jt
from jittor import Var
from jittor.nn import Module, Linear, softmax, pad, linear, dropout
from jittor.init import xavier_uniform_, xavier_gauss_, constant_
def _canonical_mask(
    mask: Optional[Var],
    mask_name: str,
    other_type,
    other_name: str,
    target_type,
    check_other: bool = True,
) -> Optional[Var]:
    if mask is not None:
        _mask_dtype = mask.dtype
        _mask_is_float = (mask.dtype == jt.float16 or mask.dtype == jt.float32
                          or mask.dtype == jt.float64)
        if _mask_dtype != jt.bool and not _mask_is_float:
            raise AssertionError(
                f"only bool and floating types of {mask_name} are supported")
        if check_other and other_type is not None:
            if _mask_dtype != other_type:
                warnings.warn(
                    f"Support for mismatched {mask_name} and {other_name} "
                    "is deprecated. Use same type for both instead.")
        if not _mask_is_float:
            # WARNING(514flowey): Check Here
            new_mask = jt.zeros_like(mask, dtype=target_type)
            new_mask[mask] = float("-inf")
            mask = new_mask
    return mask
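
# Illustrative sketch (not part of the library): a bool key-padding mask is
# canonicalized into an additive float mask, with True (padded) positions
# becoming -inf and the rest 0. Shapes and values below are assumptions.
#
#   kpm = jt.array([[False, False, True]])            # (N, S), True = pad
#   additive = _canonical_mask(kpm, "key_padding_mask",
#                              None, "attn_mask", jt.float32,
#                              check_other=False)      # zeros with -inf at pads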
def _none_or_dtype(input: Optional[Var]):
    if input is None:
        return None
    elif isinstance(input, jt.Var):
        return input.dtype
def baddbmm(input_var: jt.Var,
            batch1: jt.Var,
            batch2: jt.Var,
            beta=1,
            alpha=1) -> jt.Var:
    # WARNING(514flowey): Check here
    return beta * input_var + alpha * (batch1 @ batch2)
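
# Illustrative sketch (not part of the library): this helper mirrors the
# semantics of torch.baddbmm, i.e. beta * input + alpha * (batch1 @ batch2)
# with a batched matrix product. Shapes below are assumptions only.
#
#   bias = jt.zeros((8, 4, 6))          # (B, L, S) additive bias
#   a = jt.randn((8, 4, 16))            # (B, L, E)
#   b = jt.randn((8, 16, 6))            # (B, E, S)
#   scores = baddbmm(bias, a, b)        # (B, L, S)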
def scaled_dot_product_attention(query,
                                 key,
                                 value,
                                 attn_mask=None,
                                 dropout_p=0.0,
                                 is_causal=False,
                                 scale=None,
                                 training=True) -> jt.Var:
    # Reference implementation of scaled dot-product attention
    # (mirrors the reference code in the PyTorch documentation).
    L, S = query.size(-2), key.size(-2)
    scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale
    attn_bias = jt.zeros((L, S), dtype=query.dtype)
    if is_causal:
        assert attn_mask is None
        temp_mask = jt.ones((L, S), dtype=jt.bool).tril(diagonal=0)
        attn_bias[jt.logical_not(temp_mask)] = float("-inf")
        # attn_bias.to(query.dtype)
        attn_bias = jt.array(attn_bias, query.dtype)
    if attn_mask is not None:
        if attn_mask.dtype == jt.bool:
            # A bool attn_mask marks positions to keep; masked-out positions
            # receive -inf in the additive bias.
            attn_bias[jt.logical_not(attn_mask)] = float("-inf")
        else:
            attn_bias += attn_mask
    attn_weight = query @ key.transpose(-2, -1) * scale_factor
    attn_weight += attn_bias
    attn_weight = softmax(attn_weight, dim=-1)
    attn_weight = dropout(attn_weight, dropout_p, is_train=training)
    return attn_weight @ value
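
# Illustrative sketch (not part of the library): causal self-attention over a
# batch of sequences. Shapes and values are assumptions for illustration only.
#
#   q = jt.randn((2, 8, 5, 16))   # (batch, heads, tgt_len, head_dim)
#   k = jt.randn((2, 8, 5, 16))
#   v = jt.randn((2, 8, 5, 16))
#   out = scaled_dot_product_attention(q, k, v, is_causal=True,
#                                      dropout_p=0.1, training=True)
#   # out: (2, 8, 5, 16)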
def _mha_shape_check(query: Var, key: Var, value: Var,
                     key_padding_mask: Optional[Var], attn_mask: Optional[Var],
                     num_heads: int):
    if query.dim() == 3:
        is_batched = True
        assert key.dim() == 3 and value.dim() == 3, \
            ("For batched (3-D) `query`, expected `key` and `value` to be 3-D"
             f" but found {key.dim()}-D and {value.dim()}-D Vars respectively")
        if key_padding_mask is not None:
            assert key_padding_mask.dim() == 2, \
                ("For batched (3-D) `query`, expected `key_padding_mask` to be `None` or 2-D"
                 f" but found {key_padding_mask.dim()}-D Var instead")
        if attn_mask is not None:
            assert attn_mask.dim() in (2, 3), \
                ("For batched (3-D) `query`, expected `attn_mask` to be `None`, 2-D or 3-D"
                 f" but found {attn_mask.dim()}-D Var instead")
    elif query.dim() == 2:
        is_batched = False
        assert key.dim() == 2 and value.dim() == 2, \
            ("For unbatched (2-D) `query`, expected `key` and `value` to be 2-D"
             f" but found {key.dim()}-D and {value.dim()}-D Vars respectively")
        if key_padding_mask is not None:
            assert key_padding_mask.dim() == 1, \
                ("For unbatched (2-D) `query`, expected `key_padding_mask` to be `None` or 1-D"
                 f" but found {key_padding_mask.dim()}-D Var instead")
        if attn_mask is not None:
            assert attn_mask.dim() in (2, 3), \
                ("For unbatched (2-D) `query`, expected `attn_mask` to be `None`, 2-D or 3-D"
                 f" but found {attn_mask.dim()}-D Var instead")
            if attn_mask.dim() == 3:
                expected_shape = (num_heads, query.shape[0], key.shape[0])
                assert attn_mask.shape == expected_shape, \
                    f"Expected `attn_mask` shape to be {expected_shape} but got {attn_mask.shape}"
    else:
        raise AssertionError(
            f"query should be unbatched 2D or batched 3D Var but received {query.dim()}-D query Var"
        )
    return is_batched
def _in_projection_packed(
    q: Var,
    k: Var,
    v: Var,
    w: Var,
    b: Optional[Var] = None,
) -> List[Var]:
    E = q.size(-1)
    if k is v:
        if q is k:
            # self-attention
            proj = linear(q, w, b)
            # reshape to 3, E and not E, 3 is deliberate for better memory coalescing and keeping same order as chunk()
            # proj = proj.unflatten(-1, (3, E)).unsqueeze(0).transpose(0, -2).squeeze(-2).contiguous()
            nshape = list(proj.shape[:-1]) + [3, E]
            proj = proj.reshape(nshape).unsqueeze(0).transpose(0, -2).squeeze(-2)
            return proj[0], proj[1], proj[2]
        else:
            # encoder-decoder attention
            w_q, w_kv = w.split([E, E * 2])
            if b is None:
                b_q = b_kv = None
            else:
                b_q, b_kv = b.split([E, E * 2])
            q_proj = linear(q, w_q, b_q)
            kv_proj = linear(k, w_kv, b_kv)
            # reshape to 2, E and not E, 2 is deliberate for better memory coalescing and keeping same order as chunk()
            # kv_proj = kv_proj.unflatten(-1, (2, E)).unsqueeze(0).transpose(0, -2).squeeze(-2).contiguous()
            nshape = list(kv_proj.shape[:-1]) + [2, E]
            kv_proj = kv_proj.reshape(nshape).unsqueeze(0).transpose(0, -2).squeeze(-2)
            return (q_proj, kv_proj[0], kv_proj[1])
    else:
        w_q, w_k, w_v = w.chunk(3)
        if b is None:
            b_q = b_k = b_v = None
        else:
            b_q, b_k, b_v = b.chunk(3)
        return linear(q, w_q, b_q), linear(k, w_k, b_k), linear(v, w_v, b_v)
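
# Illustrative sketch (not part of the library): for self-attention the packed
# projection turns one (L, N, E) input into three (L, N, E) outputs using a
# single (3*E, E) weight. Shapes are assumptions for illustration only.
#
#   E = 16
#   x = jt.randn((5, 2, E))                          # (seq, batch, embed)
#   w = jt.randn((3 * E, E))
#   b = jt.zeros((3 * E,))
#   q, k, v = _in_projection_packed(x, x, x, w, b)   # each (5, 2, E)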
def _in_projection(
    q: Var,
    k: Var,
    v: Var,
    w_q: Var,
    w_k: Var,
    w_v: Var,
    b_q: Optional[Var] = None,
    b_k: Optional[Var] = None,
    b_v: Optional[Var] = None,
) -> Tuple[Var, Var, Var]:
    Eq, Ek, Ev = q.size(-1), k.size(-1), v.size(-1)
    assert w_q.shape == (Eq, Eq), \
        f"expecting query weights shape of {(Eq, Eq)}, but got {w_q.shape}"
    assert w_k.shape == (Eq, Ek), \
        f"expecting key weights shape of {(Eq, Ek)}, but got {w_k.shape}"
    assert w_v.shape == (Eq, Ev), \
        f"expecting value weights shape of {(Eq, Ev)}, but got {w_v.shape}"
    assert b_q is None or b_q.shape == (Eq,), \
        f"expecting query bias shape of {(Eq,)}, but got {b_q.shape}"
    assert b_k is None or b_k.shape == (Eq,), \
        f"expecting key bias shape of {(Eq,)}, but got {b_k.shape}"
    assert b_v is None or b_v.shape == (Eq,), \
        f"expecting value bias shape of {(Eq,)}, but got {b_v.shape}"
    return linear(q, w_q, b_q), linear(k, w_k, b_k), linear(v, w_v, b_v)
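
# Illustrative sketch (not part of the library): separate projections allow the
# key/value embedding dims (kdim/vdim) to differ from the query dim; everything
# is projected to the query dim Eq. Shapes are assumptions only.
#
#   Eq, Ek, Ev = 16, 12, 20
#   q = jt.randn((5, 2, Eq)); k = jt.randn((7, 2, Ek)); v = jt.randn((7, 2, Ev))
#   w_q = jt.randn((Eq, Eq)); w_k = jt.randn((Eq, Ek)); w_v = jt.randn((Eq, Ev))
#   q_p, k_p, v_p = _in_projection(q, k, v, w_q, w_k, w_v)  # last dim is Eq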
def multi_head_attention_forward(
    query: Var,
    key: Var,
    value: Var,
    embed_dim_to_check: int,
    num_heads: int,
    in_proj_weight: Optional[Var],
    in_proj_bias: Optional[Var],
    bias_k: Optional[Var],
    bias_v: Optional[Var],
    add_zero_attn: bool,
    dropout_p: float,
    out_proj_weight: Var,
    out_proj_bias: Optional[Var],
    training: bool = True,
    key_padding_mask: Optional[Var] = None,
    need_weights: bool = True,
    attn_mask: Optional[Var] = None,
    use_separate_proj_weight: bool = False,
    q_proj_weight: Optional[Var] = None,
    k_proj_weight: Optional[Var] = None,
    v_proj_weight: Optional[Var] = None,
    static_k: Optional[Var] = None,
    static_v: Optional[Var] = None,
    average_attn_weights: bool = True,
    is_causal: bool = False,
) -> Tuple[Var, Optional[Var]]:
    is_batched = _mha_shape_check(query, key, value, key_padding_mask,
                                  attn_mask, num_heads)

    # For unbatched input, we unsqueeze at the expected batch dim to pretend the
    # input is batched, run the computation, and squeeze the batch dimension
    # before returning so the output does not carry this temporary batch dim.
    if not is_batched:
        # unsqueeze if the input is unbatched
        query = query.unsqueeze(1)
        key = key.unsqueeze(1)
        value = value.unsqueeze(1)
        if key_padding_mask is not None:
            key_padding_mask = key_padding_mask.unsqueeze(0)

    # set up shape vars
    tgt_len, bsz, embed_dim = query.shape
    src_len, _, _ = key.shape

    key_padding_mask = _canonical_mask(mask=key_padding_mask,
                                       mask_name="key_padding_mask",
                                       other_type=_none_or_dtype(attn_mask),
                                       other_name="attn_mask",
                                       target_type=query.dtype)

    if is_causal and attn_mask is None:
        raise RuntimeError(
            "Need attn_mask if specifying the is_causal hint. "
            "You may use the Transformer module method "
            "`generate_square_subsequent_mask` to create this mask.")

    if is_causal and key_padding_mask is None and not need_weights:
        # With no key_padding_mask and no weights requested, the is_causal hint
        # can be passed straight through to SDPA, so no explicit attn_mask is
        # needed.
        attn_mask = None
    else:
        attn_mask = _canonical_mask(
            mask=attn_mask,
            mask_name="attn_mask",
            other_type=None,
            other_name="",
            target_type=query.dtype,
            check_other=False,
        )
        if key_padding_mask is not None:
            # We have the attn_mask, and use that to merge kpm into it.
            # Turn off use of is_causal hint, as the merged mask is no
            # longer causal.
            is_causal = False

    assert embed_dim == embed_dim_to_check, \
        f"was expecting embedding dimension of {embed_dim_to_check}, but got {embed_dim}"
    if isinstance(embed_dim, jt.Var):
        # embed_dim can be a Var when JIT tracing
        head_dim = embed_dim.div(num_heads, rounding_mode='trunc')
    else:
        head_dim = embed_dim // num_heads
    assert head_dim * num_heads == embed_dim, f"embed_dim {embed_dim} not divisible by num_heads {num_heads}"
    if use_separate_proj_weight:
        # allow MHA to have different embedding dimensions when separate projection weights are used
        assert key.shape[:2] == value.shape[:2], \
            f"key's sequence and batch dims {key.shape[:2]} do not match value's {value.shape[:2]}"
    else:
        assert key.shape == value.shape, f"key shape {key.shape} does not match value shape {value.shape}"

    #
    # compute in-projection
    #
    if not use_separate_proj_weight:
        assert in_proj_weight is not None, "use_separate_proj_weight is False but in_proj_weight is None"
        q, k, v = _in_projection_packed(query, key, value, in_proj_weight,
                                        in_proj_bias)
    else:
        assert q_proj_weight is not None, "use_separate_proj_weight is True but q_proj_weight is None"
        assert k_proj_weight is not None, "use_separate_proj_weight is True but k_proj_weight is None"
        assert v_proj_weight is not None, "use_separate_proj_weight is True but v_proj_weight is None"
        if in_proj_bias is None:
            b_q = b_k = b_v = None
        else:
            b_q, b_k, b_v = in_proj_bias.chunk(3)
        q, k, v = _in_projection(query, key, value, q_proj_weight,
                                 k_proj_weight, v_proj_weight, b_q, b_k, b_v)

    # prep attention mask
    if attn_mask is not None:
        # ensure attn_mask's dim is 3
        if attn_mask.dim() == 2:
            correct_2d_size = (tgt_len, src_len)
            if attn_mask.shape != correct_2d_size:
                raise RuntimeError(
                    f"The shape of the 2D attn_mask is {attn_mask.shape}, but should be {correct_2d_size}."
                )
            attn_mask = attn_mask.unsqueeze(0)
        elif attn_mask.dim() == 3:
            correct_3d_size = (bsz * num_heads, tgt_len, src_len)
            if attn_mask.shape != correct_3d_size:
                raise RuntimeError(
                    f"The shape of the 3D attn_mask is {attn_mask.shape}, but should be {correct_3d_size}."
                )
        else:
            raise RuntimeError(
                f"attn_mask's dimension {attn_mask.dim()} is not supported")

    # add bias along batch dimension (currently second)
    if bias_k is not None and bias_v is not None:
        assert static_k is None, "bias cannot be added to static key."
        assert static_v is None, "bias cannot be added to static value."
        k = jt.concat([k, bias_k.repeat(1, bsz, 1)])
        v = jt.concat([v, bias_v.repeat(1, bsz, 1)])
        if attn_mask is not None:
            attn_mask = pad(attn_mask, (0, 1))
        if key_padding_mask is not None:
            key_padding_mask = pad(key_padding_mask, (0, 1))
    else:
        assert bias_k is None
        assert bias_v is None

    #
    # reshape q, k, v for multihead attention and make them batch first
    #
    q = q.view(tgt_len, bsz * num_heads, head_dim).transpose(0, 1)
    if static_k is None:
        k = k.view(k.shape[0], bsz * num_heads, head_dim).transpose(0, 1)
    else:
        # TODO finish disentangling control flow so we don't do in-projections when statics are passed
        assert static_k.size(0) == bsz * num_heads, \
            f"expecting static_k.size(0) of {bsz * num_heads}, but got {static_k.size(0)}"
        assert static_k.size(2) == head_dim, \
            f"expecting static_k.size(2) of {head_dim}, but got {static_k.size(2)}"
        k = static_k
    if static_v is None:
        v = v.view(v.shape[0], bsz * num_heads, head_dim).transpose(0, 1)
    else:
        # TODO finish disentangling control flow so we don't do in-projections when statics are passed
        assert static_v.size(0) == bsz * num_heads, \
            f"expecting static_v.size(0) of {bsz * num_heads}, but got {static_v.size(0)}"
        assert static_v.size(2) == head_dim, \
            f"expecting static_v.size(2) of {head_dim}, but got {static_v.size(2)}"
        v = static_v

    # add zero attention along batch dimension (now first)
    if add_zero_attn:
        zero_attn_shape = (bsz * num_heads, 1, head_dim)
        k = jt.concat([k, jt.zeros(zero_attn_shape, dtype=k.dtype)], dim=1)
        v = jt.concat([v, jt.zeros(zero_attn_shape, dtype=v.dtype)], dim=1)
        if attn_mask is not None:
            attn_mask = pad(attn_mask, (0, 1))
        if key_padding_mask is not None:
            key_padding_mask = pad(key_padding_mask, (0, 1))

    # update source sequence length after adjustments
    src_len = k.size(1)

    # merge key padding and attention masks
    if key_padding_mask is not None:
        assert key_padding_mask.shape == (bsz, src_len), \
            f"expecting key_padding_mask shape of {(bsz, src_len)}, but got {key_padding_mask.shape}"
        key_padding_mask = key_padding_mask.view(bsz, 1, 1, src_len) \
            .expand(-1, num_heads, -1, -1).reshape(bsz * num_heads, 1, src_len)
        if attn_mask is None:
            attn_mask = key_padding_mask
        else:
            attn_mask = attn_mask + key_padding_mask

    # adjust dropout probability
    if not training:
        dropout_p = 0.0

    #
    # (deep breath) calculate attention and out projection
    #
    if need_weights:
        B, Nt, E = q.shape
        q_scaled = q / math.sqrt(E)

        assert not (is_causal and attn_mask is None), \
            "FIXME: is_causal not implemented for need_weights"

        if attn_mask is not None:
            attn_output_weights = baddbmm(attn_mask, q_scaled,
                                          k.transpose(-2, -1))
        else:
            attn_output_weights = jt.bmm(q_scaled, k.transpose(-2, -1))
        attn_output_weights = softmax(attn_output_weights, dim=-1)
        if dropout_p > 0.0:
            # dropout_p has already been zeroed above when not training, so the
            # dropout here is always applied in training mode.
            attn_output_weights = dropout(attn_output_weights,
                                          p=dropout_p,
                                          is_train=True)

        attn_output = jt.bmm(attn_output_weights, v)

        attn_output = attn_output.transpose(0, 1).contiguous().view(
            tgt_len * bsz, embed_dim)
        attn_output = linear(attn_output, out_proj_weight, out_proj_bias)
        attn_output = attn_output.view(tgt_len, bsz, attn_output.size(1))

        # optionally average attention weights over heads
        attn_output_weights = attn_output_weights.view(
            bsz, num_heads, tgt_len, src_len)
        if average_attn_weights:
            attn_output_weights = attn_output_weights.mean(dim=1)

        if not is_batched:
            # squeeze the output if input was unbatched
            attn_output = attn_output.squeeze(1)
            attn_output_weights = attn_output_weights.squeeze(0)
        return attn_output, attn_output_weights
    else:
        # attn_mask can be either (L, S) or (N*num_heads, L, S)
        # if attn_mask's shape is (1, L, S) we need to unsqueeze to (1, 1, L, S)
        # in order to match the input for SDPA of (N, num_heads, L, S)
        if attn_mask is not None:
            if attn_mask.size(0) == 1 and attn_mask.dim() == 3:
                attn_mask = attn_mask.unsqueeze(0)
            else:
                attn_mask = attn_mask.view(bsz, num_heads, -1, src_len)

        q = q.view(bsz, num_heads, tgt_len, head_dim)
        k = k.view(bsz, num_heads, src_len, head_dim)
        v = v.view(bsz, num_heads, src_len, head_dim)

        attn_output = scaled_dot_product_attention(q,
                                                   k,
                                                   v,
                                                   attn_mask,
                                                   dropout_p,
                                                   is_causal,
                                                   training=training)
        attn_output = attn_output.permute(2, 0, 1, 3).contiguous().view(
            bsz * tgt_len, embed_dim)

        attn_output = linear(attn_output, out_proj_weight, out_proj_bias)
        attn_output = attn_output.view(tgt_len, bsz, attn_output.size(1))
        if not is_batched:
            # squeeze the output if input was unbatched
            attn_output = attn_output.squeeze(1)
        return attn_output, None
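
# Illustrative sketch (not part of the library): calling the functional API
# directly with a packed in-projection weight. Names and shapes below are
# assumptions for illustration only.
#
#   E, H, L, N = 16, 4, 5, 2
#   x = jt.randn((L, N, E))                      # (tgt_len, batch, embed)
#   in_w = jt.randn((3 * E, E)); in_b = jt.zeros((3 * E,))
#   out_w = jt.randn((E, E)); out_b = jt.zeros((E,))
#   out, weights = multi_head_attention_forward(
#       x, x, x, E, H, in_w, in_b, None, None, False, 0.0,
#       out_w, out_b, training=False, need_weights=True)
#   # out: (L, N, E); weights: (N, L, L) after head averaging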
class MultiheadAttention(Module):
    __constants__ = ['batch_first']
    bias_k: Optional[jt.Var]
    bias_v: Optional[jt.Var]

    def __init__(self,
                 embed_dim,
                 num_heads,
                 dropout=0.,
                 bias=True,
                 add_bias_kv=False,
                 add_zero_attn=False,
                 kdim=None,
                 vdim=None,
                 batch_first=False,
                 dtype=jt.float32) -> None:
        if embed_dim <= 0 or num_heads <= 0:
            raise ValueError(
                f"embed_dim and num_heads must be greater than 0,"
                f" got embed_dim={embed_dim} and num_heads={num_heads} instead"
            )
        factory_kwargs = {'dtype': dtype}
        super().__init__()
        self.embed_dim = embed_dim
        self.kdim = kdim if kdim is not None else embed_dim
        self.vdim = vdim if vdim is not None else embed_dim
        self._qkv_same_embed_dim = self.kdim == embed_dim and self.vdim == embed_dim

        self.num_heads = num_heads
        self.dropout = dropout
        self.batch_first = batch_first
        self.head_dim = embed_dim // num_heads
        assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"

        if not self._qkv_same_embed_dim:
            self.q_proj_weight = jt.empty((embed_dim, embed_dim),
                                          **factory_kwargs)
            self.k_proj_weight = jt.empty((embed_dim, self.kdim),
                                          **factory_kwargs)
            self.v_proj_weight = jt.empty((embed_dim, self.vdim),
                                          **factory_kwargs)
            self.in_proj_weight = None
        else:
            self.q_proj_weight = None
            self.k_proj_weight = None
            self.v_proj_weight = None
            self.in_proj_weight = jt.empty((3 * embed_dim, embed_dim),
                                           **factory_kwargs)

        if bias:
            self.in_proj_bias = jt.empty(3 * embed_dim, **factory_kwargs)
        else:
            self.in_proj_bias = None
        self.out_proj = Linear(embed_dim, embed_dim, bias=bias)

        if add_bias_kv:
            self.bias_k = jt.empty((1, 1, embed_dim), **factory_kwargs)
            self.bias_v = jt.empty((1, 1, embed_dim), **factory_kwargs)
        else:
            self.bias_k = self.bias_v = None

        self.add_zero_attn = add_zero_attn

        self._reset_parameters()

    def _reset_parameters(self):
        if self._qkv_same_embed_dim:
            xavier_uniform_(self.in_proj_weight)
        else:
            xavier_uniform_(self.q_proj_weight)
            xavier_uniform_(self.k_proj_weight)
            xavier_uniform_(self.v_proj_weight)

        if self.in_proj_bias is not None:
            constant_(self.in_proj_bias, 0.)
            constant_(self.out_proj.bias, 0.)
        if self.bias_k is not None:
            xavier_gauss_(self.bias_k)
        if self.bias_v is not None:
            xavier_gauss_(self.bias_v)

    def __setstate__(self, state):
        # Support loading old MultiheadAttention checkpoints generated by v1.1.0
        if '_qkv_same_embed_dim' not in state:
            state['_qkv_same_embed_dim'] = True
        super().__setstate__(state)

    def execute(self,
                query: Var,
                key: Var,
                value: Var,
                key_padding_mask: Optional[Var] = None,
                need_weights: bool = True,
                attn_mask: Optional[Var] = None,
                average_attn_weights: bool = True,
                is_causal: bool = False) -> Tuple[Var, Optional[Var]]:
        #####
        # Fast Path is not Supported.
        #####

        is_batched = query.dim() == 3

        key_padding_mask = _canonical_mask(
            mask=key_padding_mask,
            mask_name="key_padding_mask",
            other_type=_none_or_dtype(attn_mask),
            other_name="attn_mask",
            target_type=query.dtype)

        attn_mask = _canonical_mask(
            mask=attn_mask,
            mask_name="attn_mask",
            other_type=None,
            other_name="",
            target_type=query.dtype,
            check_other=False,
        )

        if self.batch_first and is_batched:
            # make sure that the transpose op does not affect the "is" property
            if key is value:
                if query is key:
                    query = key = value = query.transpose(1, 0)
                else:
                    query, key = (x.transpose(1, 0) for x in (query, key))
                    value = key
            else:
                query, key, value = (x.transpose(1, 0)
                                     for x in (query, key, value))

        if not self._qkv_same_embed_dim:
            attn_output, attn_output_weights = multi_head_attention_forward(
                query,
                key,
                value,
                self.embed_dim,
                self.num_heads,
                self.in_proj_weight,
                self.in_proj_bias,
                self.bias_k,
                self.bias_v,
                self.add_zero_attn,
                self.dropout,
                self.out_proj.weight,
                self.out_proj.bias,
                training=self.is_training(),
                key_padding_mask=key_padding_mask,
                need_weights=need_weights,
                attn_mask=attn_mask,
                use_separate_proj_weight=True,
                q_proj_weight=self.q_proj_weight,
                k_proj_weight=self.k_proj_weight,
                v_proj_weight=self.v_proj_weight,
                average_attn_weights=average_attn_weights,
                is_causal=is_causal)
        else:
            attn_output, attn_output_weights = multi_head_attention_forward(
                query,
                key,
                value,
                self.embed_dim,
                self.num_heads,
                self.in_proj_weight,
                self.in_proj_bias,
                self.bias_k,
                self.bias_v,
                self.add_zero_attn,
                self.dropout,
                self.out_proj.weight,
                self.out_proj.bias,
                training=self.is_training(),
                key_padding_mask=key_padding_mask,
                need_weights=need_weights,
                attn_mask=attn_mask,
                average_attn_weights=average_attn_weights,
                is_causal=is_causal)

        if self.batch_first and is_batched:
            return attn_output.transpose(1, 0), attn_output_weights
        else:
            return attn_output, attn_output_weights
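
# Illustrative sketch (not part of the library): typical use of the module with
# batch_first inputs. Shapes and hyperparameters are assumptions only.
#
#   mha = MultiheadAttention(embed_dim=16, num_heads=4, batch_first=True)
#   x = jt.randn((2, 5, 16))                      # (batch, seq, embed)
#   out, attn = mha(x, x, x, need_weights=True)   # out: (2, 5, 16), attn: (2, 5, 5)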