From 0f4cf3030130af85d29d14b44c0c2d0ef832f9de Mon Sep 17 00:00:00 2001 From: yh Date: Wed, 19 Jun 2019 23:59:40 +0800 Subject: [PATCH] =?UTF-8?q?LSTM=E4=BF=AE=E6=94=B9=E9=94=99=E8=AF=AF?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- fastNLP/modules/encoder/lstm.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/fastNLP/modules/encoder/lstm.py b/fastNLP/modules/encoder/lstm.py index 1cc0dec1..10d0e339 100644 --- a/fastNLP/modules/encoder/lstm.py +++ b/fastNLP/modules/encoder/lstm.py @@ -82,12 +82,12 @@ class LSTM(nn.Module): # 解决LSTM无法在DataParallel下使用的问题问题https://github.com/pytorch/pytorch/issues/1591 if self.batch_first: if output.size(1) < max_len: - dummy_tensor = output.new_zeros(max_len - output.size(1), batch_size, output.size(-1)) - output = torch.cat([output, dummy_tensor], 0) - else: - if output.size(0) < max_len: dummy_tensor = output.new_zeros(batch_size, max_len - output.size(1), output.size(-1)) output = torch.cat([output, dummy_tensor], 1) + else: + if output.size(0) < max_len: + dummy_tensor = output.new_zeros(max_len - output.size(0), batch_size, output.size(-1)) + output = torch.cat([output, dummy_tensor], 0) else: output, hx = self.lstm(x, hx) return output, hx