diff --git a/fastNLP/modules/encoder/star_transformer.py b/fastNLP/modules/encoder/star_transformer.py index b28d3d1d..809948bb 100644 --- a/fastNLP/modules/encoder/star_transformer.py +++ b/fastNLP/modules/encoder/star_transformer.py @@ -57,6 +57,7 @@ class StarTransformer(nn.Module): nodes = embs relay = embs.mean(2, keepdim=True) ex_mask = mask[:, None, :, None].expand(B, H, L, 1) + ex_mask = ex_mask.ne(1) # reverse mask for next masked_fill. r_embs = embs.view(B, H, 1, L) for i in range(self.iters): ax = torch.cat([r_embs, relay.expand(B, H, 1, L)], 2)