Browse Source

LSTM修改错误

tags/v0.4.10
yh 6 years ago
parent
commit
0f4cf30301
1 changed files with 4 additions and 4 deletions
  1. +4
    -4
      fastNLP/modules/encoder/lstm.py

+ 4
- 4
fastNLP/modules/encoder/lstm.py View File

@@ -82,12 +82,12 @@ class LSTM(nn.Module):
# 解决LSTM无法在DataParallel下使用的问题问题https://github.com/pytorch/pytorch/issues/1591 # 解决LSTM无法在DataParallel下使用的问题问题https://github.com/pytorch/pytorch/issues/1591
if self.batch_first: if self.batch_first:
if output.size(1) < max_len: if output.size(1) < max_len:
dummy_tensor = output.new_zeros(max_len - output.size(1), batch_size, output.size(-1))
output = torch.cat([output, dummy_tensor], 0)
else:
if output.size(0) < max_len:
dummy_tensor = output.new_zeros(batch_size, max_len - output.size(1), output.size(-1)) dummy_tensor = output.new_zeros(batch_size, max_len - output.size(1), output.size(-1))
output = torch.cat([output, dummy_tensor], 1) output = torch.cat([output, dummy_tensor], 1)
else:
if output.size(0) < max_len:
dummy_tensor = output.new_zeros(max_len - output.size(1), batch_size, output.size(-1))
output = torch.cat([output, dummy_tensor], 0)
else: else:
output, hx = self.lstm(x, hx) output, hx = self.lstm(x, hx)
return output, hx return output, hx

Loading…
Cancel
Save