import numpy as np

import mindspore.nn as nn
import mindspore.ops as ops
from mindspore import Tensor
from mindspore.common import dtype as mstype

# NOTE: TriangularCausalMask and ProbMask are expected to be provided by the
# project's masking utilities (a MindSpore port of Informer's utils/masking.py);
# they are not defined in this file.


class FullAttention(nn.Cell):
    """Canonical scaled dot-product attention."""

    def __init__(self, mask_flag=True, factor=5, scale=None,
                 attention_dropout=0.1, output_attention=False):
        super(FullAttention, self).__init__()
        self.scale = scale
        self.mask_flag = mask_flag
        self.output_attention = output_attention
        # MindSpore >= 1.9: `p` is the drop probability (older releases use keep_prob).
        self.dropout = nn.Dropout(p=attention_dropout)
        self.softmax = nn.Softmax(axis=-1)
        self.matmul = ops.BatchMatMul()

    def construct(self, queries, keys, values, attn_mask):
        # queries: (B, L, H, E), keys: (B, S, H, E), values: (B, S, H, D)
        B, L, H, E = queries.shape
        _, S, _, D = values.shape
        scale = self.scale or 1.0 / np.sqrt(E)

        # (B, H, L, E) x (B, H, E, S) -> (B, H, L, S)
        scores = self.matmul(queries.transpose(0, 2, 1, 3),
                             keys.transpose(0, 2, 3, 1))

        if self.mask_flag:
            if attn_mask is None:
                attn_mask = TriangularCausalMask(B, L)
            # mask out future positions before the softmax
            scores = scores.masked_fill(attn_mask.mask, -np.inf)

        A = self.dropout(self.softmax(scores * scale))
        # (B, H, L, S) x (B, H, S, D) -> (B, H, L, D) -> (B, L, H, D)
        V = self.matmul(A, values.transpose(0, 2, 1, 3)).transpose(0, 2, 1, 3)

        if self.output_attention:
            return V, A
        return V, None


class ProbAttention(nn.Cell):
    """ProbSparse attention: computes full scores only for the top-u "active" queries."""

    def __init__(self, mask_flag=True, factor=5, scale=None,
                 attention_dropout=0.1, output_attention=False):
        super(ProbAttention, self).__init__()
        self.factor = factor
        self.scale = scale
        self.mask_flag = mask_flag
        self.output_attention = output_attention
        self.dropout = nn.Dropout(p=attention_dropout)

    def _prob_QK(self, Q, K, sample_k, n_top):  # n_top: c*ln(L_q)
        # Q, K: (B, H, L, E)
        B, H, L_K, E = K.shape
        _, _, L_Q, D = Q.shape

        # calculate the sampled Q_K
        K_expand = ops.broadcast_to(K.unsqueeze(-3), (B, H, L_Q, L_K, E))
        # ops.randint requires MindSpore >= 2.0
        index_sample = ops.randint(0, L_K, (L_Q, sample_k), dtype=mstype.int32)
        K_sample = K_expand[:, :, ops.arange(L_Q).unsqueeze(1), index_sample, :]
        Q_K_sample = ops.matmul(Q.unsqueeze(-2), K_sample.swapaxes(-2, -1)).squeeze(-2)

        # sparsity measurement: max of the sampled scores minus their mean
        M = Q_K_sample.max(axis=-1) - ops.div(Q_K_sample.sum(-1), L_K)
        _, M_top = ops.top_k(M, n_top, sorted=False)

        # use the reduced Q (top-n_top queries) to calculate Q_K
        Q_reduce = Q[ops.arange(B).unsqueeze(1).unsqueeze(2),
                     ops.arange(H).unsqueeze(0).unsqueeze(2),
                     M_top, :]                              # (B, H, n_top, E)
        Q_K = ops.matmul(Q_reduce, K.swapaxes(-2, -1))      # (B, H, n_top, L_K)

        return Q_K, M_top

    def _get_initial_context(self, V, L_Q):
        B, H, L_V, D = V.shape
        if not self.mask_flag:
            # V_sum = V.sum(axis=-2)
            V_sum = V.mean(axis=-2)
            contex = ops.broadcast_to(V_sum.unsqueeze(-2),
                                      (B, H, L_Q, V_sum.shape[-1])).copy()
        else:  # use mask
            assert L_Q == L_V  # requires that L_Q == L_V, i.e. for self-attention only
            contex = V.cumsum(axis=-2)
        return contex

    def _update_context(self, context_in, V, scores, index, L_Q, attn_mask):
        B, H, L_V, D = V.shape

        if self.mask_flag:
            attn_mask = ProbMask(B, H, L_Q, index, scores)
            scores = scores.masked_fill(attn_mask.mask, -np.inf)

        attn = ops.softmax(scores, axis=-1)

        # scatter the attention output back to the selected query positions
        context_in[ops.arange(B).unsqueeze(1).unsqueeze(2),
                   ops.arange(H).unsqueeze(0).unsqueeze(2),
                   index, :] = ops.matmul(attn, V).astype(context_in.dtype)
        if self.output_attention:
            # non-selected queries are assigned a uniform attention distribution
            attns = ops.ones((B, H, L_V, L_V), attn.dtype) / L_V
            attns[ops.arange(B).unsqueeze(1).unsqueeze(2),
                  ops.arange(H).unsqueeze(0).unsqueeze(2),
                  index, :] = attn
            return context_in, attns
        return context_in, None

    def construct(self, queries, keys, values, attn_mask):
        B, L_Q, H, D = queries.shape
        _, L_K, _, _ = keys.shape

        # (B, L, H, D) -> (B, H, L, D)
        queries = queries.swapaxes(2, 1)
        keys = keys.swapaxes(2, 1)
        values = values.swapaxes(2, 1)

        U_part = self.factor * np.ceil(np.log(L_K)).astype('int').item()  # c*ln(L_k)
        u = self.factor * np.ceil(np.log(L_Q)).astype('int').item()       # c*ln(L_q)

        U_part = U_part if U_part < L_K else L_K
        u = u if u < L_Q else L_Q

        # select the top-u queries, scale, and scatter-update the context,
        # as in the reference Informer ProbAttention
        scores_top, index = self._prob_QK(queries, keys, sample_k=U_part, n_top=u)

        # add scale factor
        scale = self.scale or 1.0 / np.sqrt(D)
        if scale is not None:
            scores_top = scores_top * scale

        # get the initial context, then update it with the selected top-u queries
        context = self._get_initial_context(values, L_Q)
        context, attn = self._update_context(context, values, scores_top, index,
                                             L_Q, attn_mask)

        # back to (B, L, H, D)
        return context.swapaxes(2, 1), attn
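

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original module). It assumes
# MindSpore >= 2.0 running in PyNative mode, and uses mask_flag=False so the
# external mask helpers (TriangularCausalMask / ProbMask) are not required.
# Input layout is (batch, length, heads, depth), as expected by construct().
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import mindspore as ms

    ms.set_context(mode=ms.PYNATIVE_MODE)

    B, L, H, E = 2, 16, 4, 32
    rng = np.random.default_rng(0)
    q = Tensor(rng.standard_normal((B, L, H, E)), mstype.float32)
    k = Tensor(rng.standard_normal((B, L, H, E)), mstype.float32)
    v = Tensor(rng.standard_normal((B, L, H, E)), mstype.float32)

    attn = FullAttention(mask_flag=False, output_attention=True)
    out, weights = attn(q, k, v, None)
    print(out.shape)      # (2, 16, 4, 32)
    print(weights.shape)  # (2, 4, 16, 16)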