
BertEmbedding supports choosing whether to shrink the word-piece vocabulary

tags/v0.5.5
yh_cc · 5 years ago
commit 45139ebbff
2 changed files with 86 additions and 31 deletions:
  1. fastNLP/embeddings/bert_embedding.py (+45 / -31)
  2. test/embeddings/test_bert_embedding.py (+41 / -0)

fastNLP/embeddings/bert_embedding.py (+45 / -31)

@@ -80,6 +80,8 @@ class BertEmbedding(ContextualEmbedding):
:param kwargs:
bool only_use_pretrain_bpe: only use BPE tokens that already appear in the pretrained vocabulary; words that cannot be
tokenized fall back to unk. Recommended to set to True if the embedding does not need to be updated.
int min_freq: only effective when only_use_pretrain_bpe is False; words occurring at least this many times are newly added
to BERT's BPE vocabulary.
bool truncate_embed: whether to keep only the BPE tokens actually used (this reduces memory usage and speeds things up).
"""
super(BertEmbedding, self).__init__(vocab, word_dropout=word_dropout, dropout=dropout)
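The kwargs documented above combine as in the short sketch below. It mirrors the usage exercised by the test added in this commit; the model path points at the repository's small test fixture and the vocabulary is illustrative, so substitute your own in practice.

    from fastNLP import Vocabulary
    from fastNLP.embeddings import BertEmbedding

    # Illustrative vocabulary; any Vocabulary built over a corpus works the same way.
    vocab = Vocabulary().add_word_lst("this is a texta and [SEP] NotInBERT".split())

    # Keep only the word pieces the vocabulary actually reaches (smaller embedding matrix),
    # and never add new BPE entries for words missing from the pretrained vocab.
    embed = BertEmbedding(vocab, model_dir_or_name='test/data_for_tests/embedding/small_bert',
                          only_use_pretrain_bpe=True, truncate_embed=True, min_freq=1)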

@@ -92,25 +94,27 @@ class BertEmbedding(ContextualEmbedding):
" faster speed.")
warnings.warn("For Chinese bert, pooled_method should choose from 'first', 'last' in order to achieve"
" faster speed.")
self._word_sep_index = None
if '[SEP]' in vocab:
self._word_sep_index = vocab['[SEP]']

only_use_pretrain_bpe = kwargs.get('only_use_pretrain_bpe', False)
truncate_embed = kwargs.get('truncate_embed', True)
min_freq = kwargs.get('min_freq', 2)

self.model = _WordBertModel(model_dir_or_name=model_dir_or_name, vocab=vocab, layers=layers,
pool_method=pool_method, include_cls_sep=include_cls_sep,
- pooled_cls=pooled_cls, auto_truncate=auto_truncate, min_freq=2,
- only_use_pretrain_bpe=only_use_pretrain_bpe)
+ pooled_cls=pooled_cls, auto_truncate=auto_truncate, min_freq=min_freq,
+ only_use_pretrain_bpe=only_use_pretrain_bpe, truncate_embed=truncate_embed)
self._sep_index = self.model._sep_index
self._cls_index = self.model._cls_index
self.requires_grad = requires_grad
self._embed_size = len(self.model.layers) * self.model.encoder.hidden_size
def _delete_model_weights(self):
del self.model
def forward(self, words):
r"""
Compute the BERT embedding representation of words. Before computation, [CLS] is added at the start and [SEP] at the end of each sentence, and include_cls_sep determines whether to
@@ -125,9 +129,9 @@ class BertEmbedding(ContextualEmbedding):
return self.dropout(outputs)
outputs = self.model(words)
outputs = torch.cat([*outputs], dim=-1)
return self.dropout(outputs)
def drop_word(self, words):
r"""
Randomly replace words with unknown_index according to the configured rate.
@@ -167,11 +171,11 @@ class BertWordPieceEncoder(nn.Module):
multi-base-uncased: multilingual uncased

"""
def __init__(self, model_dir_or_name: str = 'en-base-uncased', layers: str = '-1', pooled_cls: bool = False,
word_dropout=0, dropout=0, requires_grad: bool = True):
r"""
:param str model_dir_or_name: directory containing the model, or the model's name. Defaults to ``en-base-uncased``
:param str layers: which layers go into the final representation. Layer indices separated by ','; negative numbers index from the last layer.
:param bool pooled_cls: whether the [CLS] at the start of the returned sentence is mapped through the pretrained BertPool. Usually set to True if the downstream task uses [CLS] for prediction.
@@ -180,7 +184,7 @@ class BertWordPieceEncoder(nn.Module):
:param bool requires_grad: whether gradients are required.
"""
super().__init__()
self.model = _WordPieceBertModel(model_dir_or_name=model_dir_or_name, layers=layers, pooled_cls=pooled_cls)
self._sep_index = self.model._sep_index
self._cls_index = self.model._cls_index
@@ -190,19 +194,19 @@ class BertWordPieceEncoder(nn.Module):
self.requires_grad = requires_grad
self.word_dropout = word_dropout
self.dropout_layer = nn.Dropout(dropout)
@property
def embed_size(self):
return self._embed_size
@property
def embedding_dim(self):
return self._embed_size
@property
def num_embedding(self):
return self.model.encoder.config.vocab_size
def index_datasets(self, *datasets, field_name, add_cls_sep=True):
r"""
Use BERT's tokenizer to generate a new word_pieces field, add it to the datasets and set it as input, with the pad value of the word_pieces field set to
@@ -214,7 +218,7 @@ class BertWordPieceEncoder(nn.Module):
:return:
"""
self.model.index_dataset(*datasets, field_name=field_name, add_cls_sep=add_cls_sep)
def forward(self, word_pieces, token_type_ids=None):
r"""
Compute the BERT embedding representation of words. The input words should already contain the [CLS] and [SEP] tags.
@@ -231,13 +235,13 @@ class BertWordPieceEncoder(nn.Module):
token_type_ids = sep_mask_cumsum.fmod(2)
if token_type_ids[0, 0].item(): # if the first value is odd, the result must be flipped, because the sequence has to start with 0
token_type_ids = token_type_ids.eq(0).long()
word_pieces = self.drop_word(word_pieces)
outputs = self.model(word_pieces, token_type_ids)
outputs = torch.cat([*outputs], dim=-1)
return self.dropout_layer(outputs)
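The token_type_ids arithmetic above is easy to sanity-check in isolation. The cumulative-sum line itself falls outside this hunk, so the reversed cumsum below is an assumption about how sep_mask_cumsum is built; the fmod(2) and the odd-start flip are taken directly from the code shown.

    import torch

    # Toy word-piece ids for "[CLS] A A [SEP] B [SEP]" -> two sentences
    sep_index = 102
    word_pieces = torch.LongTensor([[101, 7, 8, 102, 9, 102]])

    sep_mask = word_pieces.eq(sep_index).long()                              # [[0, 0, 0, 1, 0, 1]]
    # Assumed reversed cumulative sum: how many [SEP]s lie at or after each position.
    sep_mask_cumsum = sep_mask.flip([-1]).cumsum(dim=-1).flip([-1])          # [[2, 2, 2, 2, 1, 1]]

    token_type_ids = sep_mask_cumsum.fmod(2)                                 # [[0, 0, 0, 0, 1, 1]]
    if token_type_ids[0, 0].item():  # odd start -> flip so the first segment is 0
        token_type_ids = token_type_ids.eq(0).long()

    print(token_type_ids)  # tensor([[0, 0, 0, 0, 1, 1]])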
def drop_word(self, words):
r"""
Randomly replace words with unknown_index according to the configured rate.
@@ -261,9 +265,9 @@ class BertWordPieceEncoder(nn.Module):
class _WordBertModel(nn.Module):
def __init__(self, model_dir_or_name: str, vocab: Vocabulary, layers: str = '-1', pool_method: str = 'first',
include_cls_sep: bool = False, pooled_cls: bool = False, auto_truncate: bool = False, min_freq=2,
- only_use_pretrain_bpe=False):
+ only_use_pretrain_bpe=False, truncate_embed=True):
super().__init__()
self.tokenzier = BertTokenizer.from_pretrained(model_dir_or_name)
self.encoder = BertModel.from_pretrained(model_dir_or_name)
self._max_position_embeddings = self.encoder.config.max_position_embeddings
@@ -277,19 +281,21 @@ class _WordBertModel(nn.Module):
else:
assert layer < encoder_layer_number, f"The layer index:{layer} is out of scope for " \
f"a bert model with {encoder_layer_number} layers."
assert pool_method in ('avg', 'max', 'first', 'last')
self.pool_method = pool_method
self.include_cls_sep = include_cls_sep
self.pooled_cls = pooled_cls
self.auto_truncate = auto_truncate
# compute the word pieces for every word in vocab; [CLS] and [SEP] need extra handling
logger.info("Start to generate word pieces for word.")
self._has_sep_in_vocab = '[SEP]' in vocab  # used to decide whether token_ids need to be generated for the input data

# step 1: collect the word pieces that are needed, then create the new embed and word_piece_vocab, then fill in the values
word_piece_dict = {'[CLS]': 1, '[SEP]': 1}  # word pieces that are used, plus newly added ones
found_count = 0
self._has_sep_in_vocab = '[SEP]' in vocab  # used to decide whether token_ids need to be generated for the input data
new_add_to_bpe_vocab = 0
unsegment_word = 0
if '[sep]' in vocab:
warnings.warn("Lower cased [sep] detected, it cannot be correctly recognized as [SEP] by BertEmbedding.")
if "[CLS]" in vocab:
@@ -311,14 +317,19 @@ class _WordBertModel(nn.Module):
if vocab.word_count[word] >= min_freq and not vocab._is_word_no_create_entry(
word) and not only_use_pretrain_bpe:  # only add the word if its count reaches this threshold
word_piece_dict[word] = 1  # add a new entry
new_add_to_bpe_vocab += 1
unsegment_word += 1
continue
for word_piece in word_pieces:
word_piece_dict[word_piece] = 1
found_count += 1
original_embed = self.encoder.embeddings.word_embeddings.weight.data

# special tokens need special handling
if not truncate_embed:  # if we do not truncate, the existing vocabulary has to be kept as well
word_piece_dict.update(self.tokenzier.vocab)
embed = nn.Embedding(len(word_piece_dict), original_embed.size(1))  # the new embed
new_word_piece_vocab = collections.OrderedDict()

for index, token in enumerate(['[PAD]', '[UNK]']):
word_piece_dict.pop(token, None)
embed.weight.data[index] = original_embed[self.tokenzier.vocab[token]]
@@ -331,7 +342,11 @@ class _WordBertModel(nn.Module):
new_word_piece_vocab[token] = len(new_word_piece_vocab)
self.tokenzier._reinit_on_new_vocab(new_word_piece_vocab)
self.encoder.embeddings.word_embeddings = embed
if only_use_pretrain_bpe:
logger.info(f"{unsegment_word} words are unsegmented.")
else:
logger.info(f"{unsegment_word} words are unsegmented. Among them, {new_add_to_bpe_vocab} added to the BPE vocab.")

word_to_wordpieces = []
word_pieces_lengths = []
for word, index in vocab:
@@ -347,11 +362,10 @@ class _WordBertModel(nn.Module):
self._sep_index = self.tokenzier.vocab['[SEP]']
self._word_pad_index = vocab.padding_idx
self._wordpiece_pad_index = self.tokenzier.vocab['[PAD]']  # needed when generating word_piece
logger.info("Found(Or segment into word pieces) {} words out of {}.".format(found_count, len(vocab)))
self.word_to_wordpieces = np.array(word_to_wordpieces)
self.register_buffer('word_pieces_lengths', torch.LongTensor(word_pieces_lengths))
logger.debug("Successfully generate word pieces.")
def forward(self, words):
r"""

@@ -376,7 +390,7 @@ class _WordBertModel(nn.Module):
"After split words into word pieces, the lengths of word pieces are longer than the "
f"maximum allowed sequence length:{self._max_position_embeddings} of bert. You can set "
f"`auto_truncate=True` for BertEmbedding to automatically truncate overlong input.")
# +2 because [CLS] and [SEP] need to be added
word_pieces = words.new_full((batch_size, min(word_piece_length + 2, self._max_position_embeddings)),
fill_value=self._wordpiece_pad_index)
@@ -406,7 +420,7 @@ class _WordBertModel(nn.Module):
bert_outputs, pooled_cls = self.encoder(word_pieces, token_type_ids=token_type_ids, attention_mask=attn_masks,
output_all_encoded_layers=True)
# output_layers = [self.layers] # len(self.layers) x batch_size x real_word_piece_length x hidden_size
if self.include_cls_sep:
s_shift = 1
outputs = bert_outputs[-1].new_zeros(len(self.layers), batch_size, max_word_len + 2,
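The core change in this file is the vocabulary truncation inside _WordBertModel: when truncate_embed is True, the word-piece embedding matrix is rebuilt so that it only contains rows the task vocabulary can actually reach. A condensed sketch of that idea follows; truncate_word_piece_embed and its arguments are hypothetical names for illustration, while the commit itself does this inline on self.tokenzier.vocab and self.encoder.embeddings.word_embeddings.

    import collections
    import torch.nn as nn

    def truncate_word_piece_embed(original_embed: nn.Embedding, full_vocab: dict, used_tokens: list):
        # full_vocab maps word-piece -> row index in the original embedding;
        # used_tokens are the pieces the task vocabulary can produce, plus the
        # specials such as [PAD], [UNK], [CLS], [SEP].
        new_vocab = collections.OrderedDict()
        new_embed = nn.Embedding(len(used_tokens), original_embed.weight.size(1))
        for token in used_tokens:
            new_vocab[token] = len(new_vocab)
            if token in full_vocab:  # copy the pretrained row when the piece already exists
                new_embed.weight.data[new_vocab[token]] = original_embed.weight.data[full_vocab[token]]
        return new_embed, new_vocab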


test/embeddings/test_bert_embedding.py (+41 / -0)

@@ -54,6 +54,47 @@ class TestBertEmbedding(unittest.TestCase):
result = embed(words)
self.assertEqual(result.size(), (1, 516, 16))

def test_bert_embedding_2(self):
# test that only_use_pretrain_bpe and truncate_embed work as expected
with open('test/data_for_tests/embedding/small_bert/vocab.txt', 'r', encoding='utf-8') as f:
num_word = len(f.readlines())
Embedding = BertEmbedding
vocab = Vocabulary().add_word_lst("this is a texta and [SEP] NotInBERT".split())
embed1 = Embedding(vocab, model_dir_or_name='test/data_for_tests/embedding/small_bert',
only_use_pretrain_bpe=True, truncate_embed=True, min_freq=1)
embed_bpe_vocab_size = len(vocab)-1 + 2  # exclude NotInBERT; additionally add ##a and [CLS]
self.assertEqual(embed_bpe_vocab_size, len(embed1.model.tokenzier.vocab))

embed2 = Embedding(vocab, model_dir_or_name='test/data_for_tests/embedding/small_bert',
only_use_pretrain_bpe=True, truncate_embed=False, min_freq=1)
embed_bpe_vocab_size = num_word  # exclude NotInBERT
self.assertEqual(embed_bpe_vocab_size, len(embed2.model.tokenzier.vocab))

embed3 = Embedding(vocab, model_dir_or_name='test/data_for_tests/embedding/small_bert',
only_use_pretrain_bpe=False, truncate_embed=True, min_freq=1)
embed_bpe_vocab_size = len(vocab)+2  # newly add ##a and [CLS]
self.assertEqual(embed_bpe_vocab_size, len(embed3.model.tokenzier.vocab))

embed4 = Embedding(vocab, model_dir_or_name='test/data_for_tests/embedding/small_bert',
only_use_pretrain_bpe=False, truncate_embed=False, min_freq=1)
embed_bpe_vocab_size = num_word+1  # newly add ##a
self.assertEqual(embed_bpe_vocab_size, len(embed4.model.tokenzier.vocab))

# check that the following tensors are equal in all of the above configurations
embed1.eval()
embed2.eval()
embed3.eval()
embed4.eval()
tensor = torch.LongTensor([[vocab.to_index(w) for w in 'this is a texta and'.split()]])
t1 = embed1(tensor)
t2 = embed2(tensor)
t3 = embed3(tensor)
t4 = embed4(tensor)

self.assertEqual((t1-t2).sum(), 0)
self.assertEqual((t1-t3).sum(), 0)
self.assertEqual((t1-t4).sum(), 0)


class TestBertWordPieceEncoder(unittest.TestCase):
def test_bert_word_piece_encoder(self):

