|
|
@@ -148,11 +148,10 @@ class VarRNNBase(nn.Module): |
|
|
|
seq_len = x.size(1) if self.batch_first else x.size(0) |
|
|
|
max_batch_size = x.size(0) if self.batch_first else x.size(1) |
|
|
|
seq_lens = torch.LongTensor([seq_len for _ in range(max_batch_size)]) |
|
|
|
_tmp = pack_padded_sequence(x, seq_lens, batch_first=self.batch_first) |
|
|
|
x, batch_sizes = _tmp.data, _tmp.batch_sizes |
|
|
|
input = pack_padded_sequence(input, seq_lens, batch_first=self.batch_first) |
|
|
|
else: |
|
|
|
max_batch_size = int(x.batch_sizes[0]) |
|
|
|
x, batch_sizes = x.data, x.batch_sizes |
|
|
|
max_batch_size = int(input.batch_sizes[0]) |
|
|
|
input, batch_sizes = input.data, input.batch_sizes |
|
|
|
|
|
|
|
if hx is None: |
|
|
|
hx = x.new_zeros(self.num_layers * self.num_directions, |
|
|
|