|
|
@@ -80,6 +80,8 @@ class BertEmbedding(ContextualEmbedding): |
|
|
|
:param kwargs: |
|
|
|
bool only_use_pretrain_bpe: only use BPE tokens that already appear in the pretrained vocabulary; a word that cannot be tokenized this way falls back to unk. If the embedding does not need to be updated,
|
|
|
it is recommended to set this to True.
|
|
|
int min_freq: only effective when only_use_pretrain_bpe is False; words occurring at least this many times are newly added to BERT's BPE vocabulary
|
|
|
bool truncate_embed: whether to keep only the BPE tokens that are actually used (this reduces memory usage and speeds things up); see the sketch below
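A minimal usage sketch for these kwargs (hypothetical names; assumes ``vocab`` is a fastNLP Vocabulary built from the training data)::

    # extend the BPE vocab with frequent in-domain words, but keep only the word pieces actually used
    embed = BertEmbedding(vocab, model_dir_or_name='en-base-uncased', layers='-1',
                          only_use_pretrain_bpe=False, min_freq=2, truncate_embed=True)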
|
|
|
""" |
|
|
|
super(BertEmbedding, self).__init__(vocab, word_dropout=word_dropout, dropout=dropout) |
|
|
|
|
|
|
@@ -92,25 +94,27 @@ class BertEmbedding(ContextualEmbedding): |
|
|
|
" faster speed.") |
|
|
|
warnings.warn("For Chinese bert, pooled_method should choose from 'first', 'last' in order to achieve" |
|
|
|
" faster speed.") |
|
|
|
|
|
|
|
|
|
|
|
self._word_sep_index = None |
|
|
|
if '[SEP]' in vocab: |
|
|
|
self._word_sep_index = vocab['[SEP]'] |
|
|
|
|
|
|
|
only_use_pretrain_bpe = kwargs.get('only_use_pretrain_bpe', False) |
|
|
|
|
|
|
|
truncate_embed = kwargs.get('truncate_embed', True) |
|
|
|
min_freq = kwargs.get('min_freq', 2) |
|
|
|
|
|
|
|
self.model = _WordBertModel(model_dir_or_name=model_dir_or_name, vocab=vocab, layers=layers, |
|
|
|
pool_method=pool_method, include_cls_sep=include_cls_sep, |
|
|
|
pooled_cls=pooled_cls, auto_truncate=auto_truncate, min_freq=2, |
|
|
|
only_use_pretrain_bpe=only_use_pretrain_bpe) |
|
|
|
pooled_cls=pooled_cls, auto_truncate=auto_truncate, min_freq=min_freq, |
|
|
|
only_use_pretrain_bpe=only_use_pretrain_bpe, truncate_embed=truncate_embed) |
|
|
|
self._sep_index = self.model._sep_index |
|
|
|
self._cls_index = self.model._cls_index |
|
|
|
self.requires_grad = requires_grad |
|
|
|
self._embed_size = len(self.model.layers) * self.model.encoder.hidden_size |
|
|
|
|
|
|
|
|
|
|
|
def _delete_model_weights(self): |
|
|
|
del self.model |
|
|
|
|
|
|
|
|
|
|
|
def forward(self, words): |
|
|
|
r""" |
|
|
|
Compute the BERT embedding of words. Before computation, [CLS] is added at the start and [SEP] at the end of every sentence, and include_cls_sep decides whether to
|
|
@@ -125,9 +129,9 @@ class BertEmbedding(ContextualEmbedding): |
|
|
|
return self.dropout(outputs) |
|
|
|
outputs = self.model(words) |
|
|
|
outputs = torch.cat([*outputs], dim=-1) |
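# shape note (hedged): self.model(words) returns one representation per requested layer, each of shape
# [batch_size, seq_len, hidden_size] (seq_len includes [CLS]/[SEP] when include_cls_sep=True); the
# concatenation along the last dim therefore yields [batch_size, seq_len, len(layers) * hidden_size].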
|
|
|
|
|
|
|
|
|
|
|
return self.dropout(outputs) |
|
|
|
|
|
|
|
|
|
|
|
def drop_word(self, words): |
|
|
|
r""" |
|
|
|
Randomly replace words with unknown_index according to the configured word_dropout.
|
|
@@ -167,11 +171,11 @@ class BertWordPieceEncoder(nn.Module): |
|
|
|
multi-base-uncased: multilingual uncased |
|
|
|
|
|
|
|
""" |
|
|
|
|
|
|
|
|
|
|
|
def __init__(self, model_dir_or_name: str = 'en-base-uncased', layers: str = '-1', pooled_cls: bool = False, |
|
|
|
word_dropout=0, dropout=0, requires_grad: bool = True): |
|
|
|
r""" |
|
|
|
|
|
|
|
|
|
|
|
:param str model_dir_or_name: the directory containing the model, or the model name. Defaults to ``en-base-uncased``
|
|
|
:param str layers: which layers go into the final representation. Layer indices are separated by ','; negative numbers index from the last layer
|
|
|
:param bool pooled_cls: whether the [CLS] at the start of the returned sentence is mapped through the pretrained BertPool. If the downstream task uses [CLS] for prediction, this is usually True.
|
|
@@ -180,7 +184,7 @@ class BertWordPieceEncoder(nn.Module): |
|
|
|
:param bool requires_grad: whether gradients are required.
|
|
|
""" |
|
|
|
super().__init__() |
|
|
|
|
|
|
|
|
|
|
|
self.model = _WordPieceBertModel(model_dir_or_name=model_dir_or_name, layers=layers, pooled_cls=pooled_cls) |
|
|
|
self._sep_index = self.model._sep_index |
|
|
|
self._cls_index = self.model._cls_index |
|
|
@@ -190,19 +194,19 @@ class BertWordPieceEncoder(nn.Module): |
|
|
|
self.requires_grad = requires_grad |
|
|
|
self.word_dropout = word_dropout |
|
|
|
self.dropout_layer = nn.Dropout(dropout) |
|
|
|
|
|
|
|
|
|
|
|
@property |
|
|
|
def embed_size(self): |
|
|
|
return self._embed_size |
|
|
|
|
|
|
|
|
|
|
|
@property |
|
|
|
def embedding_dim(self): |
|
|
|
return self._embed_size |
|
|
|
|
|
|
|
|
|
|
|
@property |
|
|
|
def num_embedding(self): |
|
|
|
return self.model.encoder.config.vocab_size |
|
|
|
|
|
|
|
|
|
|
|
def index_datasets(self, *datasets, field_name, add_cls_sep=True): |
|
|
|
r""" |
|
|
|
Use the BERT tokenizer to generate a new word_pieces field, add it to the datasets and set it as input, and set the pad value of the word_pieces field to
|
|
@@ -214,7 +218,7 @@ class BertWordPieceEncoder(nn.Module): |
|
|
|
:return: |
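Example (a minimal sketch; hypothetical names, assuming ``data_bundle`` is a fastNLP DataBundle whose DataSets contain a 'words' field of token lists)::

    encoder = BertWordPieceEncoder(model_dir_or_name='en-base-uncased', layers='-1')
    encoder.index_datasets(data_bundle.get_dataset('train'), data_bundle.get_dataset('dev'),
                           field_name='words', add_cls_sep=True)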
|
|
|
""" |
|
|
|
self.model.index_dataset(*datasets, field_name=field_name, add_cls_sep=add_cls_sep) |
|
|
|
|
|
|
|
|
|
|
|
def forward(self, word_pieces, token_type_ids=None): |
|
|
|
r""" |
|
|
|
Compute the BERT embedding of words. The words passed in should already contain the [CLS] and [SEP] tags.
|
|
@@ -231,13 +235,13 @@ class BertWordPieceEncoder(nn.Module): |
|
|
|
token_type_ids = sep_mask_cumsum.fmod(2) |
|
|
|
if token_type_ids[0, 0].item():  # if the first value is odd, the result needs to be flipped, because the first token_type_id must be 0
|
|
|
token_type_ids = token_type_ids.eq(0).long() |
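# worked example (assuming sep_mask_cumsum counts [SEP] tokens from the end of the sequence): for a
# single segment "[CLS] a b [SEP]" every position sees one remaining [SEP], so fmod(2) gives all ones
# and the flip restores the required all-zero type ids; with two segments the counts are 2 then 1, so
# the first segment gets 0, the second gets 1, and no flip is needed.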
|
|
|
|
|
|
|
|
|
|
|
word_pieces = self.drop_word(word_pieces) |
|
|
|
outputs = self.model(word_pieces, token_type_ids) |
|
|
|
outputs = torch.cat([*outputs], dim=-1) |
|
|
|
|
|
|
|
|
|
|
|
return self.dropout_layer(outputs) |
|
|
|
|
|
|
|
|
|
|
|
def drop_word(self, words): |
|
|
|
r""" |
|
|
|
Randomly replace words with unknown_index according to the configured word_dropout.
|
|
@@ -261,9 +265,9 @@ class BertWordPieceEncoder(nn.Module): |
|
|
|
class _WordBertModel(nn.Module): |
|
|
|
def __init__(self, model_dir_or_name: str, vocab: Vocabulary, layers: str = '-1', pool_method: str = 'first', |
|
|
|
include_cls_sep: bool = False, pooled_cls: bool = False, auto_truncate: bool = False, min_freq=2, |
|
|
|
only_use_pretrain_bpe=False): |
|
|
|
only_use_pretrain_bpe=False, truncate_embed=True): |
|
|
|
super().__init__() |
|
|
|
|
|
|
|
|
|
|
|
self.tokenzier = BertTokenizer.from_pretrained(model_dir_or_name) |
|
|
|
self.encoder = BertModel.from_pretrained(model_dir_or_name) |
|
|
|
self._max_position_embeddings = self.encoder.config.max_position_embeddings |
|
|
@@ -277,19 +281,21 @@ class _WordBertModel(nn.Module): |
|
|
|
else: |
|
|
|
assert layer < encoder_layer_number, f"The layer index:{layer} is out of scope for " \ |
|
|
|
f"a bert model with {encoder_layer_number} layers." |
|
|
|
|
|
|
|
|
|
|
|
assert pool_method in ('avg', 'max', 'first', 'last') |
|
|
|
self.pool_method = pool_method |
|
|
|
self.include_cls_sep = include_cls_sep |
|
|
|
self.pooled_cls = pooled_cls |
|
|
|
self.auto_truncate = auto_truncate |
|
|
|
|
|
|
|
|
|
|
|
# compute the word pieces for every word in vocab; [CLS] and [SEP] need extra handling
|
|
|
logger.info("Start to generate word pieces for word.") |
|
|
|
self._has_sep_in_vocab = '[SEP]' in vocab  # used to decide whether token_type_ids need to be generated for the input data
|
|
|
|
|
|
|
# step 1: collect the required word pieces, then create the new embed and word_piece vocab and fill in the values
|
|
|
word_piece_dict = {'[CLS]': 1, '[SEP]': 1}  # word pieces that are used, plus newly added ones
|
|
|
found_count = 0 |
|
|
|
self._has_sep_in_vocab = '[SEP]' in vocab  # used to decide whether token_type_ids need to be generated for the input data
|
|
|
new_add_to_bpe_vocab = 0 |
|
|
|
unsegment_word = 0 |
|
|
|
if '[sep]' in vocab: |
|
|
|
warnings.warn("Lower cased [sep] detected, it cannot be correctly recognized as [SEP] by BertEmbedding.") |
|
|
|
if "[CLS]" in vocab: |
|
|
@@ -311,14 +317,19 @@ class _WordBertModel(nn.Module): |
|
|
|
if vocab.word_count[word] >= min_freq and not vocab._is_word_no_create_entry( |
|
|
|
word) and not only_use_pretrain_bpe:  # only add a word that occurs at least min_freq times
|
|
|
word_piece_dict[word] = 1  # add a new entry
|
|
|
new_add_to_bpe_vocab += 1 |
|
|
|
unsegment_word += 1 |
|
|
|
continue |
|
|
|
for word_piece in word_pieces: |
|
|
|
word_piece_dict[word_piece] = 1 |
|
|
|
found_count += 1 |
|
|
|
original_embed = self.encoder.embeddings.word_embeddings.weight.data |
|
|
|
|
|
|
|
# special tokens need special handling
|
|
|
if not truncate_embed:  # if we are not truncating, the existing tokenizer vocab must be added as well
|
|
|
word_piece_dict.update(self.tokenzier.vocab) |
|
|
|
embed = nn.Embedding(len(word_piece_dict), original_embed.size(1))  # the new embed
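# the rows of this (smaller, when truncate_embed=True) embedding are filled below by copying the
# pretrained vectors of every kept word piece; the tokenizer is then re-initialized on the matching
# reduced vocab so that piece ids and embedding rows stay aligned.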
|
|
|
new_word_piece_vocab = collections.OrderedDict() |
|
|
|
|
|
|
|
for index, token in enumerate(['[PAD]', '[UNK]']): |
|
|
|
word_piece_dict.pop(token, None) |
|
|
|
embed.weight.data[index] = original_embed[self.tokenzier.vocab[token]] |
|
|
@@ -331,7 +342,11 @@ class _WordBertModel(nn.Module): |
|
|
|
new_word_piece_vocab[token] = len(new_word_piece_vocab) |
|
|
|
self.tokenzier._reinit_on_new_vocab(new_word_piece_vocab) |
|
|
|
self.encoder.embeddings.word_embeddings = embed |
|
|
|
|
|
|
|
if only_use_pretrain_bpe: |
|
|
|
logger.info(f"{unsegment_word} words are unsegmented.") |
|
|
|
else: |
|
|
|
logger.info(f"{unsegment_word} words are unsegmented. Among them, {new_add_to_bpe_vocab} added to the BPE vocab.") |
|
|
|
|
|
|
|
word_to_wordpieces = [] |
|
|
|
word_pieces_lengths = [] |
|
|
|
for word, index in vocab: |
|
|
@@ -347,11 +362,10 @@ class _WordBertModel(nn.Module): |
|
|
|
self._sep_index = self.tokenzier.vocab['[SEP]'] |
|
|
|
self._word_pad_index = vocab.padding_idx |
|
|
|
self._wordpiece_pad_index = self.tokenzier.vocab['[PAD]']  # needed when generating word_pieces
|
|
|
logger.info("Found(Or segment into word pieces) {} words out of {}.".format(found_count, len(vocab))) |
|
|
|
self.word_to_wordpieces = np.array(word_to_wordpieces) |
|
|
|
self.register_buffer('word_pieces_lengths', torch.LongTensor(word_pieces_lengths)) |
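# word_to_wordpieces[i] holds the word-piece ids that vocab word i maps to, and word_pieces_lengths[i]
# caches how many pieces that word expands to, so forward() can size its buffers without re-tokenizing.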
|
|
|
logger.debug("Successfully generate word pieces.") |
|
|
|
|
|
|
|
|
|
|
|
def forward(self, words): |
|
|
|
r""" |
|
|
|
|
|
|
@@ -376,7 +390,7 @@ class _WordBertModel(nn.Module): |
|
|
|
"After split words into word pieces, the lengths of word pieces are longer than the " |
|
|
|
f"maximum allowed sequence length:{self._max_position_embeddings} of bert. You can set " |
|
|
|
f"`auto_truncate=True` for BertEmbedding to automatically truncate overlong input.") |
|
|
|
|
|
|
|
|
|
|
|
# +2 because [CLS] and [SEP] need to be added
|
|
|
word_pieces = words.new_full((batch_size, min(word_piece_length + 2, self._max_position_embeddings)), |
|
|
|
fill_value=self._wordpiece_pad_index) |
|
|
@@ -406,7 +420,7 @@ class _WordBertModel(nn.Module): |
|
|
|
bert_outputs, pooled_cls = self.encoder(word_pieces, token_type_ids=token_type_ids, attention_mask=attn_masks, |
|
|
|
output_all_encoded_layers=True) |
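# bert_outputs is a list with one tensor per encoder layer, each of shape
# [batch_size, word_piece_length, hidden_size]; pooled_cls is the BertPool projection of the [CLS]
# position with shape [batch_size, hidden_size] (a hedged reading of this call).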
|
|
|
# output_layers = [self.layers] # len(self.layers) x batch_size x real_word_piece_length x hidden_size |
|
|
|
|
|
|
|
|
|
|
|
if self.include_cls_sep: |
|
|
|
s_shift = 1 |
|
|
|
outputs = bert_outputs[-1].new_zeros(len(self.layers), batch_size, max_word_len + 2, |
|
|
|