From e78fc82a6ea7ca55bcd9ad37e158a96bda181588 Mon Sep 17 00:00:00 2001
From: yh_cc
Date: Sat, 27 Nov 2021 20:30:43 +0800
Subject: [PATCH] Fix ELMO not supporting CUDA
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 fastNLP/models/biaffine_parser.py | 2 +-
 fastNLP/modules/encoder/_elmo.py  | 4 ++--
 2 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/fastNLP/models/biaffine_parser.py b/fastNLP/models/biaffine_parser.py
index dff4809c..cd874e7c 100644
--- a/fastNLP/models/biaffine_parser.py
+++ b/fastNLP/models/biaffine_parser.py
@@ -376,7 +376,7 @@ class BiaffineParser(GraphParser):
         if self.encoder_name.endswith('lstm'):
             sort_lens, sort_idx = torch.sort(seq_len, dim=0, descending=True)
             x = x[sort_idx]
-            x = nn.utils.rnn.pack_padded_sequence(x, sort_lens, batch_first=True)
+            x = nn.utils.rnn.pack_padded_sequence(x, sort_lens.cpu(), batch_first=True)
             feat, _ = self.encoder(x)  # -> [N,L,C]
             feat, _ = nn.utils.rnn.pad_packed_sequence(feat, batch_first=True)
             _, unsort_idx = torch.sort(sort_idx, dim=0, descending=False)
diff --git a/fastNLP/modules/encoder/_elmo.py b/fastNLP/modules/encoder/_elmo.py
index 13843f83..7a2cf4bc 100644
--- a/fastNLP/modules/encoder/_elmo.py
+++ b/fastNLP/modules/encoder/_elmo.py
@@ -251,7 +251,7 @@ class LstmbiLm(nn.Module):
     def forward(self, inputs, seq_len):
         sort_lens, sort_idx = torch.sort(seq_len, dim=0, descending=True)
         inputs = inputs[sort_idx]
-        inputs = nn.utils.rnn.pack_padded_sequence(inputs, sort_lens, batch_first=self.batch_first)
+        inputs = nn.utils.rnn.pack_padded_sequence(inputs, sort_lens.cpu(), batch_first=self.batch_first)
         output, hx = self.encoder(inputs, None)  # -> [N,L,C]
         output, _ = nn.utils.rnn.pad_packed_sequence(output, batch_first=self.batch_first)
         _, unsort_idx = torch.sort(sort_idx, dim=0, descending=False)
@@ -316,7 +316,7 @@ class ElmobiLm(torch.nn.Module):
         max_len = inputs.size(1)
         sort_lens, sort_idx = torch.sort(seq_len, dim=0, descending=True)
         inputs = inputs[sort_idx]
-        inputs = nn.utils.rnn.pack_padded_sequence(inputs, sort_lens, batch_first=True)
+        inputs = nn.utils.rnn.pack_padded_sequence(inputs, sort_lens.cpu(), batch_first=True)
         output, _ = self._lstm_forward(inputs, None)
         _, unsort_idx = torch.sort(sort_idx, dim=0, descending=False)
         output = output[:, unsort_idx]
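
Why this fix works: since PyTorch 1.5, torch.nn.utils.rnn.pack_padded_sequence
requires its lengths argument to be a 1D CPU int64 tensor and raises a
RuntimeError when handed a CUDA tensor, which is what broke ELMO on GPU.
Calling .cpu() on the sorted lengths, as the three hunks above do, satisfies
that requirement while the data tensor stays on the GPU. Below is a minimal
standalone sketch of the patched pattern; the helper name pack_sorted_batch
and the toy shapes are illustrative and not part of fastNLP.

    import torch
    import torch.nn as nn

    def pack_sorted_batch(x, seq_len, batch_first=True):
        # Sort by length (descending), mirroring the patched fastNLP code;
        # pack_padded_sequence with enforce_sorted=True (the default)
        # expects lengths in descending order.
        sort_lens, sort_idx = torch.sort(seq_len, dim=0, descending=True)
        x = x[sort_idx]
        # The fix: lengths must live on the CPU even when x is on CUDA.
        packed = nn.utils.rnn.pack_padded_sequence(
            x, sort_lens.cpu(), batch_first=batch_first)
        return packed, sort_idx

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    x = torch.randn(3, 5, 8, device=device)           # [N, L, C] padded batch
    seq_len = torch.tensor([5, 3, 4], device=device)  # lengths on the same device
    packed, sort_idx = pack_sorted_batch(x, seq_len)
    output, _ = nn.utils.rnn.pad_packed_sequence(packed, batch_first=True)
    print(output.shape)                               # torch.Size([3, 5, 8])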