From ae3356b0bb6b14dec5fed33efd2e6b0cad76a29a Mon Sep 17 00:00:00 2001
From: Yunfan Shao
Date: Fri, 3 May 2019 13:32:36 +0800
Subject: [PATCH] fix for changing torch API

---
 fastNLP/modules/encoder/variational_rnn.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/fastNLP/modules/encoder/variational_rnn.py b/fastNLP/modules/encoder/variational_rnn.py
index a7902813..0d58d67b 100644
--- a/fastNLP/modules/encoder/variational_rnn.py
+++ b/fastNLP/modules/encoder/variational_rnn.py
@@ -41,7 +41,7 @@ class VarRnnCellWrapper(nn.Module):
                 return torch.cat([hi, h0[:h0_size]], dim=0)
             return hi[:size]
         is_lstm = isinstance(hidden, tuple)
-        input, batch_sizes = input_x
+        input, batch_sizes = input_x.data, input_x.batch_sizes
         output = []
         cell = self.cell
         if is_reversed:
@@ -127,10 +127,10 @@ class VarRNNBase(nn.Module):
             seq_len = input.size(1) if self.batch_first else input.size(0)
             max_batch_size = input.size(0) if self.batch_first else input.size(1)
             seq_lens = torch.LongTensor([seq_len for _ in range(max_batch_size)])
-            input, batch_sizes = pack_padded_sequence(input, seq_lens, batch_first=self.batch_first)
+            input = pack_padded_sequence(input, seq_lens, batch_first=self.batch_first)
         else:
             max_batch_size = int(input.batch_sizes[0])
-            input, batch_sizes = input
+        input, batch_sizes = input.data, input.batch_sizes
 
         if hx is None:
             hx = input.new_zeros(self.num_layers * self.num_directions,
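
For context, the torch API change this patch adapts to: starting with PyTorch 1.1, PackedSequence carries sorted_indices and unsorted_indices in addition to data and batch_sizes, so the old two-element tuple unpacking raises a ValueError, and the PackedSequence returned by pack_padded_sequence can no longer be unpacked that way either. Below is a minimal sketch of the old failure mode and the attribute-access workaround the patch uses; the tensors and shapes are illustrative, not taken from fastNLP.

    import torch
    from torch.nn.utils.rnn import pack_padded_sequence

    # Pack a small padded batch: 2 sequences of length 3, feature size 4.
    padded = torch.randn(2, 3, 4)              # (batch, seq, feature)
    lengths = torch.tensor([3, 3])
    packed = pack_padded_sequence(padded, lengths, batch_first=True)

    # Before PyTorch 1.1, PackedSequence had exactly two fields, so this worked:
    #     data, batch_sizes = packed
    # From 1.1 on it also carries sorted_indices/unsorted_indices, so the same
    # unpacking raises "ValueError: too many values to unpack (expected 2)".

    # Attribute access works on both old and new versions, as in the patch:
    data, batch_sizes = packed.data, packed.batch_sizes
    print(data.shape, batch_sizes)             # torch.Size([6, 4]) tensor([2, 2, 2])

Using .data and .batch_sizes directly keeps the code version-agnostic, which is why the patch replaces each unpacking site rather than pinning an older torch release.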