From 1167d3b58788aa675e914aed5980f7956a67e713 Mon Sep 17 00:00:00 2001 From: yh Date: Wed, 19 Jun 2019 22:35:44 +0800 Subject: [PATCH] =?UTF-8?q?=E5=86=8D=E6=AC=A1=E4=BF=AE=E6=94=B9elmo?= =?UTF-8?q?=E7=9A=84dataparallel=E9=97=AE=E9=A2=98?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- fastNLP/modules/encoder/_elmo.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/fastNLP/modules/encoder/_elmo.py b/fastNLP/modules/encoder/_elmo.py index 11feead6..ab43d32f 100644 --- a/fastNLP/modules/encoder/_elmo.py +++ b/fastNLP/modules/encoder/_elmo.py @@ -761,11 +761,11 @@ class _ElmoModel(nn.Module): token_embedding = self.token_embedder(expanded_words, chars) if self.config['encoder']['name'] == 'elmo': encoder_output = self.encoder(token_embedding, seq_len) - if encoder_output.size(2) < max_len: + if encoder_output.size(2) < max_len+2: dummy_tensor = encoder_output.new_zeros(encoder_output.size(0), batch_size, - max_len - encoder_output.size(2), encoder_output.size(-1)) + max_len + 2 - encoder_output.size(2), encoder_output.size(-1)) encoder_output = torch.cat([encoder_output, dummy_tensor], 2) - sz = encoder_output.size() # batch_size, max_len, hidden_size + sz = encoder_output.size() # 2, batch_size, max_len, hidden_size token_embedding = torch.cat([token_embedding, token_embedding], dim=2).view(1, sz[1], sz[2], sz[3]) encoder_output = torch.cat([token_embedding, encoder_output], dim=0) elif self.config['encoder']['name'] == 'lstm':