@@ -6,6 +6,14 @@ from torch.nn.utils.rnn import PackedSequence
 from fastNLP.modules.utils import initial_parameter
+try:
+    from torch import flip
+except ImportError:
+    def flip(x, dims):
+        indices = [slice(None)] * x.dim()
+        for dim in dims:
+            indices[dim] = torch.arange(x.size(dim) - 1, -1, -1, dtype=torch.long, device=x.device)
+        return x[tuple(indices)]
 class VarRnnCellWrapper(nn.Module):
     """Wrapper for normal RNN Cells, make it support variational dropout
@@ -102,13 +110,13 @@ class VarRNNBase(nn.Module):
         for layer in range(self.num_layers):
             output_list = []
             for direction in range(self.num_directions):
-                input_x = input if direction == 0 else input.flip(0)
+                input_x = input if direction == 0 else flip(input, [0])
                 idx = self.num_directions * layer + direction
                 cell = self._all_cells[idx]
                 hi = (hx[0][idx], hx[1][idx]) if is_lstm else hx[idx]
                 mask_xi = mask_x if layer == 0 else mask_out
                 output_x, hidden_x = cell(input_x, hi, mask_xi, mask_h)
-                output_list.append(output_x if direction == 0 else output_x.flip(0))
+                output_list.append(output_x if direction == 0 else flip(output_x, [0]))
                 hidden_list.append(hidden_x)
             input = torch.cat(output_list, dim=-1)
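In the bidirectional loop above, the backward direction is realised by reversing the input along the time dimension before the cell and reversing the cell's output back afterwards, so that step t of both directions lines up before concatenation. A standalone sketch of that pattern, assuming a recent PyTorch with torch.flip and using plain nn.RNN modules rather than fastNLP's variational cells:

    import torch
    import torch.nn as nn

    seq_len, batch, hidden = 5, 2, 4
    forward_rnn = nn.RNN(hidden, hidden)
    backward_rnn = nn.RNN(hidden, hidden)

    x = torch.randn(seq_len, batch, hidden)
    out_fwd, _ = forward_rnn(x)                     # left-to-right pass
    out_bwd, _ = backward_rnn(torch.flip(x, [0]))   # reverse time, then process
    out_bwd = torch.flip(out_bwd, [0])              # flip back so timesteps align
    output = torch.cat([out_fwd, out_bwd], dim=-1)  # (seq_len, batch, 2 * hidden)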