From 4ac4cda049a85f6cc73b854ac79f2bb549cfee97 Mon Sep 17 00:00:00 2001
From: yunfan
Date: Fri, 17 May 2019 13:22:42 +0800
Subject: [PATCH] fix var rnn

---
 fastNLP/modules/encoder/variational_rnn.py | 92 +++++++++++++---------
 1 file changed, 53 insertions(+), 39 deletions(-)

diff --git a/fastNLP/modules/encoder/variational_rnn.py b/fastNLP/modules/encoder/variational_rnn.py
index 60cdf9c5..753741de 100644
--- a/fastNLP/modules/encoder/variational_rnn.py
+++ b/fastNLP/modules/encoder/variational_rnn.py
@@ -11,7 +11,8 @@ except ImportError:
     def flip(x, dims):
         indices = [slice(None)] * x.dim()
         for dim in dims:
-            indices[dim] = torch.arange(x.size(dim) - 1, -1, -1, dtype=torch.long, device=x.device)
+            indices[dim] = torch.arange(
+                x.size(dim) - 1, -1, -1, dtype=torch.long, device=x.device)
         return x[tuple(indices)]
 
 from ..utils import initial_parameter
@@ -27,14 +28,14 @@ class VarRnnCellWrapper(nn.Module):
     """
     Wrapper for normal RNN Cells, make it support variational dropout
     """
-    
+
     def __init__(self, cell, hidden_size, input_p, hidden_p):
         super(VarRnnCellWrapper, self).__init__()
         self.cell = cell
         self.hidden_size = hidden_size
         self.input_p = input_p
         self.hidden_p = hidden_p
-    
+
     def forward(self, input_x, hidden, mask_x, mask_h, is_reversed=False):
         """
         :param PackedSequence input_x: [seq_len, batch_size, input_size]
@@ -46,13 +47,13 @@ class VarRnnCellWrapper(nn.Module):
             hidden: for LSTM, tuple of (h_n, c_n), [batch_size, hidden_size]
                 for other RNN, h_n, [batch_size, hidden_size]
         """
-        
+
         def get_hi(hi, h0, size):
             h0_size = size - hi.size(0)
             if h0_size > 0:
                 return torch.cat([hi, h0[:h0_size]], dim=0)
             return hi[:size]
-        
+
         is_lstm = isinstance(hidden, tuple)
         input, batch_sizes = input_x.data, input_x.batch_sizes
         output = []
@@ -63,7 +64,7 @@ class VarRnnCellWrapper(nn.Module):
         else:
             batch_iter = batch_sizes
         idx = 0
-        
+
         if is_lstm:
             hn = (hidden[0].clone(), hidden[1].clone())
         else:
@@ -79,7 +80,8 @@ class VarRnnCellWrapper(nn.Module):
                 mask_hi = mask_h[:size]
             if is_lstm:
                 hx, cx = hi
-                hi = (get_hi(hx, hidden[0], size) * mask_hi, get_hi(cx, hidden[1], size))
+                hi = (get_hi(hx, hidden[0], size) *
+                      mask_hi, get_hi(cx, hidden[1], size))
                 hi = cell(input_i, hi)
                 hn[0][:size] = hi[0]
                 hn[1][:size] = hi[1]
@@ -89,7 +91,7 @@ class VarRnnCellWrapper(nn.Module):
                 hi = cell(input_i, hi)
                 hn[:size] = hi
                 output.append(hi)
-        
+
         if is_reversed:
             output = list(reversed(output))
         output = torch.cat(output, dim=0)
@@ -99,7 +101,7 @@
 class VarRNNBase(nn.Module):
     """
     Variational Dropout RNN implementation.
-    
+
     Paper reference: `A Theoretically Grounded Application of Dropout in Recurrent Neural Networks (Yarin Gal and Zoubin Ghahramani, 2016)
     https://arxiv.org/abs/1512.05287`.
 
@@ -115,7 +117,7 @@ class VarRNNBase(nn.Module):
     :param hidden_dropout: dropout probability for each hidden state. Default: 0
    :param bidirectional: if ``True``, use a bidirectional RNN. Default: ``False``
     """
-    
+
     def __init__(self, mode, Cell, input_size, hidden_size, num_layers=1,
                  bias=True, batch_first=False,
                  input_dropout=0, hidden_dropout=0, bidirectional=False):
@@ -135,18 +137,20 @@ class VarRNNBase(nn.Module):
             for direction in range(self.num_directions):
                 input_size = self.input_size if layer == 0 else self.hidden_size * self.num_directions
                 cell = Cell(input_size, self.hidden_size, bias)
-                self._all_cells.append(VarRnnCellWrapper(cell, self.hidden_size, input_dropout, hidden_dropout))
+                self._all_cells.append(VarRnnCellWrapper(
+                    cell, self.hidden_size, input_dropout, hidden_dropout))
         initial_parameter(self)
         self.is_lstm = (self.mode == "LSTM")
-    
+
     def _forward_one(self, n_layer, n_direction, input, hx, mask_x, mask_h):
         is_lstm = self.is_lstm
         idx = self.num_directions * n_layer + n_direction
         cell = self._all_cells[idx]
         hi = (hx[0][idx], hx[1][idx]) if is_lstm else hx[idx]
-        output_x, hidden_x = cell(input, hi, mask_x, mask_h, is_reversed=(n_direction == 1))
+        output_x, hidden_x = cell(
+            input, hi, mask_x, mask_h, is_reversed=(n_direction == 1))
         return output_x, hidden_x
-    
+
     def forward(self, x, hx=None):
         """
 
@@ -160,31 +164,38 @@ class VarRNNBase(nn.Module):
         if not is_packed:
             seq_len = x.size(1) if self.batch_first else x.size(0)
             max_batch_size = x.size(0) if self.batch_first else x.size(1)
-            seq_lens = torch.LongTensor([seq_len for _ in range(max_batch_size)])
-            input = pack_padded_sequence(input, seq_lens, batch_first=self.batch_first)
+            seq_lens = torch.LongTensor(
+                [seq_len for _ in range(max_batch_size)])
+            x = pack_padded_sequence(x, seq_lens, batch_first=self.batch_first)
         else:
-            max_batch_size = int(input.batch_sizes[0])
-            input, batch_sizes = input.data, input.batch_sizes
-        
+            max_batch_size = int(x.batch_sizes[0])
+        x, batch_sizes = x.data, x.batch_sizes
+
         if hx is None:
             hx = x.new_zeros(self.num_layers * self.num_directions,
                              max_batch_size, self.hidden_size, requires_grad=True)
             if is_lstm:
                 hx = (hx, hx.new_zeros(hx.size(), requires_grad=True))
-        
+
         mask_x = x.new_ones((max_batch_size, self.input_size))
-        mask_out = x.new_ones((max_batch_size, self.hidden_size * self.num_directions))
+        mask_out = x.new_ones(
+            (max_batch_size, self.hidden_size * self.num_directions))
         mask_h_ones = x.new_ones((max_batch_size, self.hidden_size))
-        nn.functional.dropout(mask_x, p=self.input_dropout, training=self.training, inplace=True)
-        nn.functional.dropout(mask_out, p=self.hidden_dropout, training=self.training, inplace=True)
-        
-        hidden = x.new_zeros((self.num_layers * self.num_directions, max_batch_size, self.hidden_size))
+        nn.functional.dropout(mask_x, p=self.input_dropout,
+                              training=self.training, inplace=True)
+        nn.functional.dropout(mask_out, p=self.hidden_dropout,
+                              training=self.training, inplace=True)
+
+        hidden = x.new_zeros(
+            (self.num_layers * self.num_directions, max_batch_size, self.hidden_size))
         if is_lstm:
-            cellstate = x.new_zeros((self.num_layers * self.num_directions, max_batch_size, self.hidden_size))
+            cellstate = x.new_zeros(
+                (self.num_layers * self.num_directions, max_batch_size, self.hidden_size))
         for layer in range(self.num_layers):
             output_list = []
             input_seq = PackedSequence(x, batch_sizes)
-            mask_h = nn.functional.dropout(mask_h_ones, p=self.hidden_dropout, training=self.training, inplace=False)
+            mask_h = nn.functional.dropout(
+                mask_h_ones, p=self.hidden_dropout, training=self.training, inplace=False)
             for direction in range(self.num_directions):
                 output_x, hidden_x = self._forward_one(layer, direction, input_seq, hx,
                                                        mask_x if layer == 0 else mask_out, mask_h)
@@ -196,16 +207,16 @@ class VarRNNBase(nn.Module):
             else:
                 hidden[idx] = hidden_x
             x = torch.cat(output_list, dim=-1)
-        
+
         if is_lstm:
             hidden = (hidden, cellstate)
-        
+
         if is_packed:
             output = PackedSequence(x, batch_sizes)
         else:
             x = PackedSequence(x, batch_sizes)
             output, _ = pad_packed_sequence(x, batch_first=self.batch_first)
-        
+
         return output, hidden
 
 
@@ -225,10 +236,11 @@ class VarLSTM(VarRNNBase):
     :param hidden_dropout: dropout probability for each hidden state. Default: 0
     :param bidirectional: if ``True``, use a bidirectional LSTM. Default: ``False``
     """
-    
+
     def __init__(self, *args, **kwargs):
-        super(VarLSTM, self).__init__(mode="LSTM", Cell=nn.LSTMCell, *args, **kwargs)
-    
+        super(VarLSTM, self).__init__(
+            mode="LSTM", Cell=nn.LSTMCell, *args, **kwargs)
+
     def forward(self, x, hx=None):
         return super(VarLSTM, self).forward(x, hx)
 
@@ -249,10 +261,11 @@ class VarRNN(VarRNNBase):
     :param hidden_dropout: dropout probability for each hidden state. Default: 0
     :param bidirectional: if ``True``, use a bidirectional RNN. Default: ``False``
     """
-    
+
     def __init__(self, *args, **kwargs):
-        super(VarRNN, self).__init__(mode="RNN", Cell=nn.RNNCell, *args, **kwargs)
-    
+        super(VarRNN, self).__init__(
+            mode="RNN", Cell=nn.RNNCell, *args, **kwargs)
+
     def forward(self, x, hx=None):
         return super(VarRNN, self).forward(x, hx)
 
@@ -273,9 +286,10 @@ class VarGRU(VarRNNBase):
     :param hidden_dropout: dropout probability for each hidden state. Default: 0
     :param bidirectional: if ``True``, use a bidirectional GRU. Default: ``False``
     """
-    
+
    def __init__(self, *args, **kwargs):
-        super(VarGRU, self).__init__(mode="GRU", Cell=nn.GRUCell, *args, **kwargs)
-    
+        super(VarGRU, self).__init__(
+            mode="GRU", Cell=nn.GRUCell, *args, **kwargs)
+
     def forward(self, x, hx=None):
         return super(VarGRU, self).forward(x, hx)
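
A quick way to see what the rename from `input` to `x` in `VarRNNBase.forward` fixes: before this patch, the unpacked (padded-tensor) path handed the Python builtin `input` to `pack_padded_sequence` and never bound `batch_sizes`, so calling the module on a plain tensor could not run. The sketch below is a minimal usage example under the assumption that the patched fastNLP and PyTorch are installed; the import path follows the file path in the diff, and all sizes are made up for illustration.

import torch
from fastNLP.modules.encoder.variational_rnn import VarLSTM

# Illustrative sizes only.
batch_size, seq_len, input_size, hidden_size = 4, 7, 16, 32

# Constructor arguments follow the docstring in the diff: input_dropout/hidden_dropout
# are the variational dropout probabilities, batch_first=True expects
# input of shape [batch, seq_len, input_size].
lstm = VarLSTM(input_size=input_size, hidden_size=hidden_size, num_layers=2,
               batch_first=True, input_dropout=0.3, hidden_dropout=0.3,
               bidirectional=True)

x = torch.randn(batch_size, seq_len, input_size)

# With the patch, a plain padded tensor is packed internally
# (x = pack_padded_sequence(x, seq_lens, ...)) instead of referencing
# the stale `input` name, so this call goes through.
output, (h_n, c_n) = lstm(x)

print(output.shape)  # expected [batch, seq_len, hidden_size * 2] -> torch.Size([4, 7, 64])
print(h_n.shape)     # expected [num_layers * 2, batch, hidden_size] -> torch.Size([4, 4, 32])

The same pattern should apply to VarRNN and VarGRU, whose forward simply delegates to VarRNNBase.forward; only the returned hidden state differs (a single tensor rather than an (h_n, c_n) tuple).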