@@ -18,25 +18,26 @@ class FullAttention(nn.Module):
         self.output_attention = output_attention
         self.dropout = nn.Dropout(attention_dropout)

-    def forward(self, queries, keys, values, attn_mask):
+    def construct(self, queries, keys, values, attn_mask):  # def forward(self, queries, keys, values, attn_mask):
         B, L, H, E = queries.shape
         _, S, _, D = values.shape
-        scale = self.scale or 1./sqrt(E)
+        scale = self.scale or 1.0 / np.sqrt(E)

-        scores = torch.einsum("blhe,bshe->bhls", queries, keys)
+        scores = self.matmul(queries, keys.transpose(0, 1, 3, 2))
         if self.mask_flag:
             if attn_mask is None:
                 attn_mask = TriangularCausalMask(B, L, device=queries.device)

-            scores.masked_fill_(attn_mask.mask, -np.inf)
+            scores = scores + (1 - attn_mask.mask) * (-np.inf)

-        A = self.dropout(torch.softmax(scale * scores, dim=-1))
-        V = torch.einsum("bhls,bshd->blhd", A, values)
+        A = self.dropout(self.softmax(scale * scores))
+        V = self.matmul(A, values)

         if self.output_attention:
-            return (V.contiguous(), A)
+            return V.contiguous(), A
         else:
-            return (V.contiguous(), None)
+            return V.contiguous(), None

 class ProbAttention(nn.Module):
     def __init__(self, mask_flag=True, factor=5, scale=None, attention_dropout=0.1, output_attention=False):
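The new `construct` body uses `self.matmul` and `self.softmax`, which are not declared anywhere in this hunk. Below is a minimal sketch of how the matching `__init__` could look on the MindSpore side; the `nn.Cell` base class and the choice of `ops.BatchMatMul` / `nn.Softmax` are assumptions for illustration, not part of the diff.

```python
# Illustrative sketch only, not taken from the diff: one plausible __init__
# for the MindSpore port, assuming the class now derives from nn.Cell.
import mindspore.nn as nn
import mindspore.ops as ops


class FullAttention(nn.Cell):
    def __init__(self, mask_flag=True, factor=5, scale=None,
                 attention_dropout=0.1, output_attention=False):
        super().__init__()
        self.scale = scale
        self.mask_flag = mask_flag
        self.output_attention = output_attention
        # Unchanged context line in the hunk; note that in older MindSpore
        # releases nn.Dropout interprets its argument as keep_prob rather
        # than a drop probability.
        self.dropout = nn.Dropout(attention_dropout)
        # Assumed definitions of the attributes used by construct():
        self.softmax = nn.Softmax(axis=-1)  # stands in for torch.softmax(..., dim=-1)
        self.matmul = ops.BatchMatMul()     # stands in for the two torch.einsum contractions
```

Expressing the einsum contractions as `BatchMatMul` plus an explicit `transpose` is a common workaround in MindSpore ports, since einsum support there has historically been limited.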