diff --git a/fastNLP/modules/encoder/lstm.py b/fastNLP/modules/encoder/lstm.py index 9bf5c628..06f8bbb7 100644 --- a/fastNLP/modules/encoder/lstm.py +++ b/fastNLP/modules/encoder/lstm.py @@ -70,7 +70,7 @@ class LSTM(nn.Module): x = x[sort_idx] else: x = x[:, sort_idx] - x = rnn.pack_padded_sequence(x, sort_lens, batch_first=self.batch_first) + x = rnn.pack_padded_sequence(x, sort_lens.cpu(), batch_first=self.batch_first) output, hx = self.lstm(x, hx) # -> [N,L,C] output, _ = rnn.pad_packed_sequence(output, batch_first=self.batch_first, total_length=max_len) _, unsort_idx = torch.sort(sort_idx, dim=0, descending=False)