|
|
@@ -41,7 +41,7 @@ class VarRnnCellWrapper(nn.Module): |
|
|
|
return torch.cat([hi, h0[:h0_size]], dim=0) |
|
|
|
return hi[:size] |
|
|
|
is_lstm = isinstance(hidden, tuple) |
|
|
|
input, batch_sizes = input_x |
|
|
|
input, batch_sizes = input_x.data, input_x.batch_sizes |
|
|
|
output = [] |
|
|
|
cell = self.cell |
|
|
|
if is_reversed: |
|
|
@@ -127,10 +127,10 @@ class VarRNNBase(nn.Module): |
|
|
|
seq_len = input.size(1) if self.batch_first else input.size(0) |
|
|
|
max_batch_size = input.size(0) if self.batch_first else input.size(1) |
|
|
|
seq_lens = torch.LongTensor([seq_len for _ in range(max_batch_size)]) |
|
|
|
input, batch_sizes = pack_padded_sequence(input, seq_lens, batch_first=self.batch_first) |
|
|
|
input = pack_padded_sequence(input, seq_lens, batch_first=self.batch_first) |
|
|
|
else: |
|
|
|
max_batch_size = int(input.batch_sizes[0]) |
|
|
|
input, batch_sizes = input |
|
|
|
input, batch_sizes = input.data, input.batch_sizes |
|
|
|
|
|
|
|
if hx is None: |
|
|
|
hx = input.new_zeros(self.num_layers * self.num_directions, |
|
|
|