@@ -5,8 +5,9 @@ import torch.nn as nn
import torch.nn.functional as F
from torch.nn._functions.thnn import rnnFusedPointwise as fusedBackend
from torch.nn.parameter import Parameter
from torch.nn.utils.rnn import PackedSequence
from fastNLP.modules.utils import initial_parameter
# from fastNLP.modules.utils import initial_parameter


def default_initializer(hidden_size):
    stdv = 1.0 / math.sqrt(hidden_size)

@@ -383,3 +384,132 @@ class VarFastLSTMCell(VarRNNCellBase):
        hy = outgate * F.tanh(cy)
        return hy, cy

class VarRnnCellWrapper(nn.Module):
    """Wrap a single RNN cell so that the same dropout masks on the input and the
    hidden state are reused at every time step of a sequence (variational dropout).
    """

    def __init__(self, cell, hidden_size, input_p, hidden_p):
        super(VarRnnCellWrapper, self).__init__()
        self.cell = cell
        self.hidden_size = hidden_size
        self.input_p = input_p
        self.hidden_p = hidden_p
    def forward(self, input, hidden):
        """
        :param input: [seq_len, batch_size, input_size]
        :param hidden: for LSTM, a tuple of (h_0, c_0), each [batch_size, hidden_size];
            for other RNNs, h_0 of shape [batch_size, hidden_size]
        :return output: [seq_len, batch_size, hidden_size]
            hidden: for LSTM, a tuple of (h_n, c_n), each [batch_size, hidden_size];
            for other RNNs, h_n of shape [batch_size, hidden_size]
        """
        is_lstm = isinstance(hidden, tuple)
        _, batch_size, input_size = input.shape
        # Sample one dropout mask per call and reuse it at every time step
        # (variational dropout), instead of resampling per step.
        mask_x = input.new_ones((batch_size, input_size))
        mask_h = input.new_ones((batch_size, self.hidden_size))
        nn.functional.dropout(mask_x, p=self.input_p, training=self.training, inplace=True)
        nn.functional.dropout(mask_h, p=self.hidden_p, training=self.training, inplace=True)

        input_x = input * mask_x.unsqueeze(0)
        output_list = []
        for x in input_x:
            if is_lstm:
                hx, cx = hidden
                hidden = (hx * mask_h, cx)
            else:
                # Out-of-place multiply so the caller's h_0 is never modified in place.
                hidden = hidden * mask_h
            hidden = self.cell(x, hidden)
            output_list.append(hidden[0] if is_lstm else hidden)
        output = torch.stack(output_list, dim=0)
        return output, hidden
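
# A minimal driving sketch for the wrapper (illustrative only, assuming a plain
# nn.LSTMCell); one mask_x/mask_h pair is sampled per call and applied at every
# one of the 5 steps below:
#
#     cell = VarRnnCellWrapper(nn.LSTMCell(10, 20), hidden_size=20, input_p=0.3, hidden_p=0.3)
#     x = torch.randn(5, 4, 10)                      # [seq_len, batch_size, input_size]
#     h0 = (torch.zeros(4, 20), torch.zeros(4, 20))  # (h_0, c_0)
#     out, (h_n, c_n) = cell(x, h0)                  # out: [5, 4, 20]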


class VarRNNBase(nn.Module):
    """Base class for multi-layer (optionally bidirectional) RNNs built from
    VarRnnCellWrapper, so input and hidden dropout masks are shared across time steps.
    """

    def __init__(self, mode, Cell, input_size, hidden_size, num_layers=1,
                 bias=True, batch_first=False,
                 input_dropout=0, hidden_dropout=0, bidirectional=False):
        super(VarRNNBase, self).__init__()
        self.mode = mode
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.bias = bias
        self.batch_first = batch_first
        self.input_dropout = input_dropout
        self.hidden_dropout = hidden_dropout
        self.bidirectional = bidirectional
        self.num_directions = 2 if bidirectional else 1
        self._all_cells = nn.ModuleList()
        for layer in range(self.num_layers):
            for direction in range(self.num_directions):
                input_size = self.input_size if layer == 0 else self.hidden_size * self.num_directions
                cell = Cell(input_size, self.hidden_size, bias)
                self._all_cells.append(VarRnnCellWrapper(cell, self.hidden_size, input_dropout, hidden_dropout))
    def forward(self, input, hx=None):
        """
        :param input: [batch_size, seq_len, input_size] if batch_first is True,
            else [seq_len, batch_size, input_size]; may be a PackedSequence
        :param hx: initial hidden state; for LSTM a tuple of (h_0, c_0), each of shape
            [num_layers * num_directions, batch_size, hidden_size]; zeros if None
        :return output: same layout as input, with hidden_size * num_directions features
            hidden: final hidden state(s) of shape
            [num_layers * num_directions, batch_size, hidden_size]
        """
        is_packed = isinstance(input, PackedSequence)
        is_lstm = (self.mode == "LSTM")
        if is_packed:
            input, batch_sizes = input
            max_batch_size = int(batch_sizes[0])
        else:
            batch_sizes = None
            max_batch_size = input.size(0) if self.batch_first else input.size(1)

        if hx is None:
            hx = input.new_zeros(self.num_layers * self.num_directions,
                                 max_batch_size, self.hidden_size,
                                 requires_grad=False)
            if is_lstm:
                hx = (hx, hx)

        if self.batch_first:
            input = input.transpose(0, 1)

        hidden_list = []
        for layer in range(self.num_layers):
            output_list = []
            for direction in range(self.num_directions):
                # Run the reversed copy of the sequence for the backward direction.
                input_x = input if direction == 0 else input.flip(0)
                idx = self.num_directions * layer + direction
                cell = self._all_cells[idx]
                output_x, hidden_x = cell(input_x, (hx[0][idx], hx[1][idx]) if is_lstm else hx[idx])
                output_list.append(output_x if direction == 0 else output_x.flip(0))
                hidden_list.append(hidden_x)
            # Outputs of both directions feed the next layer.
            input = torch.cat(output_list, dim=-1)

        output = input.transpose(0, 1) if self.batch_first else input
        if is_lstm:
            h_list, c_list = zip(*hidden_list)
            hn = torch.stack(h_list, dim=0)
            cn = torch.stack(c_list, dim=0)
            hidden = (hn, cn)
        else:
            hidden = torch.stack(hidden_list, dim=0)

        if is_packed:
            output = PackedSequence(output, batch_sizes)

        return output, hidden
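
# Other recurrent flavours could be wired through VarRNNBase in the same way as
# VarLSTM below; a sketch, assuming nn.GRUCell's (input_size, hidden_size, bias)
# constructor matches the Cell(...) call in __init__:
#
#     class VarGRU(VarRNNBase):
#         def __init__(self, *args, **kwargs):
#             super(VarGRU, self).__init__(mode="GRU", Cell=nn.GRUCell, *args, **kwargs)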


class VarLSTM(VarRNNBase):
    def __init__(self, *args, **kwargs):
        super(VarLSTM, self).__init__(mode="LSTM", Cell=nn.LSTMCell, *args, **kwargs)
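
# The diff builds VarLSTM on top of nn.LSTMCell; the VarFastLSTMCell defined earlier
# in this file could presumably be dropped in instead, assuming its constructor also
# accepts (input_size, hidden_size, bias) (sketch only; class name hypothetical):
#
#     class VarFastLSTM(VarRNNBase):
#         def __init__(self, *args, **kwargs):
#             super(VarFastLSTM, self).__init__(mode="LSTM", Cell=VarFastLSTMCell, *args, **kwargs)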


if __name__ == '__main__':
    net = VarLSTM(input_size=10, hidden_size=20, num_layers=3, batch_first=True,
                  bidirectional=True, input_dropout=0.33, hidden_dropout=0.33)
    lstm = nn.LSTM(input_size=10, hidden_size=20, num_layers=3, batch_first=True, bidirectional=True)
    x = torch.randn(2, 8, 10)
    y, hidden = net(x)
    y0, h0 = lstm(x)
    print(y.shape)
    print(y0.shape)
    print(y)
    print(hidden[0])
    print(hidden[0].shape)
    print(y0)
    print(h0[0])
    print(h0[0].shape)
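
    # Quick sanity checks on the demo above (a minimal sketch, not an exhaustive test):
    # VarLSTM should match nn.LSTM's output/hidden shapes, and in eval() mode the
    # dropout masks are all ones, so two forward passes give identical results.
    assert y.shape == y0.shape             # [2, 8, 40]: hidden_size * num_directions features
    assert hidden[0].shape == h0[0].shape  # [num_layers * num_directions, batch, hidden] = [6, 2, 20]

    net.eval()
    y_eval_1, _ = net(x)
    y_eval_2, _ = net(x)
    assert torch.equal(y_eval_1, y_eval_2)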