@@ -19,7 +19,7 @@ class LSTM(nn.Module):
     Alias: :class:`fastNLP.modules.LSTM` :class:`fastNLP.modules.encoder.lstm.LSTM`
 
     LSTM module, a lightweight wrapper around PyTorch's LSTM. When seq_len is provided, pack_padded_sequence is applied automatically; by default the forget gate bias is initialized
-    to 1; it also works around the problem of using LSTM under DataParallel
+    to 1; it also works around the problem of using LSTM under DataParallel.
 
     :param input_size: feature dimension of the input `x`
     :param hidden_size: feature dimension of the hidden state `h`.
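
A minimal usage sketch of the wrapper documented above (the positional seq_len argument and the (output, (h_n, c_n)) return are assumptions inferred from the docstring and the forward code below; exact argument names may differ):

import torch
from fastNLP.modules.encoder.lstm import LSTM

lstm = LSTM(input_size=50, hidden_size=100, num_layers=1, batch_first=True)
x = torch.randn(4, 12, 50)               # (batch, max_len, input_size)
seq_len = torch.tensor([12, 9, 7, 3])    # true lengths; triggers pack_padded_sequence internally
output, (h_n, c_n) = lstm(x, seq_len)    # output: (4, 12, 100), zero-padded past each sequence's length
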
@@ -32,13 +32,12 @@ class LSTM(nn.Module):
     """
 
     def __init__(self, input_size, hidden_size=100, num_layers=1, dropout=0.0, batch_first=True,
-                 bidirectional=False, bias=True, initial_method=None):
+                 bidirectional=False, bias=True):
         super(LSTM, self).__init__()
         self.batch_first = batch_first
         self.lstm = nn.LSTM(input_size, hidden_size, num_layers, bias=bias, batch_first=batch_first,
                             dropout=dropout, bidirectional=bidirectional)
         self.init_param()
-        initial_parameter(self, initial_method)
 
     def init_param(self):
         for name, param in self.named_parameters():
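
The init_param body is cut off by the hunk context above; a sketch of the forget-gate bias initialization it performs, assuming the standard nn.LSTM parameter layout [b_i | b_f | b_g | b_o] (init_lstm_param is a hypothetical standalone helper, not necessarily the exact fastNLP code):

import torch.nn as nn

def init_lstm_param(lstm: nn.LSTM):
    # Zero every bias, then set the forget-gate quarter to 1; Xavier-init the weight matrices.
    for name, param in lstm.named_parameters():
        if 'bias' in name:
            param.data.fill_(0)
            n = param.size(0)
            param.data[n // 4:n // 2].fill_(1)   # forget-gate slice of [b_i | b_f | b_g | b_o]
        else:
            nn.init.xavier_uniform_(param)

init_lstm_param(nn.LSTM(50, 100, batch_first=True))
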
@@ -81,9 +80,14 @@ class LSTM(nn.Module):
             else:
                 output = output[:, unsort_idx]
             # Work around LSTM not being usable under DataParallel: https://github.com/pytorch/pytorch/issues/1591
-            if output.size(1) < max_len:
-                dummy_tensor = output.new_zeros(batch_size, max_len - output.size(1), output.size(-1))
-                output = torch.cat([output, dummy_tensor], 1)
+            if self.batch_first:
+                if output.size(1) < max_len:
+                    dummy_tensor = output.new_zeros(batch_size, max_len - output.size(1), output.size(-1))
+                    output = torch.cat([output, dummy_tensor], 1)
+            else:
+                if output.size(0) < max_len:
+                    dummy_tensor = output.new_zeros(max_len - output.size(0), batch_size, output.size(-1))
+                    output = torch.cat([output, dummy_tensor], 0)
         else:
             output, hx = self.lstm(x, hx)
         return output, hx
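
For context on the added branch: under DataParallel each replica sees only its slice of the batch, so pad_packed_sequence may return a time dimension shorter than the global max_len on some replicas, and the per-device outputs then cannot be gathered (pytorch/pytorch#1591). A standalone sketch of the same padding logic (pad_to_max_len is a hypothetical helper, not part of fastNLP):

import torch

def pad_to_max_len(output, max_len, batch_first=True):
    # Zero-pad the time dimension up to max_len so every replica returns the same shape.
    if batch_first:                                   # output: (batch, seq, hidden)
        if output.size(1) < max_len:
            pad = output.new_zeros(output.size(0), max_len - output.size(1), output.size(-1))
            output = torch.cat([output, pad], dim=1)
    else:                                             # output: (seq, batch, hidden)
        if output.size(0) < max_len:
            pad = output.new_zeros(max_len - output.size(0), output.size(1), output.size(-1))
            output = torch.cat([output, pad], dim=0)
    return output

out = torch.randn(4, 7, 100)                          # replica whose longest local sequence is 7
print(pad_to_max_len(out, max_len=12).shape)          # torch.Size([4, 12, 100])

Later PyTorch releases also expose total_length= on pad_packed_sequence, which addresses the same issue without manual padding.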