From 9fdcafff6a26c8f76ea4aa9bffa920af9cc7ef35 Mon Sep 17 00:00:00 2001 From: johnson7788 Date: Mon, 7 Dec 2020 13:16:17 +0800 Subject: [PATCH] =?UTF-8?q?"bugs:fix=20lstm=20rnn.pack=5Fpadded=5Fsequence?= =?UTF-8?q?=20RuntimeError,=20=E8=AF=A6=E7=BB=86=E4=BF=A1=E6=81=AF:=20http?= =?UTF-8?q?s://github.com/pytorch/pytorch/issues/43227"=20(#345)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit zhicai.guo --- fastNLP/modules/encoder/lstm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/fastNLP/modules/encoder/lstm.py b/fastNLP/modules/encoder/lstm.py index 9bf5c628..06f8bbb7 100644 --- a/fastNLP/modules/encoder/lstm.py +++ b/fastNLP/modules/encoder/lstm.py @@ -70,7 +70,7 @@ class LSTM(nn.Module): x = x[sort_idx] else: x = x[:, sort_idx] - x = rnn.pack_padded_sequence(x, sort_lens, batch_first=self.batch_first) + x = rnn.pack_padded_sequence(x, sort_lens.cpu(), batch_first=self.batch_first) output, hx = self.lstm(x, hx) # -> [N,L,C] output, _ = rnn.pad_packed_sequence(output, batch_first=self.batch_first, total_length=max_len) _, unsort_idx = torch.sort(sort_idx, dim=0, descending=False)