Browse Source

修复CRF中seq_len_to_byte_mask的bug

tags/v0.3.0^2
yh 5 years ago
parent
commit
082afbfa45
1 changed files with 1 additions and 1 deletions
  1. +1
    -1
      fastNLP/modules/decoder/CRF.py

+ 1
- 1
fastNLP/modules/decoder/CRF.py View File

@@ -15,7 +15,7 @@ def seq_len_to_byte_mask(seq_lens):
# return value: ByteTensor, batch_size x max_len # return value: ByteTensor, batch_size x max_len
batch_size = seq_lens.size(0) batch_size = seq_lens.size(0)
max_len = seq_lens.max() max_len = seq_lens.max()
broadcast_arange = torch.arange(max_len).view(1, -1).repeat(batch_size, 1)
broadcast_arange = torch.arange(max_len).view(1, -1).repeat(batch_size, 1).to(seq_lens.device)
mask = broadcast_arange.lt(seq_lens.float().view(-1, 1)) mask = broadcast_arange.lt(seq_lens.float().view(-1, 1))
return mask return mask




Loading…
Cancel
Save