
fix the "masked_fill" bug

If masked_fill is applied with ex_mask as-is (0 for pad), it fills the non-padding positions (where ex_mask is 1) with 0 instead of the padding positions, which hurts performance. (A minimal illustration follows the diff below.)
tags/v0.4.0
wlhgtc committed 5 years ago
commit 8d61cd684e
1 changed file with 1 addition and 0 deletions:
  fastNLP/modules/encoder/star_transformer.py  +1  -0

fastNLP/modules/encoder/star_transformer.py

@@ -57,6 +57,7 @@ class StarTransformer(nn.Module):
         nodes = embs
         relay = embs.mean(2, keepdim=True)
         ex_mask = mask[:, None, :, None].expand(B, H, L, 1)
+        ex_mask = ex_mask.ne(1)  # reverse mask for next masked_fill.
         r_embs = embs.view(B, H, 1, L)
         for i in range(self.iters):
             ax = torch.cat([r_embs, relay.expand(B, H, 1, L)], 2)
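
To make the commit message concrete, here is a minimal, self-contained PyTorch sketch of the bug and the fix; the tensor names and shapes are illustrative only and are not taken from the fastNLP source.

    import torch

    # Toy batch: 1 sentence of length 4, hidden size 5 (illustrative shapes only).
    embs = torch.randn(1, 4, 5)
    mask = torch.tensor([[1, 1, 1, 0]])          # 1 = real token, 0 = padding

    pad_mask = mask[:, :, None].expand_as(embs)  # broadcast the mask to embs' shape

    # Buggy: masked_fill writes the fill value wherever the mask is True,
    # so using the padding mask as-is zeroes the real tokens, not the padding.
    buggy = embs.masked_fill(pad_mask.bool(), 0)

    # Fixed: invert the mask first (the commit does this with .ne(1)),
    # so only the padding positions are filled with 0.
    fixed = embs.masked_fill(pad_mask.ne(1), 0)

With the inverted mask, only the positions where mask == 0 (padding) are True and get overwritten, which matches the intent described in the commit message.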

