@@ -761,11 +761,11 @@ class _ElmoModel(nn.Module):
         token_embedding = self.token_embedder(expanded_words, chars)
         if self.config['encoder']['name'] == 'elmo':
             encoder_output = self.encoder(token_embedding, seq_len)
-            if encoder_output.size(2) < max_len:
+            if encoder_output.size(2) < max_len+2:
                 dummy_tensor = encoder_output.new_zeros(encoder_output.size(0), batch_size,
-                                                        max_len - encoder_output.size(2), encoder_output.size(-1))
+                                                        max_len + 2 - encoder_output.size(2), encoder_output.size(-1))
                 encoder_output = torch.cat([encoder_output, dummy_tensor], 2)
-            sz = encoder_output.size() # batch_size, max_len, hidden_size
+            sz = encoder_output.size() # 2, batch_size, max_len, hidden_size
             token_embedding = torch.cat([token_embedding, token_embedding], dim=2).view(1, sz[1], sz[2], sz[3])
             encoder_output = torch.cat([token_embedding, encoder_output], dim=0)
         elif self.config['encoder']['name'] == 'lstm':
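For context, the patch pads the ELMo encoder output up to `max_len + 2` rather than `max_len`, presumably because the model wraps each sequence with BOS/EOS tokens, so the time dimension of `encoder_output` runs two positions longer than the raw `max_len`. Below is a minimal sketch of the padding step after the fix; the shapes (`num_layers`, `batch_size`, `max_len`, `hidden_size`) are illustrative values I chose, not taken from the patch:

```python
import torch

# Illustrative shapes only; the real values come from the model and config.
num_layers, batch_size, max_len, hidden_size = 2, 4, 10, 16

# Suppose the encoder returned fewer than max_len + 2 timesteps, e.g. because
# the longest sequence in this batch is shorter than max_len.
encoder_output = torch.randn(num_layers, batch_size, max_len - 1, hidden_size)

# Zero-pad the time dimension (dim 2) up to max_len + 2, mirroring the patch.
if encoder_output.size(2) < max_len + 2:
    dummy_tensor = encoder_output.new_zeros(encoder_output.size(0), batch_size,
                                            max_len + 2 - encoder_output.size(2),
                                            encoder_output.size(-1))
    encoder_output = torch.cat([encoder_output, dummy_tensor], 2)

assert encoder_output.size() == (num_layers, batch_size, max_len + 2, hidden_size)
```

Without the `+ 2`, a batch whose encoder output fell between `max_len` and `max_len + 2` timesteps would skip the padding branch, and the subsequent `.view(1, sz[1], sz[2], sz[3])` and `torch.cat` with `token_embedding` (whose time dimension includes the BOS/EOS positions) would fail on mismatched sizes.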