From a51ede46f7733e9f18c85182d100d51a35d1d2b7 Mon Sep 17 00:00:00 2001
From: yunfan
Date: Tue, 9 Oct 2018 12:52:04 +0800
Subject: [PATCH] update var_rnn

---
 fastNLP/modules/encoder/variational_rnn.py | 132 ++++++++++++++++++++-
 1 file changed, 131 insertions(+), 1 deletion(-)

diff --git a/fastNLP/modules/encoder/variational_rnn.py b/fastNLP/modules/encoder/variational_rnn.py
index fb75fabb..6702aa8c 100644
--- a/fastNLP/modules/encoder/variational_rnn.py
+++ b/fastNLP/modules/encoder/variational_rnn.py
@@ -5,8 +5,9 @@ import torch.nn as nn
 import torch.nn.functional as F
 from torch.nn._functions.thnn import rnnFusedPointwise as fusedBackend
 from torch.nn.parameter import Parameter
+from torch.nn.utils.rnn import PackedSequence
 
-from fastNLP.modules.utils import initial_parameter
+# from fastNLP.modules.utils import initial_parameter
 
 def default_initializer(hidden_size):
     stdv = 1.0 / math.sqrt(hidden_size)
@@ -383,3 +384,132 @@ class VarFastLSTMCell(VarRNNCellBase):
         hy = outgate * F.tanh(cy)
 
         return hy, cy
+
+
+class VarRnnCellWrapper(nn.Module):
+    def __init__(self, cell, hidden_size, input_p, hidden_p):
+        super(VarRnnCellWrapper, self).__init__()
+        self.cell = cell
+        self.hidden_size = hidden_size
+        self.input_p = input_p
+        self.hidden_p = hidden_p
+
+    def forward(self, input, hidden):
+        """
+        :param input: [seq_len, batch_size, input_size]
+        :param hidden: for LSTM, tuple of (h_0, c_0), [batch_size, hidden_size]
+                       for other RNN, h_0, [batch_size, hidden_size]
+
+        :return output: [seq_len, batch_size, hidden_size]
+                hidden: for LSTM, tuple of (h_n, c_n), [batch_size, hidden_size]
+                        for other RNN, h_n, [batch_size, hidden_size]
+        """
+        is_lstm = isinstance(hidden, tuple)
+        _, batch_size, input_size = input.shape
+        mask_x = input.new_ones((batch_size, input_size))
+        mask_h = input.new_ones((batch_size, self.hidden_size))
+        nn.functional.dropout(mask_x, p=self.input_p, training=self.training, inplace=True)
+        nn.functional.dropout(mask_h, p=self.hidden_p, training=self.training, inplace=True)
+
+        input_x = input * mask_x.unsqueeze(0)
+        output_list = []
+        for x in input_x:
+            if is_lstm:
+                hx, cx = hidden
+                hidden = (hx * mask_h, cx)
+            else:
+                hidden *= mask_h
+            hidden = self.cell(x, hidden)
+            output_list.append(hidden[0] if is_lstm else hidden)
+        output = torch.stack(output_list, dim=0)
+        return output, hidden
+
+
+class VarRNNBase(nn.Module):
+    def __init__(self, mode, Cell, input_size, hidden_size, num_layers=1,
+                 bias=True, batch_first=False,
+                 input_dropout=0, hidden_dropout=0, bidirectional=False):
+        super(VarRNNBase, self).__init__()
+        self.mode = mode
+        self.input_size = input_size
+        self.hidden_size = hidden_size
+        self.num_layers = num_layers
+        self.bias = bias
+        self.batch_first = batch_first
+        self.input_dropout = input_dropout
+        self.hidden_dropout = hidden_dropout
+        self.bidirectional = bidirectional
+        self.num_directions = 2 if bidirectional else 1
+        self._all_cells = nn.ModuleList()
+        for layer in range(self.num_layers):
+            for direction in range(self.num_directions):
+                input_size = self.input_size if layer == 0 else self.hidden_size * self.num_directions
+                cell = Cell(input_size, self.hidden_size, bias)
+                self._all_cells.append(VarRnnCellWrapper(cell, self.hidden_size, input_dropout, hidden_dropout))
+
+    def forward(self, input, hx=None):
+        is_packed = isinstance(input, PackedSequence)
+        is_lstm = (self.mode == "LSTM")
+        if is_packed:
+            input, batch_sizes = input
+            max_batch_size = int(batch_sizes[0])
+        else:
+            batch_sizes = None
+            max_batch_size = input.size(0) if self.batch_first else input.size(1)
+
+        if hx is None:
+            hx = input.new_zeros(self.num_layers * self.num_directions,
+                                 max_batch_size, self.hidden_size,
+                                 requires_grad=False)
+            if is_lstm:
+                hx = (hx, hx)
+
+        if self.batch_first:
+            input = input.transpose(0, 1)
+
+        hidden_list = []
+        for layer in range(self.num_layers):
+            output_list = []
+            for direction in range(self.num_directions):
+                input_x = input if direction == 0 else input.flip(0)
+                idx = self.num_directions * layer + direction
+                cell = self._all_cells[idx]
+                output_x, hidden_x = cell(input_x, (hx[0][idx], hx[1][idx]) if is_lstm else hx[idx])
+                output_list.append(output_x if direction == 0 else output_x.flip(0))
+                hidden_list.append(hidden_x)
+            input = torch.cat(output_list, dim=-1)
+
+        output = input.transpose(0, 1) if self.batch_first else input
+        if is_lstm:
+            h_list, c_list = zip(*hidden_list)
+            hn = torch.stack(h_list, dim=0)
+            cn = torch.stack(c_list, dim=0)
+            hidden = (hn, cn)
+        else:
+            hidden = torch.stack(hidden_list, dim=0)
+
+        if is_packed:
+            output = PackedSequence(output, batch_sizes)
+
+        return output, hidden
+
+
+class VarLSTM(VarRNNBase):
+    def __init__(self, *args, **kwargs):
+        super(VarLSTM, self).__init__(mode="LSTM", Cell=nn.LSTMCell, *args, **kwargs)
+
+
+if __name__ == '__main__':
+    net = VarLSTM(input_size=10, hidden_size=20, num_layers=3, batch_first=True, bidirectional=True, input_dropout=0.33, hidden_dropout=0.33)
+    lstm = nn.LSTM(input_size=10, hidden_size=20, num_layers=3, batch_first=True, bidirectional=True)
+    x = torch.randn(2, 8, 10)
+    y, hidden = net(x)
+    y0, h0 = lstm(x)
+    print(y.shape)
+    print(y0.shape)
+    print(y)
+    print(hidden[0])
+    print(hidden[0].shape)
+    print(y0)
+    print(h0[0])
+    print(h0[0].shape)
\ No newline at end of file
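
Below is a minimal usage sketch of the new VarLSTM, not part of the patch itself; it assumes the
module path fastNLP.modules.encoder.variational_rnn and the constructor arguments shown in the
diff. Because the dropout masks in VarRnnCellWrapper are sampled once per forward call and reused
at every time step (variational dropout), VarLSTM can stand in for nn.LSTM with two extra
arguments, input_dropout and hidden_dropout:

    # Hypothetical usage sketch, assuming the patched module is importable as below.
    import torch
    import torch.nn as nn
    from fastNLP.modules.encoder.variational_rnn import VarLSTM

    class BiLSTMEncoder(nn.Module):
        def __init__(self, input_size=10, hidden_size=20):
            super(BiLSTMEncoder, self).__init__()
            # Same call signature as the __main__ demo in the patch.
            self.rnn = VarLSTM(input_size=input_size, hidden_size=hidden_size,
                               num_layers=2, batch_first=True, bidirectional=True,
                               input_dropout=0.33, hidden_dropout=0.33)

        def forward(self, x):
            # output: [batch_size, seq_len, 2 * hidden_size]; hidden: (h_n, c_n) for LSTM
            output, hidden = self.rnn(x)
            return output

    encoder = BiLSTMEncoder()
    x = torch.randn(2, 8, 10)      # [batch_size, seq_len, input_size]
    print(encoder(x).shape)        # torch.Size([2, 8, 40])
    encoder.eval()                 # in eval mode the dropout masks stay all-ones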