Browse Source

Merge pull request #145 from fastnlp/choosewhatulike-patch-1

fix for changing torch API
tags/v0.4.0
Xipeng Qiu GitHub 5 years ago
parent
commit
863a99f741
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 3 additions and 3 deletions
  1. +3
    -3
      fastNLP/modules/encoder/variational_rnn.py

+ 3
- 3
fastNLP/modules/encoder/variational_rnn.py View File

@@ -41,7 +41,7 @@ class VarRnnCellWrapper(nn.Module):
return torch.cat([hi, h0[:h0_size]], dim=0)
return hi[:size]
is_lstm = isinstance(hidden, tuple)
input, batch_sizes = input_x
input, batch_sizes = input_x.data, input_x.batch_sizes
output = []
cell = self.cell
if is_reversed:
@@ -127,10 +127,10 @@ class VarRNNBase(nn.Module):
seq_len = input.size(1) if self.batch_first else input.size(0)
max_batch_size = input.size(0) if self.batch_first else input.size(1)
seq_lens = torch.LongTensor([seq_len for _ in range(max_batch_size)])
input, batch_sizes = pack_padded_sequence(input, seq_lens, batch_first=self.batch_first)
input = pack_padded_sequence(input, seq_lens, batch_first=self.batch_first)
else:
max_batch_size = int(input.batch_sizes[0])
input, batch_sizes = input
input, batch_sizes = input.data, input.batch_sizes

if hx is None:
hx = input.new_zeros(self.num_layers * self.num_directions,


Loading…
Cancel
Save