diff --git a/fastNLP/modules/encoder/_elmo.py b/fastNLP/modules/encoder/_elmo.py index 11feead6..ab43d32f 100644 --- a/fastNLP/modules/encoder/_elmo.py +++ b/fastNLP/modules/encoder/_elmo.py @@ -761,11 +761,11 @@ class _ElmoModel(nn.Module): token_embedding = self.token_embedder(expanded_words, chars) if self.config['encoder']['name'] == 'elmo': encoder_output = self.encoder(token_embedding, seq_len) - if encoder_output.size(2) < max_len: + if encoder_output.size(2) < max_len+2: dummy_tensor = encoder_output.new_zeros(encoder_output.size(0), batch_size, - max_len - encoder_output.size(2), encoder_output.size(-1)) + max_len + 2 - encoder_output.size(2), encoder_output.size(-1)) encoder_output = torch.cat([encoder_output, dummy_tensor], 2) - sz = encoder_output.size() # batch_size, max_len, hidden_size + sz = encoder_output.size() # 2, batch_size, max_len, hidden_size token_embedding = torch.cat([token_embedding, token_embedding], dim=2).view(1, sz[1], sz[2], sz[3]) encoder_output = torch.cat([token_embedding, encoder_output], dim=0) elif self.config['encoder']['name'] == 'lstm':