From b0c50f7299f4439f1b015ad74c2aa291e0dd798f Mon Sep 17 00:00:00 2001
From: yh
Date: Mon, 12 Aug 2019 01:18:24 +0800
Subject: [PATCH] Fix a bug in BertEmbedding

The trailing [SEP] index was computed from the untruncated word-piece
lengths, so whenever auto_truncate shortened a sequence the [SEP] was
written past the truncated pieces. Write it per sample instead, right
after the (possibly truncated) word pieces.

---
 fastNLP/embeddings/bert_embedding.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/fastNLP/embeddings/bert_embedding.py b/fastNLP/embeddings/bert_embedding.py
index 9bedd983..963ba04c 100644
--- a/fastNLP/embeddings/bert_embedding.py
+++ b/fastNLP/embeddings/bert_embedding.py
@@ -311,7 +311,6 @@ class _WordBertModel(nn.Module):
         word_pieces = words.new_full((batch_size, max_word_piece_length+2), fill_value=self._wordpiece_pad_index)
         word_pieces[:, 0].fill_(self._cls_index)
         batch_indexes = torch.arange(batch_size).to(words)
-        word_pieces[batch_indexes, word_pieces_lengths+1] = self._sep_index
         attn_masks = torch.zeros_like(word_pieces)
         # 1. get the word_piece ids of the words, and the corresponding span ranges
         word_indexes = words.tolist()
@@ -320,6 +319,7 @@ class _WordBertModel(nn.Module):
             if self.auto_truncate and len(word_pieces_i)>self._max_position_embeddings-2:
                 word_pieces_i = word_pieces_i[:self._max_position_embeddings-2]
             word_pieces[i, 1:len(word_pieces_i)+1] = torch.LongTensor(word_pieces_i)
+            word_pieces[i, len(word_pieces_i)+1] = self._sep_index  # add the trailing sep
             attn_masks[i, :word_pieces_lengths[i]+2].fill_(1)
         # 2. get the hidden states and pool them per word according to word_pieces
         # all_outputs: [batch_size x max_len x hidden_size, batch_size x max_len x hidden_size, ...]
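
For reviewers, a minimal standalone sketch of the failure mode this patch
fixes. Everything below is illustrative: the names mirror the patched code,
but the token ids, position limit, and lengths are made up, and the forward
pass is reduced to just the [SEP] indexing.

    import torch

    # Illustrative values, not taken from fastNLP.
    _cls_index, _sep_index, _pad_index = 101, 102, 0
    _max_position_embeddings = 8                  # toy position limit
    word_pieces_lengths = torch.tensor([4, 10])   # row 1 exceeds the limit
    batch_size = 2
    max_word_piece_length = int(word_pieces_lengths.max())

    word_pieces = torch.full((batch_size, max_word_piece_length + 2),
                             _pad_index, dtype=torch.long)
    word_pieces[:, 0] = _cls_index

    # Old behaviour: [SEP] indexed from the untruncated lengths,
    # before auto-truncation is applied.
    buggy = word_pieces.clone()
    batch_indexes = torch.arange(batch_size)
    buggy[batch_indexes, word_pieces_lengths + 1] = _sep_index  # row 1 -> index 11

    # New behaviour: place [SEP] per row, after auto-truncation.
    fixed = word_pieces.clone()
    for i in range(batch_size):
        n = min(int(word_pieces_lengths[i]), _max_position_embeddings - 2)
        # (the patched code copies the word-piece ids into fixed[i, 1:n+1] first)
        fixed[i, n + 1] = _sep_index

    print(buggy[1])  # [SEP] at index 11, past the truncation point
    print(fixed[1])  # [SEP] at index 7, right after the truncated pieces

With the old batched write, row 1's [SEP] lands at index 11 even though
auto-truncation cuts its pieces off after index 6, so the truncated window
handed to BERT ends without a [SEP]; the per-row write after truncation
places it at index 7, directly behind the last kept piece.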