From 1b661e907aa768f63f8ba60c66a9a27c45d686e2 Mon Sep 17 00:00:00 2001
From: yh
Date: Mon, 12 Aug 2019 01:12:08 +0800
Subject: [PATCH] Fix a bug in BertEmbedding

---
 fastNLP/embeddings/bert_embedding.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/fastNLP/embeddings/bert_embedding.py b/fastNLP/embeddings/bert_embedding.py
index 80a5b45f..9bedd983 100644
--- a/fastNLP/embeddings/bert_embedding.py
+++ b/fastNLP/embeddings/bert_embedding.py
@@ -294,7 +294,7 @@ class _WordBertModel(nn.Module):
         word_mask = words.ne(self._pad_index)
         seq_len = word_mask.sum(dim=-1)
         batch_word_pieces_length = self.word_pieces_lengths[words]  # batch_size x max_len
-        word_pieces_lengths = batch_word_pieces_length.masked_fill(word_mask, 0).sum(dim=-1)
+        word_pieces_lengths = batch_word_pieces_length.masked_fill(word_mask.eq(0), 0).sum(dim=-1)
         max_word_piece_length = word_pieces_lengths.max().item()
         real_max_word_piece_length = max_word_piece_length  # word-piece length before truncation
         if max_word_piece_length+2>self._max_position_embeddings:
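
For context (not part of the patch itself): a minimal PyTorch sketch, with made-up toy values, of why the old call was wrong. Tensor.masked_fill(mask, value) fills the positions where the mask is True, so passing word_mask (True at real tokens, False at padding) zeroed out the word-piece lengths of the real tokens and summed only the padding; word_mask.eq(0) inverts the mask so the padding positions are zeroed instead.

    import torch

    pad_index = 0
    # Two toy sequences of word ids: [5, 7, 3] and [5, 7, <pad>].
    words = torch.tensor([[5, 7, 3],
                          [5, 7, 0]])
    # Hypothetical lookup table: number of word pieces per word id.
    word_pieces_lengths = torch.tensor([1, 1, 1, 1, 1, 2, 1, 3])

    word_mask = words.ne(pad_index)                        # True at real tokens
    batch_word_pieces_length = word_pieces_lengths[words]  # batch_size x max_len
    # -> [[2, 3, 1], [2, 3, 1]]

    # Buggy: zeroes the REAL tokens and keeps only the pad positions.
    buggy = batch_word_pieces_length.masked_fill(word_mask, 0).sum(dim=-1)
    # -> tensor([0, 1])  (wrong)

    # Fixed: zeroes the PAD positions, keeping the real word-piece counts.
    fixed = batch_word_pieces_length.masked_fill(word_mask.eq(0), 0).sum(dim=-1)
    # -> tensor([6, 5])  (correct total word pieces per sequence)

With the fix, word_pieces_lengths holds the true word-piece count of each sequence, so the max_word_piece_length computed on the next line of the patched code is no longer (near) zero.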