update var_rnn

yunfan committed 6 years ago · tags/v0.2.0 · commit a51ede46f7
1 changed file with 131 additions and 1 deletion:
  fastNLP/modules/encoder/variational_rnn.py (+131, −1)

fastNLP/modules/encoder/variational_rnn.py

@@ -5,8 +5,9 @@ import torch.nn as nn
import torch.nn.functional as F
from torch.nn._functions.thnn import rnnFusedPointwise as fusedBackend
from torch.nn.parameter import Parameter
from torch.nn.utils.rnn import PackedSequence

from fastNLP.modules.utils import initial_parameter
# from fastNLP.modules.utils import initial_parameter

def default_initializer(hidden_size):
    stdv = 1.0 / math.sqrt(hidden_size)
@@ -383,3 +384,132 @@ class VarFastLSTMCell(VarRNNCellBase):
        hy = outgate * F.tanh(cy)

        return hy, cy


class VarRnnCellWrapper(nn.Module):
    """Wrap an RNN cell, applying variational dropout to its input and hidden state."""

    def __init__(self, cell, hidden_size, input_p, hidden_p):
        super(VarRnnCellWrapper, self).__init__()
        self.cell = cell
        self.hidden_size = hidden_size
        self.input_p = input_p    # dropout probability on the input
        self.hidden_p = hidden_p  # dropout probability on the recurrent state
    def forward(self, input, hidden):
        """
        :param input: [seq_len, batch_size, input_size]
        :param hidden: for LSTM, tuple of (h_0, c_0), [batch_size, hidden_size];
            for other RNNs, h_0, [batch_size, hidden_size]

        :return output: [seq_len, batch_size, hidden_size]
            hidden: for LSTM, tuple of (h_n, c_n), [batch_size, hidden_size];
            for other RNNs, h_n, [batch_size, hidden_size]
        """
        is_lstm = isinstance(hidden, tuple)
        _, batch_size, input_size = input.shape
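        # Variational dropout: sample one input mask and one hidden mask per
        # forward pass and reuse them at every time step, rather than
        # resampling a fresh mask each step as ordinary dropout would.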
        mask_x = input.new_ones((batch_size, input_size))
        mask_h = input.new_ones((batch_size, self.hidden_size))
        nn.functional.dropout(mask_x, p=self.input_p, training=self.training, inplace=True)
        nn.functional.dropout(mask_h, p=self.hidden_p, training=self.training, inplace=True)

        input_x = input * mask_x.unsqueeze(0)
        output_list = []
        for x in input_x:
            if is_lstm:
                hx, cx = hidden
                hidden = (hx * mask_h, cx)  # mask the hidden state, not the cell state
            else:
                hidden *= mask_h
            hidden = self.cell(x, hidden)
            output_list.append(hidden[0] if is_lstm else hidden)
        output = torch.stack(output_list, dim=0)
        return output, hidden


class VarRNNBase(nn.Module):
    def __init__(self, mode, Cell, input_size, hidden_size, num_layers=1,
                 bias=True, batch_first=False,
                 input_dropout=0, hidden_dropout=0, bidirectional=False):
        super(VarRNNBase, self).__init__()
        self.mode = mode
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.bias = bias
        self.batch_first = batch_first
        self.input_dropout = input_dropout
        self.hidden_dropout = hidden_dropout
        self.bidirectional = bidirectional
        self.num_directions = 2 if bidirectional else 1
        self._all_cells = nn.ModuleList()
        for layer in range(self.num_layers):
            for direction in range(self.num_directions):
                input_size = self.input_size if layer == 0 else self.hidden_size * self.num_directions
                cell = Cell(input_size, self.hidden_size, bias)
                self._all_cells.append(VarRnnCellWrapper(cell, self.hidden_size, input_dropout, hidden_dropout))

    def forward(self, input, hx=None):
        is_packed = isinstance(input, PackedSequence)
        is_lstm = (self.mode == "LSTM")
        if is_packed:
            input, batch_sizes = input
            max_batch_size = int(batch_sizes[0])
        else:
            batch_sizes = None
            max_batch_size = input.size(0) if self.batch_first else input.size(1)

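        # default to zero initial states; an LSTM expects an (h_0, c_0) tuple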
        if hx is None:
            hx = input.new_zeros(self.num_layers * self.num_directions,
                                 max_batch_size, self.hidden_size,
                                 requires_grad=False)
            if is_lstm:
                hx = (hx, hx)

        if self.batch_first:
            input = input.transpose(0, 1)

        hidden_list = []
        for layer in range(self.num_layers):
            output_list = []
            for direction in range(self.num_directions):
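                # the reverse direction runs on a flipped copy of the sequence;
                # its outputs are flipped back so both directions align per step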
                input_x = input if direction == 0 else input.flip(0)
                idx = self.num_directions * layer + direction
                cell = self._all_cells[idx]
                output_x, hidden_x = cell(input_x, (hx[0][idx], hx[1][idx]) if is_lstm else hx[idx])
                output_list.append(output_x if direction == 0 else output_x.flip(0))
                hidden_list.append(hidden_x)
            input = torch.cat(output_list, dim=-1)
        output = input.transpose(0, 1) if self.batch_first else input
        if is_lstm:
            h_list, c_list = zip(*hidden_list)
            hn = torch.stack(h_list, dim=0)
            cn = torch.stack(c_list, dim=0)
            hidden = (hn, cn)
        else:
            hidden = torch.stack(hidden_list, dim=0)
        if is_packed:
            output = PackedSequence(output, batch_sizes)

        return output, hidden


class VarLSTM(VarRNNBase):
    def __init__(self, *args, **kwargs):
        super(VarLSTM, self).__init__(mode="LSTM", Cell=nn.LSTMCell, *args, **kwargs)


if __name__ == '__main__':
    net = VarLSTM(input_size=10, hidden_size=20, num_layers=3, batch_first=True,
                  bidirectional=True, input_dropout=0.33, hidden_dropout=0.33)
    lstm = nn.LSTM(input_size=10, hidden_size=20, num_layers=3, batch_first=True, bidirectional=True)
    x = torch.randn(2, 8, 10)
    y, hidden = net(x)
    y0, h0 = lstm(x)
    print(y.shape)
    print(y0.shape)
    print(y)
    print(hidden[0])
    print(hidden[0].shape)
    print(y0)
    print(h0[0])
    print(h0[0].shape)
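
A minimal sketch (not part of the commit) of the mask trick VarRnnCellWrapper.forward relies on: dropout is applied once, in place, to a tensor of ones, and the resulting mask is broadcast over every time step, so the same coordinates are zeroed for the whole sequence. All names below are illustrative.

import torch
import torch.nn.functional as F

batch_size, input_size, p = 4, 6, 0.5
mask = torch.ones(batch_size, input_size)
F.dropout(mask, p=p, training=True, inplace=True)  # entries become 0 or 1/(1-p)

x = torch.randn(10, batch_size, input_size)  # [seq_len, batch_size, input_size]
dropped = x * mask.unsqueeze(0)              # one mask, shared across all steps
print(mask)                                  # fixed zero pattern, survivors scaled
# the first and last steps share the zero pattern (almost surely True)
print(torch.equal(dropped[0] == 0, dropped[-1] == 0))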
