From 2098a81f2fad4a11c53f6347f41670c71f06bdb9 Mon Sep 17 00:00:00 2001
From: yh
Date: Mon, 12 Aug 2019 00:58:57 +0800
Subject: [PATCH] Fix a bug in BertEmbedding
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 fastNLP/embeddings/bert_embedding.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/fastNLP/embeddings/bert_embedding.py b/fastNLP/embeddings/bert_embedding.py
index 38b8daf2..80a5b45f 100644
--- a/fastNLP/embeddings/bert_embedding.py
+++ b/fastNLP/embeddings/bert_embedding.py
@@ -291,9 +291,10 @@ class _WordBertModel(nn.Module):
         :return: num_layers x batch_size x max_len x hidden_size or num_layers x batch_size x (max_len+2) x hidden_size
         """
         batch_size, max_word_len = words.size()
-        seq_len = words.ne(self._pad_index).sum(dim=-1)
+        word_mask = words.ne(self._pad_index)
+        seq_len = word_mask.sum(dim=-1)
         batch_word_pieces_length = self.word_pieces_lengths[words]  # batch_size x max_len
-        word_pieces_lengths = batch_word_pieces_length.sum(dim=-1)
+        word_pieces_lengths = batch_word_pieces_length.masked_fill(word_mask.eq(False), 0).sum(dim=-1)
         max_word_piece_length = word_pieces_lengths.max().item()
         real_max_word_piece_length = max_word_piece_length  # the word piece length before any truncation
         if max_word_piece_length+2>self._max_position_embeddings:
@@ -319,8 +320,7 @@ class _WordBertModel(nn.Module):
             if self.auto_truncate and len(word_pieces_i)>self._max_position_embeddings-2:
                 word_pieces_i = word_pieces_i[:self._max_position_embeddings-2]
             word_pieces[i, 1:len(word_pieces_i)+1] = torch.LongTensor(word_pieces_i)
-            attn_masks[i, :len(word_pieces_i)+2].fill_(1)
-            # TODO truncate the part that exceeds the maximum length.
+            attn_masks[i, :word_pieces_lengths[i]+2].fill_(1)
         # 2. Get the hidden states and pool them back to word level according to word_pieces
         # all_outputs: [batch_size x max_len x hidden_size, batch_size x max_len x hidden_size, ...]
         bert_outputs, pooled_cls = self.encoder(word_pieces, token_type_ids=None, attention_mask=attn_masks,
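
For context, the heart of the change is that word-piece lengths looked up at padding positions must be zeroed out before summing per sentence, because the attention mask is now sized from word_pieces_lengths[i] + 2 ([CLS] and [SEP]). Below is a minimal sketch of that masking step, not taken from fastNLP: the tensor values and the word_pieces_lengths_per_id lookup table are made up purely to illustrate how the unmasked sum over-counts.

# Minimal sketch (not part of the patch): why padded positions must be excluded
# before summing word-piece lengths. Names mirror the diff; values are illustrative.
import torch

pad_index = 0
# Two sentences, padded to length 4 with pad_index.
words = torch.tensor([[3, 7, 2, 0],
                      [5, 4, 0, 0]])
# Hypothetical number of word pieces per vocabulary id (the pad id maps to 1 piece here).
word_pieces_lengths_per_id = torch.tensor([1, 1, 1, 2, 1, 3, 1, 2])

word_mask = words.ne(pad_index)                               # True where a real word sits
batch_word_pieces_length = word_pieces_lengths_per_id[words]  # batch_size x max_len

# Buggy: padding positions contribute their (meaningless) piece counts to the sum.
naive = batch_word_pieces_length.sum(dim=-1)
# Fixed: zero out padding positions first, as the patch intends.
masked = batch_word_pieces_length.masked_fill(word_mask.eq(False), 0).sum(dim=-1)

print(naive.tolist())   # [6, 6] -- over-counts the padded slots
print(masked.tolist())  # [5, 4] -- word pieces of real words only

With the masked sum, word_pieces_lengths[i] equals len(word_pieces_i) for untruncated sentences, so the attention mask built from it covers exactly the real word pieces plus the two special tokens.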