From b19de5278cb62e75dd0b24456bb5396670ffc74c Mon Sep 17 00:00:00 2001
From: yunfan
Date: Wed, 10 Oct 2018 10:22:16 +0800
Subject: [PATCH] fix tests

---
 fastNLP/modules/encoder/variational_rnn.py | 12 ++++++++++--
 1 file changed, 10 insertions(+), 2 deletions(-)

diff --git a/fastNLP/modules/encoder/variational_rnn.py b/fastNLP/modules/encoder/variational_rnn.py
index 3b2084ce..16bd4172 100644
--- a/fastNLP/modules/encoder/variational_rnn.py
+++ b/fastNLP/modules/encoder/variational_rnn.py
@@ -6,6 +6,14 @@ from torch.nn.utils.rnn import PackedSequence
 from fastNLP.modules.utils import initial_parameter
 
+try:
+    from torch import flip
+except ImportError:
+    def flip(x, dims):
+        indices = [slice(None)] * x.dim()
+        for dim in dims:
+            indices[dim] = torch.arange(x.size(dim) - 1, -1, -1, dtype=torch.long, device=x.device)
+        return x[tuple(indices)]
 
 class VarRnnCellWrapper(nn.Module):
     """Wrapper for normal RNN Cells, make it support variational dropout
@@ -102,13 +110,13 @@ class VarRNNBase(nn.Module):
         for layer in range(self.num_layers):
             output_list = []
             for direction in range(self.num_directions):
-                input_x = input if direction == 0 else input.flip(0)
+                input_x = input if direction == 0 else flip(input, [0])
                 idx = self.num_directions * layer + direction
                 cell = self._all_cells[idx]
                 hi = (hx[0][idx], hx[1][idx]) if is_lstm else hx[idx]
                 mask_xi = mask_x if layer == 0 else mask_out
                 output_x, hidden_x = cell(input_x, hi, mask_xi, mask_h)
-                output_list.append(output_x if direction == 0 else output_x.flip(0))
+                output_list.append(output_x if direction == 0 else flip(output_x, [0]))
                 hidden_list.append(hidden_x)
             input = torch.cat(output_list, dim=-1)
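
Note: torch.flip only became available around PyTorch 0.4.1, so the
Tensor.flip calls removed above break on older installs; the patch
imports torch.flip when it exists and otherwise substitutes a fallback
that reverses a dim via advanced indexing with a descending arange.
The patch only ever calls flip with a single dim ([0]); with several
dims the fallback's index tensors would be broadcast against each
other rather than reversing each dim independently, so it should not
be reused for the multi-dim case. A minimal sketch of the fallback
with single-dim equivalence checks (the tensor values and the
torch >= 0.4.1 comparison are illustrative, not part of the patch):

    import torch

    def flip(x, dims):
        # Reverse x along each dim in dims by indexing with a
        # descending arange (mirrors the patch's fallback).
        indices = [slice(None)] * x.dim()
        for dim in dims:
            indices[dim] = torch.arange(x.size(dim) - 1, -1, -1,
                                        dtype=torch.long, device=x.device)
        return x[tuple(indices)]

    x = torch.arange(6, dtype=torch.float).reshape(2, 3)
    assert torch.equal(flip(x, [0]), torch.flip(x, [0]))  # rows reversed
    assert torch.equal(flip(x, [1]), torch.flip(x, [1]))  # cols reversed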