
fix mask bug in star-transformer

Fixes the bug described in #138. Thanks to @wlhgtc for reporting the bug and submitting a PR.
tag: v0.4.0
Author: Yunfan Shao, 5 years ago
Commit: b7008cba78
GPG Key ID: 4AEE18F83AFDEB23 (no known key found for this signature in database)
1 changed file with 1 addition and 0 deletions:
  fastNLP/modules/encoder/star_transformer.py (+1, -0)

@@ -46,6 +46,7 @@ class StarTransformer(nn.Module):
         return f(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)

         B, L, H = data.size()
+        mask = (mask == 0) # flip the mask for masked_fill_
         smask = torch.cat([torch.zeros(B, 1, ).byte().to(mask), mask], 1)

         embs = data.permute(0, 2, 1)[:,:,:,None] # B H L 1
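For context, a minimal sketch (not part of the patch itself) of why the flip is needed: PyTorch's `masked_fill_` writes the fill value wherever the mask tensor is True, while the input `mask` here marks real tokens with 1 and padding with 0. Without inverting it, real tokens would be filled and padding kept, which is the bug this commit fixes. The tensor values below are illustrative only.

```python
import torch

# Illustrative scores for a batch of one sequence with 3 positions,
# where the last position is padding.
scores = torch.tensor([[1.0, 2.0, 3.0]])
mask = torch.tensor([[1, 1, 0]])  # 1 = real token, 0 = padding

# masked_fill_ fills where the mask is True, so padding positions
# must become True -- hence the `mask == 0` flip from the patch.
flipped = (mask == 0)
scores.masked_fill_(flipped, float('-inf'))
# padding now scores -inf, so a subsequent softmax gives it zero weight
```

Fastest way to see the fix: with the un-flipped mask, `masked_fill_` would have zeroed out the real tokens instead of the padding.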

