| @@ -22,6 +22,7 @@ from .embedding import Embedding, TokenEmbedding | |||
| from .static_embedding import StaticEmbedding | |||
| from .elmo_embedding import ElmoEmbedding | |||
| from .bert_embedding import BertEmbedding, BertWordPieceEncoder | |||
| from .roberta_embedding import RobertaEmbedding | |||
| from .char_embedding import CNNCharEmbedding, LSTMCharEmbedding | |||
| from .stack_embedding import StackEmbedding | |||
| from .utils import get_embeddings | |||
| @@ -0,0 +1,339 @@ | |||
| import os | |||
| import collections | |||
| import warnings | |||
| from itertools import chain | |||
| import numpy as np | |||
| import torch | |||
| import torch.nn as nn | |||
| from .contextual_embedding import ContextualEmbedding | |||
| from ..core import logger, Vocabulary | |||
| from ..modules.encoder.roberta import RobertaModel, RobertaTokenizer | |||
| class RobertaEmbedding(ContextualEmbedding): | |||
| r""" | |||
| Embedding that encodes words with RoBERTa. It is recommended to keep the input sentences within roughly 430 words rather | |||
| than 512 (the exact limit depends on the pretrained model). The pretrained model accepts at most 512 tokens, and since the | |||
| input words are not yet split into word pieces (RobertaEmbedding performs the word-piece segmentation itself when the words | |||
| are fed in), the sequence may exceed the maximum length after segmentation. | |||
| RobertaEmbedding supports automatic download of pretrained weights. Currently supported models: | |||
| ..TODO | |||
| Example:: | |||
| >>> import torch | |||
| >>> from fastNLP import Vocabulary | |||
| >>> from fastNLP.embeddings import RobertaEmbedding | |||
| >>> vocab = Vocabulary().add_word_lst("The weather is good .".split()) | |||
| >>> embed = RobertaEmbedding(vocab, model_dir_or_name='en-base-uncased', requires_grad=False, layers='4,-2,-1') | |||
| >>> words = torch.LongTensor([[vocab.to_index(word) for word in "The weather is good .".split()]]) | |||
| >>> outputs = embed(words) | |||
| >>> outputs.size() | |||
| >>> # torch.Size([1, 5, 2304]) | |||
| """ | |||
| def __init__(self, vocab: Vocabulary, model_dir_or_name: str = 'en-base-uncased', layers: str = '-1', | |||
| pool_method: str = 'first', word_dropout=0, dropout=0, include_cls_sep: bool = False, | |||
| pooled_cls=True, requires_grad: bool = True, auto_truncate: bool = False, **kwargs): | |||
| r""" | |||
| :param ~fastNLP.Vocabulary vocab: the vocabulary. | |||
| :param str model_dir_or_name: directory of the model, or the model name. When a directory is passed, it should contain a | |||
| vocabulary file (ending with vocab.json), a weight file (ending with .bin) and a config file (ending with config.json). | |||
| :param str layers: which layers the output embedding is taken from; the chosen layers are concatenated along the last | |||
| dimension in the order given in layers. Layer indices are separated by ',', start at 0, and negative values index from the end. | |||
| :param str pool_method: each word is represented by several word pieces, so this decides how the representation of a word | |||
| is computed from its word pieces. Supports ``last`` , ``first`` , ``avg`` , ``max``. | |||
| :param float word_dropout: probability of replacing a word with unk. This both trains the unk token and acts as a mild | |||
| form of regularization. | |||
| :param float dropout: dropout probability applied to the embedding output; 0.1 randomly sets 10% of the values to 0. | |||
| :param bool include_cls_sep: when encoding a sentence, <s> and </s> are added at the beginning and the end; this decides | |||
| whether to keep these two tokens in the result, which makes the word embedding output two tokens longer than the input. | |||
| If True, the length may not match other embedding types when used with :class:`StackEmbedding`. | |||
| :param bool pooled_cls: whether the returned <s> is mapped through the pretrained pooler. Only effective when include_cls_sep | |||
| is True. If the downstream task predicts from <s> alone, this is usually set to True. | |||
| :param bool requires_grad: whether gradients are needed to update the RoBERTa weights. | |||
| :param bool auto_truncate: when the word pieces of a sentence exceed the maximum allowed length (usually 512), automatically | |||
| drop everything after the first 510 word pieces and set the 512th word piece to </s>; the encodings of the dropped part | |||
| are set to zero. Usually only tasks that classify from <s> alone set auto_truncate to True. | |||
| :param kwargs: | |||
| bool only_use_pretrain_bpe: only use bpe tokens that appear in the pretrained vocabulary; a word that cannot be tokenized | |||
| is mapped to unk. Recommended to set to True if the embedding does not need to be updated. | |||
| """ | |||
| super().__init__(vocab, word_dropout=word_dropout, dropout=dropout) | |||
| if word_dropout > 0: | |||
| assert vocab.unknown is not None, "When word_dropout > 0, Vocabulary must contain the unknown token." | |||
| self._word_sep_index = None | |||
| if '</s>' in vocab: | |||
| self._word_sep_index = vocab['</s>'] | |||
| only_use_pretrain_bpe = kwargs.get('only_use_pretrain_bpe', False) | |||
| self.model = _WordRobertaModel(model_dir_or_name=model_dir_or_name, vocab=vocab, layers=layers, | |||
| pool_method=pool_method, include_cls_sep=include_cls_sep, | |||
| pooled_cls=pooled_cls, auto_truncate=auto_truncate, min_freq=2, | |||
| only_use_pretrain_bpe=only_use_pretrain_bpe) | |||
| self._sep_index = self.model._sep_index | |||
| self._cls_index = self.model._cls_index | |||
| self.requires_grad = requires_grad | |||
| self._embed_size = len(self.model.layers) * self.model.encoder.hidden_size | |||
| def _delete_model_weights(self): | |||
| del self.model | |||
| def forward(self, words): | |||
| r""" | |||
| Compute the RoBERTa embedding of words. Before encoding, <s> is added at the beginning and </s> at the end of every | |||
| sentence, and include_cls_sep decides whether the representations of these two tokens are removed from the result. | |||
| :param torch.LongTensor words: [batch_size, max_len] | |||
| :return: torch.FloatTensor. batch_size x max_len x (768*len(self.layers)) | |||
| """ | |||
| words = self.drop_word(words) | |||
| outputs = self._get_sent_reprs(words) | |||
| if outputs is not None: | |||
| return self.dropout(outputs) | |||
| outputs = self.model(words) | |||
| outputs = torch.cat([*outputs], dim=-1) | |||
| return self.dropout(outputs) | |||
| def drop_word(self, words): | |||
| r""" | |||
| Randomly replace words with unknown_index according to the configured probability. | |||
| :param torch.LongTensor words: batch_size x max_len | |||
| :return: | |||
| """ | |||
| if self.word_dropout > 0 and self.training: | |||
| with torch.no_grad(): | |||
| not_sep_mask = words.ne(self._sep_index) | |||
| not_cls_mask = words.ne(self._cls_index) | |||
| if self._word_sep_index: | |||
| not_sep_mask = not_sep_mask.__and__(words.ne(self._word_sep_index)) | |||
| replaceable_mask = not_sep_mask.__and__(not_cls_mask) | |||
| mask = torch.full_like(words, fill_value=self.word_dropout, dtype=torch.float, device=words.device) | |||
| mask = torch.bernoulli(mask).eq(1)  # the larger word_dropout is, the more positions become 1 | |||
| pad_mask = words.ne(self._word_pad_index) | |||
| mask = pad_mask.__and__(mask).__and__(replaceable_mask)  # padded positions are never replaced with unk | |||
| words = words.masked_fill(mask, self._word_unk_index) | |||
| return words | |||
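| # An illustrative sketch of how drop_word behaves (the concrete indices below are hypothetical, assuming | |||
| # word_dropout=0.3, a padding index of 0 and an unknown index of 1): | |||
| #   >>> words = torch.LongTensor([[5, 8, 12, 0]])   # the last position is padding | |||
| #   A Bernoulli(0.3) mask is sampled for every position that is neither padding nor <s>/</s>, and positions drawn | |||
| #   as 1 are replaced by the unk index, so one possible outcome is tensor([[5, 1, 12, 0]]); padding stays untouched. | |||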
| class _WordRobertaModel(nn.Module): | |||
| def __init__(self, model_dir_or_name: str, vocab: Vocabulary, layers: str = '-1', pool_method: str = 'first', | |||
| include_cls_sep: bool = False, pooled_cls: bool = False, auto_truncate: bool = False, min_freq=2, | |||
| only_use_pretrain_bpe=False): | |||
| super().__init__() | |||
| self.tokenzier = RobertaTokenizer.from_pretrained(model_dir_or_name) | |||
| self.encoder = RobertaModel.from_pretrained(model_dir_or_name) | |||
| self._max_position_embeddings = self.encoder.config.max_position_embeddings | |||
| # check that the requested layer indices are valid for this encoder | |||
| encoder_layer_number = len(self.encoder.encoder.layer) | |||
| self.layers = list(map(int, layers.split(','))) | |||
| for layer in self.layers: | |||
| if layer < 0: | |||
| assert -layer <= encoder_layer_number, f"The layer index:{layer} is out of scope for " \ | |||
| f"a roberta model with {encoder_layer_number} layers." | |||
| else: | |||
| assert layer < encoder_layer_number, f"The layer index:{layer} is out of scope for " \ | |||
| f"a roberta model with {encoder_layer_number} layers." | |||
| assert pool_method in ('avg', 'max', 'first', 'last') | |||
| self.pool_method = pool_method | |||
| self.include_cls_sep = include_cls_sep | |||
| self.pooled_cls = pooled_cls | |||
| self.auto_truncate = auto_truncate | |||
| # compute the word pieces of every word in vocab; <s> and </s> need extra handling | |||
| logger.info("Start to generate word pieces for word.") | |||
| # first collect the required word pieces, then create the new embed and word piece vocab and fill in the values | |||
| word_piece_dict = {'<s>': 1, '</s>': 1}  # word pieces that are used, plus newly added ones | |||
| found_count = 0 | |||
| self._has_sep_in_vocab = '</s>' in vocab  # used to decide whether token_ids need to be generated for the input | |||
| if "<s>" in vocab: | |||
| warnings.warn("<s> detected in your vocabulary. RobertaEmbedding will add <s> and </s> to the beginning " | |||
| "and end of the input automatically, so make sure you do not add <s> and </s> at the beginning" | |||
| " and end yourself.") | |||
| for word, index in vocab: | |||
| if index == vocab.padding_idx:  # pad is a special symbol | |||
| word = '<pad>' | |||
| elif index == vocab.unknown_idx: | |||
| word = '<unk>' | |||
| # _words = self.tokenzier.basic_tokenizer._tokenize_chinese_chars(word).split()  # Chinese characters are not handled here for now | |||
| word_pieces = [] | |||
| word_pieces.extend(self.tokenzier.tokenize(word)) | |||
| if len(word_pieces) == 1: | |||
| if not vocab._is_word_no_create_entry(word):  # the word comes from the training data but was not found | |||
| if index != vocab.unknown_idx and word_pieces[0] == '<unk>':  # the word is not in the pretrained vocabulary | |||
| if vocab.word_count[word] >= min_freq and not vocab._is_word_no_create_entry( | |||
| word) and not only_use_pretrain_bpe:  # only add the word if it occurs at least min_freq times | |||
| word_piece_dict[word] = 1  # add a new entry | |||
| continue | |||
| for word_piece in word_pieces: | |||
| word_piece_dict[word_piece] = 1 | |||
| found_count += 1 | |||
| original_embed = self.encoder.embeddings.word_embeddings.weight.data | |||
| # special tokens need special handling | |||
| embed = nn.Embedding(len(word_piece_dict), original_embed.size(1))  # the new embedding | |||
| new_word_piece_vocab = collections.OrderedDict() | |||
| for index, token in enumerate(['<pad>', '<unk>']): | |||
| word_piece_dict.pop(token, None) | |||
| embed.weight.data[index] = original_embed[self.tokenzier.encoder[token]] | |||
| new_word_piece_vocab[token] = index | |||
| for token in word_piece_dict.keys(): | |||
| if token in self.tokenzier.encoder: | |||
| embed.weight.data[len(new_word_piece_vocab)] = original_embed[self.tokenzier.encoder[token]] | |||
| else: | |||
| embed.weight.data[len(new_word_piece_vocab)] = original_embed[self.tokenzier.encoder['<unk>']] | |||
| new_word_piece_vocab[token] = len(new_word_piece_vocab) | |||
| self._reinit_on_new_vocab(new_word_piece_vocab, model_dir_or_name) | |||
| self.encoder.embeddings.word_embeddings = embed | |||
| word_to_wordpieces = [] | |||
| word_pieces_lengths = [] | |||
| for word, index in vocab: | |||
| if index == vocab.padding_idx:  # pad is a special symbol | |||
| word = '<pad>' | |||
| elif index == vocab.unknown_idx: | |||
| word = '<unk>' | |||
| word_pieces = self.tokenzier.tokenize(word) | |||
| word_pieces = self.tokenzier.convert_tokens_to_ids(word_pieces) | |||
| word_to_wordpieces.append(word_pieces) | |||
| word_pieces_lengths.append(len(word_pieces)) | |||
| self._cls_index = self.tokenzier.encoder['<s>'] | |||
| self._sep_index = self.tokenzier.encoder['</s>'] | |||
| self._word_pad_index = vocab.padding_idx | |||
| self._wordpiece_pad_index = self.tokenzier.encoder['<pad>']  # needed when generating word pieces | |||
| logger.info("Found(Or segment into word pieces) {} words out of {}.".format(found_count, len(vocab))) | |||
| self.word_to_wordpieces = np.array(word_to_wordpieces) | |||
| self.register_buffer('word_pieces_lengths', torch.LongTensor(word_pieces_lengths)) | |||
| logger.debug("Successfully generate word pieces.") | |||
| def _reinit_on_new_vocab(self, vocab, model_dir_or_name): | |||
| import json | |||
| with open('./.tmp-new-vocab-file.json', 'w') as f: | |||
| json.dump(vocab, f) | |||
| self.tokenzier = RobertaTokenizer.from_pretrained(model_dir_or_name, vocab_file='./.tmp-new-vocab-file.json') | |||
| os.remove('./.tmp-new-vocab-file.json') | |||
| def forward(self, words): | |||
| r""" | |||
| :param words: torch.LongTensor, batch_size x max_len | |||
| :return: num_layers x batch_size x max_len x hidden_size, or num_layers x batch_size x (max_len+2) x hidden_size | |||
| """ | |||
| with torch.no_grad(): | |||
| batch_size, max_word_len = words.size() | |||
| word_mask = words.ne(self._word_pad_index)  # positions equal to 1 contain a word | |||
| seq_len = word_mask.sum(dim=-1) | |||
| batch_word_pieces_length = self.word_pieces_lengths[words].masked_fill(word_mask.eq(False), | |||
| 0) # batch_size x max_len | |||
| word_pieces_lengths = batch_word_pieces_length.sum(dim=-1) # batch_size | |||
| word_piece_length = batch_word_pieces_length.sum(dim=-1).max().item()  # the word piece length (including padding) | |||
| if word_piece_length + 2 > self._max_position_embeddings: | |||
| if self.auto_truncate: | |||
| word_pieces_lengths = word_pieces_lengths.masked_fill( | |||
| word_pieces_lengths + 2 > self._max_position_embeddings, | |||
| self._max_position_embeddings - 2) | |||
| else: | |||
| raise RuntimeError( | |||
| "After split words into word pieces, the lengths of word pieces are longer than the " | |||
| f"maximum allowed sequence length:{self._max_position_embeddings} of bert. You can set " | |||
| f"`auto_truncate=True` for BertEmbedding to automatically truncate overlong input.") | |||
| # +2是由于需要加入<s>与</s> | |||
| word_pieces = words.new_full((batch_size, min(word_piece_length + 2, self._max_position_embeddings)), | |||
| fill_value=self._wordpiece_pad_index) | |||
| attn_masks = torch.zeros_like(word_pieces) | |||
| # 1. get the word piece ids of the words and the corresponding spans | |||
| word_indexes = words.cpu().numpy() | |||
| for i in range(batch_size): | |||
| word_pieces_i = list(chain(*self.word_to_wordpieces[word_indexes[i, :seq_len[i]]])) | |||
| if self.auto_truncate and len(word_pieces_i) > self._max_position_embeddings - 2: | |||
| word_pieces_i = word_pieces_i[:self._max_position_embeddings - 2] | |||
| word_pieces[i, 1:word_pieces_lengths[i] + 1] = torch.LongTensor(word_pieces_i) | |||
| attn_masks[i, :word_pieces_lengths[i] + 2].fill_(1) | |||
| # add <s> and </s> | |||
| word_pieces[:, 0].fill_(self._cls_index) | |||
| batch_indexes = torch.arange(batch_size).to(words) | |||
| word_pieces[batch_indexes, word_pieces_lengths + 1] = self._sep_index | |||
| # if self._has_sep_in_vocab:  # token_ids should only be needed when </s> appears in vocab | |||
| # sep_mask = word_pieces.eq(self._sep_index).long()  # batch_size x max_len | |||
| # sep_mask_cumsum = sep_mask.flip(dims=[-1]).cumsum(dim=-1).flip(dims=[-1]) | |||
| # token_type_ids = sep_mask_cumsum.fmod(2) | |||
| # if token_type_ids[0, 0].item():  # if the first position is odd, flip the result because the sequence must start with 0 | |||
| # token_type_ids = token_type_ids.eq(0).long() | |||
| # else:  # RoBERTa does not need extra token_type_ids | |||
| token_type_ids = torch.zeros_like(word_pieces) | |||
| # 2. get the hidden states and pool them per word according to the word pieces | |||
| # all_outputs: [batch_size x max_len x hidden_size, batch_size x max_len x hidden_size, ...] | |||
| bert_outputs, pooled_cls = self.encoder(word_pieces, token_type_ids=token_type_ids, | |||
| attention_mask=attn_masks, | |||
| output_all_encoded_layers=True) | |||
| # output_layers = [self.layers] # len(self.layers) x batch_size x real_word_piece_length x hidden_size | |||
| if self.include_cls_sep: | |||
| s_shift = 1 | |||
| outputs = bert_outputs[-1].new_zeros(len(self.layers), batch_size, max_word_len + 2, | |||
| bert_outputs[-1].size(-1)) | |||
| else: | |||
| s_shift = 0 | |||
| outputs = bert_outputs[-1].new_zeros(len(self.layers), batch_size, max_word_len, | |||
| bert_outputs[-1].size(-1)) | |||
| batch_word_pieces_cum_length = batch_word_pieces_length.new_zeros(batch_size, max_word_len + 1) | |||
| batch_word_pieces_cum_length[:, 1:] = batch_word_pieces_length.cumsum(dim=-1) # batch_size x max_len | |||
| if self.pool_method == 'first': | |||
| batch_word_pieces_cum_length = batch_word_pieces_cum_length[:, :seq_len.max()] | |||
| batch_word_pieces_cum_length.masked_fill_(batch_word_pieces_cum_length.ge(word_piece_length), 0) | |||
| _batch_indexes = batch_indexes[:, None].expand((batch_size, batch_word_pieces_cum_length.size(1))) | |||
| elif self.pool_method == 'last': | |||
| batch_word_pieces_cum_length = batch_word_pieces_cum_length[:, 1:seq_len.max() + 1] - 1 | |||
| batch_word_pieces_cum_length.masked_fill_(batch_word_pieces_cum_length.ge(word_piece_length), 0) | |||
| _batch_indexes = batch_indexes[:, None].expand((batch_size, batch_word_pieces_cum_length.size(1))) | |||
| for l_index, l in enumerate(self.layers): | |||
| output_layer = bert_outputs[l] | |||
| real_word_piece_length = output_layer.size(1) - 2 | |||
| if word_piece_length > real_word_piece_length:  # the sequence was actually truncated | |||
| paddings = output_layer.new_zeros(batch_size, | |||
| word_piece_length - real_word_piece_length, | |||
| output_layer.size(2)) | |||
| output_layer = torch.cat((output_layer, paddings), dim=1).contiguous() | |||
| # collapse from word piece representations to word representations | |||
| truncate_output_layer = output_layer[:, 1:-1]  # remove <s> and </s>; batch_size x len x hidden_size | |||
| if self.pool_method == 'first': | |||
| tmp = truncate_output_layer[_batch_indexes, batch_word_pieces_cum_length] | |||
| tmp = tmp.masked_fill(word_mask[:, :batch_word_pieces_cum_length.size(1), None].eq(False), 0) | |||
| outputs[l_index, :, s_shift:batch_word_pieces_cum_length.size(1) + s_shift] = tmp | |||
| elif self.pool_method == 'last': | |||
| tmp = truncate_output_layer[_batch_indexes, batch_word_pieces_cum_length] | |||
| tmp = tmp.masked_fill(word_mask[:, :batch_word_pieces_cum_length.size(1), None].eq(False), 0) | |||
| outputs[l_index, :, s_shift:batch_word_pieces_cum_length.size(1) + s_shift] = tmp | |||
| elif self.pool_method == 'max': | |||
| for i in range(batch_size): | |||
| for j in range(seq_len[i]): | |||
| start, end = batch_word_pieces_cum_length[i, j], batch_word_pieces_cum_length[i, j + 1] | |||
| outputs[l_index, i, j + s_shift], _ = torch.max(truncate_output_layer[i, start:end], dim=-2) | |||
| else: | |||
| for i in range(batch_size): | |||
| for j in range(seq_len[i]): | |||
| start, end = batch_word_pieces_cum_length[i, j], batch_word_pieces_cum_length[i, j + 1] | |||
| outputs[l_index, i, j + s_shift] = torch.mean(truncate_output_layer[i, start:end], dim=-2) | |||
| if self.include_cls_sep: | |||
| if l in (len(bert_outputs) - 1, -1) and self.pooled_cls: | |||
| outputs[l_index, :, 0] = pooled_cls | |||
| else: | |||
| outputs[l_index, :, 0] = output_layer[:, 0] | |||
| outputs[l_index, batch_indexes, seq_len + s_shift] = output_layer[batch_indexes, seq_len + s_shift] | |||
| # 3. the final embedding result | |||
| return outputs | |||
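| # A worked example of the pooling above: suppose a sentence has three words whose word-piece lengths are [1, 3, 2]. | |||
| # Then batch_word_pieces_cum_length is [0, 1, 4, 6] and, after <s>/</s> are stripped, the piece indices per word are | |||
| # [0], [1, 2, 3] and [4, 5]. pool_method='first' gathers indices [0, 1, 4], 'last' gathers cum_length[1:] - 1 = | |||
| # [0, 3, 5], and 'max'/'avg' reduce over each [start, end) span given by consecutive cumulative lengths. | |||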
| @@ -34,6 +34,7 @@ __all__ = [ | |||
| from .attention import MultiHeadAttention, BiAttention, SelfAttention | |||
| from .bert import BertModel | |||
| from .roberta import RobertaModel | |||
| from .char_encoder import ConvolutionCharEncoder, LSTMCharEncoder | |||
| from .conv_maxpool import ConvMaxpool | |||
| from .lstm import LSTM | |||
| @@ -245,14 +245,18 @@ class BertEmbeddings(nn.Module): | |||
| self.LayerNorm = BertLayerNorm(config.hidden_size, eps=config.layer_norm_eps) | |||
| self.dropout = nn.Dropout(config.hidden_dropout_prob) | |||
| def forward(self, input_ids, token_type_ids=None): | |||
| def forward(self, input_ids, token_type_ids=None, position_ids=None, words_embeddings=None): | |||
| seq_length = input_ids.size(1) | |||
| position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device) | |||
| position_ids = position_ids.unsqueeze(0).expand_as(input_ids) | |||
| if position_ids is None: | |||
| position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device) | |||
| position_ids = position_ids.unsqueeze(0).expand_as(input_ids) | |||
| if token_type_ids is None: | |||
| token_type_ids = torch.zeros_like(input_ids) | |||
| words_embeddings = self.word_embeddings(input_ids) | |||
| if words_embeddings is None: | |||
| words_embeddings = self.word_embeddings(input_ids) | |||
| else: | |||
| assert input_ids.size() == words_embeddings.size()[: -1] | |||
| position_embeddings = self.position_embeddings(position_ids) | |||
| token_type_embeddings = self.token_type_embeddings(token_type_ids) | |||
| @@ -0,0 +1,773 @@ | |||
| from functools import lru_cache | |||
| import json | |||
| import regex as re | |||
| import itertools | |||
| from ...io.file_utils import _get_embedding_url, cached_path | |||
| from ...core import logger | |||
| import os | |||
| PRETRAINED_GPT2_MODEL_DIR = PRETRAINED_BERT_MODEL_DIR = { | |||
| 'en-small': 'gpt2-small.zip', | |||
| 'en-median': 'gpt2-medium.zip', | |||
| 'en': 'gpt2-medium.zip' | |||
| } | |||
| def _get_gpt2_dir(model_dir_or_name: str = 'en-median'): | |||
| if model_dir_or_name.lower() in PRETRAINED_GPT2_MODEL_DIR: | |||
| model_url = _get_embedding_url('gpt2', model_dir_or_name.lower()) | |||
| model_dir = cached_path(model_url, name='embedding') | |||
| # check whether the path exists | |||
| elif os.path.isdir(os.path.abspath(os.path.expanduser(model_dir_or_name))): | |||
| model_dir = os.path.abspath(os.path.expanduser(model_dir_or_name)) | |||
| else: | |||
| logger.error(f"Cannot recognize GPT2 dir or name ``{model_dir_or_name}``.") | |||
| raise ValueError(f"Cannot recognize GPT2 dir or name ``{model_dir_or_name}``.") | |||
| return str(model_dir) | |||
| def _get_filepath_based_on_postfix(folder, postfix): | |||
| """ | |||
| Look for a file under folder whose name ends with postfix, e.g. a file ending with vocab.txt. Only the first match is | |||
| returned; multiple matches do not raise an error, but a missing file does. Returns the full path of the file. | |||
| :param str folder: | |||
| :param str postfix: | |||
| :return: | |||
| """ | |||
| for filename in os.listdir(folder): | |||
| if os.path.isfile(os.path.join(folder, filename)): | |||
| if filename.endswith(postfix): | |||
| return os.path.join(folder, filename) | |||
| raise FileNotFoundError(f"File {postfix} is not found in {folder}.") | |||
| @lru_cache() | |||
| def bytes_to_unicode(): | |||
| """ | |||
| Returns list of utf-8 byte and a mapping to unicode strings. | |||
| We specifically avoid mapping to whitespace/control characters the bpe code barfs on. | |||
| The reversible bpe codes work on unicode strings. | |||
| This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. | |||
| When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. | |||
| This is a significant percentage of your normal, say, 32K bpe vocab. | |||
| To avoid that, we want lookup tables between utf-8 bytes and unicode strings. | |||
| """ | |||
| bs = ( | |||
| list(range(ord("!"), ord("~") + 1)) + list(range(ord("¡"), ord("¬") + 1)) + list(range(ord("®"), ord("ÿ") + 1)) | |||
| ) | |||
| cs = bs[:] | |||
| n = 0 | |||
| for b in range(2 ** 8): | |||
| if b not in bs: | |||
| bs.append(b) | |||
| cs.append(2 ** 8 + n) | |||
| n += 1 | |||
| cs = [chr(n) for n in cs] | |||
| return dict(zip(bs, cs)) | |||
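| # For reference, the mapping produced above can be checked interactively (doctest-style sketch): | |||
| # >>> b2u = bytes_to_unicode() | |||
| # >>> b2u[ord('A')], b2u[ord(' ')], b2u[ord('\n')] | |||
| # ('A', 'Ġ', 'Ċ')   # printable bytes map to themselves; space/control bytes are shifted to 256+n, hence the 'Ġ' prefix | |||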
| def get_pairs(word): | |||
| """Return set of symbol pairs in a word. | |||
| Word is represented as tuple of symbols (symbols being variable-length strings). | |||
| """ | |||
| pairs = set() | |||
| prev_char = word[0] | |||
| for char in word[1:]: | |||
| pairs.add((prev_char, char)) | |||
| prev_char = char | |||
| return pairs | |||
| VOCAB_FILES_NAMES = { | |||
| "vocab_file": "vocab.json", | |||
| "merges_file": "merges.txt", | |||
| } | |||
| PRETRAINED_VOCAB_FILES_MAP = { | |||
| "vocab_file": { | |||
| "gpt2": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-vocab.json", | |||
| "gpt2-medium": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-medium-vocab.json", | |||
| "gpt2-large": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-large-vocab.json", | |||
| "gpt2-xl": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-xl-vocab.json", | |||
| "distilgpt2": "https://s3.amazonaws.com/models.huggingface.co/bert/distilgpt2-vocab.json", | |||
| }, | |||
| "merges_file": { | |||
| "gpt2": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-merges.txt", | |||
| "gpt2-medium": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-medium-merges.txt", | |||
| "gpt2-large": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-large-merges.txt", | |||
| "gpt2-xl": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-xl-merges.txt", | |||
| "distilgpt2": "https://s3.amazonaws.com/models.huggingface.co/bert/distilgpt2-merges.txt", | |||
| }, | |||
| } | |||
| PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = { | |||
| "en-small": 1024, | |||
| 'en': 1024, | |||
| "en-medium": 1024, | |||
| "en-large": 1024, | |||
| "en-xl": 1024, | |||
| "en-distilgpt2": 1024, | |||
| } | |||
| PATTERN = re.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""") | |||
| def gpt2_tokenize(text, add_prefix_space=True): | |||
| """ | |||
| :param str text: | |||
| :param bool add_prefix_space: whether to prepend a space to the sentence; only then is tokenization consistent with how GPT-2 was trained | |||
| :return: [] | |||
| """ | |||
| if text == '': | |||
| return [] | |||
| if add_prefix_space: | |||
| text = ' ' + text | |||
| tokens = [] | |||
| for token in re.findall(PATTERN, text): | |||
| tokens.append(token) | |||
| return tokens | |||
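| # A small doctest-style illustration of the regex pre-tokenization above: | |||
| # >>> gpt2_tokenize("Hello world!") | |||
| # [' Hello', ' world', '!']   # with the default add_prefix_space=True a space is prepended first | |||
| # >>> gpt2_tokenize("Hello world!", add_prefix_space=False) | |||
| # ['Hello', ' world', '!'] | |||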
| class GPT2Tokenizer: | |||
| """ | |||
| GPT-2 BPE tokenizer. Peculiarities: | |||
| - Byte-level Byte-Pair-Encoding | |||
| - Requires a space to start the input string => the encoding and tokenize methods should be called with the | |||
| ``add_prefix_space`` flag set to ``True``. | |||
| Otherwise, this tokenizer's ``encode``, ``decode``, and ``tokenize`` methods will not conserve | |||
| the spaces at the beginning of a string: `tokenizer.decode(tokenizer.encode(" Hello")) = "Hello"` | |||
| """ | |||
| vocab_files_names = VOCAB_FILES_NAMES | |||
| pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP | |||
| SPECIAL_TOKENS_ATTRIBUTES = [ | |||
| "bos_token", | |||
| "eos_token", | |||
| "unk_token", | |||
| "pad_token", | |||
| "cls_token", | |||
| "mask_token", | |||
| ] | |||
| padding_side = "right" | |||
| def __init__( | |||
| self, | |||
| vocab_file, | |||
| merges_file, | |||
| errors="replace", | |||
| unk_token="<|endoftext|>", | |||
| bos_token="<|endoftext|>", | |||
| eos_token="<|endoftext|>", | |||
| **kwargs | |||
| ): | |||
| self._bos_token = None | |||
| self._eos_token = None | |||
| self._unk_token = None | |||
| self._sep_token = None | |||
| self._pad_token = None | |||
| self._cls_token = None | |||
| self._mask_token = None | |||
| self._pad_token_type_id = 0 | |||
| self.bos_token = bos_token | |||
| self.eos_token = eos_token | |||
| self.unk_token = unk_token | |||
| self.max_len = int(1e12) | |||
| self.padding_side = kwargs.pop("padding_side", self.padding_side) | |||
| self.added_tokens_encoder = {} | |||
| self.unique_added_tokens_encoder = set() | |||
| self.added_tokens_decoder = {} | |||
| # inputs and kwargs for saving and re-loading (see ``from_pretrained`` and ``save_pretrained``) | |||
| self.init_inputs = () | |||
| self.init_kwargs = {} | |||
| for key, value in kwargs.items(): | |||
| if key in self.SPECIAL_TOKENS_ATTRIBUTES: | |||
| if key == "additional_special_tokens": | |||
| assert isinstance(value, (list, tuple)) and all(isinstance(t, str) for t in value) | |||
| else: | |||
| assert isinstance(value, str) | |||
| setattr(self, key, value) | |||
| self.max_len_single_sentence = ( | |||
| self.max_len | |||
| ) # no default special tokens - you can update this value if you add special tokens | |||
| self.max_len_sentences_pair = ( | |||
| self.max_len | |||
| ) # no default special tokens - you can update this value if you add special tokens | |||
| with open(vocab_file, encoding="utf-8") as vocab_handle: | |||
| self.encoder = json.load(vocab_handle) | |||
| self.decoder = {v: k for k, v in self.encoder.items()} | |||
| self.errors = errors # how to handle errors in decoding | |||
| self.byte_encoder = bytes_to_unicode() | |||
| self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} | |||
| with open(merges_file, encoding="utf-8") as merges_handle: | |||
| bpe_merges = merges_handle.read().split("\n")[1:-1] | |||
| bpe_merges = [tuple(merge.split()) for merge in bpe_merges] | |||
| self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges)))) | |||
| self.cache = {} | |||
| def add_special_tokens(self, special_tokens_dict): | |||
| """ | |||
| Add a dictionary of special tokens (eos, pad, cls...) to the encoder and link them | |||
| to class attributes. If special tokens are NOT in the vocabulary, they are added | |||
| to it (indexed starting from the last index of the current vocabulary). | |||
| Using `add_special_tokens` will ensure your special tokens can be used in several ways: | |||
| - special tokens are carefully handled by the tokenizer (they are never split) | |||
| - you can easily refer to special tokens using tokenizer class attributes like `tokenizer.cls_token`. This makes it easy to develop model-agnostic training and fine-tuning scripts. | |||
| When possible, special tokens are already registered for provided pretrained models (ex: BertTokenizer cls_token is already registered to be '[CLS]' and XLM's one is also registered to be '</s>') | |||
| Args: | |||
| special_tokens_dict: dict of string. Keys should be in the list of predefined special attributes: | |||
| [``bos_token``, ``eos_token``, ``unk_token``, ``sep_token``, ``pad_token``, ``cls_token``, ``mask_token``, | |||
| ``additional_special_tokens``]. | |||
| Tokens are only added if they are not already in the vocabulary (tested by checking if the tokenizer assigns the index of the ``unk_token`` to them). | |||
| Returns: | |||
| Number of tokens added to the vocabulary. | |||
| Examples:: | |||
| # Let's see how to add a new classification token to GPT-2 | |||
| tokenizer = GPT2Tokenizer.from_pretrained('gpt2') | |||
| model = GPT2Model.from_pretrained('gpt2') | |||
| special_tokens_dict = {'cls_token': '<CLS>'} | |||
| num_added_toks = tokenizer.add_special_tokens(special_tokens_dict) | |||
| print('We have added', num_added_toks, 'tokens') | |||
| model.resize_token_embeddings(len(tokenizer)) # Notice: resize_token_embeddings expect to receive the full size of the new vocabulary, i.e. the length of the tokenizer. | |||
| assert tokenizer.cls_token == '<CLS>' | |||
| """ | |||
| if not special_tokens_dict: | |||
| return 0 | |||
| added_tokens = 0 | |||
| for key, value in special_tokens_dict.items(): | |||
| assert key in self.SPECIAL_TOKENS_ATTRIBUTES | |||
| if key == "additional_special_tokens": | |||
| assert isinstance(value, (list, tuple)) and all(isinstance(t, str) for t in value) | |||
| added_tokens += self.add_tokens(value) | |||
| else: | |||
| assert isinstance(value, str) | |||
| added_tokens += self.add_tokens([value]) | |||
| logger.debug("Assigning %s to the %s key of the tokenizer", value, key) | |||
| setattr(self, key, value) | |||
| return added_tokens | |||
| def add_tokens(self, new_tokens): | |||
| """ | |||
| Add a list of new tokens to the tokenizer class. If the new tokens are not in the | |||
| vocabulary, they are added to it with indices starting from length of the current vocabulary. | |||
| Args: | |||
| new_tokens: list of string. Each string is a token to add. Tokens are only added if they are not already in the vocabulary (tested by checking if the tokenizer assigns the index of the ``unk_token`` to them). | |||
| Returns: | |||
| Number of tokens added to the vocabulary. | |||
| Examples:: | |||
| # Let's see how to increase the vocabulary of Bert model and tokenizer | |||
| tokenizer = BertTokenizer.from_pretrained('bert-base-uncased') | |||
| model = BertModel.from_pretrained('bert-base-uncased') | |||
| num_added_toks = tokenizer.add_tokens(['new_tok1', 'my_new-tok2']) | |||
| print('We have added', num_added_toks, 'tokens') | |||
| model.resize_token_embeddings(len(tokenizer)) # Notice: resize_token_embeddings expect to receive the full size of the new vocabulary, i.e. the length of the tokenizer. | |||
| """ | |||
| if not new_tokens: | |||
| return 0 | |||
| to_add_tokens = [] | |||
| for token in new_tokens: | |||
| assert isinstance(token, str) | |||
| if self.init_kwargs.get("do_lower_case", False) and token not in self.all_special_tokens: | |||
| token = token.lower() | |||
| if ( | |||
| token != self.unk_token | |||
| and self.convert_tokens_to_ids(token) == self.convert_tokens_to_ids(self.unk_token) | |||
| and token not in to_add_tokens | |||
| ): | |||
| to_add_tokens.append(token) | |||
| logger.debug("Adding %s to the vocabulary", token) | |||
| added_tok_encoder = dict((tok, len(self) + i) for i, tok in enumerate(to_add_tokens)) | |||
| added_tok_decoder = {v: k for k, v in added_tok_encoder.items()} | |||
| self.added_tokens_encoder.update(added_tok_encoder) | |||
| self.unique_added_tokens_encoder = set(self.added_tokens_encoder.keys()).union(set(self.all_special_tokens)) | |||
| self.added_tokens_decoder.update(added_tok_decoder) | |||
| return len(to_add_tokens) | |||
| @property | |||
| def bos_token(self): | |||
| """ Beginning of sentence token (string). Log an error if used while not having been set. """ | |||
| if self._bos_token is None: | |||
| logger.error("Using bos_token, but it is not set yet.") | |||
| return self._bos_token | |||
| @property | |||
| def eos_token(self): | |||
| """ End of sentence token (string). Log an error if used while not having been set. """ | |||
| if self._eos_token is None: | |||
| logger.error("Using eos_token, but it is not set yet.") | |||
| return self._eos_token | |||
| @property | |||
| def unk_token(self): | |||
| """ Unknown token (string). Log an error if used while not having been set. """ | |||
| if self._unk_token is None: | |||
| logger.error("Using unk_token, but it is not set yet.") | |||
| return self._unk_token | |||
| @property | |||
| def pad_token(self): | |||
| """ Padding token (string). Log an error if used while not having been set. """ | |||
| if self._pad_token is None: | |||
| logger.error("Using pad_token, but it is not set yet.") | |||
| return self._pad_token | |||
| @property | |||
| def cls_token(self): | |||
| """ Classification token (string). E.g. to extract a summary of an input sequence leveraging self-attention along the full depth of the model. Log an error if used while not having been set. """ | |||
| if self._cls_token is None: | |||
| logger.error("Using cls_token, but it is not set yet.") | |||
| return self._cls_token | |||
| @property | |||
| def mask_token(self): | |||
| """ Mask token (string). E.g. when training a model with masked-language modeling. Log an error if used while not having been set. """ | |||
| if self._mask_token is None: | |||
| logger.error("Using mask_token, but it is not set yet.") | |||
| return self._mask_token | |||
| @bos_token.setter | |||
| def bos_token(self, value): | |||
| self._bos_token = value | |||
| @eos_token.setter | |||
| def eos_token(self, value): | |||
| self._eos_token = value | |||
| @unk_token.setter | |||
| def unk_token(self, value): | |||
| self._unk_token = value | |||
| @pad_token.setter | |||
| def pad_token(self, value): | |||
| self._pad_token = value | |||
| @cls_token.setter | |||
| def cls_token(self, value): | |||
| self._cls_token = value | |||
| @mask_token.setter | |||
| def mask_token(self, value): | |||
| self._mask_token = value | |||
| @property | |||
| def bos_token_id(self): | |||
| """ Id of the beginning of sentence token in the vocabulary. Log an error if used while not having been set. """ | |||
| return self.convert_tokens_to_ids(self.bos_token) | |||
| @property | |||
| def eos_token_id(self): | |||
| """ Id of the end of sentence token in the vocabulary. Log an error if used while not having been set. """ | |||
| return self.convert_tokens_to_ids(self.eos_token) | |||
| @property | |||
| def unk_token_id(self): | |||
| """ Id of the unknown token in the vocabulary. Log an error if used while not having been set. """ | |||
| return self.convert_tokens_to_ids(self.unk_token) | |||
| @property | |||
| def pad_token_id(self): | |||
| """ Id of the padding token in the vocabulary. Log an error if used while not having been set. """ | |||
| return self.convert_tokens_to_ids(self.pad_token) | |||
| @property | |||
| def pad_token_type_id(self): | |||
| """ Id of the padding token type in the vocabulary.""" | |||
| return self._pad_token_type_id | |||
| @property | |||
| def cls_token_id(self): | |||
| """ Id of the classification token in the vocabulary. E.g. to extract a summary of an input sequence leveraging self-attention along the full depth of the model. Log an error if used while not having been set. """ | |||
| return self.convert_tokens_to_ids(self.cls_token) | |||
| @property | |||
| def mask_token_id(self): | |||
| """ Id of the mask token in the vocabulary. E.g. when training a model with masked-language modeling. Log an error if used while not having been set. """ | |||
| return self.convert_tokens_to_ids(self.mask_token) | |||
| @property | |||
| def vocab_size(self): | |||
| return len(self.encoder) | |||
| def bpe(self, token): | |||
| if token in self.cache: | |||
| return self.cache[token] | |||
| word = tuple(token) | |||
| pairs = get_pairs(word) | |||
| if not pairs: | |||
| return token | |||
| while True: | |||
| bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf"))) | |||
| if bigram not in self.bpe_ranks: | |||
| break | |||
| first, second = bigram | |||
| new_word = [] | |||
| i = 0 | |||
| while i < len(word): | |||
| try: | |||
| j = word.index(first, i) | |||
| except ValueError: | |||
| new_word.extend(word[i:]) | |||
| break | |||
| else: | |||
| new_word.extend(word[i:j]) | |||
| i = j | |||
| if word[i] == first and i < len(word) - 1 and word[i + 1] == second: | |||
| new_word.append(first + second) | |||
| i += 2 | |||
| else: | |||
| new_word.append(word[i]) | |||
| i += 1 | |||
| new_word = tuple(new_word) | |||
| word = new_word | |||
| if len(word) == 1: | |||
| break | |||
| else: | |||
| pairs = get_pairs(word) | |||
| word = " ".join(word) | |||
| self.cache[token] = word | |||
| return word | |||
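| # How the merge loop above behaves on a toy example (the ranks are hypothetical, purely for illustration): with | |||
| # bpe_ranks = {('h','e'): 0, ('l','l'): 1, ('he','ll'): 2, ('hell','o'): 3}, the token "hello" is merged as | |||
| # ('h','e','l','l','o') -> ('he','l','l','o') -> ('he','ll','o') -> ('hell','o') -> ('hello',), always taking the | |||
| # lowest-ranked adjacent pair, and bpe() returns the space-joined result "hello". | |||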
| def _tokenize(self, text, add_prefix_space=False): | |||
| """ Tokenize a string. | |||
| Args: | |||
| - add_prefix_space (boolean, default False): | |||
| Begin the sentence with at least one space to get invariance to word order in GPT-2 (and RoBERTa) tokenizers. | |||
| """ | |||
| bpe_tokens = [] | |||
| for token in gpt2_tokenize(text, add_prefix_space=add_prefix_space): | |||
| token = "".join( | |||
| self.byte_encoder[b] for b in token.encode("utf-8") | |||
| ) # Maps all our bytes to unicode strings, avoiding control tokens of the BPE (spaces in our case) | |||
| bpe_tokens.extend(bpe_token for bpe_token in self.bpe(token).split(" ")) | |||
| return bpe_tokens | |||
| def _convert_token_to_id(self, token): | |||
| """ Converts a token (str) in an id using the vocab. """ | |||
| return self.encoder.get(token, self.encoder.get(self.unk_token)) | |||
| def _convert_id_to_token(self, index): | |||
| """Converts an index (integer) in a token (str) using the vocab.""" | |||
| return self.decoder.get(index) | |||
| def convert_tokens_to_string(self, tokens): | |||
| """ Converts a sequence of tokens (string) in a single string. """ | |||
| text = "".join(tokens) | |||
| text = bytearray([self.byte_decoder[c] for c in text]).decode("utf-8", errors=self.errors) | |||
| return text | |||
| def save_vocabulary(self, save_directory): | |||
| """Save the tokenizer vocabulary and merge files to a directory.""" | |||
| if not os.path.isdir(save_directory): | |||
| logger.error("Vocabulary path ({}) should be a directory".format(save_directory)) | |||
| return | |||
| vocab_file = os.path.join(save_directory, VOCAB_FILES_NAMES["vocab_file"]) | |||
| merge_file = os.path.join(save_directory, VOCAB_FILES_NAMES["merges_file"]) | |||
| with open(vocab_file, "w", encoding="utf-8") as f: | |||
| f.write(json.dumps(self.encoder, ensure_ascii=False)) | |||
| index = 0 | |||
| with open(merge_file, "w", encoding="utf-8") as writer: | |||
| writer.write("#version: 0.2\n") | |||
| for bpe_tokens, token_index in sorted(self.bpe_ranks.items(), key=lambda kv: kv[1]): | |||
| if index != token_index: | |||
| logger.warning( | |||
| "Saving vocabulary to {}: BPE merge indices are not consecutive." | |||
| " Please check that the tokenizer is not corrupted!".format(merge_file) | |||
| ) | |||
| index = token_index | |||
| writer.write(" ".join(bpe_tokens) + "\n") | |||
| index += 1 | |||
| return vocab_file, merge_file | |||
| @classmethod | |||
| def from_pretrained(cls, model_dir_or_name): | |||
| r""" | |||
| """ | |||
| return cls._from_pretrained(model_dir_or_name) | |||
| # modified so that a directory is always passed in | |||
| @classmethod | |||
| def _from_pretrained(cls, model_dir_or_name): | |||
| """ | |||
| :param str model_dir_or_name: a directory or a shortcut name | |||
| :param init_inputs: | |||
| :param kwargs: | |||
| :return: | |||
| """ | |||
| # it needs two files: vocab.json and the merges file | |||
| model_dir = _get_gpt2_dir(model_dir_or_name) | |||
| # the directory contains four files: vocab.json, merges.txt, config.json and model.bin | |||
| tokenizer_config_file = _get_filepath_based_on_postfix(model_dir, 'config.json') | |||
| with open(tokenizer_config_file, encoding="utf-8") as tokenizer_config_handle: | |||
| init_kwargs = json.load(tokenizer_config_handle) | |||
| # Set max length if needed | |||
| if model_dir_or_name in PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES: | |||
| # if we're using a pretrained model, ensure the tokenizer | |||
| # won't index sequences longer than the number of positional embeddings | |||
| max_len = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES[model_dir_or_name] | |||
| if max_len is not None and isinstance(max_len, (int, float)): | |||
| init_kwargs["max_len"] = min(init_kwargs.get("max_len", int(1e12)), max_len) | |||
| # add the vocab and merges files to init_kwargs | |||
| init_kwargs['vocab_file'] = _get_filepath_based_on_postfix(model_dir, 'vocab.json') | |||
| init_kwargs['merges_file'] = _get_filepath_based_on_postfix(model_dir, 'merges.txt') | |||
| init_inputs = init_kwargs.pop("init_inputs", ()) | |||
| # Instantiate tokenizer. | |||
| try: | |||
| tokenizer = cls(*init_inputs, **init_kwargs) | |||
| except OSError: | |||
| raise OSError( | |||
| "Unable to load vocabulary from file. " | |||
| "Please check that the provided vocabulary is accessible and not corrupted." | |||
| ) | |||
| return tokenizer | |||
| def __len__(self): | |||
| """ Size of the full vocabulary with the added tokens """ | |||
| return self.vocab_size + len(self.added_tokens_encoder) | |||
| def tokenize(self, text, add_prefix_space=True): | |||
| """ Converts a string in a sequence of tokens (string), using the tokenizer. | |||
| Split in words for word-based vocabulary or sub-words for sub-word-based | |||
| vocabularies (BPE/SentencePieces/WordPieces). | |||
| Take care of added tokens. | |||
| Args: | |||
| - text: The sequence to be encoded. | |||
| - add_prefix_space (boolean, default True): | |||
| Begin the sentence with at least one space to get invariance to word order in GPT-2 (and RoBERTa) tokenizers. | |||
| """ | |||
| all_special_tokens = self.all_special_tokens | |||
| def lowercase_text(t): | |||
| # convert non-special tokens to lowercase | |||
| escaped_special_toks = [re.escape(s_tok) for s_tok in all_special_tokens] | |||
| pattern = r'(' + r'|'.join(escaped_special_toks) + r')|' + \ | |||
| r'(.+?)' | |||
| return re.sub( | |||
| pattern, | |||
| lambda m: m.groups()[0] or m.groups()[1].lower(), | |||
| t) | |||
| if self.init_kwargs.get('do_lower_case', False): | |||
| text = lowercase_text(text) | |||
| def split_on_token(tok, text): | |||
| result = [] | |||
| split_text = text.split(tok) | |||
| for i, sub_text in enumerate(split_text): | |||
| sub_text = sub_text.strip() | |||
| if i == 0 and not sub_text: | |||
| result += [tok] | |||
| elif i == len(split_text) - 1: | |||
| if sub_text: | |||
| result += [sub_text] | |||
| else: | |||
| pass | |||
| else: | |||
| if sub_text: | |||
| result += [sub_text] | |||
| result += [tok] | |||
| return result | |||
| def split_on_tokens(tok_list, text): | |||
| if not text.strip(): | |||
| return [] | |||
| if not tok_list: | |||
| return self._tokenize(text, add_prefix_space=add_prefix_space) | |||
| tokenized_text = [] | |||
| text_list = [text] | |||
| for tok in tok_list: | |||
| tokenized_text = [] | |||
| for sub_text in text_list: | |||
| if sub_text not in self.added_tokens_encoder \ | |||
| and sub_text not in all_special_tokens: | |||
| tokenized_text += split_on_token(tok, sub_text) | |||
| else: | |||
| tokenized_text += [sub_text] | |||
| text_list = tokenized_text | |||
| return list(itertools.chain.from_iterable((self._tokenize(token, add_prefix_space=add_prefix_space) if token not \ | |||
| in self.added_tokens_encoder and token not in all_special_tokens \ | |||
| else [token] for token in tokenized_text))) | |||
| added_tokens = list(self.added_tokens_encoder.keys()) + all_special_tokens | |||
| tokenized_text = split_on_tokens(added_tokens, text) | |||
| return tokenized_text | |||
| def convert_tokens_to_ids(self, tokens): | |||
| """ Converts a single token, or a sequence of tokens, (str) in a single integer id | |||
| (resp. a sequence of ids), using the vocabulary. | |||
| """ | |||
| if tokens is None: | |||
| return None | |||
| if isinstance(tokens, str): | |||
| return self._convert_token_to_id_with_added_voc(tokens) | |||
| ids = [] | |||
| for token in tokens: | |||
| ids.append(self._convert_token_to_id_with_added_voc(token)) | |||
| return ids | |||
| def _convert_token_to_id_with_added_voc(self, token): | |||
| if token is None: | |||
| return None | |||
| if token in self.added_tokens_encoder: | |||
| return self.added_tokens_encoder[token] | |||
| return self._convert_token_to_id(token) | |||
| def convert_ids_to_tokens(self, ids, skip_special_tokens=False): | |||
| """ Converts a single index or a sequence of indices (integers) in a token " | |||
| (resp.) a sequence of tokens (str), using the vocabulary and added tokens. | |||
| Args: | |||
| skip_special_tokens: Don't decode special tokens (self.all_special_tokens). Default: False | |||
| """ | |||
| if isinstance(ids, int): | |||
| return self._convert_id_to_token(ids) | |||
| tokens = [] | |||
| for index in ids: | |||
| index = int(index) | |||
| if skip_special_tokens and index in self.all_special_ids: | |||
| continue | |||
| tokens.append(self._convert_id_to_token(index)) | |||
| return tokens | |||
| def convert_id_to_tokens(self, token_ids, skip_special_tokens=False, clean_up_tokenization_spaces=True): | |||
| """ | |||
| Converts a sequence of ids (integer) into a string, using the tokenizer and vocabulary | |||
| with options to remove special tokens and clean up tokenization spaces. | |||
| Similar to doing ``self.convert_tokens_to_string(self.convert_ids_to_tokens(token_ids))``. | |||
| Args: | |||
| token_ids: list of tokenized input ids. Can be obtained using the `encode` or `encode_plus` methods. | |||
| skip_special_tokens: if set to True, will remove special tokens. | |||
| clean_up_tokenization_spaces: if set to True, will clean up the tokenization spaces. | |||
| """ | |||
| filtered_tokens = self.convert_ids_to_tokens(token_ids, skip_special_tokens=skip_special_tokens) | |||
| # To avoid mixing byte-level and unicode for byte-level BPE | |||
| # we need to build the string separately for added tokens and byte-level tokens | |||
| # cf. https://github.com/huggingface/transformers/issues/1133 | |||
| sub_texts = [] | |||
| current_sub_text = [] | |||
| for token in filtered_tokens: | |||
| if skip_special_tokens and token in self.all_special_ids: | |||
| continue | |||
| if token in self.added_tokens_encoder: | |||
| if current_sub_text: | |||
| sub_texts.append(self.convert_tokens_to_string(current_sub_text)) | |||
| current_sub_text = [] | |||
| sub_texts.append(token) | |||
| else: | |||
| current_sub_text.append(token) | |||
| if current_sub_text: | |||
| sub_texts.append(self.convert_tokens_to_string(current_sub_text)) | |||
| text = " ".join(sub_texts) | |||
| if clean_up_tokenization_spaces: | |||
| clean_text = self.clean_up_tokenization(text) | |||
| return clean_text | |||
| else: | |||
| return text | |||
| @property | |||
| def special_tokens_map(self): | |||
| """ A dictionary mapping special token class attribute (cls_token, unk_token...) to their | |||
| values ('<unk>', '<cls>'...) | |||
| """ | |||
| set_attr = {} | |||
| for attr in self.SPECIAL_TOKENS_ATTRIBUTES: | |||
| attr_value = getattr(self, "_" + attr) | |||
| if attr_value: | |||
| set_attr[attr] = attr_value | |||
| return set_attr | |||
| @property | |||
| def all_special_tokens(self): | |||
| """ List all the special tokens ('<unk>', '<cls>'...) mapped to class attributes | |||
| (cls_token, unk_token...). | |||
| """ | |||
| all_toks = [] | |||
| set_attr = self.special_tokens_map | |||
| for attr_value in set_attr.values(): | |||
| all_toks = all_toks + (list(attr_value) if isinstance(attr_value, (list, tuple)) else [attr_value]) | |||
| all_toks = list(set(all_toks)) | |||
| return all_toks | |||
| @property | |||
| def all_special_ids(self): | |||
| """ List the vocabulary indices of the special tokens ('<unk>', '<cls>'...) mapped to | |||
| class attributes (cls_token, unk_token...). | |||
| """ | |||
| all_toks = self.all_special_tokens | |||
| all_ids = self.convert_tokens_to_ids(all_toks) | |||
| return all_ids | |||
| @staticmethod | |||
| def clean_up_tokenization(out_string): | |||
| """ Clean up a list of simple English tokenization artifacts like spaces before punctuations and abreviated forms. | |||
| """ | |||
| out_string = ( | |||
| out_string.replace(" .", ".") | |||
| .replace(" ?", "?") | |||
| .replace(" !", "!") | |||
| .replace(" ,", ",") | |||
| .replace(" ' ", "'") | |||
| .replace(" n't", "n't") | |||
| .replace(" 'm", "'m") | |||
| .replace(" do not", " don't") | |||
| .replace(" 's", "'s") | |||
| .replace(" 've", "'ve") | |||
| .replace(" 're", "'re") | |||
| ) | |||
| return out_string | |||
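| # A hedged usage sketch of the tokenizer above; the directory is a placeholder and must contain vocab.json, | |||
| # merges.txt and config.json: | |||
| # >>> tokenizer = GPT2Tokenizer.from_pretrained('/path/to/gpt2-dir') | |||
| # >>> tokens = tokenizer.tokenize("Hello world")   # typically ['ĠHello', 'Ġworld'] with the GPT-2 vocabulary | |||
| # >>> ids = tokenizer.convert_tokens_to_ids(tokens) | |||
| # >>> tokenizer.convert_tokens_to_string(tokens)   # " Hello world": the leading 'Ġ' decodes back to a space | |||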
| @@ -0,0 +1,357 @@ | |||
| from typing import List, Optional | |||
| import json | |||
| import torch | |||
| import torch.nn as nn | |||
| from .bert import BertEmbeddings, BertModel, BertConfig, _get_bert_dir | |||
| from .gpt2 import GPT2Tokenizer | |||
| from ..utils import create_position_ids_from_input_ids, _get_file_name_base_on_postfix | |||
| from ...core import logger | |||
| PRETRAINED_ROBERTA_POSITIONAL_EMBEDDINGS_SIZES = { | |||
| "roberta-base": 512, | |||
| "roberta-large": 512, | |||
| "roberta-large-mnli": 512, | |||
| "distilroberta-base": 512, | |||
| "roberta-base-openai-detector": 512, | |||
| "roberta-large-openai-detector": 512, | |||
| } | |||
| class RobertaEmbeddings(BertEmbeddings): | |||
| """ | |||
| Same as BertEmbeddings with a tiny tweak for positional embeddings indexing. | |||
| """ | |||
| def __init__(self, config): | |||
| super().__init__(config) | |||
| self.padding_idx = 1 | |||
| self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=self.padding_idx) | |||
| self.position_embeddings = nn.Embedding( | |||
| config.max_position_embeddings, config.hidden_size, padding_idx=self.padding_idx | |||
| ) | |||
| def forward(self, input_ids=None, token_type_ids=None, position_ids=None, words_embeddings=None): | |||
| if position_ids is None: | |||
| if input_ids is not None: | |||
| # Create the position ids from the input token ids. Any padded tokens remain padded. | |||
| position_ids = create_position_ids_from_input_ids(input_ids, self.padding_idx).to(input_ids.device) | |||
| else: | |||
| position_ids = self.create_position_ids_from_inputs_embeds(words_embeddings) | |||
| return super().forward( | |||
| input_ids, token_type_ids=token_type_ids, position_ids=position_ids, words_embeddings=words_embeddings | |||
| ) | |||
| def create_position_ids_from_inputs_embeds(self, inputs_embeds): | |||
| """ | |||
| :param torch.Tensor inputs_embeds: | |||
| :return torch.Tensor: | |||
| """ | |||
| input_shape = inputs_embeds.size()[:-1] | |||
| sequence_length = input_shape[1] | |||
| position_ids = torch.arange( | |||
| self.padding_idx + 1, sequence_length + self.padding_idx + 1, dtype=torch.long, device=inputs_embeds.device | |||
| ) | |||
| return position_ids.unsqueeze(0).expand(input_shape) | |||
| class RobertaModel(BertModel): | |||
| r""" | |||
| undocumented | |||
| """ | |||
| def __init__(self, config): | |||
| super().__init__(config) | |||
| self.embeddings = RobertaEmbeddings(config) | |||
| self.apply(self.init_bert_weights) | |||
| def get_input_embeddings(self): | |||
| return self.embeddings.word_embeddings | |||
| def set_input_embeddings(self, value): | |||
| self.embeddings.word_embeddings = value | |||
| @classmethod | |||
| def from_pretrained(cls, model_dir_or_name, *inputs, **kwargs): | |||
| state_dict = kwargs.get('state_dict', None) | |||
| kwargs.pop('state_dict', None) | |||
| kwargs.pop('cache_dir', None) | |||
| kwargs.pop('from_tf', None) | |||
| # get model dir from name or dir | |||
| pretrained_model_dir = _get_bert_dir(model_dir_or_name) | |||
| # Load config | |||
| config_file = _get_file_name_base_on_postfix(pretrained_model_dir, 'config.json') | |||
| config = BertConfig.from_json_file(config_file) | |||
| # Load model | |||
| if state_dict is None: | |||
| weights_path = _get_file_name_base_on_postfix(pretrained_model_dir, '.bin') | |||
| state_dict = torch.load(weights_path, map_location='cpu') | |||
| else: | |||
| logger.error(f'Cannot load parameters through `state_dict` variable.') | |||
| raise RuntimeError(f'Cannot load parameters through `state_dict` variable.') | |||
| # Instantiate model. | |||
| model = cls(config, *inputs, **kwargs) | |||
| missing_keys = [] | |||
| unexpected_keys = [] | |||
| error_msgs = [] | |||
| # Convert old format to new format if needed from a PyTorch state_dict | |||
| old_keys = [] | |||
| new_keys = [] | |||
| for key in state_dict.keys(): | |||
| new_key = None | |||
| if "gamma" in key: | |||
| new_key = key.replace("gamma", "weight") | |||
| if "beta" in key: | |||
| new_key = key.replace("beta", "bias") | |||
| if new_key: | |||
| old_keys.append(key) | |||
| new_keys.append(new_key) | |||
| for old_key, new_key in zip(old_keys, new_keys): | |||
| state_dict[new_key] = state_dict.pop(old_key) | |||
| # copy state_dict so _load_from_state_dict can modify it | |||
| metadata = getattr(state_dict, "_metadata", None) | |||
| state_dict = state_dict.copy() | |||
| if metadata is not None: | |||
| state_dict._metadata = metadata | |||
| # PyTorch's `_load_from_state_dict` does not copy parameters in a module's descendants | |||
| # so we need to apply the function recursively. | |||
| def load(module: nn.Module, prefix=""): | |||
| local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {}) | |||
| module._load_from_state_dict( | |||
| state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs, | |||
| ) | |||
| for name, child in module._modules.items(): | |||
| if child is not None: | |||
| load(child, prefix + name + ".") | |||
| # Make sure we are able to load base models as well as derived models (with heads) | |||
| start_prefix = "" | |||
| model_to_load = model | |||
| if not hasattr(model, 'roberta') and any( | |||
| s.startswith('roberta') for s in state_dict.keys() | |||
| ): | |||
| start_prefix = 'roberta.' | |||
| if hasattr(model, 'roberta') and not any( | |||
| s.startswith('roberta') for s in state_dict.keys() | |||
| ): | |||
| model_to_load = getattr(model, 'roberta') | |||
| load(model_to_load, prefix=start_prefix) | |||
| if model.__class__.__name__ != model_to_load.__class__.__name__: | |||
| base_model_state_dict = model_to_load.state_dict().keys() | |||
| head_model_state_dict_without_base_prefix = [ | |||
| key.split('roberta.')[-1] for key in model.state_dict().keys() | |||
| ] | |||
| missing_keys.extend(head_model_state_dict_without_base_prefix - base_model_state_dict) | |||
| if len(missing_keys) > 0: | |||
| logger.info( | |||
| "Weights of {} not initialized from pretrained model: {}".format( | |||
| model.__class__.__name__, missing_keys | |||
| ) | |||
| ) | |||
| if len(unexpected_keys) > 0: | |||
| logger.info( | |||
| "Weights from pretrained model not used in {}: {}".format( | |||
| model.__class__.__name__, unexpected_keys | |||
| ) | |||
| ) | |||
| if len(error_msgs) > 0: | |||
| raise RuntimeError( | |||
| "Error(s) in loading state_dict for {}:\n\t{}".format( | |||
| model.__class__.__name__, "\n\t".join(error_msgs) | |||
| ) | |||
| ) | |||
| # Set model in evaluation mode to deactivate Dropout modules by default | |||
| model.eval() | |||
| logger.info(f"Load pre-trained RoBERTa parameters from file {weights_path}.") | |||
| return model | |||
| class RobertaTokenizer(GPT2Tokenizer): | |||
| vocab_files_names = { | |||
| "vocab_file": "vocab.json", | |||
| "merges_file": "merges.txt", | |||
| } | |||
| def __init__( | |||
| self, | |||
| vocab_file, | |||
| merges_file, | |||
| errors="replace", | |||
| bos_token="<s>", | |||
| eos_token="</s>", | |||
| sep_token="</s>", | |||
| cls_token="<s>", | |||
| unk_token="<unk>", | |||
| pad_token="<pad>", | |||
| mask_token="<mask>", | |||
| **kwargs | |||
| ): | |||
| super().__init__( | |||
| vocab_file=vocab_file, | |||
| merges_file=merges_file, | |||
| errors=errors, | |||
| bos_token=bos_token, | |||
| eos_token=eos_token, | |||
| unk_token=unk_token, | |||
| sep_token=sep_token, | |||
| cls_token=cls_token, | |||
| pad_token=pad_token, | |||
| mask_token=mask_token, | |||
| **kwargs, | |||
| ) | |||
| self.max_len_single_sentence = self.max_len - 2 # take into account special tokens | |||
| self.max_len_sentences_pair = self.max_len - 4 # take into account special tokens | |||
| def build_inputs_with_special_tokens( | |||
| self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None | |||
| ) -> List[int]: | |||
| """ | |||
| Build model inputs from a sequence or a pair of sequence for sequence classification tasks | |||
| by concatenating and adding special tokens. | |||
| A RoBERTa sequence has the following format: | |||
| - single sequence: ``<s> X </s>`` | |||
| - pair of sequences: ``<s> A </s></s> B </s>`` | |||
| Args: | |||
| token_ids_0 (:obj:`List[int]`): | |||
| List of IDs to which the special tokens will be added | |||
| token_ids_1 (:obj:`List[int]`, `optional`, defaults to :obj:`None`): | |||
| Optional second list of IDs for sequence pairs. | |||
| Returns: | |||
| :obj:`List[int]`: list of `input IDs <../glossary.html#input-ids>`__ with the appropriate special tokens. | |||
| """ | |||
| if token_ids_1 is None: | |||
| return [self.cls_token_id] + token_ids_0 + [self.sep_token_id] | |||
| cls = [self.cls_token_id] | |||
| sep = [self.sep_token_id] | |||
| return cls + token_ids_0 + sep + sep + token_ids_1 + sep | |||
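| # For illustration (the token ids are arbitrary; with the standard RoBERTa vocabulary <s> is id 0 and </s> is id 2): | |||
| # >>> tokenizer.build_inputs_with_special_tokens([713, 16, 205]) | |||
| # [0, 713, 16, 205, 2]            # i.e. <s> X </s> | |||
| # >>> tokenizer.build_inputs_with_special_tokens([713], [205]) | |||
| # [0, 713, 2, 2, 205, 2]          # i.e. <s> A </s></s> B </s> | |||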
| def get_special_tokens_mask( | |||
| self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False | |||
| ) -> List[int]: | |||
| """ | |||
| Retrieves sequence ids from a token list that has no special tokens added. This method is called when adding | |||
| special tokens using the tokenizer ``prepare_for_model`` or ``encode_plus`` methods. | |||
| Args: | |||
| token_ids_0 (:obj:`List[int]`): | |||
| List of ids. | |||
| token_ids_1 (:obj:`List[int]`, `optional`, defaults to :obj:`None`): | |||
| Optional second list of IDs for sequence pairs. | |||
| already_has_special_tokens (:obj:`bool`, `optional`, defaults to :obj:`False`): | |||
| Set to True if the token list is already formatted with special tokens for the model | |||
| Returns: | |||
| :obj:`List[int]`: A list of integers in the range [0, 1]: 0 for a special token, 1 for a sequence token. | |||
| """ | |||
| if already_has_special_tokens: | |||
| if token_ids_1 is not None: | |||
| raise ValueError( | |||
| "You should not supply a second sequence if the provided sequence of " | |||
| "ids is already formated with special tokens for the model." | |||
| ) | |||
| return list(map(lambda x: 1 if x in [self.sep_token_id, self.cls_token_id] else 0, token_ids_0)) | |||
| if token_ids_1 is None: | |||
| return [1] + ([0] * len(token_ids_0)) + [1] | |||
| return [1] + ([0] * len(token_ids_0)) + [1, 1] + ([0] * len(token_ids_1)) + [1] | |||
| def create_token_type_ids_from_sequences( | |||
| self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None | |||
| ) -> List[int]: | |||
| """ | |||
| Creates a mask from the two sequences passed to be used in a sequence-pair classification task. | |||
| RoBERTa does not make use of token type ids, therefore a list of zeros is returned. | |||
| Args: | |||
| token_ids_0 (:obj:`List[int]`): | |||
| List of ids. | |||
| token_ids_1 (:obj:`List[int]`, `optional`, defaults to :obj:`None`): | |||
| Optional second list of IDs for sequence pairs. | |||
| Returns: | |||
| :obj:`List[int]`: List of zeros. | |||
| """ | |||
| sep = [self.sep_token_id] | |||
| cls = [self.cls_token_id] | |||
| if token_ids_1 is None: | |||
| return len(cls + token_ids_0 + sep) * [0] | |||
| return len(cls + token_ids_0 + sep + sep + token_ids_1 + sep) * [0] | |||
| def prepare_for_tokenization(self, text, add_special_tokens=False, **kwargs): | |||
| if "add_prefix_space" in kwargs: | |||
| add_prefix_space = kwargs["add_prefix_space"] | |||
| else: | |||
| add_prefix_space = add_special_tokens | |||
| if add_prefix_space and not text[0].isspace(): | |||
| text = " " + text | |||
| return text | |||
| @classmethod | |||
| def from_pretrained(cls, model_dir_or_name, *inputs, **kwargs): | |||
| """ | |||
| :param str model_dir_or_name: a directory or a shortcut name | |||
| :param kwargs: | |||
| :return: | |||
| """ | |||
| # it needs two files: vocab.json and the merges file | |||
| model_dir = _get_bert_dir(model_dir_or_name) | |||
| # the directory contains four files: vocab.json, merges.txt, config.json and model.bin | |||
| tokenizer_config_file = _get_file_name_base_on_postfix(model_dir, 'config.json') | |||
| with open(tokenizer_config_file, encoding="utf-8") as tokenizer_config_handle: | |||
| init_kwargs = json.load(tokenizer_config_handle) | |||
| # Set max length if needed | |||
| if model_dir_or_name in PRETRAINED_ROBERTA_POSITIONAL_EMBEDDINGS_SIZES: | |||
| # if we're using a pretrained model, ensure the tokenizer | |||
| # won't index sequences longer than the number of positional embeddings | |||
| max_len = PRETRAINED_ROBERTA_POSITIONAL_EMBEDDINGS_SIZES[model_dir_or_name] | |||
| if max_len is not None and isinstance(max_len, (int, float)): | |||
| init_kwargs["max_len"] = min(init_kwargs.get("max_len", int(1e12)), max_len) | |||
| # add the vocab and merges files to init_kwargs | |||
| if 'vocab_file' in kwargs:  # if a vocabulary file is specified explicitly, use it | |||
| init_kwargs['vocab_file'] = kwargs['vocab_file'] | |||
| else: | |||
| init_kwargs['vocab_file'] = _get_file_name_base_on_postfix(model_dir, 'vocab.json') | |||
| init_kwargs['merges_file'] = _get_file_name_base_on_postfix(model_dir, 'merges.txt') | |||
| init_inputs = init_kwargs.pop("init_inputs", ()) | |||
| # Instantiate tokenizer. | |||
| try: | |||
| tokenizer = cls(*init_inputs, **init_kwargs) | |||
| except OSError: | |||
| raise OSError( | |||
| "Unable to load vocabulary from file. " | |||
| "Please check that the provided vocabulary is accessible and not corrupted." | |||
| ) | |||
| return tokenizer | |||
| @@ -148,3 +148,14 @@ def _get_file_name_base_on_postfix(dir_path, postfix): | |||
| elif len(files) > 1: | |||
| raise FileExistsError(f"There are multiple *{postfix} files in {dir_path}") | |||
| return os.path.join(dir_path, files[0]) | |||
| def create_position_ids_from_input_ids(input_ids, padding_idx=0): | |||
| r""" Replace non-padding symbols with their position numbers. Position numbers begin at | |||
| padding_idx+1. Padding symbols are ignored. This is modified from fairseq's | |||
| `utils.make_positions`. | |||
| """ | |||
| # The series of casts and type-conversions here are carefully balanced to both work with ONNX export and XLA. | |||
| mask = input_ids.ne(padding_idx).int() | |||
| incremental_indices = torch.cumsum(mask, dim=1).type_as(mask) * mask | |||
| return incremental_indices.long() + padding_idx | |||
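| # A quick doctest-style check of the function above (token ids are arbitrary; padding_idx=1 as used by RoBERTa): | |||
| # >>> create_position_ids_from_input_ids(torch.tensor([[0, 31414, 232, 2, 1, 1]]), padding_idx=1) | |||
| # tensor([[2, 3, 4, 5, 1, 1]])    # non-padding tokens are numbered from padding_idx + 1; padding keeps padding_idx | |||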