@@ -18,25 +18,26 @@ class FullAttention(nn.Module):
         self.output_attention = output_attention
         self.dropout = nn.Dropout(attention_dropout)
 
-    def forward(self, queries, keys, values, attn_mask):
+    def construct(self, queries, keys, values, attn_mask):  # def forward(self, queries, keys, values, attn_mask):
         B, L, H, E = queries.shape
         _, S, _, D = values.shape
-        scale = self.scale or 1./sqrt(E)
+        scale = self.scale or 1.0 / np.sqrt(E)
 
-        scores = torch.einsum("blhe,bshe->bhls", queries, keys)
+        scores = self.matmul(queries, keys.transpose(0, 1, 3, 2))
         if self.mask_flag:
             if attn_mask is None:
                 attn_mask = TriangularCausalMask(B, L, device=queries.device)
 
-            scores.masked_fill_(attn_mask.mask, -np.inf)
+            scores = scores + (1 - attn_mask.mask) * (-np.inf)
 
-        A = self.dropout(torch.softmax(scale * scores, dim=-1))
-        V = torch.einsum("bhls,bshd->blhd", A, values)
+        A = self.dropout(self.softmax(scale * scores))
+        V = self.matmul(A, values)
 
         if self.output_attention:
-            return (V.contiguous(), A)
+            return V.contiguous(), A
         else:
-            return (V.contiguous(), None)
+            return V.contiguous(), None
 
 class ProbAttention(nn.Module):
     def __init__(self, mask_flag=True, factor=5, scale=None, attention_dropout=0.1, output_attention=False):
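For reference, below is a minimal NumPy sketch (not part of the patch) of the computation the einsum-based FullAttention in the left-hand column performs; it can serve as a shape and numerics check when validating the batched-matmul port. The function name full_attention_reference, the causal flag, and the use of a finite -1e9 in place of -np.inf are illustrative choices, not taken from this repository, and dropout is omitted.

# Reference sketch (NumPy), not part of the patch.
# Shapes follow the code above: queries [B, L, H, E], keys [B, S, H, E], values [B, S, H, D].
import numpy as np

def full_attention_reference(queries, keys, values, causal=True):
    B, L, H, E = queries.shape
    _, S, _, D = values.shape
    scale = 1.0 / np.sqrt(E)

    # "blhe,bshe->bhls": per-head dot products over the feature axis E
    scores = np.einsum("blhe,bshe->bhls", queries, keys)

    if causal:
        # Mask out future keys (strict upper triangle); a large negative
        # constant stands in for -inf so softmax drives these weights to ~0.
        mask = np.triu(np.ones((L, S), dtype=bool), k=1)
        scores = np.where(mask, -1e9, scores)

    # Softmax over the key axis S, applied to the scaled scores (dropout omitted)
    scores = scale * scores
    A = np.exp(scores - scores.max(axis=-1, keepdims=True))
    A = A / A.sum(axis=-1, keepdims=True)

    # "bhls,bshd->blhd": attention-weighted sum of values -> [B, L, H, D]
    return np.einsum("bhls,bshd->blhd", A, values)

# Example: B=2, L=S=16, H=4, E=D=8; output shape is (2, 16, 4, 8)
out = full_attention_reference(np.random.randn(2, 16, 4, 8),
                               np.random.randn(2, 16, 4, 8),
                               np.random.randn(2, 16, 4, 8))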