|
@@ -6,6 +6,14 @@ from torch.nn.utils.rnn import PackedSequence |
|
|
|
|
|
|
|
|
from fastNLP.modules.utils import initial_parameter |
|
|
from fastNLP.modules.utils import initial_parameter |
|
|
|
|
|
|
|
|
|
|
|
try: |
|
|
|
|
|
from torch import flip |
|
|
|
|
|
except ImportError: |
|
|
|
|
|
def flip(x, dims): |
|
|
|
|
|
indices = [slice(None)] * x.dim() |
|
|
|
|
|
for dim in dims: |
|
|
|
|
|
indices[dim] = torch.arange(x.size(dim) - 1, -1, -1, dtype=torch.long, device=x.device) |
|
|
|
|
|
return x[tuple(indices)] |
|
|
|
|
|
|
|
|
class VarRnnCellWrapper(nn.Module): |
|
|
class VarRnnCellWrapper(nn.Module): |
|
|
"""Wrapper for normal RNN Cells, make it support variational dropout |
|
|
"""Wrapper for normal RNN Cells, make it support variational dropout |
|
@@ -102,13 +110,13 @@ class VarRNNBase(nn.Module): |
|
|
for layer in range(self.num_layers): |
|
|
for layer in range(self.num_layers): |
|
|
output_list = [] |
|
|
output_list = [] |
|
|
for direction in range(self.num_directions): |
|
|
for direction in range(self.num_directions): |
|
|
input_x = input if direction == 0 else input.flip(0) |
|
|
|
|
|
|
|
|
input_x = input if direction == 0 else flip(input, [0]) |
|
|
idx = self.num_directions * layer + direction |
|
|
idx = self.num_directions * layer + direction |
|
|
cell = self._all_cells[idx] |
|
|
cell = self._all_cells[idx] |
|
|
hi = (hx[0][idx], hx[1][idx]) if is_lstm else hx[idx] |
|
|
hi = (hx[0][idx], hx[1][idx]) if is_lstm else hx[idx] |
|
|
mask_xi = mask_x if layer == 0 else mask_out |
|
|
mask_xi = mask_x if layer == 0 else mask_out |
|
|
output_x, hidden_x = cell(input_x, hi, mask_xi, mask_h) |
|
|
output_x, hidden_x = cell(input_x, hi, mask_xi, mask_h) |
|
|
output_list.append(output_x if direction == 0 else output_x.flip(0)) |
|
|
|
|
|
|
|
|
output_list.append(output_x if direction == 0 else flip(output_x, [0])) |
|
|
hidden_list.append(hidden_x) |
|
|
hidden_list.append(hidden_x) |
|
|
input = torch.cat(output_list, dim=-1) |
|
|
input = torch.cat(output_list, dim=-1) |
|
|
|
|
|
|
|
|