
Fix the elmo DataParallel issue again

tags/v0.4.10
yh committed 6 years ago
commit 1167d3b587
1 changed file with 3 additions and 3 deletions
fastNLP/modules/encoder/_elmo.py  +3 -3

@@ -761,11 +761,11 @@ class _ElmoModel(nn.Module):
         token_embedding = self.token_embedder(expanded_words, chars)
         if self.config['encoder']['name'] == 'elmo':
             encoder_output = self.encoder(token_embedding, seq_len)
-            if encoder_output.size(2) < max_len:
+            if encoder_output.size(2) < max_len+2:
                 dummy_tensor = encoder_output.new_zeros(encoder_output.size(0), batch_size,
-                                                        max_len - encoder_output.size(2), encoder_output.size(-1))
+                                                        max_len + 2 - encoder_output.size(2), encoder_output.size(-1))
                 encoder_output = torch.cat([encoder_output, dummy_tensor], 2)
-            sz = encoder_output.size()  # batch_size, max_len, hidden_size
+            sz = encoder_output.size()  # 2, batch_size, max_len, hidden_size
             token_embedding = torch.cat([token_embedding, token_embedding], dim=2).view(1, sz[1], sz[2], sz[3])
             encoder_output = torch.cat([token_embedding, encoder_output], dim=0)
         elif self.config['encoder']['name'] == 'lstm':
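
For context: the change pads the ELMo encoder output to max_len + 2 instead of max_len, because a <bos> and an <eos> token are added to every sentence, and under nn.DataParallel each replica only sees its own longest sentence, so each replica must zero-pad its output to the common full length before the outputs are gathered. The following is a minimal sketch of that padding step, not the fastNLP source; the helper name pad_encoder_output and the standalone shapes are assumptions for illustration.

import torch

def pad_encoder_output(encoder_output: torch.Tensor, max_len: int) -> torch.Tensor:
    # encoder_output: (num_layers=2, batch_size, seq_len_with_bos_eos, hidden_size).
    # With <bos>/<eos> added to every sentence, the full length is max_len + 2.
    # A DataParallel replica may have produced a shorter sequence dimension than the
    # global maximum, so it is zero-padded along dim 2 before the replicas are gathered.
    if encoder_output.size(2) < max_len + 2:
        dummy = encoder_output.new_zeros(encoder_output.size(0), encoder_output.size(1),
                                         max_len + 2 - encoder_output.size(2),
                                         encoder_output.size(-1))
        encoder_output = torch.cat([encoder_output, dummy], dim=2)
    return encoder_output

# Usage sketch: a replica produced length-8 outputs while the global max_len is 10.
out = pad_encoder_output(torch.randn(2, 4, 8, 512), max_len=10)
print(out.shape)  # torch.Size([2, 4, 12, 512]) -> padded to max_len + 2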

