|
|
|
@@ -42,28 +42,24 @@ class ProbAttention(nn.Module):
        self.output_attention = output_attention
        self.dropout = nn.Dropout(attention_dropout)
    # Original PyTorch implementation.
    def _prob_QK(self, Q, K, sample_k, n_top):  # n_top: c*ln(L_q)
        # Q [B, H, L, D]
        B, H, L_K, E = K.shape
        _, _, L_Q, _ = Q.shape

        # calculate the sampled Q_K
        K_expand = K.unsqueeze(-3).expand(B, H, L_Q, L_K, E)
        index_sample = torch.randint(L_K, (L_Q, sample_k))  # real U = U_part(factor*ln(L_k))*L_q
        K_sample = K_expand[:, :, torch.arange(L_Q).unsqueeze(1), index_sample, :]
        Q_K_sample = torch.matmul(Q.unsqueeze(-2), K_sample.transpose(-2, -1)).squeeze(-2)

        # find the Top_k queries with the sparsity measurement
        M = Q_K_sample.max(-1)[0] - torch.div(Q_K_sample.sum(-1), L_K)
        M_top = M.topk(n_top, sorted=False)[1]

        # use the reduced Q to calculate Q_K
        Q_reduce = Q[torch.arange(B)[:, None, None],
                     torch.arange(H)[None, :, None],
                     M_top, :]  # factor*ln(L_q)
        Q_K = torch.matmul(Q_reduce, K.transpose(-2, -1))  # factor*ln(L_q)*L_k

        return Q_K, M_top
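    # --- Illustrative usage (added commentary; not part of the original code) ---
    # A minimal sketch of how _prob_QK could be smoke-tested on dummy tensors.
    # It assumes the usual Informer constructor signature
    # ProbAttention(mask_flag, factor, scale, attention_dropout, output_attention),
    # and derives sample_k / n_top from the factor*ln(L) rule referenced in the
    # comments above; U_part and u are illustrative names, not taken from this file.
    #
    #     import math
    #     import torch
    #
    #     B, H, L_Q, L_K, D = 2, 8, 96, 96, 64
    #     factor = 5
    #     Q = torch.randn(B, H, L_Q, D)
    #     K = torch.randn(B, H, L_K, D)
    #
    #     U_part = min(int(factor * math.ceil(math.log(L_K))), L_K)  # sampled keys per query
    #     u = min(int(factor * math.ceil(math.log(L_Q))), L_Q)       # queries kept by top-k
    #
    #     attn = ProbAttention(False, factor, None, 0.1, False)      # assumed signature
    #     scores_top, index = attn._prob_QK(Q, K, sample_k=U_part, n_top=u)
    #     assert scores_top.shape == (B, H, u, L_K)
    #     assert index.shape == (B, H, u)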
|
|
|
    # MindSpore port of the same routine.
    def _prob_QK(self, Q, K, sample_k, n_top):
        B, H, L_K, E = K.shape
        _, _, L_Q, D = Q.shape

        # calculate the sampled Q_K
        K_expand = K.unsqueeze(-3).expand(B, H, L_Q, L_K, E)
        index_sample = ops.random.randint(0, L_K, (L_Q, sample_k), dtype=mstype.int32)
        K_sample = K_expand[:, :, ops.arange(L_Q).unsqueeze(1), index_sample, :]
        # K_sample is 5-D (B, H, L_Q, sample_k, E), so all five axes must appear
        # in the transpose permutation in order to swap the last two of them.
        Q_K_sample = ops.matmul(Q.unsqueeze(-2), K_sample.transpose((0, 1, 2, 4, 3))).squeeze(-2)
        # find the Top_k queries with the sparsity measurement
        M = Q_K_sample.max(-1)[0] - ops.div(Q_K_sample.sum(-1), L_K)
        _, M_top = ops.top_k(M, n_top, sorted=False)

        # use the reduced Q to calculate Q_K
        Q_reduce = Q[ops.arange(B).unsqueeze(1).unsqueeze(2),
                     ops.arange(H).unsqueeze(0).unsqueeze(2),
                     M_top, :]
        Q_K = ops.matmul(Q_reduce, K.transpose((0, 1, 3, 2)))

        return Q_K, M_top
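    # --- Porting notes (added commentary; not from the original code) -----------
    # Relative to the PyTorch version above, the MindSpore port differs mainly in:
    #   * Tensor.transpose takes an explicit permutation of all axes, e.g.
    #     K.transpose((0, 1, 3, 2)) instead of torch's K.transpose(-2, -1).
    #   * ops.top_k returns a (values, indices) tuple, while torch's Tensor.topk
    #     result is indexed with [1] to obtain the indices.
    #   * The sampled indices are created with an explicit mstype.int32 dtype so
    #     they can be used directly for tensor indexing.
    #
    # A rough shape check on dummy inputs (assuming the MindSpore class is built
    # as a mindspore.nn.Cell with the same constructor arguments as the PyTorch
    # version; this is a sketch, not code from the repository):
    #
    #     import numpy as np
    #     from mindspore import Tensor
    #
    #     Q = Tensor(np.random.randn(2, 8, 96, 64).astype(np.float32))
    #     K = Tensor(np.random.randn(2, 8, 96, 64).astype(np.float32))
    #     attn = ProbAttention(False, 5, None, 0.1, False)  # assumed signature
    #     scores_top, index = attn._prob_QK(Q, K, sample_k=25, n_top=25)
    #     # expected shapes: scores_top (2, 8, 25, 96), index (2, 8, 25)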
|
|
|
|
|
|
|
    def _get_initial_context(self, V, L_Q):
        B, H, L_V, D = V.shape