
Update attn.py

Branch: main
BBing committed 2 years ago
Commit: 16ca852075

1 changed file with 9 additions and 8 deletions:

models/attn.py  (+9, −8)

@@ -18,25 +18,26 @@ class FullAttention(nn.Module):
         self.output_attention = output_attention
         self.dropout = nn.Dropout(attention_dropout)

-    def forward(self, queries, keys, values, attn_mask):
+    def construct(self, queries, keys, values, attn_mask):  # def forward(self, queries, keys, values, attn_mask):
         B, L, H, E = queries.shape
         _, S, _, D = values.shape
-        scale = self.scale or 1./sqrt(E)
+        scale = self.scale or 1.0 / np.sqrt(E)

-        scores = torch.einsum("blhe,bshe->bhls", queries, keys)
+        scores = self.matmul(queries, keys.transpose(0, 1, 3, 2))
         if self.mask_flag:
             if attn_mask is None:
                 attn_mask = TriangularCausalMask(B, L, device=queries.device)

-            scores.masked_fill_(attn_mask.mask, -np.inf)
+            scores = scores + (1 - attn_mask.mask) * (-np.inf)

-        A = self.dropout(torch.softmax(scale * scores, dim=-1))
-        V = torch.einsum("bhls,bshd->blhd", A, values)
+        A = self.dropout(self.softmax(scale * scores))
+        V = self.matmul(A, values)
         if self.output_attention:
-            return (V.contiguous(), A)
+            return V.contiguous(), A
         else:
-            return (V.contiguous(), None)
+            return V.contiguous(), None

 class ProbAttention(nn.Module):
     def __init__(self, mask_flag=True, factor=5, scale=None, attention_dropout=0.1, output_attention=False):
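
The hunk swaps PyTorch-specific calls (torch.einsum, Tensor.masked_fill_, torch.softmax) for operator attributes (self.matmul, self.softmax) and renames forward to construct, consistent with porting the module from PyTorch's forward to a construct-based API such as MindSpore's nn.Cell. Below is a minimal NumPy sketch of the tensor algebra behind the two main rewrites. It is an illustration under assumed shapes and mask conventions, not the repository's code; the explicit head-axis transposes and the -1e9 masking constant are choices made here for the example and may differ from what TriangularCausalMask produces.

import numpy as np

# Shapes used for the illustration: batch B, query length L, key length S,
# heads H, head dimension E.
B, L, S, H, E = 2, 4, 6, 3, 8
rng = np.random.default_rng(0)
queries = rng.standard_normal((B, L, H, E)).astype(np.float32)
keys    = rng.standard_normal((B, S, H, E)).astype(np.float32)

# (1) einsum -> transpose + batched matmul.
# Reference contraction from the removed line: "blhe,bshe->bhls".
scores_einsum = np.einsum("blhe,bshe->bhls", queries, keys)   # (B, H, L, S)
# Equivalent form with explicit transposes: move the head axis forward,
# then batch-multiply Q by K^T.
q_t = queries.transpose(0, 2, 1, 3)                           # (B, H, L, E)
k_t = keys.transpose(0, 2, 3, 1)                              # (B, H, E, S)
scores = q_t @ k_t                                            # (B, H, L, S)
assert np.allclose(scores_einsum, scores, atol=1e-5)

# (2) in-place masked_fill_ -> additive mask.
# Here mask == 1 marks positions to hide (strict upper triangle, i.e. a
# causal mask); this convention is assumed for the example.
mask = np.triu(np.ones((L, S), dtype=np.float32), k=1)        # broadcasts over (B, H)
# A large negative constant is used instead of -np.inf: where the mask is 0,
# 0 * (-np.inf) would evaluate to NaN rather than leaving the score unchanged.
masked = scores + mask * (-1e9)

# Softmax over the key axis, as in softmax(scale * scores).
scale = 1.0 / np.sqrt(E)
a = np.exp(scale * masked - (scale * masked).max(axis=-1, keepdims=True))
attn = a / a.sum(axis=-1, keepdims=True)                      # (B, H, L, S)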

