diff --git a/fastNLP/modules/encoder/lstm.py b/fastNLP/modules/encoder/lstm.py
index 1cc0dec1..10d0e339 100644
--- a/fastNLP/modules/encoder/lstm.py
+++ b/fastNLP/modules/encoder/lstm.py
@@ -82,12 +82,12 @@ class LSTM(nn.Module):
             # Work around LSTM not being usable under DataParallel: https://github.com/pytorch/pytorch/issues/1591
             if self.batch_first:
                 if output.size(1) < max_len:
-                    dummy_tensor = output.new_zeros(max_len - output.size(1), batch_size, output.size(-1))
-                    output = torch.cat([output, dummy_tensor], 0)
-            else:
-                if output.size(0) < max_len:
                     dummy_tensor = output.new_zeros(batch_size, max_len - output.size(1), output.size(-1))
                     output = torch.cat([output, dummy_tensor], 1)
+            else:
+                if output.size(0) < max_len:
+                    dummy_tensor = output.new_zeros(max_len - output.size(1), batch_size, output.size(-1))
+                    output = torch.cat([output, dummy_tensor], 0)
         else:
             output, hx = self.lstm(x, hx)
         return output, hx
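
For context, a minimal standalone sketch of the padding step this hunk rearranges. The helper name pad_to_max_len is made up for illustration and is not part of fastNLP's API. After nn.utils.rnn.pad_packed_sequence, the output on a DataParallel replica can be shorter than the global max_len, so it is zero-padded along the time dimension of whichever layout is in use: (batch, seq, hidden) when batch_first, otherwise (seq, batch, hidden).

import torch


def pad_to_max_len(output: torch.Tensor, max_len: int, batch_size: int,
                   batch_first: bool) -> torch.Tensor:
    """Zero-pad an LSTM output back to max_len along its time dimension.

    Hypothetical helper, not fastNLP code; it only illustrates why the
    dummy tensor's shape and the torch.cat dimension must follow the
    time axis of the chosen layout.
    """
    if batch_first:
        # (batch, seq, hidden): the time dimension is dim 1
        if output.size(1) < max_len:
            dummy = output.new_zeros(batch_size, max_len - output.size(1), output.size(-1))
            output = torch.cat([output, dummy], dim=1)
    else:
        # (seq, batch, hidden): the time dimension is dim 0
        if output.size(0) < max_len:
            dummy = output.new_zeros(max_len - output.size(0), batch_size, output.size(-1))
            output = torch.cat([output, dummy], dim=0)
    return output


if __name__ == "__main__":
    # Longest sequence on this replica is 5, global max_len is 7, batch of 4, hidden size 3.
    assert pad_to_max_len(torch.randn(4, 5, 3), 7, 4, batch_first=True).shape == (4, 7, 3)
    assert pad_to_max_len(torch.randn(5, 4, 3), 7, 4, batch_first=False).shape == (7, 4, 3)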