diff --git a/fastNLP/models/biaffine_parser.py b/fastNLP/models/biaffine_parser.py index dff4809c..cd874e7c 100644 --- a/fastNLP/models/biaffine_parser.py +++ b/fastNLP/models/biaffine_parser.py @@ -376,7 +376,7 @@ class BiaffineParser(GraphParser): if self.encoder_name.endswith('lstm'): sort_lens, sort_idx = torch.sort(seq_len, dim=0, descending=True) x = x[sort_idx] - x = nn.utils.rnn.pack_padded_sequence(x, sort_lens, batch_first=True) + x = nn.utils.rnn.pack_padded_sequence(x, sort_lens.cpu(), batch_first=True) feat, _ = self.encoder(x) # -> [N,L,C] feat, _ = nn.utils.rnn.pad_packed_sequence(feat, batch_first=True) _, unsort_idx = torch.sort(sort_idx, dim=0, descending=False) diff --git a/fastNLP/modules/encoder/_elmo.py b/fastNLP/modules/encoder/_elmo.py index 13843f83..7a2cf4bc 100644 --- a/fastNLP/modules/encoder/_elmo.py +++ b/fastNLP/modules/encoder/_elmo.py @@ -251,7 +251,7 @@ class LstmbiLm(nn.Module): def forward(self, inputs, seq_len): sort_lens, sort_idx = torch.sort(seq_len, dim=0, descending=True) inputs = inputs[sort_idx] - inputs = nn.utils.rnn.pack_padded_sequence(inputs, sort_lens, batch_first=self.batch_first) + inputs = nn.utils.rnn.pack_padded_sequence(inputs, sort_lens.cpu(), batch_first=self.batch_first) output, hx = self.encoder(inputs, None) # -> [N,L,C] output, _ = nn.utils.rnn.pad_packed_sequence(output, batch_first=self.batch_first) _, unsort_idx = torch.sort(sort_idx, dim=0, descending=False) @@ -316,7 +316,7 @@ class ElmobiLm(torch.nn.Module): max_len = inputs.size(1) sort_lens, sort_idx = torch.sort(seq_len, dim=0, descending=True) inputs = inputs[sort_idx] - inputs = nn.utils.rnn.pack_padded_sequence(inputs, sort_lens, batch_first=True) + inputs = nn.utils.rnn.pack_padded_sequence(inputs, sort_lens.cpu(), batch_first=True) output, _ = self._lstm_forward(inputs, None) _, unsort_idx = torch.sort(sort_idx, dim=0, descending=False) output = output[:, unsort_idx]