diff --git a/fastNLP/modules/decoder/CRF.py b/fastNLP/modules/decoder/CRF.py index a15f2175..91064da2 100644 --- a/fastNLP/modules/decoder/CRF.py +++ b/fastNLP/modules/decoder/CRF.py @@ -16,7 +16,7 @@ def seq_len_to_byte_mask(seq_lens): batch_size = seq_lens.size(0) max_len = seq_lens.max() broadcast_arange = torch.arange(max_len).view(1, -1).repeat(batch_size, 1).to(seq_lens.device) - mask = broadcast_arange.lt(seq_lens.float().view(-1, 1)) + mask = broadcast_arange.float().lt(seq_lens.float().view(-1, 1)) return mask def allowed_transitions(id2label, encoding_type='bio'):