@@ -5,8 +5,9 @@ import torch.nn as nn
import torch.nn.functional as F
from torch.nn._functions.thnn import rnnFusedPointwise as fusedBackend
from torch.nn.parameter import Parameter
from torch.nn.utils.rnn import PackedSequence, pack_padded_sequence, pad_packed_sequence
# from fastNLP.modules.utils import initial_parameter


def default_initializer(hidden_size):
    stdv = 1.0 / math.sqrt(hidden_size)

@@ -383,3 +384,132 @@ class VarFastLSTMCell(VarRNNCellBase):
        hy = outgate * F.tanh(cy)
        return hy, cy


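# VarRnnCellWrapper applies "variational" dropout around a single recurrent cell:
# one dropout mask is sampled for the inputs and one for the hidden state, and the
# same masks are reused at every time step of the sequence, instead of being
# resampled per step as ordinary dropout would do.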
class VarRnnCellWrapper(nn.Module):
    def __init__(self, cell, hidden_size, input_p, hidden_p):
        super(VarRnnCellWrapper, self).__init__()
        self.cell = cell
        self.hidden_size = hidden_size
        self.input_p = input_p
        self.hidden_p = hidden_p

    def forward(self, input, hidden):
        """
        :param input: [seq_len, batch_size, input_size]
        :param hidden: for LSTM, tuple of (h_0, c_0), each [batch_size, hidden_size];
                       for other RNNs, h_0, [batch_size, hidden_size]
        :return output: [seq_len, batch_size, hidden_size]
                hidden: for LSTM, tuple of (h_n, c_n), each [batch_size, hidden_size];
                        for other RNNs, h_n, [batch_size, hidden_size]
        """
        is_lstm = isinstance(hidden, tuple)
        _, batch_size, input_size = input.shape
        mask_x = input.new_ones((batch_size, input_size))
        mask_h = input.new_ones((batch_size, self.hidden_size))
        nn.functional.dropout(mask_x, p=self.input_p, training=self.training, inplace=True)
        nn.functional.dropout(mask_h, p=self.hidden_p, training=self.training, inplace=True)

        input_x = input * mask_x.unsqueeze(0)
        output_list = []
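        # Step through the sequence, re-applying the same hidden-state mask before each
        # cell call; for LSTM cells only h is masked, the cell state c is left untouched.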
        for x in input_x:
            if is_lstm:
                hx, cx = hidden
                hidden = (hx * mask_h, cx)
            else:
                # avoid the in-place update (hidden *= mask_h): it would mutate the
                # caller's initial hidden state and can break autograd
                hidden = hidden * mask_h
            hidden = self.cell(x, hidden)
            output_list.append(hidden[0] if is_lstm else hidden)
        output = torch.stack(output_list, dim=0)
        return output, hidden


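# VarRNNBase stacks num_layers * num_directions wrapped cells. The cells live in a flat
# ModuleList indexed as layer * num_directions + direction, and every layer above the
# first consumes the concatenated outputs of both directions of the layer below.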
class VarRNNBase(nn.Module):
    def __init__(self, mode, Cell, input_size, hidden_size, num_layers=1,
                 bias=True, batch_first=False,
                 input_dropout=0, hidden_dropout=0, bidirectional=False):
        super(VarRNNBase, self).__init__()
        self.mode = mode
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.bias = bias
        self.batch_first = batch_first
        self.input_dropout = input_dropout
        self.hidden_dropout = hidden_dropout
        self.bidirectional = bidirectional
        self.num_directions = 2 if bidirectional else 1
        self._all_cells = nn.ModuleList()
        for layer in range(self.num_layers):
            for direction in range(self.num_directions):
                input_size = self.input_size if layer == 0 else self.hidden_size * self.num_directions
                cell = Cell(input_size, self.hidden_size, bias)
                self._all_cells.append(VarRnnCellWrapper(cell, self.hidden_size, input_dropout, hidden_dropout))
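
    # forward accepts either a padded tensor ([batch, seq, features] if batch_first,
    # otherwise [seq, batch, features]) or a PackedSequence, plus an optional initial
    # state hx of shape [num_layers * num_directions, batch, hidden_size]
    # (a tuple (h_0, c_0) of two such tensors in LSTM mode).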
    def forward(self, input, hx=None):
        is_packed = isinstance(input, PackedSequence)
        is_lstm = (self.mode == "LSTM")
        if is_packed:
            # unpack to a padded tensor so the per-timestep loop below can index it;
            # keep the lengths so the output can be re-packed at the end
            input, seq_lengths = pad_packed_sequence(input, batch_first=self.batch_first)
        max_batch_size = input.size(0) if self.batch_first else input.size(1)
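        # When no initial state is given, start from zeros; in LSTM mode the same
        # zero tensor serves as both h_0 and c_0.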
        if hx is None:
            hx = input.new_zeros(self.num_layers * self.num_directions,
                                 max_batch_size, self.hidden_size,
                                 requires_grad=False)
            if is_lstm:
                hx = (hx, hx)

        if self.batch_first:
            input = input.transpose(0, 1)
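        # Run the layers bottom-up. The reverse direction simply consumes the flipped
        # sequence and flips its output back; the two directions are concatenated along
        # the feature dimension to form the next layer's input.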
        hidden_list = []
        for layer in range(self.num_layers):
            output_list = []
            for direction in range(self.num_directions):
                input_x = input if direction == 0 else input.flip(0)
                idx = self.num_directions * layer + direction
                cell = self._all_cells[idx]
                output_x, hidden_x = cell(input_x, (hx[0][idx], hx[1][idx]) if is_lstm else hx[idx])
                output_list.append(output_x if direction == 0 else output_x.flip(0))
                hidden_list.append(hidden_x)
            input = torch.cat(output_list, dim=-1)
        output = input.transpose(0, 1) if self.batch_first else input
        if is_lstm:
            h_list, c_list = zip(*hidden_list)
            hn = torch.stack(h_list, dim=0)
            cn = torch.stack(c_list, dim=0)
            hidden = (hn, cn)
        else:
            hidden = torch.stack(hidden_list, dim=0)
        if is_packed:
            # re-pack so callers get back the same format they passed in
            output = pack_padded_sequence(output, seq_lengths, batch_first=self.batch_first)

        return output, hidden


class VarLSTM(VarRNNBase):
    def __init__(self, *args, **kwargs):
        super(VarLSTM, self).__init__(mode="LSTM", Cell=nn.LSTMCell, *args, **kwargs)


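# Illustrative sketch, not part of the change above: a GRU variant could reuse VarRNNBase in
# exactly the same way, assuming nn.GRUCell as the wrapped cell (its hidden state is a single
# tensor, so any mode string other than "LSTM" works).
class VarGRU(VarRNNBase):
    def __init__(self, *args, **kwargs):
        super(VarGRU, self).__init__(mode="GRU", Cell=nn.GRUCell, *args, **kwargs)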


if __name__ == '__main__':
    net = VarLSTM(input_size=10, hidden_size=20, num_layers=3, batch_first=True,
                  bidirectional=True, input_dropout=0.33, hidden_dropout=0.33)
    lstm = nn.LSTM(input_size=10, hidden_size=20, num_layers=3, batch_first=True, bidirectional=True)
    x = torch.randn(2, 8, 10)
    y, hidden = net(x)
    y0, h0 = lstm(x)
    print(y.shape)
    print(y0.shape)
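    # Both outputs should have shape [2, 8, 40] (batch, seq_len, hidden_size * num_directions),
    # and both hidden[0] and h0[0] should be [6, 2, 20] (num_layers * num_directions, batch, hidden_size).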
    print(y)
    print(hidden[0])
    print(hidden[0].shape)
    print(y0)
    print(h0[0])
    print(h0[0].shape)