From 6b9bc007ee5fbd759591a7c704a33ae732939afe Mon Sep 17 00:00:00 2001
From: yh
Date: Wed, 19 Jun 2019 23:49:01 +0800
Subject: [PATCH] Fix a potential DataParallel issue in LSTM and remove the
 initial_method parameter
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 fastNLP/modules/encoder/lstm.py | 16 ++++++++++------
 1 file changed, 10 insertions(+), 6 deletions(-)

diff --git a/fastNLP/modules/encoder/lstm.py b/fastNLP/modules/encoder/lstm.py
index 2966426a..1cc0dec1 100644
--- a/fastNLP/modules/encoder/lstm.py
+++ b/fastNLP/modules/encoder/lstm.py
@@ -19,7 +19,7 @@ class LSTM(nn.Module):
     Alias: :class:`fastNLP.modules.LSTM`  :class:`fastNLP.modules.encoder.lstm.LSTM`
 
     LSTM module, a lightweight wrapper around PyTorch's LSTM. When seq_len is given, pack_padded_sequence is used automatically; the forget-gate bias is initialized
-    to 1 by default, and the module works around the issue of using an LSTM inside DataParallel
+    to 1 by default, and the module works around the issue of using an LSTM inside DataParallel.
 
     :param input_size: feature dimension of the input `x`
     :param hidden_size: feature dimension of the hidden state `h`.
@@ -32,13 +32,12 @@ class LSTM(nn.Module):
     """
 
     def __init__(self, input_size, hidden_size=100, num_layers=1, dropout=0.0, batch_first=True,
-                 bidirectional=False, bias=True, initial_method=None):
+                 bidirectional=False, bias=True):
         super(LSTM, self).__init__()
         self.batch_first = batch_first
         self.lstm = nn.LSTM(input_size, hidden_size, num_layers, bias=bias, batch_first=batch_first, dropout=dropout,
                             bidirectional=bidirectional)
         self.init_param()
-        initial_parameter(self, initial_method)
 
     def init_param(self):
         for name, param in self.named_parameters():
@@ -81,9 +80,14 @@ class LSTM(nn.Module):
             else:
                 output = output[:, unsort_idx]
             # Work around LSTM not being usable under DataParallel: https://github.com/pytorch/pytorch/issues/1591
-            if output.size(1) < max_len:
-                dummy_tensor = output.new_zeros(batch_size, max_len - output.size(1), output.size(-1))
-                output = torch.cat([output, dummy_tensor], 1)
+            if self.batch_first:
+                if output.size(1) < max_len:
+                    dummy_tensor = output.new_zeros(batch_size, max_len - output.size(1), output.size(-1))
+                    output = torch.cat([output, dummy_tensor], 1)
+            else:
+                if output.size(0) < max_len:
+                    dummy_tensor = output.new_zeros(max_len - output.size(0), batch_size, output.size(-1))
+                    output = torch.cat([output, dummy_tensor], 0)
         else:
             output, hx = self.lstm(x, hx)
         return output, hx
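
Note on the workaround: under nn.DataParallel each replica only sees a slice of the batch, so pad_packed_sequence pads its output only up to the longest sequence in that slice; replicas can then return tensors of different lengths and the gather step fails (pytorch/pytorch#1591). The patch therefore pads every replica's output back to the full max_len, along dim 1 when batch_first is set and along dim 0 otherwise. The sketch below illustrates the same idea outside fastNLP; ToyEncoder and the tensor shapes are made up for illustration, it assumes PyTorch 1.1+ for enforce_sorted=False, and passing total_length to torch.nn.utils.rnn.pad_packed_sequence would be an alternative way to force a fixed-length output.

import torch
import torch.nn as nn
import torch.nn.utils.rnn as rnn

class ToyEncoder(nn.Module):
    # Minimal sketch (not fastNLP code): pad the unpacked LSTM output back to the
    # full batch length so every DataParallel replica returns the same shape.
    def __init__(self, input_size=8, hidden_size=16):
        super().__init__()
        self.lstm = nn.LSTM(input_size, hidden_size, batch_first=True)

    def forward(self, x, seq_len):
        batch_size, max_len = x.size(0), x.size(1)   # length of the padded input batch
        packed = rnn.pack_padded_sequence(x, seq_len.cpu(), batch_first=True,
                                          enforce_sorted=False)
        output, _ = self.lstm(packed)
        # Without total_length, this pads only to the longest sequence this replica saw.
        output, _ = rnn.pad_packed_sequence(output, batch_first=True)
        if output.size(1) < max_len:                 # same workaround as in the patch
            pad = output.new_zeros(batch_size, max_len - output.size(1), output.size(-1))
            output = torch.cat([output, pad], dim=1)
        return output

x = torch.randn(4, 10, 8)              # batch padded to length 10
seq_len = torch.tensor([7, 6, 5, 3])   # real lengths are all shorter than 10
print(ToyEncoder()(x, seq_len).shape)  # torch.Size([4, 10, 16]) thanks to the extra padding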