|
|
@@ -70,7 +70,7 @@ class LSTM(nn.Module): |
|
|
|
x = x[sort_idx] |
|
|
|
else: |
|
|
|
x = x[:, sort_idx] |
|
|
|
x = rnn.pack_padded_sequence(x, sort_lens, batch_first=self.batch_first) |
|
|
|
x = rnn.pack_padded_sequence(x, sort_lens.cpu(), batch_first=self.batch_first) |
|
|
|
output, hx = self.lstm(x, hx) # -> [N,L,C] |
|
|
|
output, _ = rnn.pad_packed_sequence(output, batch_first=self.batch_first, total_length=max_len) |
|
|
|
_, unsort_idx = torch.sort(sort_idx, dim=0, descending=False) |
|
|
|