| @@ -42,28 +42,24 @@ class ProbAttention(nn.Module): | |||||
| self.output_attention = output_attention | self.output_attention = output_attention | ||||
| self.dropout = nn.Dropout(attention_dropout) | self.dropout = nn.Dropout(attention_dropout) | ||||
| def _prob_QK(self, Q, K, sample_k, n_top): # n_top: c*ln(L_q) | |||||
| # Q [B, H, L, D] | |||||
| B, H, L_K, E = K.shape | |||||
| _, _, L_Q, _ = Q.shape | |||||
| # calculate the sampled Q_K | |||||
| K_expand = K.unsqueeze(-3).expand(B, H, L_Q, L_K, E) | |||||
| index_sample = torch.randint(L_K, (L_Q, sample_k)) # real U = U_part(factor*ln(L_k))*L_q | |||||
| K_sample = K_expand[:, :, torch.arange(L_Q).unsqueeze(1), index_sample, :] | |||||
| Q_K_sample = torch.matmul(Q.unsqueeze(-2), K_sample.transpose(-2, -1)).squeeze(-2) | |||||
| # find the Top_k query with sparisty measurement | |||||
| M = Q_K_sample.max(-1)[0] - torch.div(Q_K_sample.sum(-1), L_K) | |||||
| M_top = M.topk(n_top, sorted=False)[1] | |||||
| # use the reduced Q to calculate Q_K | |||||
| Q_reduce = Q[torch.arange(B)[:, None, None], | |||||
| torch.arange(H)[None, :, None], | |||||
| M_top, :] # factor*ln(L_q) | |||||
| Q_K = torch.matmul(Q_reduce, K.transpose(-2, -1)) # factor*ln(L_q)*L_k | |||||
| return Q_K, M_top | |||||
| def _prob_QK(self, Q, K, sample_k, n_top): | |||||
| B, H, L_K, E = K.shape | |||||
| _, _, L_Q, D = Q.shape | |||||
| K_expand = K.unsqueeze(-3).expand(B, H, L_Q, L_K, E) | |||||
| index_sample = ops.random.randint(0, L_K, (L_Q, sample_k), dtype=mstype.int32) | |||||
| K_sample = K_expand[:, :, ops.arange(L_Q).unsqueeze(1), index_sample, :] | |||||
| Q_K_sample = ops.matmul(Q.unsqueeze(-2), K_sample.transpose((0, 1, 3, 2))).squeeze(-2) | |||||
| M = Q_K_sample.max(-1)[0] - ops.div(Q_K_sample.sum(-1), L_K) | |||||
| _, M_top = ops.top_k(M, n_top, sorted=False) | |||||
| Q_reduce = Q[ops.arange(B).unsqueeze(1).unsqueeze(2), | |||||
| ops.arange(H).unsqueeze(0).unsqueeze(2), | |||||
| M_top, :] | |||||
| Q_K = ops.matmul(Q_reduce, K.transpose((0, 1, 3, 2))) | |||||
| return Q_K, M_top | |||||
| def _get_initial_context(self, V, L_Q): | def _get_initial_context(self, V, L_Q): | ||||
| B, H, L_V, D = V.shape | B, H, L_V, D = V.shape | ||||