@@ -82,12 +82,12 @@ class LSTM(nn.Module):
             # Work around LSTM not being usable under DataParallel: https://github.com/pytorch/pytorch/issues/1591
             if self.batch_first:
                 if output.size(1) < max_len:
-                    dummy_tensor = output.new_zeros(max_len - output.size(1), batch_size, output.size(-1))
-                    output = torch.cat([output, dummy_tensor], 0)
+                    dummy_tensor = output.new_zeros(batch_size, max_len - output.size(1), output.size(-1))
+                    output = torch.cat([output, dummy_tensor], 1)
             else:
                 if output.size(0) < max_len:
-                    dummy_tensor = output.new_zeros(batch_size, max_len - output.size(1), output.size(-1))
-                    output = torch.cat([output, dummy_tensor], 1)
+                    dummy_tensor = output.new_zeros(max_len - output.size(0), batch_size, output.size(-1))
+                    output = torch.cat([output, dummy_tensor], 0)
         else:
             output, hx = self.lstm(x, hx)
         return output, hx
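
For reference, the fix aligns the padding with the output layout: with `batch_first=True` the output is `[batch, time, hidden]`, so missing timesteps are padded along dim 1, while the default time-major layout `[time, batch, hidden]` pads along dim 0 (the old code had the two cases swapped). Below is a minimal, self-contained sketch of the same workaround; the helper name `pad_to_max_len` is hypothetical and not part of the patched class:

```python
import torch

def pad_to_max_len(output: torch.Tensor, max_len: int, batch_first: bool) -> torch.Tensor:
    """Right-pad an LSTM output with zeros along the time dimension.

    Under DataParallel each replica sees only a slice of the batch, so after
    pad_packed_sequence its time dimension can be shorter than the global
    max_len, and gathering the replicas' outputs fails (pytorch/pytorch#1591).
    """
    time_dim = 1 if batch_first else 0  # [N, L, C] vs. [L, N, C]
    gap = max_len - output.size(time_dim)
    if gap <= 0:
        return output  # already full length; nothing to pad
    pad_shape = list(output.shape)
    pad_shape[time_dim] = gap
    dummy_tensor = output.new_zeros(pad_shape)  # same dtype/device as output
    return torch.cat([output, dummy_tensor], dim=time_dim)
```

On PyTorch 0.4+, the same effect is available directly via `torch.nn.utils.rnn.pad_packed_sequence(output, batch_first=self.batch_first, total_length=max_len)`; the `total_length` argument was added for exactly this DataParallel case and makes the manual `torch.cat` branches unnecessary.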