From c4e131a0c551af3d5d22b3d53b673167eba7b613 Mon Sep 17 00:00:00 2001
From: yh
Date: Wed, 19 Jun 2019 22:19:41 +0800
Subject: [PATCH] =?UTF-8?q?=E9=87=8D=E6=96=B0=E4=BF=AE=E6=94=B9ELMO?=
 =?UTF-8?q?=E4=B8=8ELSTM=20DataParallel=E7=9A=84=E9=97=AE=E9=A2=98?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 fastNLP/modules/encoder/_elmo.py | 5 +++--
 fastNLP/modules/encoder/lstm.py  | 2 +-
 2 files changed, 4 insertions(+), 3 deletions(-)

diff --git a/fastNLP/modules/encoder/_elmo.py b/fastNLP/modules/encoder/_elmo.py
index 7fa29201..11feead6 100644
--- a/fastNLP/modules/encoder/_elmo.py
+++ b/fastNLP/modules/encoder/_elmo.py
@@ -762,8 +762,9 @@ class _ElmoModel(nn.Module):
         if self.config['encoder']['name'] == 'elmo':
             encoder_output = self.encoder(token_embedding, seq_len)
             if encoder_output.size(2) < max_len:
-                dummy_tensor = autograd.Variable(torch.zeros(batch_size, max_len - encoder_output.size(2), encoder_output.size(-1)))
-                encoder_output = torch.cat([encoder_output, dummy_tensor], 1)
+                dummy_tensor = encoder_output.new_zeros(encoder_output.size(0), batch_size,
+                                                        max_len - encoder_output.size(2), encoder_output.size(-1))
+                encoder_output = torch.cat([encoder_output, dummy_tensor], 2)
             sz = encoder_output.size()  # batch_size, max_len, hidden_size
             token_embedding = torch.cat([token_embedding, token_embedding], dim=2).view(1, sz[1], sz[2], sz[3])
             encoder_output = torch.cat([token_embedding, encoder_output], dim=0)
diff --git a/fastNLP/modules/encoder/lstm.py b/fastNLP/modules/encoder/lstm.py
index 0118d6d7..2966426a 100644
--- a/fastNLP/modules/encoder/lstm.py
+++ b/fastNLP/modules/encoder/lstm.py
@@ -82,7 +82,7 @@ class LSTM(nn.Module):
                 output = output[:, unsort_idx]
             # 解决LSTM无法在DataParallel下使用的问题问题https://github.com/pytorch/pytorch/issues/1591
             if output.size(1) < max_len:
-                dummy_tensor = autograd.Variable(torch.zeros(batch_size, max_len - output.size(1), output.size(-1)))
+                dummy_tensor = output.new_zeros(batch_size, max_len - output.size(1), output.size(-1))
                 output = torch.cat([output, dummy_tensor], 1)
         else:
             output, hx = self.lstm(x, hx)