
LSTM: fix a potential issue when used under DataParallel, and remove the init_method parameter

tags/v0.4.10
yh committed 6 years ago · commit 6b9bc007ee
1 changed file with 10 additions and 6 deletions

fastNLP/modules/encoder/lstm.py  (+10, -6)

@@ -19,7 +19,7 @@ class LSTM(nn.Module):
     Alias: :class:`fastNLP.modules.LSTM` :class:`fastNLP.modules.encoder.lstm.LSTM`

     LSTM module, a light wrapper around the PyTorch LSTM. When seq_len is provided, pack_padded_sequence is used automatically; by default the forget gate bias is initialized
-    to 1; and it handles the issue of using LSTM under DataParallel.
+    to 1; and it handles the issue of using LSTM under DataParallel.

     :param input_size: feature dimension of the input `x`
     :param hidden_size: feature dimension of the hidden state `h`.
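
The docstring in the hunk above says that, when seq_len is given, the wrapper packs the batch with pack_padded_sequence before running the underlying nn.LSTM. A minimal plain-PyTorch sketch of that pack/sort/unsort pattern (this is not fastNLP's API; the tensor sizes and names are made up for illustration):

import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

# Toy batch: 3 sequences with 5 features each, zero-padded to length 4 (batch_first layout).
x = torch.randn(3, 4, 5)
seq_len = torch.tensor([4, 2, 3])

lstm = nn.LSTM(input_size=5, hidden_size=7, batch_first=True)

# Sort by length (descending), as older PyTorch versions of pack_padded_sequence require.
sort_lens, sort_idx = seq_len.sort(descending=True)
packed = pack_padded_sequence(x[sort_idx], sort_lens, batch_first=True)
packed_out, (h, c) = lstm(packed)
output, _ = pad_packed_sequence(packed_out, batch_first=True)

# Undo the sort so rows line up with the original batch again.
_, unsort_idx = sort_idx.sort()
output = output[unsort_idx]
print(output.shape)  # torch.Size([3, 4, 7])
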
@@ -32,13 +32,12 @@ class LSTM(nn.Module):
     """

     def __init__(self, input_size, hidden_size=100, num_layers=1, dropout=0.0, batch_first=True,
-                 bidirectional=False, bias=True, initial_method=None):
+                 bidirectional=False, bias=True):
         super(LSTM, self).__init__()
         self.batch_first = batch_first
         self.lstm = nn.LSTM(input_size, hidden_size, num_layers, bias=bias, batch_first=batch_first,
                             dropout=dropout, bidirectional=bidirectional)
         self.init_param()
-        initial_parameter(self, initial_method)

     def init_param(self):
         for name, param in self.named_parameters():
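
The init_param body is truncated by the hunk above; per the docstring, it is what initializes the forget gate bias to 1. A rough sketch of that kind of initialization as a standalone helper (init_forget_gate_bias is a hypothetical name, not fastNLP's method), assuming PyTorch's standard gate ordering of input, forget, cell, output within each bias_ih_l*/bias_hh_l* vector:

import torch.nn as nn

def init_forget_gate_bias(lstm: nn.LSTM, value: float = 1.0) -> None:
    # PyTorch packs LSTM biases as [input, forget, cell, output] gates,
    # so the forget-gate slice is the second quarter of each bias vector.
    for name, param in lstm.named_parameters():
        if "bias" in name:
            n = param.size(0)
            param.data[n // 4: n // 2].fill_(value)

lstm = nn.LSTM(input_size=5, hidden_size=7, num_layers=2, bidirectional=True)
init_forget_gate_bias(lstm)
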
@@ -81,9 +80,14 @@ class LSTM(nn.Module):
             else:
                 output = output[:, unsort_idx]
             # Work around LSTM being unusable under DataParallel: https://github.com/pytorch/pytorch/issues/1591
-            if output.size(1) < max_len:
-                dummy_tensor = output.new_zeros(batch_size, max_len - output.size(1), output.size(-1))
-                output = torch.cat([output, dummy_tensor], 1)
+            if self.batch_first:
+                if output.size(1) < max_len:
+                    dummy_tensor = output.new_zeros(batch_size, max_len - output.size(1), output.size(-1))
+                    output = torch.cat([output, dummy_tensor], 1)
+            else:
+                if output.size(0) < max_len:
+                    dummy_tensor = output.new_zeros(max_len - output.size(0), batch_size, output.size(-1))
+                    output = torch.cat([output, dummy_tensor], 0)
         else:
             output, hx = self.lstm(x, hx)
         return output, hx
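
Why the padding exists at all: nn.DataParallel splits the batch across replicas, and a replica whose longest sequence is shorter than the batch-wide max_len gets back fewer time steps from pad_packed_sequence, so the per-replica outputs cannot be gathered back into one tensor (see pytorch/pytorch#1591, linked in the comment). The old code only padded the batch_first layout; this commit adds the seq-first branch. A standalone sketch of the same padding logic (pad_to_max_len is a hypothetical helper, not part of fastNLP):

import torch

def pad_to_max_len(output: torch.Tensor, max_len: int, batch_first: bool) -> torch.Tensor:
    # Zero-pad the time dimension so every DataParallel replica returns
    # a tensor of the same sequence length and the results can be gathered.
    if batch_first:   # output: (batch, seq, hidden)
        if output.size(1) < max_len:
            pad = output.new_zeros(output.size(0), max_len - output.size(1), output.size(-1))
            output = torch.cat([output, pad], dim=1)
    else:             # output: (seq, batch, hidden)
        if output.size(0) < max_len:
            pad = output.new_zeros(max_len - output.size(0), output.size(1), output.size(-1))
            output = torch.cat([output, pad], dim=0)
    return output

print(pad_to_max_len(torch.randn(3, 2, 7), max_len=4, batch_first=True).shape)  # torch.Size([3, 4, 7])
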
