Browse Source

Merge branch 'dev' of github.com:fastnlp/fastNLP into dev

tags/v1.0.0alpha
yh_cc 4 years ago
parent
commit
d4fda68840
1 changed files with 1 additions and 1 deletions
  1. +1
    -1
      fastNLP/modules/encoder/lstm.py

+ 1
- 1
fastNLP/modules/encoder/lstm.py View File

@@ -70,7 +70,7 @@ class LSTM(nn.Module):
x = x[sort_idx]
else:
x = x[:, sort_idx]
x = rnn.pack_padded_sequence(x, sort_lens, batch_first=self.batch_first)
x = rnn.pack_padded_sequence(x, sort_lens.cpu(), batch_first=self.batch_first)
output, hx = self.lstm(x, hx) # -> [N,L,C]
output, _ = rnn.pad_packed_sequence(output, batch_first=self.batch_first, total_length=max_len)
_, unsort_idx = torch.sort(sort_idx, dim=0, descending=False)


Loading…
Cancel
Save