From 37f6921a18e3cad7da812a41108543496e9f1e10 Mon Sep 17 00:00:00 2001 From: BBing <67486385@qq.com> Date: Wed, 4 Oct 2023 10:23:02 +0800 Subject: [PATCH] Update attn.py --- models/attn.py | 40 ++++++++++++++++++---------------------- 1 file changed, 18 insertions(+), 22 deletions(-) diff --git a/models/attn.py b/models/attn.py index 969b93f..5057c36 100644 --- a/models/attn.py +++ b/models/attn.py @@ -42,28 +42,24 @@ class ProbAttention(nn.Module): self.output_attention = output_attention self.dropout = nn.Dropout(attention_dropout) - def _prob_QK(self, Q, K, sample_k, n_top): # n_top: c*ln(L_q) - # Q [B, H, L, D] - B, H, L_K, E = K.shape - _, _, L_Q, _ = Q.shape - - # calculate the sampled Q_K - K_expand = K.unsqueeze(-3).expand(B, H, L_Q, L_K, E) - index_sample = torch.randint(L_K, (L_Q, sample_k)) # real U = U_part(factor*ln(L_k))*L_q - K_sample = K_expand[:, :, torch.arange(L_Q).unsqueeze(1), index_sample, :] - Q_K_sample = torch.matmul(Q.unsqueeze(-2), K_sample.transpose(-2, -1)).squeeze(-2) - - # find the Top_k query with sparisty measurement - M = Q_K_sample.max(-1)[0] - torch.div(Q_K_sample.sum(-1), L_K) - M_top = M.topk(n_top, sorted=False)[1] - - # use the reduced Q to calculate Q_K - Q_reduce = Q[torch.arange(B)[:, None, None], - torch.arange(H)[None, :, None], - M_top, :] # factor*ln(L_q) - Q_K = torch.matmul(Q_reduce, K.transpose(-2, -1)) # factor*ln(L_q)*L_k - - return Q_K, M_top + def _prob_QK(self, Q, K, sample_k, n_top): + B, H, L_K, E = K.shape + _, _, L_Q, D = Q.shape + + K_expand = K.unsqueeze(-3).expand(B, H, L_Q, L_K, E) + index_sample = ops.random.randint(0, L_K, (L_Q, sample_k), dtype=mstype.int32) + K_sample = K_expand[:, :, ops.arange(L_Q).unsqueeze(1), index_sample, :] + Q_K_sample = ops.matmul(Q.unsqueeze(-2), K_sample.transpose((0, 1, 3, 2))).squeeze(-2) + + M = Q_K_sample.max(-1)[0] - ops.div(Q_K_sample.sum(-1), L_K) + _, M_top = ops.top_k(M, n_top, sorted=False) + + Q_reduce = Q[ops.arange(B).unsqueeze(1).unsqueeze(2), + ops.arange(H).unsqueeze(0).unsqueeze(2), + M_top, :] + Q_K = ops.matmul(Q_reduce, K.transpose((0, 1, 3, 2))) + + return Q_K, M_top def _get_initial_context(self, V, L_Q): B, H, L_V, D = V.shape