From b7008cba7800a2140bcbaec14370f36862c3eb7c Mon Sep 17 00:00:00 2001
From: Yunfan Shao
Date: Fri, 22 Mar 2019 17:37:50 +0800
Subject: [PATCH] fix mask bug in star-transformer

Fix the bug described in #138. Thanks to @wlhgtc for the bug report and PR.
---
 fastNLP/modules/encoder/star_transformer.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/fastNLP/modules/encoder/star_transformer.py b/fastNLP/modules/encoder/star_transformer.py
index b28d3d1d..5b9ae7ec 100644
--- a/fastNLP/modules/encoder/star_transformer.py
+++ b/fastNLP/modules/encoder/star_transformer.py
@@ -46,6 +46,7 @@ class StarTransformer(nn.Module):
             return f(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
 
         B, L, H = data.size()
+        mask = (mask == 0)  # flip the mask for masked_fill_
         smask = torch.cat([torch.zeros(B, 1, ).byte().to(mask), mask], 1)
 
         embs = data.permute(0, 2, 1)[:,:,:,None]  # B H L 1
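
For reference, a minimal sketch of the masking semantics this patch corrects. The tensors, shapes, and variable names below are illustrative only, not the module's real attention internals: the point is that masked_fill_ writes the fill value where its mask argument is True, while the incoming mask marks real tokens with 1 and padding with 0, so it must be flipped first.

import torch

scores = torch.randn(2, 4)                    # e.g. attention scores, B x L
mask = torch.tensor([[1, 1, 1, 0],
                     [1, 1, 0, 0]])           # 1 = real token, 0 = padding

# Before the patch: the un-flipped mask sets the *real token* scores to -inf.
buggy = scores.clone().masked_fill_(mask.bool(), float('-inf'))

# After the patch: mask = (mask == 0) makes the padding positions True,
# so only the padding scores are filled with -inf, as intended.
fixed = scores.clone().masked_fill_((mask == 0), float('-inf'))

print(buggy)  # -inf at token positions   (wrong)
print(fixed)  # -inf at padding positions (intended)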