|
|
@@ -5,8 +5,9 @@ import torch.nn as nn |
|
|
|
import torch.nn.functional as F |
|
|
|
from torch.nn._functions.thnn import rnnFusedPointwise as fusedBackend |
|
|
|
from torch.nn.parameter import Parameter |
|
|
|
from torch.nn.utils.rnn import PackedSequence |
|
|
|
|
|
|
|
from fastNLP.modules.utils import initial_parameter |
|
|
|
|
|
|
|
|
|
|
def default_initializer(hidden_size): |
|
|
|
stdv = 1.0 / math.sqrt(hidden_size) |
|
|
@@ -383,3 +384,132 @@ class VarFastLSTMCell(VarRNNCellBase): |
|
|
|
hy = outgate * F.tanh(cy) |
|
|
|
|
|
|
|
return hy, cy |
|
|
|
|
|
|
|
|
|
|
|
class VarRnnCellWrapper(nn.Module): |
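# Wraps a single RNN/LSTM/GRU cell with variational (locked) dropout: one dropout
# mask is sampled per forward pass for the input and one for the hidden state,
# and both masks are reused at every time step of the sequence.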
|
|
|
def __init__(self, cell, hidden_size, input_p, hidden_p): |
|
|
|
super(VarRnnCellWrapper, self).__init__() |
|
|
|
self.cell = cell |
|
|
|
self.hidden_size = hidden_size |
|
|
|
self.input_p = input_p |
|
|
|
self.hidden_p = hidden_p |
|
|
|
|
|
|
|
def forward(self, input, hidden): |
|
|
|
""" |
|
|
|
:param input: [seq_len, batch_size, input_size] |
|
|
|
:param hidden: for LSTM, tuple of (h_0, c_0), [batch_size, hidden_size] |
|
|
|
for other RNNs, h_0, [batch_size, hidden_size]
|
|
|
|
|
|
|
:return output: [seq_len, batch_size, hidden_size]
|
|
|
hidden: for LSTM, tuple of (h_n, c_n), [batch_size, hidden_size] |
|
|
|
for other RNNs, h_n, [batch_size, hidden_size]
|
|
|
""" |
|
|
|
is_lstm = isinstance(hidden, tuple) |
|
|
|
_, batch_size, input_size = input.shape |
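# sample the dropout masks once per sequence (variational dropout); F.dropout with
# inplace=True zeroes entries and rescales the kept ones by 1/(1 - p), and is a
# no-op when self.training is False, so the masks stay all ones at eval time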
|
|
|
mask_x = input.new_ones((batch_size, input_size)) |
|
|
|
mask_h = input.new_ones((batch_size, self.hidden_size)) |
|
|
|
nn.functional.dropout(mask_x, p=self.input_p, training=self.training, inplace=True) |
|
|
|
nn.functional.dropout(mask_h, p=self.hidden_p, training=self.training, inplace=True) |
|
|
|
|
|
|
|
input_x = input * mask_x.unsqueeze(0) |
|
|
|
output_list = [] |
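# step through the sequence manually so the hidden-state mask can be re-applied
# at every time step; for LSTM only h is masked, the cell state c is left untouched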
|
|
|
for x in input_x: |
|
|
|
if is_lstm: |
|
|
|
hx, cx = hidden |
|
|
|
hidden = (hx * mask_h, cx) |
|
|
|
else: |
|
|
|
hidden = hidden * mask_h  # out-of-place, so the shared initial state is not mutated and autograd stays intact
|
|
|
hidden = self.cell(x, hidden) |
|
|
|
output_list.append(hidden[0] if is_lstm else hidden) |
|
|
|
output = torch.stack(output_list, dim=0) |
|
|
|
return output, hidden |
|
|
|
|
|
|
|
|
|
|
|
class VarRNNBase(nn.Module): |
|
|
|
def __init__(self, mode, Cell, input_size, hidden_size, num_layers=1, |
|
|
|
bias=True, batch_first=False, |
|
|
|
input_dropout=0, hidden_dropout=0, bidirectional=False): |
|
|
|
super(VarRNNBase, self).__init__() |
|
|
|
self.mode = mode |
|
|
|
self.input_size = input_size |
|
|
|
self.hidden_size = hidden_size |
|
|
|
self.num_layers = num_layers |
|
|
|
self.bias = bias |
|
|
|
self.batch_first = batch_first |
|
|
|
self.input_dropout = input_dropout |
|
|
|
self.hidden_dropout = hidden_dropout |
|
|
|
self.bidirectional = bidirectional |
|
|
|
self.num_directions = 2 if bidirectional else 1 |
|
|
|
self._all_cells = nn.ModuleList() |
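# one wrapped cell per (layer, direction); from the second layer on the input is
# the concatenation of both directions' outputs, hence hidden_size * num_directions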
|
|
|
for layer in range(self.num_layers): |
|
|
|
for direction in range(self.num_directions): |
|
|
|
input_size = self.input_size if layer == 0 else self.hidden_size * self.num_directions |
|
|
|
cell = Cell(input_size, self.hidden_size, bias) |
|
|
|
self._all_cells.append(VarRnnCellWrapper(cell, self.hidden_size, input_dropout, hidden_dropout)) |
|
|
|
|
|
|
|
def forward(self, input, hx=None): |
|
|
|
is_packed = isinstance(input, PackedSequence) |
|
|
|
is_lstm = (self.mode == "LSTM") |
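# a PackedSequence input is unpacked into its data tensor and batch_sizes here
# and re-packed just before returning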
|
|
|
if is_packed: |
|
|
|
input, batch_sizes = input |
|
|
|
max_batch_size = int(batch_sizes[0]) |
|
|
|
else: |
|
|
|
batch_sizes = None |
|
|
|
max_batch_size = input.size(0) if self.batch_first else input.size(1) |
|
|
|
|
|
|
|
if hx is None: |
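# default to a zero initial state for every layer and direction; for LSTM the
# same zeros serve as both h_0 and c_0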
|
|
|
hx = input.new_zeros(self.num_layers * self.num_directions, |
|
|
|
max_batch_size, self.hidden_size, |
|
|
|
requires_grad=False) |
|
|
|
if is_lstm: |
|
|
|
hx = (hx, hx) |
|
|
|
|
|
|
|
if self.batch_first: |
|
|
|
input = input.transpose(0, 1) |
|
|
|
|
|
|
|
hidden_list = [] |
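# run the stack layer by layer; the reverse direction reuses the same wrapper by
# flipping the input along time and flipping its output back, and the two
# directions are concatenated on the feature axis before feeding the next layer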
|
|
|
for layer in range(self.num_layers): |
|
|
|
output_list = [] |
|
|
|
for direction in range(self.num_directions): |
|
|
|
input_x = input if direction == 0 else input.flip(0) |
|
|
|
idx = self.num_directions * layer + direction |
|
|
|
cell = self._all_cells[idx] |
|
|
|
output_x, hidden_x = cell(input_x, (hx[0][idx], hx[1][idx]) if is_lstm else hx[idx]) |
|
|
|
output_list.append(output_x if direction == 0 else output_x.flip(0)) |
|
|
|
hidden_list.append(hidden_x) |
|
|
|
input = torch.cat(output_list, dim=-1) |
|
|
|
|
|
|
|
output = input.transpose(0, 1) if self.batch_first else input |
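# stack the final hidden states of all (layer, direction) cells into
# [num_layers * num_directions, batch_size, hidden_size], matching nn.LSTM / nn.RNN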
|
|
|
if is_lstm: |
|
|
|
h_list, c_list = zip(*hidden_list) |
|
|
|
hn = torch.stack(h_list, dim=0) |
|
|
|
cn = torch.stack(c_list, dim=0) |
|
|
|
hidden = (hn, cn) |
|
|
|
else: |
|
|
|
hidden = torch.stack(hidden_list, dim=0) |
|
|
|
|
|
|
|
if is_packed: |
|
|
|
output = PackedSequence(output, batch_sizes) |
|
|
|
|
|
|
|
return output, hidden |
|
|
|
|
|
|
|
|
|
|
|
class VarLSTM(VarRNNBase): |
|
|
|
def __init__(self, *args, **kwargs): |
|
|
|
super(VarLSTM, self).__init__(mode="LSTM", Cell=nn.LSTMCell, *args, **kwargs) |
|
|
|
|
|
|
|
|
|
|
|
if __name__ == '__main__': |
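# quick smoke test: compare output and hidden-state shapes of VarLSTM against
# nn.LSTM on a random [batch=2, seq_len=8, feature=10] input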
|
|
|
net = VarLSTM(input_size=10, hidden_size=20, num_layers=3, batch_first=True, bidirectional=True, input_dropout=0.33, hidden_dropout=0.33) |
|
|
|
lstm = nn.LSTM(input_size=10, hidden_size=20, num_layers=3, batch_first=True, bidirectional=True) |
|
|
|
x = torch.randn(2, 8, 10) |
|
|
|
y, hidden = net(x) |
|
|
|
y0, h0 = lstm(x) |
|
|
|
print(y.shape) |
|
|
|
print(y0.shape) |
|
|
|
print(y) |
|
|
|
print(hidden[0]) |
|
|
|
print(hidden[0].shape) |
|
|
|
print(y0) |
|
|
|
print(h0[0]) |
|
|
|
print(h0[0].shape) |