diff --git a/fastNLP/embeddings/bert_embedding.py b/fastNLP/embeddings/bert_embedding.py index afba9d13..1fadd491 100644 --- a/fastNLP/embeddings/bert_embedding.py +++ b/fastNLP/embeddings/bert_embedding.py @@ -278,7 +278,7 @@ class _WordBertModel(nn.Module): print("Found(Or seg into word pieces) {} words out of {}.".format(found_count, len(vocab))) self._cls_index = self.tokenzier.vocab['[CLS]'] self._sep_index = self.tokenzier.vocab['[SEP]'] - self._pad_index = vocab.padding_idx + self._word_pad_index = vocab.padding_idx self._wordpiece_pad_index = self.tokenzier.vocab['[PAD]'] # 需要用于生成word_piece self.word_to_wordpieces = np.array(word_to_wordpieces) self.word_pieces_lengths = nn.Parameter(torch.LongTensor(word_pieces_lengths), requires_grad=False) @@ -291,23 +291,22 @@ class _WordBertModel(nn.Module): :return: num_layers x batch_size x max_len x hidden_size或者num_layers x batch_size x (max_len+2) x hidden_size """ batch_size, max_word_len = words.size() - word_mask = words.ne(self._pad_index) + word_mask = words.ne(self._word_pad_index) # 为1的地方有word seq_len = word_mask.sum(dim=-1) batch_word_pieces_length = self.word_pieces_lengths[words] # batch_size x max_len - word_pieces_lengths = batch_word_pieces_length.masked_fill(word_mask.eq(0), 0).sum(dim=-1) - max_word_piece_length = word_pieces_lengths.max().item() - real_max_word_piece_length = max_word_piece_length # 表示没有截断的word piece的长度 - if max_word_piece_length+2>self._max_position_embeddings: + word_pieces_lengths = batch_word_pieces_length.masked_fill(word_mask.eq(0), 0).sum(dim=-1) # batch_size + word_piece_length = batch_word_pieces_length.sum(dim=-1).max().item() # 表示word piece的长度(包括padding) + if word_piece_length+2>self._max_position_embeddings: if self.auto_truncate: word_pieces_lengths = word_pieces_lengths.masked_fill(word_pieces_lengths+2>self._max_position_embeddings, self._max_position_embeddings-2) - max_word_piece_length = self._max_position_embeddings-2 else: raise RuntimeError("After split words into word pieces, the lengths of word pieces are longer than the " f"maximum allowed sequence length:{self._max_position_embeddings} of bert.") # +2是由于需要加入[CLS]与[SEP] - word_pieces = words.new_full((batch_size, max_word_piece_length+2), fill_value=self._wordpiece_pad_index) + word_pieces = words.new_full((batch_size, min(word_piece_length+2, self._max_position_embeddings)), + fill_value=self._wordpiece_pad_index) attn_masks = torch.zeros_like(word_pieces) # 1. 获取words的word_pieces的id,以及对应的span范围 word_indexes = words.tolist() @@ -325,7 +324,7 @@ class _WordBertModel(nn.Module): # all_outputs: [batch_size x max_len x hidden_size, batch_size x max_len x hidden_size, ...] bert_outputs, pooled_cls = self.encoder(word_pieces, token_type_ids=None, attention_mask=attn_masks, output_all_encoded_layers=True) - # output_layers = [self.layers] # len(self.layers) x batch_size x max_word_piece_length x hidden_size + # output_layers = [self.layers] # len(self.layers) x batch_size x real_word_piece_length x hidden_size if self.include_cls_sep: outputs = bert_outputs[-1].new_zeros(len(self.layers), batch_size, max_word_len + 2, @@ -339,9 +338,10 @@ class _WordBertModel(nn.Module): batch_word_pieces_cum_length[:, 1:] = batch_word_pieces_length.cumsum(dim=-1) # batch_size x max_len for l_index, l in enumerate(self.layers): output_layer = bert_outputs[l] - if real_max_word_piece_length > max_word_piece_length: # 如果实际上是截取出来的 + real_word_piece_length = output_layer.size(1) - 2 + if word_piece_length > real_word_piece_length: # 如果实际上是截取出来的 paddings = output_layer.new_zeros(batch_size, - real_max_word_piece_length-max_word_piece_length, + word_piece_length-real_word_piece_length, output_layer.size(2)) output_layer = torch.cat((output_layer, paddings), dim=1).contiguous() # 从word_piece collapse到word的表示