From 14d048f3406fd05a79e9e61b8e05c410bf8882f0 Mon Sep 17 00:00:00 2001 From: yh Date: Thu, 5 Sep 2019 00:21:46 +0800 Subject: [PATCH] =?UTF-8?q?=E4=BF=AE=E5=A4=8Dbert=20embedding=E7=9A=84bug?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- fastNLP/embeddings/bert_embedding.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/fastNLP/embeddings/bert_embedding.py b/fastNLP/embeddings/bert_embedding.py index 17f6769d..05351cbd 100644 --- a/fastNLP/embeddings/bert_embedding.py +++ b/fastNLP/embeddings/bert_embedding.py @@ -420,11 +420,11 @@ class _WordBertModel(nn.Module): if self.pool_method == 'first': batch_word_pieces_cum_length = batch_word_pieces_cum_length[:, :seq_len.max()] batch_word_pieces_cum_length.masked_fill_(batch_word_pieces_cum_length.ge(word_piece_length), 0) - batch_indexes = batch_indexes[:, None].expand((batch_size, batch_word_pieces_cum_length.size(1))) + _batch_indexes = batch_indexes[:, None].expand((batch_size, batch_word_pieces_cum_length.size(1))) elif self.pool_method == 'last': batch_word_pieces_cum_length = batch_word_pieces_cum_length[:, 1:seq_len.max()+1] - 1 batch_word_pieces_cum_length.masked_fill_(batch_word_pieces_cum_length.ge(word_piece_length), 0) - batch_indexes = batch_indexes[:, None].expand((batch_size, batch_word_pieces_cum_length.size(1))) + _batch_indexes = batch_indexes[:, None].expand((batch_size, batch_word_pieces_cum_length.size(1))) for l_index, l in enumerate(self.layers): output_layer = bert_outputs[l] @@ -437,12 +437,12 @@ class _WordBertModel(nn.Module): # 从word_piece collapse到word的表示 truncate_output_layer = output_layer[:, 1:-1] # 删除[CLS]与[SEP] batch_size x len x hidden_size if self.pool_method == 'first': - tmp = truncate_output_layer[batch_indexes, batch_word_pieces_cum_length] + tmp = truncate_output_layer[_batch_indexes, batch_word_pieces_cum_length] tmp = tmp.masked_fill(word_mask[:, :batch_word_pieces_cum_length.size(1), None].eq(0), 0) outputs[l_index, :, s_shift:batch_word_pieces_cum_length.size(1)+s_shift] = tmp elif self.pool_method == 'last': - tmp = truncate_output_layer[batch_indexes, batch_word_pieces_cum_length] + tmp = truncate_output_layer[_batch_indexes, batch_word_pieces_cum_length] tmp = tmp.masked_fill(word_mask[:, :batch_word_pieces_cum_length.size(1), None].eq(0), 0) outputs[l_index, :, s_shift:batch_word_pieces_cum_length.size(1)+s_shift] = tmp elif self.pool_method == 'max':