
fix tests

tags/v0.2.0
yunfan 6 years ago
commit b19de5278c
1 changed file with 10 additions and 2 deletions
fastNLP/modules/encoder/variational_rnn.py  +10 -2

@@ -6,6 +6,14 @@ from torch.nn.utils.rnn import PackedSequence
 
 from fastNLP.modules.utils import initial_parameter
 
+try:
+    from torch import flip
+except ImportError:
+    def flip(x, dims):
+        indices = [slice(None)] * x.dim()
+        for dim in dims:
+            indices[dim] = torch.arange(x.size(dim) - 1, -1, -1, dtype=torch.long, device=x.device)
+        return x[tuple(indices)]
 
 class VarRnnCellWrapper(nn.Module):
     """Wrapper for normal RNN Cells, make it support variational dropout
@@ -102,13 +110,13 @@ class VarRNNBase(nn.Module):
         for layer in range(self.num_layers):
             output_list = []
             for direction in range(self.num_directions):
-                input_x = input if direction == 0 else input.flip(0)
+                input_x = input if direction == 0 else flip(input, [0])
                 idx = self.num_directions * layer + direction
                 cell = self._all_cells[idx]
                 hi = (hx[0][idx], hx[1][idx]) if is_lstm else hx[idx]
                 mask_xi = mask_x if layer == 0 else mask_out
                 output_x, hidden_x = cell(input_x, hi, mask_xi, mask_h)
-                output_list.append(output_x if direction == 0 else output_x.flip(0))
+                output_list.append(output_x if direction == 0 else flip(output_x, [0]))
                 hidden_list.append(hidden_x)
             input = torch.cat(output_list, dim=-1)
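The commit swaps the Tensor.flip(0) calls for a module-level flip so the code also runs on older PyTorch builds where `from torch import flip` raises ImportError: the fallback reverses the tensor along each requested dimension with a descending index tensor. A minimal, self-contained sketch of that fallback, checked against torch.flip on a small example tensor (the comparison itself assumes a PyTorch version that already ships torch.flip; the example tensor is only illustrative):

import torch

def flip(x, dims):
    # Fallback used when `from torch import flip` raises ImportError on older
    # PyTorch: reverse x along each dim in dims via a descending index tensor.
    indices = [slice(None)] * x.dim()
    for dim in dims:
        indices[dim] = torch.arange(x.size(dim) - 1, -1, -1, dtype=torch.long, device=x.device)
    return x[tuple(indices)]

x = torch.arange(6).reshape(2, 3)
print(flip(x, [0]))                                   # rows reversed
print(torch.equal(flip(x, [0]), torch.flip(x, [0])))  # True where torch.flip exists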



