diff --git a/fastNLP/embeddings/bert_embedding.py b/fastNLP/embeddings/bert_embedding.py index 38b8daf2..80a5b45f 100644 --- a/fastNLP/embeddings/bert_embedding.py +++ b/fastNLP/embeddings/bert_embedding.py @@ -291,9 +291,10 @@ class _WordBertModel(nn.Module): :return: num_layers x batch_size x max_len x hidden_size或者num_layers x batch_size x (max_len+2) x hidden_size """ batch_size, max_word_len = words.size() - seq_len = words.ne(self._pad_index).sum(dim=-1) + word_mask = words.ne(self._pad_index) + seq_len = word_mask.sum(dim=-1) batch_word_pieces_length = self.word_pieces_lengths[words] # batch_size x max_len - word_pieces_lengths = batch_word_pieces_length.sum(dim=-1) + word_pieces_lengths = batch_word_pieces_length.masked_fill(word_mask, 0).sum(dim=-1) max_word_piece_length = word_pieces_lengths.max().item() real_max_word_piece_length = max_word_piece_length # 表示没有截断的word piece的长度 if max_word_piece_length+2>self._max_position_embeddings: @@ -319,8 +320,7 @@ class _WordBertModel(nn.Module): if self.auto_truncate and len(word_pieces_i)>self._max_position_embeddings-2: word_pieces_i = word_pieces_i[:self._max_position_embeddings-2] word_pieces[i, 1:len(word_pieces_i)+1] = torch.LongTensor(word_pieces_i) - attn_masks[i, :len(word_pieces_i)+2].fill_(1) - # TODO 截掉长度超过的部分。 + attn_masks[i, :word_pieces_lengths[i]+2].fill_(1) # 2. 获取hidden的结果,根据word_pieces进行对应的pool计算 # all_outputs: [batch_size x max_len x hidden_size, batch_size x max_len x hidden_size, ...] bert_outputs, pooled_cls = self.encoder(word_pieces, token_type_ids=None, attention_mask=attn_masks,