Browse Source

"bugs:fix lstm rnn.pack_padded_sequence RuntimeError, 详细信息: https://github.com/pytorch/pytorch/issues/43227" (#345)

zhicai.guo <zhicai.guo@lavector.com>
tags/v1.0.0alpha
johnson7788 GitHub 4 years ago
parent
commit
9fdcafff6a
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 1 additions and 1 deletions
  1. +1
    -1
      fastNLP/modules/encoder/lstm.py

+ 1
- 1
fastNLP/modules/encoder/lstm.py View File

@@ -70,7 +70,7 @@ class LSTM(nn.Module):
x = x[sort_idx]
else:
x = x[:, sort_idx]
x = rnn.pack_padded_sequence(x, sort_lens, batch_first=self.batch_first)
x = rnn.pack_padded_sequence(x, sort_lens.cpu(), batch_first=self.batch_first)
output, hx = self.lstm(x, hx) # -> [N,L,C]
output, _ = rnn.pad_packed_sequence(output, batch_first=self.batch_first, total_length=max_len)
_, unsort_idx = torch.sort(sort_idx, dim=0, descending=False)


Loading…
Cancel
Save