From 082afbfa45fe31809b2776a46b915b1eea24c142 Mon Sep 17 00:00:00 2001 From: yh Date: Sat, 5 Jan 2019 16:08:40 +0800 Subject: [PATCH] =?UTF-8?q?=E4=BF=AE=E5=A4=8DCRF=E4=B8=ADseq=5Flen=5Fto=5F?= =?UTF-8?q?byte=5Fmask=E7=9A=84bug?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- fastNLP/modules/decoder/CRF.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/fastNLP/modules/decoder/CRF.py b/fastNLP/modules/decoder/CRF.py index baa8c403..a15f2175 100644 --- a/fastNLP/modules/decoder/CRF.py +++ b/fastNLP/modules/decoder/CRF.py @@ -15,7 +15,7 @@ def seq_len_to_byte_mask(seq_lens): # return value: ByteTensor, batch_size x max_len batch_size = seq_lens.size(0) max_len = seq_lens.max() - broadcast_arange = torch.arange(max_len).view(1, -1).repeat(batch_size, 1) + broadcast_arange = torch.arange(max_len).view(1, -1).repeat(batch_size, 1).to(seq_lens.device) mask = broadcast_arange.lt(seq_lens.float().view(-1, 1)) return mask