|
|
@@ -11,7 +11,8 @@ except ImportError: |
|
|
|
def flip(x, dims): |
|
|
|
indices = [slice(None)] * x.dim() |
|
|
|
for dim in dims: |
|
|
|
-        indices[dim] = torch.arange(x.size(dim) - 1, -1, -1, dtype=torch.long, device=x.device)
+        indices[dim] = torch.arange(
+            x.size(dim) - 1, -1, -1, dtype=torch.long, device=x.device)
|
|
|
return x[tuple(indices)] |
|
|
|
|
|
|
|
from ..utils import initial_parameter |
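
The `flip` helper above reverses a tensor along the given dimensions with an advanced-indexing trick (the index for each flipped dimension is built in descending order). A minimal standalone sketch of what it computes, using an illustrative 2x3 tensor that is not part of the patch:

    import torch

    x = torch.arange(6).reshape(2, 3)

    # Build the same descending-index selector that flip() builds for dim 1.
    indices = [slice(None)] * x.dim()
    indices[1] = torch.arange(x.size(1) - 1, -1, -1, dtype=torch.long, device=x.device)

    # The indexing trick matches the built-in torch.flip.
    assert torch.equal(x[tuple(indices)], torch.flip(x, dims=[1]))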
|
|
@@ -27,14 +28,14 @@ class VarRnnCellWrapper(nn.Module): |
|
|
|
""" |
|
|
|
Wrapper around a standard RNN cell that adds support for variational dropout
|
|
|
""" |
|
|
|
|
|
|
|
|
|
|
|
def __init__(self, cell, hidden_size, input_p, hidden_p): |
|
|
|
super(VarRnnCellWrapper, self).__init__() |
|
|
|
self.cell = cell |
|
|
|
self.hidden_size = hidden_size |
|
|
|
self.input_p = input_p |
|
|
|
self.hidden_p = hidden_p |
|
|
|
|
|
|
|
|
|
|
|
def forward(self, input_x, hidden, mask_x, mask_h, is_reversed=False): |
|
|
|
""" |
|
|
|
:param PackedSequence input_x: [seq_len, batch_size, input_size] |
|
|
@@ -46,13 +47,13 @@ class VarRnnCellWrapper(nn.Module): |
|
|
|
hidden: for LSTM, tuple of (h_n, c_n), [batch_size, hidden_size] |
|
|
|
for other RNNs, h_n, [batch_size, hidden_size]
|
|
|
""" |
|
|
|
|
|
|
|
|
|
|
|
def get_hi(hi, h0, size): |
|
|
|
h0_size = size - hi.size(0) |
|
|
|
if h0_size > 0: |
|
|
|
return torch.cat([hi, h0[:h0_size]], dim=0) |
|
|
|
return hi[:size] |
|
|
|
|
|
|
|
|
|
|
|
is_lstm = isinstance(hidden, tuple) |
|
|
|
input, batch_sizes = input_x.data, input_x.batch_sizes |
|
|
|
output = [] |
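
The `get_hi` helper above reconciles the carried-over hidden state with the batch size of the current `PackedSequence` step: later timesteps contain fewer active sequences, and when a direction is processed in reverse the active batch grows again, so missing rows are filled from the initial hidden state. A small standalone sketch of that behaviour, with invented sizes:

    import torch

    def get_hi(hi, h0, size):
        # Pad hi with rows of the initial state h0 when the current step is larger,
        # otherwise truncate it to the current step's batch size.
        h0_size = size - hi.size(0)
        if h0_size > 0:
            return torch.cat([hi, h0[:h0_size]], dim=0)
        return hi[:size]

    h0 = torch.zeros(5, 4)           # initial hidden state for a batch of 5
    hi = torch.randn(3, 4)           # state carried over from a step with 3 active sequences
    print(get_hi(hi, h0, 5).shape)   # torch.Size([5, 4]), padded back up with rows of h0
    print(get_hi(hi, h0, 2).shape)   # torch.Size([2, 4]), truncated as the batch shrinks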
|
|
@@ -63,7 +64,7 @@ class VarRnnCellWrapper(nn.Module): |
|
|
|
else: |
|
|
|
batch_iter = batch_sizes |
|
|
|
idx = 0 |
|
|
|
|
|
|
|
|
|
|
|
if is_lstm: |
|
|
|
hn = (hidden[0].clone(), hidden[1].clone()) |
|
|
|
else: |
|
|
@@ -79,7 +80,8 @@ class VarRnnCellWrapper(nn.Module): |
|
|
|
mask_hi = mask_h[:size] |
|
|
|
if is_lstm: |
|
|
|
hx, cx = hi |
|
|
|
-                hi = (get_hi(hx, hidden[0], size) * mask_hi, get_hi(cx, hidden[1], size))
+                hi = (get_hi(hx, hidden[0], size) *
+                      mask_hi, get_hi(cx, hidden[1], size))
|
|
|
hi = cell(input_i, hi) |
|
|
|
hn[0][:size] = hi[0] |
|
|
|
hn[1][:size] = hi[1] |
|
|
@@ -89,7 +91,7 @@ class VarRnnCellWrapper(nn.Module): |
|
|
|
hi = cell(input_i, hi) |
|
|
|
hn[:size] = hi |
|
|
|
output.append(hi) |
|
|
|
|
|
|
|
|
|
|
|
if is_reversed: |
|
|
|
output = list(reversed(output)) |
|
|
|
output = torch.cat(output, dim=0) |
|
|
@@ -99,7 +101,7 @@ class VarRnnCellWrapper(nn.Module): |
|
|
|
class VarRNNBase(nn.Module): |
|
|
|
""" |
|
|
|
Variational Dropout RNN implementation.
|
|
|
|
|
|
|
|
|
|
|
Reference: `A Theoretically Grounded Application of Dropout in Recurrent Neural Networks (Yarin Gal and Zoubin Ghahramani, 2016)
|
|
|
https://arxiv.org/abs/1512.05287`. |
|
|
|
|
|
|
@@ -115,7 +117,7 @@ class VarRNNBase(nn.Module): |
|
|
|
:param hidden_dropout: dropout probability applied to each hidden state. Default: 0
|
|
|
:param bidirectional: if ``True``, use a bidirectional RNN. Default: ``False``
|
|
|
""" |
|
|
|
|
|
|
|
|
|
|
|
def __init__(self, mode, Cell, input_size, hidden_size, num_layers=1, |
|
|
|
bias=True, batch_first=False, |
|
|
|
input_dropout=0, hidden_dropout=0, bidirectional=False): |
|
|
@@ -135,18 +137,20 @@ class VarRNNBase(nn.Module): |
|
|
|
for direction in range(self.num_directions): |
|
|
|
input_size = self.input_size if layer == 0 else self.hidden_size * self.num_directions |
|
|
|
cell = Cell(input_size, self.hidden_size, bias) |
|
|
|
-                self._all_cells.append(VarRnnCellWrapper(cell, self.hidden_size, input_dropout, hidden_dropout))
+                self._all_cells.append(VarRnnCellWrapper(
+                    cell, self.hidden_size, input_dropout, hidden_dropout))
|
|
|
initial_parameter(self) |
|
|
|
self.is_lstm = (self.mode == "LSTM") |
|
|
|
|
|
|
|
|
|
|
|
def _forward_one(self, n_layer, n_direction, input, hx, mask_x, mask_h): |
|
|
|
is_lstm = self.is_lstm |
|
|
|
idx = self.num_directions * n_layer + n_direction |
|
|
|
cell = self._all_cells[idx] |
|
|
|
hi = (hx[0][idx], hx[1][idx]) if is_lstm else hx[idx] |
|
|
|
-        output_x, hidden_x = cell(input, hi, mask_x, mask_h, is_reversed=(n_direction == 1))
+        output_x, hidden_x = cell(
+            input, hi, mask_x, mask_h, is_reversed=(n_direction == 1))
|
|
|
return output_x, hidden_x |
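
`_forward_one` above looks up the wrapped cell for a given layer and direction in the flat `self._all_cells` list, which `__init__` fills direction-innermost, so the flat index is `num_directions * n_layer + n_direction`. A quick sketch of that layout for a hypothetical two-layer bidirectional model:

    num_layers, num_directions = 2, 2

    # Cells are appended direction-innermost, matching the nested loops in __init__.
    layout = [(layer, direction)
              for layer in range(num_layers)
              for direction in range(num_directions)]

    for layer in range(num_layers):
        for direction in range(num_directions):
            idx = num_directions * layer + direction
            assert layout[idx] == (layer, direction)

    print(layout)  # [(0, 0), (0, 1), (1, 0), (1, 1)]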
|
|
|
|
|
|
|
|
|
|
|
def forward(self, x, hx=None): |
|
|
|
""" |
|
|
|
|
|
|
@@ -160,31 +164,38 @@ class VarRNNBase(nn.Module): |
|
|
|
if not is_packed: |
|
|
|
seq_len = x.size(1) if self.batch_first else x.size(0) |
|
|
|
max_batch_size = x.size(0) if self.batch_first else x.size(1) |
|
|
|
-            seq_lens = torch.LongTensor([seq_len for _ in range(max_batch_size)])
-            input = pack_padded_sequence(input, seq_lens, batch_first=self.batch_first)
+            seq_lens = torch.LongTensor(
+                [seq_len for _ in range(max_batch_size)])
+            x = pack_padded_sequence(x, seq_lens, batch_first=self.batch_first)
|
|
|
else: |
|
|
|
-            max_batch_size = int(input.batch_sizes[0])
-        input, batch_sizes = input.data, input.batch_sizes
+            max_batch_size = int(x.batch_sizes[0])
+        x, batch_sizes = x.data, x.batch_sizes
|
|
|
|
|
|
|
if hx is None: |
|
|
|
hx = x.new_zeros(self.num_layers * self.num_directions, |
|
|
|
max_batch_size, self.hidden_size, requires_grad=True) |
|
|
|
if is_lstm: |
|
|
|
hx = (hx, hx.new_zeros(hx.size(), requires_grad=True)) |
|
|
|
|
|
|
|
|
|
|
|
mask_x = x.new_ones((max_batch_size, self.input_size)) |
|
|
|
-        mask_out = x.new_ones((max_batch_size, self.hidden_size * self.num_directions))
+        mask_out = x.new_ones(
+            (max_batch_size, self.hidden_size * self.num_directions))
|
|
|
mask_h_ones = x.new_ones((max_batch_size, self.hidden_size)) |
|
|
|
-        nn.functional.dropout(mask_x, p=self.input_dropout, training=self.training, inplace=True)
-        nn.functional.dropout(mask_out, p=self.hidden_dropout, training=self.training, inplace=True)
+        nn.functional.dropout(mask_x, p=self.input_dropout,
+                              training=self.training, inplace=True)
+        nn.functional.dropout(mask_out, p=self.hidden_dropout,
+                              training=self.training, inplace=True)

-        hidden = x.new_zeros((self.num_layers * self.num_directions, max_batch_size, self.hidden_size))
+        hidden = x.new_zeros(
+            (self.num_layers * self.num_directions, max_batch_size, self.hidden_size))
|
|
|
if is_lstm: |
|
|
|
-            cellstate = x.new_zeros((self.num_layers * self.num_directions, max_batch_size, self.hidden_size))
+            cellstate = x.new_zeros(
+                (self.num_layers * self.num_directions, max_batch_size, self.hidden_size))
|
|
|
for layer in range(self.num_layers): |
|
|
|
output_list = [] |
|
|
|
input_seq = PackedSequence(x, batch_sizes) |
|
|
|
-            mask_h = nn.functional.dropout(mask_h_ones, p=self.hidden_dropout, training=self.training, inplace=False)
+            mask_h = nn.functional.dropout(
+                mask_h_ones, p=self.hidden_dropout, training=self.training, inplace=False)
|
|
|
for direction in range(self.num_directions): |
|
|
|
output_x, hidden_x = self._forward_one(layer, direction, input_seq, hx, |
|
|
|
mask_x if layer == 0 else mask_out, mask_h) |
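
The masks built in this hunk implement the variational dropout scheme from the Gal and Ghahramani paper cited above: `mask_x`, `mask_out` and the per-layer `mask_h` are drawn once per forward call by applying `nn.functional.dropout` to tensors of ones, and the same masks are then reused at every timestep instead of being resampled per step. A minimal sketch of drawing and reusing such a mask, with invented sizes and a stand-in cell update:

    import torch
    import torch.nn as nn

    batch_size, hidden_size, p = 4, 8, 0.5

    # Dropout on a tensor of ones gives a mask of zeros and 1/(1-p) scale factors;
    # keeping it fixed across the whole sequence is the "variational" pattern.
    mask_h = nn.functional.dropout(torch.ones(batch_size, hidden_size),
                                   p=p, training=True, inplace=False)

    h = torch.randn(batch_size, hidden_size)
    for _ in range(10):              # the same mask is reused at every timestep
        h = torch.tanh(h) * mask_h   # stand-in for the real recurrent cell update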
|
|
@@ -196,16 +207,16 @@ class VarRNNBase(nn.Module): |
|
|
|
else: |
|
|
|
hidden[idx] = hidden_x |
|
|
|
x = torch.cat(output_list, dim=-1) |
|
|
|
|
|
|
|
|
|
|
|
if is_lstm: |
|
|
|
hidden = (hidden, cellstate) |
|
|
|
|
|
|
|
|
|
|
|
if is_packed: |
|
|
|
output = PackedSequence(x, batch_sizes) |
|
|
|
else: |
|
|
|
x = PackedSequence(x, batch_sizes) |
|
|
|
output, _ = pad_packed_sequence(x, batch_first=self.batch_first) |
|
|
|
|
|
|
|
|
|
|
|
return output, hidden |
|
|
|
|
|
|
|
|
|
|
@@ -225,10 +236,11 @@ class VarLSTM(VarRNNBase): |
|
|
|
:param hidden_dropout: dropout probability applied to each hidden state. Default: 0
|
|
|
:param bidirectional: if ``True``, use a bidirectional LSTM. Default: ``False``
|
|
|
""" |
|
|
|
|
|
|
|
|
|
|
|
def __init__(self, *args, **kwargs): |
|
|
|
-        super(VarLSTM, self).__init__(mode="LSTM", Cell=nn.LSTMCell, *args, **kwargs)
+        super(VarLSTM, self).__init__(
+            mode="LSTM", Cell=nn.LSTMCell, *args, **kwargs)
|
|
|
|
|
|
|
def forward(self, x, hx=None): |
|
|
|
return super(VarLSTM, self).forward(x, hx) |
|
|
|
|
|
|
@@ -249,10 +261,11 @@ class VarRNN(VarRNNBase): |
|
|
|
:param hidden_dropout: dropout probability applied to each hidden state. Default: 0
|
|
|
:param bidirectional: if ``True``, use a bidirectional RNN. Default: ``False``
|
|
|
""" |
|
|
|
|
|
|
|
|
|
|
|
def __init__(self, *args, **kwargs): |
|
|
|
-        super(VarRNN, self).__init__(mode="RNN", Cell=nn.RNNCell, *args, **kwargs)
+        super(VarRNN, self).__init__(
+            mode="RNN", Cell=nn.RNNCell, *args, **kwargs)
|
|
|
|
|
|
|
def forward(self, x, hx=None): |
|
|
|
return super(VarRNN, self).forward(x, hx) |
|
|
|
|
|
|
@@ -273,9 +286,10 @@ class VarGRU(VarRNNBase): |
|
|
|
:param hidden_dropout: dropout probability applied to each hidden state. Default: 0
|
|
|
:param bidirectional: if ``True``, use a bidirectional GRU. Default: ``False``
|
|
|
""" |
|
|
|
|
|
|
|
|
|
|
|
def __init__(self, *args, **kwargs): |
|
|
|
-        super(VarGRU, self).__init__(mode="GRU", Cell=nn.GRUCell, *args, **kwargs)
+        super(VarGRU, self).__init__(
+            mode="GRU", Cell=nn.GRUCell, *args, **kwargs)
|
|
|
|
|
|
|
def forward(self, x, hx=None): |
|
|
|
return super(VarGRU, self).forward(x, hx) |
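
Taken together, `VarLSTM`, `VarRNN` and `VarGRU` are drop-in counterparts of the corresponding `torch.nn` layers with the extra `input_dropout` and `hidden_dropout` arguments. A usage sketch, assuming the module is importable as `fastNLP.modules.encoder.variational_rnn` (the import path is not shown in this patch) and using illustrative sizes:

    import torch
    from fastNLP.modules.encoder.variational_rnn import VarLSTM  # assumed path

    # Two-layer bidirectional variational LSTM: dropout masks are sampled once
    # per forward pass and shared across all timesteps.
    lstm = VarLSTM(input_size=32, hidden_size=64, num_layers=2,
                   batch_first=True, bidirectional=True,
                   input_dropout=0.3, hidden_dropout=0.3)

    x = torch.randn(8, 15, 32)   # [batch_size, seq_len, input_size]
    output, (h_n, c_n) = lstm(x)
    print(output.shape)          # torch.Size([8, 15, 128])
    print(h_n.shape)             # torch.Size([4, 8, 64])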