Browse Source

Update attn.py

main
BBing 2 years ago
parent
commit
37f6921a18
1 changed files with 18 additions and 22 deletions
  1. +18
    -22
      models/attn.py

+ 18
- 22
models/attn.py View File

@@ -42,28 +42,24 @@ class ProbAttention(nn.Module):
self.output_attention = output_attention
self.dropout = nn.Dropout(attention_dropout)

def _prob_QK(self, Q, K, sample_k, n_top): # n_top: c*ln(L_q)
# Q [B, H, L, D]
B, H, L_K, E = K.shape
_, _, L_Q, _ = Q.shape

# calculate the sampled Q_K
K_expand = K.unsqueeze(-3).expand(B, H, L_Q, L_K, E)
index_sample = torch.randint(L_K, (L_Q, sample_k)) # real U = U_part(factor*ln(L_k))*L_q
K_sample = K_expand[:, :, torch.arange(L_Q).unsqueeze(1), index_sample, :]
Q_K_sample = torch.matmul(Q.unsqueeze(-2), K_sample.transpose(-2, -1)).squeeze(-2)

# find the Top_k query with sparisty measurement
M = Q_K_sample.max(-1)[0] - torch.div(Q_K_sample.sum(-1), L_K)
M_top = M.topk(n_top, sorted=False)[1]

# use the reduced Q to calculate Q_K
Q_reduce = Q[torch.arange(B)[:, None, None],
torch.arange(H)[None, :, None],
M_top, :] # factor*ln(L_q)
Q_K = torch.matmul(Q_reduce, K.transpose(-2, -1)) # factor*ln(L_q)*L_k

return Q_K, M_top
def _prob_QK(self, Q, K, sample_k, n_top):
B, H, L_K, E = K.shape
_, _, L_Q, D = Q.shape

K_expand = K.unsqueeze(-3).expand(B, H, L_Q, L_K, E)
index_sample = ops.random.randint(0, L_K, (L_Q, sample_k), dtype=mstype.int32)
K_sample = K_expand[:, :, ops.arange(L_Q).unsqueeze(1), index_sample, :]
Q_K_sample = ops.matmul(Q.unsqueeze(-2), K_sample.transpose((0, 1, 3, 2))).squeeze(-2)

M = Q_K_sample.max(-1)[0] - ops.div(Q_K_sample.sum(-1), L_K)
_, M_top = ops.top_k(M, n_top, sorted=False)

Q_reduce = Q[ops.arange(B).unsqueeze(1).unsqueeze(2),
ops.arange(H).unsqueeze(0).unsqueeze(2),
M_top, :]
Q_K = ops.matmul(Q_reduce, K.transpose((0, 1, 3, 2)))

return Q_K, M_top

def _get_initial_context(self, V, L_Q):
B, H, L_V, D = V.shape


Loading…
Cancel
Save