
Rework the fix for the ELMO and LSTM DataParallel issue

tags/v0.4.10
yh 6 years ago
parent commit c4e131a0c5
2 changed files with 4 additions and 3 deletions:
  1. fastNLP/modules/encoder/_elmo.py  +3 -2
  2. fastNLP/modules/encoder/lstm.py   +1 -1

fastNLP/modules/encoder/_elmo.py  (+3 -2)

@@ -762,8 +762,9 @@ class _ElmoModel(nn.Module):
         if self.config['encoder']['name'] == 'elmo':
             encoder_output = self.encoder(token_embedding, seq_len)
             if encoder_output.size(2) < max_len:
-                dummy_tensor = autograd.Variable(torch.zeros(batch_size, max_len - encoder_output.size(2), encoder_output.size(-1)))
-                encoder_output = torch.cat([encoder_output, dummy_tensor], 1)
+                dummy_tensor = encoder_output.new_zeros(encoder_output.size(0), batch_size,
+                                                        max_len - encoder_output.size(2), encoder_output.size(-1))
+                encoder_output = torch.cat([encoder_output, dummy_tensor], 2)
             sz = encoder_output.size()  # batch_size, max_len, hidden_size
             token_embedding = torch.cat([token_embedding, token_embedding], dim=2).view(1, sz[1], sz[2], sz[3])
             encoder_output = torch.cat([token_embedding, encoder_output], dim=0)
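
Both hunks swap the deprecated autograd.Variable(torch.zeros(...)) pattern for Tensor.new_zeros(...). A minimal standalone sketch of why that matters (not part of this commit; shapes and names are made up): torch.zeros() always allocates on the CPU with the default dtype, so under DataParallel, where the encoder output lives on a replica's GPU, the following torch.cat() fails with a device mismatch, while new_zeros() inherits device and dtype from the source tensor.

import torch

# Hypothetical shapes standing in for the encoder output padded in the diff above.
output = torch.randn(4, 7, 16)   # (batch_size, seq_len, hidden_size)
max_len = 10

# Removed pattern: torch.zeros() allocates on the CPU with the default dtype,
# so torch.cat() below would fail when `output` sits on a GPU replica.
# dummy = torch.zeros(output.size(0), max_len - output.size(1), output.size(-1))

# Added pattern: new_zeros() inherits device and dtype from `output`, so the
# padding tensor is always compatible with it.
dummy = output.new_zeros(output.size(0), max_len - output.size(1), output.size(-1))
padded = torch.cat([output, dummy], dim=1)
print(padded.shape)   # torch.Size([4, 10, 16])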


fastNLP/modules/encoder/lstm.py  (+1 -1)

@@ -82,7 +82,7 @@ class LSTM(nn.Module):
             output = output[:, unsort_idx]
             # Work around LSTM not being usable under DataParallel: https://github.com/pytorch/pytorch/issues/1591
             if output.size(1) < max_len:
-                dummy_tensor = autograd.Variable(torch.zeros(batch_size, max_len - output.size(1), output.size(-1)))
+                dummy_tensor = output.new_zeros(batch_size, max_len - output.size(1), output.size(-1))
                 output = torch.cat([output, dummy_tensor], 1)
         else:
             output, hx = self.lstm(x, hx)
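
For context, a toy sketch (not fastNLP's actual LSTM class) of the pattern this hunk sits in: after pack_padded_sequence / pad_packed_sequence, each DataParallel replica's output is only as long as the longest sequence in its own chunk, so it has to be padded back to the batch-wide max_len before the replicas' outputs are gathered (pytorch/pytorch#1591).

import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

class PaddedLSTM(nn.Module):
    """Toy module illustrating the pad-back-to-max_len workaround."""

    def __init__(self, input_size=8, hidden_size=16):
        super().__init__()
        self.lstm = nn.LSTM(input_size, hidden_size, batch_first=True)

    def forward(self, x, seq_len):
        max_len = x.size(1)
        packed = pack_padded_sequence(x, seq_len.cpu(), batch_first=True,
                                      enforce_sorted=False)
        output, _ = self.lstm(packed)
        output, _ = pad_packed_sequence(output, batch_first=True)
        # pad_packed_sequence only unpacks to the longest length seen by this
        # replica; pad back to max_len so DataParallel can gather the outputs.
        if output.size(1) < max_len:
            dummy = output.new_zeros(output.size(0), max_len - output.size(1),
                                     output.size(-1))
            output = torch.cat([output, dummy], dim=1)
        return output

x = torch.randn(3, 10, 8)                 # max_len = 10
seq_len = torch.tensor([5, 3, 2])         # longest real sequence is only 5
print(PaddedLSTM()(x, seq_len).shape)     # torch.Size([3, 10, 16])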

