diff --git a/fastNLP/embeddings/bert_embedding.py b/fastNLP/embeddings/bert_embedding.py
index 660e803e..d13be767 100644
--- a/fastNLP/embeddings/bert_embedding.py
+++ b/fastNLP/embeddings/bert_embedding.py
@@ -80,6 +80,8 @@ class BertEmbedding(ContextualEmbedding):
         :param kwargs:
             bool only_use_pretrain_bpe: only use BPE tokens that already exist in the pretrained vocabulary; a word that
                 cannot be tokenized falls back to unk. Recommended to set to True if the embedding does not need updating.
+            int min_freq: only effective when only_use_pretrain_bpe is False; words occurring at least this many times are added to BERT's BPE vocabulary
+            bool truncate_embed: whether to keep only the BPE tokens actually used (this reduces memory usage and speeds things up)
         """
         super(BertEmbedding, self).__init__(vocab, word_dropout=word_dropout, dropout=dropout)
@@ -92,25 +94,27 @@ class BertEmbedding(ContextualEmbedding):
                               " faster speed.")
                 warnings.warn("For Chinese bert, pooled_method should choose from 'first', 'last' in order to achieve"
                               " faster speed.")
-        
+
         self._word_sep_index = None
         if '[SEP]' in vocab:
             self._word_sep_index = vocab['[SEP]']
         only_use_pretrain_bpe = kwargs.get('only_use_pretrain_bpe', False)
-        
+        truncate_embed = kwargs.get('truncate_embed', True)
+        min_freq = kwargs.get('min_freq', 2)
+
         self.model = _WordBertModel(model_dir_or_name=model_dir_or_name, vocab=vocab, layers=layers,
                                     pool_method=pool_method, include_cls_sep=include_cls_sep,
-                                    pooled_cls=pooled_cls, auto_truncate=auto_truncate, min_freq=2,
-                                    only_use_pretrain_bpe=only_use_pretrain_bpe)
+                                    pooled_cls=pooled_cls, auto_truncate=auto_truncate, min_freq=min_freq,
+                                    only_use_pretrain_bpe=only_use_pretrain_bpe, truncate_embed=truncate_embed)
         self._sep_index = self.model._sep_index
         self._cls_index = self.model._cls_index
         self.requires_grad = requires_grad
         self._embed_size = len(self.model.layers) * self.model.encoder.hidden_size
-        
+
     def _delete_model_weights(self):
         del self.model
-        
+
     def forward(self, words):
         r"""
         Compute the BERT embedding of words. [CLS] is prepended and [SEP] appended to each sentence beforehand, and include_cls_sep decides whether to
@@ -125,9 +129,9 @@ class BertEmbedding(ContextualEmbedding):
             return self.dropout(outputs)
         outputs = self.model(words)
         outputs = torch.cat([*outputs], dim=-1)
-        
+
         return self.dropout(outputs)
-    
+
     def drop_word(self, words):
         r"""
         Randomly replace words with unknown_index according to the configured probability.
@@ -167,11 +171,11 @@ class BertWordPieceEncoder(nn.Module):
             multi-base-uncased: multilingual uncased

     """
-    
+
     def __init__(self, model_dir_or_name: str = 'en-base-uncased', layers: str = '-1', pooled_cls: bool = False,
                  word_dropout=0, dropout=0, requires_grad: bool = True):
         r"""
-        
+
         :param str model_dir_or_name: directory containing the model, or the model's name. Defaults to ``en-base-uncased``
         :param str layers: which layers make up the final representation. Layer indices are separated by ',', and negative numbers index from the last layer
         :param bool pooled_cls: whether the returned [CLS] at the start of the sentence is mapped through the pretrained BertPool. If the downstream task uses [CLS] for prediction, this is usually True.
@@ -180,7 +184,7 @@ class BertWordPieceEncoder(nn.Module):
         :param bool requires_grad: whether gradients are required.
         """
         super().__init__()
-        
+
         self.model = _WordPieceBertModel(model_dir_or_name=model_dir_or_name, layers=layers, pooled_cls=pooled_cls)
         self._sep_index = self.model._sep_index
         self._cls_index = self.model._cls_index
@@ -190,19 +194,19 @@ class BertWordPieceEncoder(nn.Module):
         self.requires_grad = requires_grad
         self.word_dropout = word_dropout
         self.dropout_layer = nn.Dropout(dropout)
-    
+
     @property
     def embed_size(self):
         return self._embed_size
-    
+
     @property
     def embedding_dim(self):
         return self._embed_size
-    
+
     @property
     def num_embedding(self):
         return self.model.encoder.config.vocab_size
-    
+
     def index_datasets(self, *datasets, field_name, add_cls_sep=True):
         r"""
         Use BERT's tokenizer to generate a new word_pieces field and add it to the datasets, set it as input, and set the pad value of the word_pieces field to
@@ -214,7 +218,7 @@ class BertWordPieceEncoder(nn.Module):
         :return:
         """
         self.model.index_dataset(*datasets, field_name=field_name, add_cls_sep=add_cls_sep)
-    
+
     def forward(self, word_pieces, token_type_ids=None):
         r"""
         Compute the BERT embedding of words. The input words should already contain the [CLS] and [SEP] tags.
@@ -231,13 +235,13 @@ class BertWordPieceEncoder(nn.Module):
             token_type_ids = sep_mask_cumsum.fmod(2)
             if token_type_ids[0, 0].item():  # if the first position is odd, flip the result, since the sequence must start with 0
                 token_type_ids = token_type_ids.eq(0).long()
-        
+
         word_pieces = self.drop_word(word_pieces)
         outputs = self.model(word_pieces, token_type_ids)
         outputs = torch.cat([*outputs], dim=-1)
-        
+
         return self.dropout_layer(outputs)
-    
+
     def drop_word(self, words):
         r"""
         Randomly replace words with unknown_index according to the configured probability.
@@ -261,9 +265,9 @@ class BertWordPieceEncoder(nn.Module):
 class _WordBertModel(nn.Module):
     def __init__(self, model_dir_or_name: str, vocab: Vocabulary, layers: str = '-1', pool_method: str = 'first',
                  include_cls_sep: bool = False, pooled_cls: bool = False, auto_truncate: bool = False, min_freq=2,
-                 only_use_pretrain_bpe=False):
+                 only_use_pretrain_bpe=False, truncate_embed=True):
         super().__init__()
-        
+
         self.tokenzier = BertTokenizer.from_pretrained(model_dir_or_name)
         self.encoder = BertModel.from_pretrained(model_dir_or_name)
         self._max_position_embeddings = self.encoder.config.max_position_embeddings
@@ -277,19 +281,21 @@ class _WordBertModel(nn.Module):
             else:
                 assert layer < encoder_layer_number, f"The layer index:{layer} is out of scope for " \
                                                      f"a bert model with {encoder_layer_number} layers."
-        
+
         assert pool_method in ('avg', 'max', 'first', 'last')
         self.pool_method = pool_method
         self.include_cls_sep = include_cls_sep
         self.pooled_cls = pooled_cls
         self.auto_truncate = auto_truncate
-        
+
         # compute the word pieces for every word in vocab; [CLS] and [SEP] need extra handling
         logger.info("Start to generate word pieces for word.")
+        self._has_sep_in_vocab = '[SEP]' in vocab  # used to decide whether token_type_ids need to be generated for the input
+        # step 1: collect the required word pieces, then create the new embed and word_piece vocab and fill in the values
         word_piece_dict = {'[CLS]': 1, '[SEP]': 1}  # word pieces that are used, plus newly added ones
-        found_count = 0
-        self._has_sep_in_vocab = '[SEP]' in vocab  # used to decide whether token_type_ids need to be generated for the input
+        new_add_to_bpe_vocab = 0
+        unsegment_word = 0
         if '[sep]' in vocab:
             warnings.warn("Lower cased [sep] detected, it cannot be correctly recognized as [SEP] by BertEmbedding.")
         if "[CLS]" in vocab:
@@ -311,14 +317,19 @@ class _WordBertModel(nn.Module):
                         if vocab.word_count[word] >= min_freq and not vocab._is_word_no_create_entry(
                                 word) and not only_use_pretrain_bpe:  # only add words whose count reaches this threshold
                             word_piece_dict[word] = 1  # add a new entry
+                            new_add_to_bpe_vocab += 1
+                        unsegment_word += 1
                         continue
             for word_piece in word_pieces:
                 word_piece_dict[word_piece] = 1
-            found_count += 1

         original_embed = self.encoder.embeddings.word_embeddings.weight.data
+        # special tokens need special handling
+        if not truncate_embed:  # if we are not truncating, the existing vocab has to be added as well
+            word_piece_dict.update(self.tokenzier.vocab)
         embed = nn.Embedding(len(word_piece_dict), original_embed.size(1))  # the new embed
         new_word_piece_vocab = collections.OrderedDict()
+
         for index, token in enumerate(['[PAD]', '[UNK]']):
             word_piece_dict.pop(token, None)
             embed.weight.data[index] = original_embed[self.tokenzier.vocab[token]]
@@ -331,7 +342,11 @@ class _WordBertModel(nn.Module):
             new_word_piece_vocab[token] = len(new_word_piece_vocab)
         self.tokenzier._reinit_on_new_vocab(new_word_piece_vocab)
         self.encoder.embeddings.word_embeddings = embed
-        
+        if only_use_pretrain_bpe:
+            logger.info(f"{unsegment_word} words are unsegmented.")
+        else:
+            logger.info(f"{unsegment_word} words are unsegmented. Among them, {new_add_to_bpe_vocab} were added to the BPE vocab.")
+
         word_to_wordpieces = []
         word_pieces_lengths = []
         for word, index in vocab:
@@ -347,11 +362,10 @@ class _WordBertModel(nn.Module):
         self._sep_index = self.tokenzier.vocab['[SEP]']
         self._word_pad_index = vocab.padding_idx
         self._wordpiece_pad_index = self.tokenzier.vocab['[PAD]']  # needed for generating word_piece
-        logger.info("Found(Or segment into word pieces) {} words out of {}.".format(found_count, len(vocab)))
         self.word_to_wordpieces = np.array(word_to_wordpieces)
         self.register_buffer('word_pieces_lengths', torch.LongTensor(word_pieces_lengths))
         logger.debug("Successfully generate word pieces.")
-    
+
     def forward(self, words):
         r"""

@@ -376,7 +390,7 @@ class _WordBertModel(nn.Module):
                         "After split words into word pieces, the lengths of word pieces are longer than the "
                         f"maximum allowed sequence length:{self._max_position_embeddings} of bert. You can set "
                         f"`auto_truncate=True` for BertEmbedding to automatically truncate overlong input.")
-        
+
         # +2 because [CLS] and [SEP] need to be added
         word_pieces = words.new_full((batch_size, min(word_piece_length + 2, self._max_position_embeddings)),
                                      fill_value=self._wordpiece_pad_index)
@@ -406,7 +420,7 @@ class _WordBertModel(nn.Module):
         bert_outputs, pooled_cls = self.encoder(word_pieces, token_type_ids=token_type_ids, attention_mask=attn_masks,
                                                 output_all_encoded_layers=True)
         # output_layers = [self.layers]  # len(self.layers) x batch_size x real_word_piece_length x hidden_size
-        
+
         if self.include_cls_sep:
             s_shift = 1
             outputs = bert_outputs[-1].new_zeros(len(self.layers), batch_size, max_word_len + 2,
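Before the test changes, a minimal usage sketch of the two keyword arguments introduced above (truncate_embed and min_freq, used together with the existing only_use_pretrain_bpe). The model name below is a placeholder rather than something this patch requires; any local BERT directory or supported model name behaves the same way:

    from fastNLP import Vocabulary
    from fastNLP.embeddings import BertEmbedding

    vocab = Vocabulary().add_word_lst("this is a texta and NotInBERT".split())
    # truncate_embed=True keeps only the word pieces this vocab actually needs;
    # min_freq only matters when only_use_pretrain_bpe=False, i.e. when new words
    # are allowed to enter the BPE vocab.
    embed = BertEmbedding(vocab, model_dir_or_name='en-base-uncased',  # placeholder model name
                          only_use_pretrain_bpe=False, min_freq=1, truncate_embed=True)
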
diff --git a/test/embeddings/test_bert_embedding.py b/test/embeddings/test_bert_embedding.py
index 9cc0592f..0afa76f0 100644
--- a/test/embeddings/test_bert_embedding.py
+++ b/test/embeddings/test_bert_embedding.py
@@ -54,6 +54,47 @@ class TestBertEmbedding(unittest.TestCase):
             result = embed(words)
             self.assertEqual(result.size(), (1, 516, 16))
 
+    def test_bert_embedding_2(self):
+        # check that only_use_pretrain_bpe and truncate_embed work as expected
+        with open('test/data_for_tests/embedding/small_bert/vocab.txt', 'r', encoding='utf-8') as f:
+            num_word = len(f.readlines())
+        Embedding = BertEmbedding
+        vocab = Vocabulary().add_word_lst("this is a texta and [SEP] NotInBERT".split())
+        embed1 = Embedding(vocab, model_dir_or_name='test/data_for_tests/embedding/small_bert',
+                           only_use_pretrain_bpe=True, truncate_embed=True, min_freq=1)
+        embed_bpe_vocab_size = len(vocab)-1 + 2  # NotInBERT is excluded; ##a and [CLS] are kept on top
+        self.assertEqual(embed_bpe_vocab_size, len(embed1.model.tokenzier.vocab))
+
+        embed2 = Embedding(vocab, model_dir_or_name='test/data_for_tests/embedding/small_bert',
+                           only_use_pretrain_bpe=True, truncate_embed=False, min_freq=1)
+        embed_bpe_vocab_size = num_word  # NotInBERT is excluded
+        self.assertEqual(embed_bpe_vocab_size, len(embed2.model.tokenzier.vocab))
+
+        embed3 = Embedding(vocab, model_dir_or_name='test/data_for_tests/embedding/small_bert',
+                           only_use_pretrain_bpe=False, truncate_embed=True, min_freq=1)
+        embed_bpe_vocab_size = len(vocab)+2  # ##a and [CLS] are newly added
+        self.assertEqual(embed_bpe_vocab_size, len(embed3.model.tokenzier.vocab))
+
+        embed4 = Embedding(vocab, model_dir_or_name='test/data_for_tests/embedding/small_bert',
+                           only_use_pretrain_bpe=False, truncate_embed=False, min_freq=1)
+        embed_bpe_vocab_size = num_word+1  # ##a is newly added
+        self.assertEqual(embed_bpe_vocab_size, len(embed4.model.tokenzier.vocab))
+
+        # check that the following tensors are equal in all of these settings
+        embed1.eval()
+        embed2.eval()
+        embed3.eval()
+        embed4.eval()
+        tensor = torch.LongTensor([[vocab.to_index(w) for w in 'this is a texta and'.split()]])
+        t1 = embed1(tensor)
+        t2 = embed2(tensor)
+        t3 = embed3(tensor)
+        t4 = embed4(tensor)
+
+        self.assertEqual((t1-t2).sum(), 0)
+        self.assertEqual((t1-t3).sum(), 0)
+        self.assertEqual((t1-t4).sum(), 0)
+
 
 class TestBertWordPieceEncoder(unittest.TestCase):
     def test_bert_word_piece_encoder(self):
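To summarize the mechanism the new test pins down: with truncate_embed=True, _WordBertModel.__init__ rebuilds the word-piece embedding so that it holds only the pieces actually needed, copying the matching rows out of the pretrained weight matrix. Below is a simplified, self-contained sketch of that idea with a toy vocabulary and random weights; the names and sizes are illustrative, not fastNLP's own code:

    import collections
    import torch
    import torch.nn as nn

    # toy stand-ins for the pretrained tokenizer vocab and BERT's word-embedding matrix
    pretrained_vocab = {'[PAD]': 0, '[UNK]': 1, '[CLS]': 2, '[SEP]': 3, 'this': 4, 'is': 5, 'unused': 6}
    pretrained_weight = torch.randn(len(pretrained_vocab), 8)

    # word pieces the corpus actually needs ([PAD]/[UNK] first, mirroring the patch above)
    used_pieces = ['[PAD]', '[UNK]', '[CLS]', '[SEP]', 'this', 'is']

    new_vocab = collections.OrderedDict()
    new_embed = nn.Embedding(len(used_pieces), pretrained_weight.size(1))
    for token in used_pieces:
        new_vocab[token] = len(new_vocab)
        new_embed.weight.data[new_vocab[token]] = pretrained_weight[pretrained_vocab[token]]

    # the tokenizer is then re-initialised on new_vocab and new_embed replaces the
    # encoder's word_embeddings; with truncate_embed=False the full pretrained
    # vocab would be merged in first, so nothing is dropped
    assert new_embed.weight.shape == (len(used_pieces), 8)

test_bert_embedding_2 asserts the resulting tokenizer vocabulary size for all four combinations of only_use_pretrain_bpe and truncate_embed and then checks that the embeddings produced under the four settings are identical.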