diff --git a/fastNLP/embeddings/__init__.py b/fastNLP/embeddings/__init__.py
index ad0ef9c7..3b3b2dce 100644
--- a/fastNLP/embeddings/__init__.py
+++ b/fastNLP/embeddings/__init__.py
@@ -22,6 +22,7 @@ from .embedding import Embedding, TokenEmbedding
from .static_embedding import StaticEmbedding
from .elmo_embedding import ElmoEmbedding
from .bert_embedding import BertEmbedding, BertWordPieceEncoder
+from .roberta_embedding import RobertaEmbedding
from .char_embedding import CNNCharEmbedding, LSTMCharEmbedding
from .stack_embedding import StackEmbedding
from .utils import get_embeddings
diff --git a/fastNLP/embeddings/roberta_embedding.py b/fastNLP/embeddings/roberta_embedding.py
new file mode 100644
index 00000000..46b4ebb2
--- /dev/null
+++ b/fastNLP/embeddings/roberta_embedding.py
@@ -0,0 +1,339 @@
+
+import os
+import collections
+import warnings
+from itertools import chain
+
+import numpy as np
+import torch
+import torch.nn as nn
+
+from .contextual_embedding import ContextualEmbedding
+from ..core import logger, Vocabulary
+from ..modules.encoder.roberta import RobertaModel, RobertaTokenizer
+
+
+class RobertaEmbedding(ContextualEmbedding):
+ r"""
+    Embedding that encodes words with RoBERTa. It is recommended to limit the input to about 430 words rather than
+    512 (this may vary with the pretrained model). The reason is that the pretrained model accepts at most 512 tokens,
+    and the input words are not yet split into word pieces (RobertaEmbedding performs the word-piece splitting when the
+    words are fed in), so the sequence may exceed the maximum length after splitting.
+
+    RobertaEmbedding supports automatic downloading of pretrained weights. Currently supported models:
+    ..TODO
+
+ Example::
+
+ >>> import torch
+ >>> from fastNLP import Vocabulary
+ >>> from fastNLP.embeddings import RobertaEmbedding
+        >>> vocab = Vocabulary().add_word_lst("The weather is good .".split())
+ >>> embed = RobertaEmbedding(vocab, model_dir_or_name='en-base-uncased', requires_grad=False, layers='4,-2,-1')
+        >>> words = torch.LongTensor([[vocab.to_index(word) for word in "The weather is good .".split()]])
+ >>> outputs = embed(words)
+ >>> outputs.size()
+ >>> # torch.Size([1, 5, 2304])
+ """
+
+ def __init__(self, vocab: Vocabulary, model_dir_or_name: str = 'en-base-uncased', layers: str = '-1',
+ pool_method: str = 'first', word_dropout=0, dropout=0, include_cls_sep: bool = False,
+ pooled_cls=True, requires_grad: bool = True, auto_truncate: bool = False, **kwargs):
+ r"""
+
+        :param ~fastNLP.Vocabulary vocab: the vocabulary.
+        :param str model_dir_or_name: the directory of the model, or the model name. When a directory is passed in, it
+            should contain a vocabulary file (vocab.json), a merges file (merges.txt), a weight file (ending with .bin)
+            and a configuration file (config.json).
+        :param str layers: which layers the output embedding is taken from; the results of the chosen layers are
+            concatenated along the last dimension in the order given in layers. Layer indices are separated by ',',
+            start from 0, and negative numbers index layers from the end.
+        :param str pool_method: every word is represented by several word pieces, so this controls how the
+            representation of a word is computed from its word pieces. Supports ``last``, ``first``, ``avg``, ``max``.
+        :param float word_dropout: the probability of replacing a word with unk. This both trains unk and acts as a
+            form of regularization.
+        :param float dropout: the probability used for dropout on the embedding output; 0.1 means 10% of the values
+            are randomly set to 0.
+        :param bool include_cls_sep: when the sentence representation is computed, <s> and </s> are added at the
+            beginning and end; this flag controls whether these two tokens are kept in the result. If True, the word
+            embedding output is two tokens longer than the input, so it may not match the length of other embeddings
+            when used in :class:`StackEmbedding`.
+        :param bool pooled_cls: whether the returned <s> representation is mapped through the pretrained pooler; only
+            effective when include_cls_sep=True. If the downstream task only uses <s> for prediction, this should
+            usually be True.
+        :param bool requires_grad: whether gradients are required, i.e. whether the RoBERTa weights are updated.
+        :param bool auto_truncate: when the word pieces of a sentence exceed the maximum allowed length (usually 512),
+            automatically truncate everything beyond the first 510 word pieces and set the 512th word piece to </s>.
+            The encoded result of the truncated part is set to all zeros. Usually only tasks that use just <s> for
+            classification set auto_truncate to True.
+        :param kwargs:
+            bool only_use_pretrain_bpe: only use BPE tokens that appear in the pretrained vocabulary; words that cannot
+                be tokenized this way use unk. If the embedding does not need to be updated, setting this to True is
+                recommended.
+ """
+ super().__init__(vocab, word_dropout=word_dropout, dropout=dropout)
+
+ if word_dropout > 0:
+ assert vocab.unknown is not None, "When word_drop > 0, Vocabulary must contain the unknown token."
+
+ self._word_sep_index = None
+        if '</s>' in vocab:
+            self._word_sep_index = vocab['</s>']
+
+ only_use_pretrain_bpe = kwargs.get('only_use_pretrain_bpe', False)
+
+ self.model = _WordRobertaModel(model_dir_or_name=model_dir_or_name, vocab=vocab, layers=layers,
+ pool_method=pool_method, include_cls_sep=include_cls_sep,
+ pooled_cls=pooled_cls, auto_truncate=auto_truncate, min_freq=2,
+ only_use_pretrain_bpe=only_use_pretrain_bpe)
+ self._sep_index = self.model._sep_index
+ self._cls_index = self.model._cls_index
+ self.requires_grad = requires_grad
+ self._embed_size = len(self.model.layers) * self.model.encoder.hidden_size
+
+ def _delete_model_weights(self):
+ del self.model
+
+ def forward(self, words):
+ r"""
+        Compute the RoBERTa embedding of words. Before the computation, <s> is added at the beginning and </s> at the
+        end of every sentence, and include_cls_sep controls whether the representations of these two tokens are kept.
+
+ :param torch.LongTensor words: [batch_size, max_len]
+ :return: torch.FloatTensor. batch_size x max_len x (768*len(self.layers))
+ """
+ words = self.drop_word(words)
+ outputs = self._get_sent_reprs(words)
+ if outputs is not None:
+ return self.dropout(outputs)
+ outputs = self.model(words)
+ outputs = torch.cat([*outputs], dim=-1)
+
+ return self.dropout(outputs)
+
+ def drop_word(self, words):
+ r"""
+        Randomly replace words with unknown_index according to the configured word_dropout.
+
+ :param torch.LongTensor words: batch_size x max_len
+ :return:
+ """
+ if self.word_dropout > 0 and self.training:
+ with torch.no_grad():
+ not_sep_mask = words.ne(self._sep_index)
+ not_cls_mask = words.ne(self._cls_index)
+ if self._word_sep_index:
+ not_sep_mask = not_sep_mask.__and__(words.ne(self._word_sep_index))
+ replaceable_mask = not_sep_mask.__and__(not_cls_mask)
+ mask = torch.full_like(words, fill_value=self.word_dropout, dtype=torch.float, device=words.device)
+                mask = torch.bernoulli(mask).eq(1)  # the larger word_dropout is, the more positions are set to 1
+ pad_mask = words.ne(self._word_pad_index)
+                mask = pad_mask.__and__(mask).__and__(replaceable_mask)  # pad positions are never replaced with unk
+ words = words.masked_fill(mask, self._word_unk_index)
+ return words
+
+
+class _WordRobertaModel(nn.Module):
+ def __init__(self, model_dir_or_name: str, vocab: Vocabulary, layers: str = '-1', pool_method: str = 'first',
+ include_cls_sep: bool = False, pooled_cls: bool = False, auto_truncate: bool = False, min_freq=2,
+ only_use_pretrain_bpe=False):
+ super().__init__()
+
+ self.tokenzier = RobertaTokenizer.from_pretrained(model_dir_or_name)
+ self.encoder = RobertaModel.from_pretrained(model_dir_or_name)
+ self._max_position_embeddings = self.encoder.config.max_position_embeddings
+        # check that the requested layers are valid for this encoder
+ encoder_layer_number = len(self.encoder.encoder.layer)
+ self.layers = list(map(int, layers.split(',')))
+ for layer in self.layers:
+ if layer < 0:
+ assert -layer <= encoder_layer_number, f"The layer index:{layer} is out of scope for " \
+ f"a roberta model with {encoder_layer_number} layers."
+ else:
+ assert layer < encoder_layer_number, f"The layer index:{layer} is out of scope for " \
+ f"a roberta model with {encoder_layer_number} layers."
+
+ assert pool_method in ('avg', 'max', 'first', 'last')
+ self.pool_method = pool_method
+ self.include_cls_sep = include_cls_sep
+ self.pooled_cls = pooled_cls
+ self.auto_truncate = auto_truncate
+
+        # compute the word pieces for every word in vocab; <s> and </s> need extra handling
+        logger.info("Start to generate word pieces for words.")
+        # first collect the required word pieces, then create the new embed and word_piece_vocab and fill in the values
+        word_piece_dict = {'<s>': 1, '</s>': 1}  # the word pieces that are used, plus newly added ones
+        found_count = 0
+        self._has_sep_in_vocab = '</s>' in vocab  # used to decide whether token_type_ids need to be generated
+        if "</s>" in vocab:
+            warnings.warn("</s> detected in your vocabulary. RobertaEmbedding will add <s> and </s> to the beginning "
+                          "and end of the input automatically, make sure you don't add <s> and </s> at the beginning"
+                          " and end.")
+ for word, index in vocab:
+            if index == vocab.padding_idx:  # pad is a special token
+                word = '<pad>'
+            elif index == vocab.unknown_idx:
+                word = '<unk>'
+            # _words = self.tokenzier.basic_tokenizer._tokenize_chinese_chars(word).split()  # Chinese characters are not handled here for now
+ word_pieces = []
+ word_pieces.extend(self.tokenzier.tokenize(word))
+ if len(word_pieces) == 1:
+                if not vocab._is_word_no_create_entry(word):  # the word appears in the training data but was not found
+                    if index != vocab.unknown_idx and word_pieces[0] == '<unk>':  # the word is not in the pretrained vocabulary
+                        if vocab.word_count[word] >= min_freq and not vocab._is_word_no_create_entry(
+                                word) and not only_use_pretrain_bpe:  # only add the word if it occurs at least min_freq times
+                            word_piece_dict[word] = 1  # add a new entry
+ continue
+ for word_piece in word_pieces:
+ word_piece_dict[word_piece] = 1
+ found_count += 1
+ original_embed = self.encoder.embeddings.word_embeddings.weight.data
+        # special tokens need special handling
+        embed = nn.Embedding(len(word_piece_dict), original_embed.size(1))  # the new embedding
+ new_word_piece_vocab = collections.OrderedDict()
+        for index, token in enumerate(['<pad>', '<unk>']):
+ word_piece_dict.pop(token, None)
+ embed.weight.data[index] = original_embed[self.tokenzier.encoder[token]]
+ new_word_piece_vocab[token] = index
+ for token in word_piece_dict.keys():
+ if token in self.tokenzier.encoder:
+ embed.weight.data[len(new_word_piece_vocab)] = original_embed[self.tokenzier.encoder[token]]
+ else:
+                embed.weight.data[len(new_word_piece_vocab)] = original_embed[self.tokenzier.encoder['<unk>']]
+ new_word_piece_vocab[token] = len(new_word_piece_vocab)
+ self._reinit_on_new_vocab(new_word_piece_vocab, model_dir_or_name)
+ self.encoder.embeddings.word_embeddings = embed
+
+ word_to_wordpieces = []
+ word_pieces_lengths = []
+ for word, index in vocab:
+            if index == vocab.padding_idx:  # pad is a special token
+                word = '<pad>'
+            elif index == vocab.unknown_idx:
+                word = '<unk>'
+ word_pieces = self.tokenzier.tokenize(word)
+ word_pieces = self.tokenzier.convert_tokens_to_ids(word_pieces)
+ word_to_wordpieces.append(word_pieces)
+ word_pieces_lengths.append(len(word_pieces))
+        self._cls_index = self.tokenzier.encoder['<s>']
+        self._sep_index = self.tokenzier.encoder['</s>']
+        self._word_pad_index = vocab.padding_idx
+        self._wordpiece_pad_index = self.tokenzier.encoder['<pad>']  # used for padding the word pieces
+ logger.info("Found(Or segment into word pieces) {} words out of {}.".format(found_count, len(vocab)))
+ self.word_to_wordpieces = np.array(word_to_wordpieces)
+ self.register_buffer('word_pieces_lengths', torch.LongTensor(word_pieces_lengths))
+        logger.debug("Successfully generated word pieces.")
+
+ def _reinit_on_new_vocab(self, vocab, model_dir_or_name):
+ import json
+ with open('./.tmp-new-vocab-file.json', 'w') as f:
+ json.dump(vocab, f)
+ self.tokenzier = RobertaTokenizer.from_pretrained(model_dir_or_name, vocab_file='./.tmp-new-vocab-file.json')
+ os.remove('./.tmp-new-vocab-file.json')
+
+ def forward(self, words):
+ r"""
+
+ :param words: torch.LongTensor, batch_size x max_len
+        :return: num_layers x batch_size x max_len x hidden_size or num_layers x batch_size x (max_len+2) x hidden_size
+ """
+ with torch.no_grad():
+ batch_size, max_word_len = words.size()
+            word_mask = words.ne(self._word_pad_index)  # positions equal to 1 contain a word
+ seq_len = word_mask.sum(dim=-1)
+ batch_word_pieces_length = self.word_pieces_lengths[words].masked_fill(word_mask.eq(False),
+ 0) # batch_size x max_len
+ word_pieces_lengths = batch_word_pieces_length.sum(dim=-1) # batch_size
+            word_piece_length = batch_word_pieces_length.sum(dim=-1).max().item()  # the word piece length (including padding)
+ if word_piece_length + 2 > self._max_position_embeddings:
+ if self.auto_truncate:
+ word_pieces_lengths = word_pieces_lengths.masked_fill(
+ word_pieces_lengths + 2 > self._max_position_embeddings,
+ self._max_position_embeddings - 2)
+ else:
+ raise RuntimeError(
+ "After split words into word pieces, the lengths of word pieces are longer than the "
+ f"maximum allowed sequence length:{self._max_position_embeddings} of bert. You can set "
+ f"`auto_truncate=True` for BertEmbedding to automatically truncate overlong input.")
+
+            # +2 because <s> and </s> need to be added
+ word_pieces = words.new_full((batch_size, min(word_piece_length + 2, self._max_position_embeddings)),
+ fill_value=self._wordpiece_pad_index)
+ attn_masks = torch.zeros_like(word_pieces)
+            # 1. get the word piece ids of words and the corresponding spans
+ word_indexes = words.cpu().numpy()
+ for i in range(batch_size):
+ word_pieces_i = list(chain(*self.word_to_wordpieces[word_indexes[i, :seq_len[i]]]))
+ if self.auto_truncate and len(word_pieces_i) > self._max_position_embeddings - 2:
+ word_pieces_i = word_pieces_i[:self._max_position_embeddings - 2]
+ word_pieces[i, 1:word_pieces_lengths[i] + 1] = torch.LongTensor(word_pieces_i)
+ attn_masks[i, :word_pieces_lengths[i] + 2].fill_(1)
+            # add <s> (cls) and </s> (sep)
+ word_pieces[:, 0].fill_(self._cls_index)
+ batch_indexes = torch.arange(batch_size).to(words)
+ word_pieces[batch_indexes, word_pieces_lengths + 1] = self._sep_index
+            # if self._has_sep_in_vocab:  # token_type_ids should only be needed when </s> appears in the vocab
+ # sep_mask = word_pieces.eq(self._sep_index).long() # batch_size x max_len
+ # sep_mask_cumsum = sep_mask.flip(dims=[-1]).cumsum(dim=-1).flip(dims=[-1])
+ # token_type_ids = sep_mask_cumsum.fmod(2)
+            #     if token_type_ids[0, 0].item():  # if the first value is odd, flip the result so that the sequence starts with 0
+ # token_type_ids = token_type_ids.eq(0).long()
+            # else:  # RoBERTa does not need token_type_ids
+ token_type_ids = torch.zeros_like(word_pieces)
+            # 2. get the hidden states and pool over the word pieces accordingly
+ # all_outputs: [batch_size x max_len x hidden_size, batch_size x max_len x hidden_size, ...]
+ bert_outputs, pooled_cls = self.encoder(word_pieces, token_type_ids=token_type_ids,
+ attention_mask=attn_masks,
+ output_all_encoded_layers=True)
+ # output_layers = [self.layers] # len(self.layers) x batch_size x real_word_piece_length x hidden_size
+
+ if self.include_cls_sep:
+ s_shift = 1
+ outputs = bert_outputs[-1].new_zeros(len(self.layers), batch_size, max_word_len + 2,
+ bert_outputs[-1].size(-1))
+
+ else:
+ s_shift = 0
+ outputs = bert_outputs[-1].new_zeros(len(self.layers), batch_size, max_word_len,
+ bert_outputs[-1].size(-1))
+ batch_word_pieces_cum_length = batch_word_pieces_length.new_zeros(batch_size, max_word_len + 1)
+ batch_word_pieces_cum_length[:, 1:] = batch_word_pieces_length.cumsum(dim=-1) # batch_size x max_len
+
+ if self.pool_method == 'first':
+ batch_word_pieces_cum_length = batch_word_pieces_cum_length[:, :seq_len.max()]
+ batch_word_pieces_cum_length.masked_fill_(batch_word_pieces_cum_length.ge(word_piece_length), 0)
+ _batch_indexes = batch_indexes[:, None].expand((batch_size, batch_word_pieces_cum_length.size(1)))
+ elif self.pool_method == 'last':
+ batch_word_pieces_cum_length = batch_word_pieces_cum_length[:, 1:seq_len.max() + 1] - 1
+ batch_word_pieces_cum_length.masked_fill_(batch_word_pieces_cum_length.ge(word_piece_length), 0)
+ _batch_indexes = batch_indexes[:, None].expand((batch_size, batch_word_pieces_cum_length.size(1)))
+
+ for l_index, l in enumerate(self.layers):
+ output_layer = bert_outputs[l]
+ real_word_piece_length = output_layer.size(1) - 2
+            if word_piece_length > real_word_piece_length:  # if the word pieces were actually truncated
+ paddings = output_layer.new_zeros(batch_size,
+ word_piece_length - real_word_piece_length,
+ output_layer.size(2))
+ output_layer = torch.cat((output_layer, paddings), dim=1).contiguous()
+            # collapse the word piece representations into word representations
+            truncate_output_layer = output_layer[:, 1:-1]  # remove <s> and </s>; batch_size x len x hidden_size
+ if self.pool_method == 'first':
+ tmp = truncate_output_layer[_batch_indexes, batch_word_pieces_cum_length]
+ tmp = tmp.masked_fill(word_mask[:, :batch_word_pieces_cum_length.size(1), None].eq(False), 0)
+ outputs[l_index, :, s_shift:batch_word_pieces_cum_length.size(1) + s_shift] = tmp
+
+ elif self.pool_method == 'last':
+ tmp = truncate_output_layer[_batch_indexes, batch_word_pieces_cum_length]
+ tmp = tmp.masked_fill(word_mask[:, :batch_word_pieces_cum_length.size(1), None].eq(False), 0)
+ outputs[l_index, :, s_shift:batch_word_pieces_cum_length.size(1) + s_shift] = tmp
+ elif self.pool_method == 'max':
+ for i in range(batch_size):
+ for j in range(seq_len[i]):
+ start, end = batch_word_pieces_cum_length[i, j], batch_word_pieces_cum_length[i, j + 1]
+ outputs[l_index, i, j + s_shift], _ = torch.max(truncate_output_layer[i, start:end], dim=-2)
+ else:
+ for i in range(batch_size):
+ for j in range(seq_len[i]):
+ start, end = batch_word_pieces_cum_length[i, j], batch_word_pieces_cum_length[i, j + 1]
+ outputs[l_index, i, j + s_shift] = torch.mean(truncate_output_layer[i, start:end], dim=-2)
+ if self.include_cls_sep:
+ if l in (len(bert_outputs) - 1, -1) and self.pooled_cls:
+ outputs[l_index, :, 0] = pooled_cls
+ else:
+ outputs[l_index, :, 0] = output_layer[:, 0]
+ outputs[l_index, batch_indexes, seq_len + s_shift] = output_layer[batch_indexes, seq_len + s_shift]
+
+        # 3. the final embedding result
+ return outputs
diff --git a/fastNLP/modules/encoder/__init__.py b/fastNLP/modules/encoder/__init__.py
index 57ed5d6c..3c9af22d 100644
--- a/fastNLP/modules/encoder/__init__.py
+++ b/fastNLP/modules/encoder/__init__.py
@@ -34,6 +34,7 @@ __all__ = [
from .attention import MultiHeadAttention, BiAttention, SelfAttention
from .bert import BertModel
+from .roberta import RobertaModel
from .char_encoder import ConvolutionCharEncoder, LSTMCharEncoder
from .conv_maxpool import ConvMaxpool
from .lstm import LSTM
diff --git a/fastNLP/modules/encoder/bert.py b/fastNLP/modules/encoder/bert.py
index 3496c5f6..32edafbe 100644
--- a/fastNLP/modules/encoder/bert.py
+++ b/fastNLP/modules/encoder/bert.py
@@ -245,14 +245,18 @@ class BertEmbeddings(nn.Module):
self.LayerNorm = BertLayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
- def forward(self, input_ids, token_type_ids=None):
+ def forward(self, input_ids, token_type_ids=None, position_ids=None, words_embeddings=None):
seq_length = input_ids.size(1)
- position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device)
- position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
+ if position_ids is None:
+ position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device)
+ position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
if token_type_ids is None:
token_type_ids = torch.zeros_like(input_ids)
- words_embeddings = self.word_embeddings(input_ids)
+ if words_embeddings is None:
+ words_embeddings = self.word_embeddings(input_ids)
+ else:
+ assert input_ids.size() == words_embeddings.size()[: -1]
position_embeddings = self.position_embeddings(position_ids)
token_type_embeddings = self.token_type_embeddings(token_type_ids)
diff --git a/fastNLP/modules/encoder/gpt2.py b/fastNLP/modules/encoder/gpt2.py
new file mode 100644
index 00000000..5b692253
--- /dev/null
+++ b/fastNLP/modules/encoder/gpt2.py
@@ -0,0 +1,773 @@
+
+from functools import lru_cache
+import json
+import regex as re
+import itertools
+
+
+from ...io.file_utils import _get_embedding_url, cached_path
+from ...core import logger
+
+import os
+
+PRETRAINED_GPT2_MODEL_DIR = PRETRAINED_BERT_MODEL_DIR = {
+ 'en-small': 'gpt2-small.zip',
+ 'en-median': 'gpt2-medium.zip',
+ 'en': 'gpt2-medium.zip'
+}
+
+
+def _get_gpt2_dir(model_dir_or_name: str = 'en-median'):
+ if model_dir_or_name.lower() in PRETRAINED_GPT2_MODEL_DIR:
+ model_url = _get_embedding_url('gpt2', model_dir_or_name.lower())
+ model_dir = cached_path(model_url, name='embedding')
+        # check whether the path exists locally
+ elif os.path.isdir(os.path.abspath(os.path.expanduser(model_dir_or_name))):
+ model_dir = os.path.abspath(os.path.expanduser(model_dir_or_name))
+ else:
+ logger.error(f"Cannot recognize GPT2 dir or name ``{model_dir_or_name}``.")
+ raise ValueError(f"Cannot recognize GPT2 dir or name ``{model_dir_or_name}``.")
+ return str(model_dir)
+
+
+def _get_filepath_based_on_postfix(folder, postfix):
+ """
+    Look for a file in folder whose name ends with postfix, e.g. a file ending with vocab.txt. Only the first match
+    is returned; multiple matches do not raise an error, but a missing file does. Returns the full path of that file.
+
+ :param str folder:
+ :param str postfix:
+ :return:
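+
+    Example (the path below is hypothetical)::
+
+        >>> _get_filepath_based_on_postfix('/path/to/gpt2', 'vocab.json')
+        '/path/to/gpt2/vocab.json'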
+ """
+ for filename in os.listdir(folder):
+ if os.path.isfile(os.path.join(folder, filename)):
+ if filename.endswith(postfix):
+ return os.path.join(folder, filename)
+ raise FileNotFoundError(f"File {postfix} is not found in {folder}.")
+
+
+@lru_cache()
+def bytes_to_unicode():
+ """
+    Returns a mapping from utf-8 bytes to unicode strings.
+    We specifically avoid mapping to whitespace/control characters the bpe code barfs on.
+
+    The reversible bpe codes work on unicode strings.
+    This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
+    When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
+    This is a significant percentage of your normal, say, 32K bpe vocab.
+    To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
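+
+    A small illustration of the mapping (printable ASCII bytes map to themselves, other bytes are shifted to unused
+    code points)::
+
+        >>> mapping = bytes_to_unicode()
+        >>> mapping[ord('A')]
+        'A'
+        >>> mapping[0]
+        'Ā'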
+ """
+ bs = (
+ list(range(ord("!"), ord("~") + 1)) + list(range(ord("¡"), ord("¬") + 1)) + list(range(ord("®"), ord("ÿ") + 1))
+ )
+ cs = bs[:]
+ n = 0
+ for b in range(2 ** 8):
+ if b not in bs:
+ bs.append(b)
+ cs.append(2 ** 8 + n)
+ n += 1
+ cs = [chr(n) for n in cs]
+ return dict(zip(bs, cs))
+
+
+def get_pairs(word):
+ """Return set of symbol pairs in a word.
+
+ Word is represented as tuple of symbols (symbols being variable-length strings).
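+
+    Example::
+
+        >>> sorted(get_pairs(('l', 'o', 'w')))
+        [('l', 'o'), ('o', 'w')]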
+ """
+ pairs = set()
+ prev_char = word[0]
+ for char in word[1:]:
+ pairs.add((prev_char, char))
+ prev_char = char
+ return pairs
+
+
+VOCAB_FILES_NAMES = {
+ "vocab_file": "vocab.json",
+ "merges_file": "merges.txt",
+}
+
+
+PRETRAINED_VOCAB_FILES_MAP = {
+ "vocab_file": {
+ "gpt2": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-vocab.json",
+ "gpt2-medium": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-medium-vocab.json",
+ "gpt2-large": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-large-vocab.json",
+ "gpt2-xl": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-xl-vocab.json",
+ "distilgpt2": "https://s3.amazonaws.com/models.huggingface.co/bert/distilgpt2-vocab.json",
+ },
+ "merges_file": {
+ "gpt2": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-merges.txt",
+ "gpt2-medium": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-medium-merges.txt",
+ "gpt2-large": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-large-merges.txt",
+ "gpt2-xl": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-xl-merges.txt",
+ "distilgpt2": "https://s3.amazonaws.com/models.huggingface.co/bert/distilgpt2-merges.txt",
+ },
+}
+
+
+PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
+ "en-small": 1024,
+ 'en': 1024,
+ "en-medium": 1024,
+ "en-large": 1024,
+ "en-xl": 1024,
+ "en-distilgpt2": 1024,
+}
+
+PATTERN = re.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""")
+
+
+def gpt2_tokenize(text, add_prefix_space=True):
+ """
+
+ :param str text:
+    :param bool add_prefix_space: whether to prepend a space to the sentence; only then is the tokenization consistent
+        with how GPT-2 was trained
+    :return: list of str
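+
+    Example::
+
+        >>> gpt2_tokenize("Hello world")
+        [' Hello', ' world']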
+ """
+    if text == '':
+ return []
+ if add_prefix_space:
+ text = ' ' + text
+ tokens = []
+ for token in re.findall(PATTERN, text):
+ tokens.append(token)
+ return tokens
+
+
+class GPT2Tokenizer:
+ """
+ GPT-2 BPE tokenizer. Peculiarities:
+ - Byte-level Byte-Pair-Encoding
+ - Requires a space to start the input string => the encoding and tokenize methods should be called with the
+ ``add_prefix_space`` flag set to ``True``.
+ Otherwise, this tokenizer's ``encode``, ``decode``, and ``tokenize`` methods will not conserve
+ the spaces at the beginning of a string: `tokenizer.decode(tokenizer.encode(" Hello")) = "Hello"`
+ """
+
+ vocab_files_names = VOCAB_FILES_NAMES
+ pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
+
+ SPECIAL_TOKENS_ATTRIBUTES = [
+ "bos_token",
+ "eos_token",
+ "unk_token",
+ "pad_token",
+ "cls_token",
+ "mask_token",
+ ]
+
+ padding_side = "right"
+
+ def __init__(
+ self,
+ vocab_file,
+ merges_file,
+ errors="replace",
+ unk_token="<|endoftext|>",
+ bos_token="<|endoftext|>",
+ eos_token="<|endoftext|>",
+ **kwargs
+ ):
+ self._bos_token = None
+ self._eos_token = None
+ self._unk_token = None
+ self._sep_token = None
+ self._pad_token = None
+ self._cls_token = None
+ self._mask_token = None
+ self._pad_token_type_id = 0
+
+ self.bos_token = bos_token
+ self.eos_token = eos_token
+ self.unk_token = unk_token
+
+ self.max_len = int(1e12)
+ self.padding_side = kwargs.pop("padding_side", self.padding_side)
+ self.added_tokens_encoder = {}
+ self.unique_added_tokens_encoder = set()
+ self.added_tokens_decoder = {}
+ # inputs and kwargs for saving and re-loading (see ``from_pretrained`` and ``save_pretrained``)
+ self.init_inputs = ()
+ self.init_kwargs = {}
+
+ for key, value in kwargs.items():
+ if key in self.SPECIAL_TOKENS_ATTRIBUTES:
+ if key == "additional_special_tokens":
+ assert isinstance(value, (list, tuple)) and all(isinstance(t, str) for t in value)
+ else:
+ assert isinstance(value, str)
+ setattr(self, key, value)
+
+ self.max_len_single_sentence = (
+ self.max_len
+ ) # no default special tokens - you can update this value if you add special tokens
+ self.max_len_sentences_pair = (
+ self.max_len
+ ) # no default special tokens - you can update this value if you add special tokens
+
+ with open(vocab_file, encoding="utf-8") as vocab_handle:
+ self.encoder = json.load(vocab_handle)
+ self.decoder = {v: k for k, v in self.encoder.items()}
+ self.errors = errors # how to handle errors in decoding
+ self.byte_encoder = bytes_to_unicode()
+ self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
+ with open(merges_file, encoding="utf-8") as merges_handle:
+ bpe_merges = merges_handle.read().split("\n")[1:-1]
+ bpe_merges = [tuple(merge.split()) for merge in bpe_merges]
+ self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges))))
+ self.cache = {}
+
+ def add_special_tokens(self, special_tokens_dict):
+ """
+ Add a dictionary of special tokens (eos, pad, cls...) to the encoder and link them
+ to class attributes. If special tokens are NOT in the vocabulary, they are added
+ to it (indexed starting from the last index of the current vocabulary).
+
+ Using `add_special_tokens` will ensure your special tokens can be used in several ways:
+
+ - special tokens are carefully handled by the tokenizer (they are never split)
+ - you can easily refer to special tokens using tokenizer class attributes like `tokenizer.cls_token`. This makes it easy to develop model-agnostic training and fine-tuning scripts.
+
+        When possible, special tokens are already registered for provided pretrained models (ex: BertTokenizer cls_token is already registered to be '[CLS]' and XLM's one is also registered to be '</s>')
+
+ Args:
+ special_tokens_dict: dict of string. Keys should be in the list of predefined special attributes:
+ [``bos_token``, ``eos_token``, ``unk_token``, ``sep_token``, ``pad_token``, ``cls_token``, ``mask_token``,
+ ``additional_special_tokens``].
+
+            Tokens are only added if they are not already in the vocabulary (tested by checking if the tokenizer assigns the index of the ``unk_token`` to them).
+
+ Returns:
+ Number of tokens added to the vocabulary.
+
+ Examples::
+
+ # Let's see how to add a new classification token to GPT-2
+ tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
+ model = GPT2Model.from_pretrained('gpt2')
+
+            special_tokens_dict = {'cls_token': '<CLS>'}
+
+ num_added_toks = tokenizer.add_special_tokens(special_tokens_dict)
+ print('We have added', num_added_toks, 'tokens')
+ model.resize_token_embeddings(len(tokenizer)) # Notice: resize_token_embeddings expect to receive the full size of the new vocabulary, i.e. the length of the tokenizer.
+
+            assert tokenizer.cls_token == '<CLS>'
+ """
+ if not special_tokens_dict:
+ return 0
+
+ added_tokens = 0
+ for key, value in special_tokens_dict.items():
+ assert key in self.SPECIAL_TOKENS_ATTRIBUTES
+ if key == "additional_special_tokens":
+ assert isinstance(value, (list, tuple)) and all(isinstance(t, str) for t in value)
+ added_tokens += self.add_tokens(value)
+ else:
+ assert isinstance(value, str)
+ added_tokens += self.add_tokens([value])
+ logger.debug("Assigning %s to the %s key of the tokenizer", value, key)
+ setattr(self, key, value)
+
+ return added_tokens
+
+ def add_tokens(self, new_tokens):
+ """
+ Add a list of new tokens to the tokenizer class. If the new tokens are not in the
+ vocabulary, they are added to it with indices starting from length of the current vocabulary.
+
+ Args:
+            new_tokens: list of string. Each string is a token to add. Tokens are only added if they are not already in the vocabulary (tested by checking if the tokenizer assigns the index of the ``unk_token`` to them).
+
+ Returns:
+ Number of tokens added to the vocabulary.
+
+ Examples::
+
+ # Let's see how to increase the vocabulary of Bert model and tokenizer
+ tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
+ model = BertModel.from_pretrained('bert-base-uncased')
+
+ num_added_toks = tokenizer.add_tokens(['new_tok1', 'my_new-tok2'])
+ print('We have added', num_added_toks, 'tokens')
+ model.resize_token_embeddings(len(tokenizer)) # Notice: resize_token_embeddings expect to receive the full size of the new vocabulary, i.e. the length of the tokenizer.
+ """
+ if not new_tokens:
+ return 0
+
+ to_add_tokens = []
+ for token in new_tokens:
+ assert isinstance(token, str)
+ if self.init_kwargs.get("do_lower_case", False) and token not in self.all_special_tokens:
+ token = token.lower()
+ if (
+ token != self.unk_token
+ and self.convert_tokens_to_ids(token) == self.convert_tokens_to_ids(self.unk_token)
+ and token not in to_add_tokens
+ ):
+ to_add_tokens.append(token)
+ logger.debug("Adding %s to the vocabulary", token)
+
+ added_tok_encoder = dict((tok, len(self) + i) for i, tok in enumerate(to_add_tokens))
+ added_tok_decoder = {v: k for k, v in added_tok_encoder.items()}
+ self.added_tokens_encoder.update(added_tok_encoder)
+ self.unique_added_tokens_encoder = set(self.added_tokens_encoder.keys()).union(set(self.all_special_tokens))
+ self.added_tokens_decoder.update(added_tok_decoder)
+
+ return len(to_add_tokens)
+
+ @property
+ def bos_token(self):
+ """ Beginning of sentence token (string). Log an error if used while not having been set. """
+ if self._bos_token is None:
+ logger.error("Using bos_token, but it is not set yet.")
+ return self._bos_token
+
+ @property
+ def eos_token(self):
+ """ End of sentence token (string). Log an error if used while not having been set. """
+ if self._eos_token is None:
+ logger.error("Using eos_token, but it is not set yet.")
+ return self._eos_token
+
+ @property
+ def unk_token(self):
+ """ Unknown token (string). Log an error if used while not having been set. """
+ if self._unk_token is None:
+ logger.error("Using unk_token, but it is not set yet.")
+ return self._unk_token
+
+ @property
+ def pad_token(self):
+ """ Padding token (string). Log an error if used while not having been set. """
+ if self._pad_token is None:
+ logger.error("Using pad_token, but it is not set yet.")
+ return self._pad_token
+
+ @property
+ def cls_token(self):
+ """ Classification token (string). E.g. to extract a summary of an input sequence leveraging self-attention along the full depth of the model. Log an error if used while not having been set. """
+ if self._cls_token is None:
+ logger.error("Using cls_token, but it is not set yet.")
+ return self._cls_token
+
+ @property
+ def mask_token(self):
+ """ Mask token (string). E.g. when training a model with masked-language modeling. Log an error if used while not having been set. """
+ if self._mask_token is None:
+ logger.error("Using mask_token, but it is not set yet.")
+ return self._mask_token
+
+ @bos_token.setter
+ def bos_token(self, value):
+ self._bos_token = value
+
+ @eos_token.setter
+ def eos_token(self, value):
+ self._eos_token = value
+
+ @unk_token.setter
+ def unk_token(self, value):
+ self._unk_token = value
+
+ @pad_token.setter
+ def pad_token(self, value):
+ self._pad_token = value
+
+ @cls_token.setter
+ def cls_token(self, value):
+ self._cls_token = value
+
+ @mask_token.setter
+ def mask_token(self, value):
+ self._mask_token = value
+
+ @property
+ def bos_token_id(self):
+ """ Id of the beginning of sentence token in the vocabulary. Log an error if used while not having been set. """
+ return self.convert_tokens_to_ids(self.bos_token)
+
+ @property
+ def eos_token_id(self):
+ """ Id of the end of sentence token in the vocabulary. Log an error if used while not having been set. """
+ return self.convert_tokens_to_ids(self.eos_token)
+
+ @property
+ def unk_token_id(self):
+ """ Id of the unknown token in the vocabulary. Log an error if used while not having been set. """
+ return self.convert_tokens_to_ids(self.unk_token)
+
+ @property
+ def pad_token_id(self):
+ """ Id of the padding token in the vocabulary. Log an error if used while not having been set. """
+ return self.convert_tokens_to_ids(self.pad_token)
+
+ @property
+ def pad_token_type_id(self):
+ """ Id of the padding token type in the vocabulary."""
+ return self._pad_token_type_id
+
+ @property
+ def cls_token_id(self):
+ """ Id of the classification token in the vocabulary. E.g. to extract a summary of an input sequence leveraging self-attention along the full depth of the model. Log an error if used while not having been set. """
+ return self.convert_tokens_to_ids(self.cls_token)
+
+ @property
+ def mask_token_id(self):
+ """ Id of the mask token in the vocabulary. E.g. when training a model with masked-language modeling. Log an error if used while not having been set. """
+ return self.convert_tokens_to_ids(self.mask_token)
+
+ @property
+ def vocab_size(self):
+ return len(self.encoder)
+
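+    # A sketch of the merge loop in `bpe` below (the exact splits depend on the loaded
+    # merges.txt, so this is illustrative only): for a token such as "lower", adjacent
+    # pairs like ('l', 'o') or ('e', 'r') are looked up in self.bpe_ranks, the best-ranked
+    # pair is merged, and this repeats until no learned merge applies; the result is
+    # returned as a space-joined string, e.g. "low er".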
+ def bpe(self, token):
+ if token in self.cache:
+ return self.cache[token]
+ word = tuple(token)
+ pairs = get_pairs(word)
+
+ if not pairs:
+ return token
+
+ while True:
+ bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf")))
+ if bigram not in self.bpe_ranks:
+ break
+ first, second = bigram
+ new_word = []
+ i = 0
+ while i < len(word):
+ try:
+ j = word.index(first, i)
+ except ValueError:
+ new_word.extend(word[i:])
+ break
+ else:
+ new_word.extend(word[i:j])
+ i = j
+
+ if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
+ new_word.append(first + second)
+ i += 2
+ else:
+ new_word.append(word[i])
+ i += 1
+ new_word = tuple(new_word)
+ word = new_word
+ if len(word) == 1:
+ break
+ else:
+ pairs = get_pairs(word)
+ word = " ".join(word)
+ self.cache[token] = word
+ return word
+
+ def _tokenize(self, text, add_prefix_space=False):
+ """ Tokenize a string.
+ Args:
+ - add_prefix_space (boolean, default False):
+ Begin the sentence with at least one space to get invariance to word order in GPT-2 (and RoBERTa) tokenizers.
+ """
+ bpe_tokens = []
+ for token in gpt2_tokenize(text, add_prefix_space=add_prefix_space):
+ token = "".join(
+ self.byte_encoder[b] for b in token.encode("utf-8")
+ ) # Maps all our bytes to unicode strings, avoiding controle tokens of the BPE (spaces in our case)
+ bpe_tokens.extend(bpe_token for bpe_token in self.bpe(token).split(" "))
+ return bpe_tokens
+
+ def _convert_token_to_id(self, token):
+ """ Converts a token (str) in an id using the vocab. """
+ return self.encoder.get(token, self.encoder.get(self.unk_token))
+
+ def _convert_id_to_token(self, index):
+ """Converts an index (integer) in a token (str) using the vocab."""
+ return self.decoder.get(index)
+
+ def convert_tokens_to_string(self, tokens):
+ """ Converts a sequence of tokens (string) in a single string. """
+ text = "".join(tokens)
+ text = bytearray([self.byte_decoder[c] for c in text]).decode("utf-8", errors=self.errors)
+ return text
+
+ def save_vocabulary(self, save_directory):
+ """Save the tokenizer vocabulary and merge files to a directory."""
+ if not os.path.isdir(save_directory):
+ logger.error("Vocabulary path ({}) should be a directory".format(save_directory))
+ return
+ vocab_file = os.path.join(save_directory, VOCAB_FILES_NAMES["vocab_file"])
+ merge_file = os.path.join(save_directory, VOCAB_FILES_NAMES["merges_file"])
+
+ with open(vocab_file, "w", encoding="utf-8") as f:
+ f.write(json.dumps(self.encoder, ensure_ascii=False))
+
+ index = 0
+ with open(merge_file, "w", encoding="utf-8") as writer:
+ writer.write("#version: 0.2\n")
+ for bpe_tokens, token_index in sorted(self.bpe_ranks.items(), key=lambda kv: kv[1]):
+ if index != token_index:
+ logger.warning(
+ "Saving vocabulary to {}: BPE merge indices are not consecutive."
+ " Please check that the tokenizer is not corrupted!".format(merge_file)
+ )
+ index = token_index
+ writer.write(" ".join(bpe_tokens) + "\n")
+ index += 1
+
+ return vocab_file, merge_file
+
+ @classmethod
+ def from_pretrained(cls, model_dir_or_name):
+ r"""
+ """
+ return cls._from_pretrained(model_dir_or_name)
+
+    # modified so that a directory is always passed in
+ @classmethod
+ def _from_pretrained(cls, model_dir_or_name):
+ """
+
+        :param str model_dir_or_name: a directory or a shortcut name
+        :return:
+ """
+        # two files are needed: vocab.json and the merges file
+ model_dir = _get_gpt2_dir(model_dir_or_name)
+        # the directory contains four files: vocab.json, merges.txt, config.json, model.bin
+
+ tokenizer_config_file = _get_filepath_based_on_postfix(model_dir, 'config.json')
+ with open(tokenizer_config_file, encoding="utf-8") as tokenizer_config_handle:
+ init_kwargs = json.load(tokenizer_config_handle)
+ # Set max length if needed
+ if model_dir_or_name in PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES:
+ # if we're using a pretrained model, ensure the tokenizer
+            # won't index sequences longer than the number of positional embeddings
+ max_len = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES[model_dir_or_name]
+ if max_len is not None and isinstance(max_len, (int, float)):
+ init_kwargs["max_len"] = min(init_kwargs.get("max_len", int(1e12)), max_len)
+
+        # add the vocab and merges files to init_kwargs
+ init_kwargs['vocab_file'] = _get_filepath_based_on_postfix(model_dir, 'vocab.json')
+ init_kwargs['merges_file'] = _get_filepath_based_on_postfix(model_dir, 'merges.txt')
+
+ init_inputs = init_kwargs.pop("init_inputs", ())
+ # Instantiate tokenizer.
+ try:
+ tokenizer = cls(*init_inputs, **init_kwargs)
+ except OSError:
+            raise OSError(
+ "Unable to load vocabulary from file. "
+ "Please check that the provided vocabulary is accessible and not corrupted."
+ )
+
+ return tokenizer
+
+ def __len__(self):
+ """ Size of the full vocabulary with the added tokens """
+ return self.vocab_size + len(self.added_tokens_encoder)
+
+ def tokenize(self, text, add_prefix_space=True):
+ """ Converts a string in a sequence of tokens (string), using the tokenizer.
+ Split in words for word-based vocabulary or sub-words for sub-word-based
+ vocabularies (BPE/SentencePieces/WordPieces).
+
+ Take care of added tokens.
+ Args:
+ - text: The sequence to be encoded.
+ - add_prefix_space (boolean, default True):
+ Begin the sentence with at least one space to get invariance to word order in GPT-2 (and RoBERTa) tokenizers.
+ """
+ all_special_tokens = self.all_special_tokens
+
+ def lowercase_text(t):
+ # convert non-special tokens to lowercase
+ escaped_special_toks = [re.escape(s_tok) for s_tok in all_special_tokens]
+ pattern = r'(' + r'|'.join(escaped_special_toks) + r')|' + \
+ r'(.+?)'
+ return re.sub(
+ pattern,
+ lambda m: m.groups()[0] or m.groups()[1].lower(),
+ t)
+
+ if self.init_kwargs.get('do_lower_case', False):
+ text = lowercase_text(text)
+
+ def split_on_token(tok, text):
+ result = []
+ split_text = text.split(tok)
+ for i, sub_text in enumerate(split_text):
+ sub_text = sub_text.strip()
+ if i == 0 and not sub_text:
+ result += [tok]
+ elif i == len(split_text) - 1:
+ if sub_text:
+ result += [sub_text]
+ else:
+ pass
+ else:
+ if sub_text:
+ result += [sub_text]
+ result += [tok]
+ return result
+
+ def split_on_tokens(tok_list, text):
+ if not text.strip():
+ return []
+ if not tok_list:
+ return self._tokenize(text, add_prefix_space=add_prefix_space)
+
+ tokenized_text = []
+ text_list = [text]
+ for tok in tok_list:
+ tokenized_text = []
+ for sub_text in text_list:
+ if sub_text not in self.added_tokens_encoder \
+ and sub_text not in all_special_tokens:
+ tokenized_text += split_on_token(tok, sub_text)
+ else:
+ tokenized_text += [sub_text]
+ text_list = tokenized_text
+
+ return list(itertools.chain.from_iterable((self._tokenize(token, add_prefix_space=add_prefix_space) if token not \
+ in self.added_tokens_encoder and token not in all_special_tokens \
+ else [token] for token in tokenized_text)))
+
+ added_tokens = list(self.added_tokens_encoder.keys()) + all_special_tokens
+ tokenized_text = split_on_tokens(added_tokens, text)
+ return tokenized_text
+
+ def convert_tokens_to_ids(self, tokens):
+ """ Converts a single token, or a sequence of tokens, (str) in a single integer id
+ (resp. a sequence of ids), using the vocabulary.
+ """
+ if tokens is None:
+ return None
+
+ if isinstance(tokens, str):
+ return self._convert_token_to_id_with_added_voc(tokens)
+
+ ids = []
+ for token in tokens:
+ ids.append(self._convert_token_to_id_with_added_voc(token))
+ return ids
+
+ def _convert_token_to_id_with_added_voc(self, token):
+ if token is None:
+ return None
+
+ if token in self.added_tokens_encoder:
+ return self.added_tokens_encoder[token]
+ return self._convert_token_to_id(token)
+
+ def convert_ids_to_tokens(self, ids, skip_special_tokens=False):
+ """ Converts a single index or a sequence of indices (integers) in a token "
+ (resp.) a sequence of tokens (str), using the vocabulary and added tokens.
+
+ Args:
+ skip_special_tokens: Don't decode special tokens (self.all_special_tokens). Default: False
+ """
+ if isinstance(ids, int):
+ return self._convert_id_to_token(ids)
+ tokens = []
+ for index in ids:
+ index = int(index)
+ if skip_special_tokens and index in self.all_special_ids:
+ continue
+ tokens.append(self._convert_id_to_token(index))
+ return tokens
+
+ def convert_id_to_tokens(self, token_ids, skip_special_tokens=False, clean_up_tokenization_spaces=True):
+ """
+ Converts a sequence of ids (integer) in a string, using the tokenizer and vocabulary
+ with options to remove special tokens and clean up tokenization spaces.
+ Similar to doing ``self.convert_tokens_to_string(self.convert_ids_to_tokens(token_ids))``.
+
+ Args:
+ token_ids: list of tokenized input ids. Can be obtained using the `encode` or `encode_plus` methods.
+            skip_special_tokens: if set to True, will remove special tokens from the output.
+ clean_up_tokenization_spaces: if set to True, will clean up the tokenization spaces.
+ """
+ filtered_tokens = self.convert_ids_to_tokens(token_ids, skip_special_tokens=skip_special_tokens)
+
+        # To avoid mixing byte-level and unicode for byte-level BPE
+        # we need to build the string separately for added tokens and byte-level tokens
+ # cf. https://github.com/huggingface/transformers/issues/1133
+ sub_texts = []
+ current_sub_text = []
+ for token in filtered_tokens:
+            if skip_special_tokens and token in self.all_special_tokens:
+ continue
+ if token in self.added_tokens_encoder:
+ if current_sub_text:
+ sub_texts.append(self.convert_tokens_to_string(current_sub_text))
+ current_sub_text = []
+ sub_texts.append(token)
+ else:
+ current_sub_text.append(token)
+ if current_sub_text:
+ sub_texts.append(self.convert_tokens_to_string(current_sub_text))
+ text = " ".join(sub_texts)
+
+ if clean_up_tokenization_spaces:
+ clean_text = self.clean_up_tokenization(text)
+ return clean_text
+ else:
+ return text
+
+ @property
+ def special_tokens_map(self):
+ """ A dictionary mapping special token class attribute (cls_token, unk_token...) to their
+            values ('<unk>', '<cls>'...)
+ """
+ set_attr = {}
+ for attr in self.SPECIAL_TOKENS_ATTRIBUTES:
+ attr_value = getattr(self, "_" + attr)
+ if attr_value:
+ set_attr[attr] = attr_value
+ return set_attr
+
+ @property
+ def all_special_tokens(self):
+ """ List all the special tokens ('', ''...) mapped to class attributes
+ (cls_token, unk_token...).
+ """
+ all_toks = []
+ set_attr = self.special_tokens_map
+ for attr_value in set_attr.values():
+ all_toks = all_toks + (list(attr_value) if isinstance(attr_value, (list, tuple)) else [attr_value])
+ all_toks = list(set(all_toks))
+ return all_toks
+
+ @property
+ def all_special_ids(self):
+ """ List the vocabulary indices of the special tokens ('', ''...) mapped to
+ class attributes (cls_token, unk_token...).
+ """
+ all_toks = self.all_special_tokens
+ all_ids = self.convert_tokens_to_ids(all_toks)
+ return all_ids
+
+ @staticmethod
+ def clean_up_tokenization(out_string):
+ """ Clean up a list of simple English tokenization artifacts like spaces before punctuations and abreviated forms.
+ """
+ out_string = (
+ out_string.replace(" .", ".")
+ .replace(" ?", "?")
+ .replace(" !", "!")
+ .replace(" ,", ",")
+ .replace(" ' ", "'")
+ .replace(" n't", "n't")
+ .replace(" 'm", "'m")
+ .replace(" do not", " don't")
+ .replace(" 's", "'s")
+ .replace(" 've", "'ve")
+ .replace(" 're", "'re")
+ )
+ return out_string
diff --git a/fastNLP/modules/encoder/roberta.py b/fastNLP/modules/encoder/roberta.py
new file mode 100644
index 00000000..af8795c6
--- /dev/null
+++ b/fastNLP/modules/encoder/roberta.py
@@ -0,0 +1,357 @@
+
+from typing import List, Optional
+import json
+
+import torch
+import torch.nn as nn
+
+from .bert import BertEmbeddings, BertModel, BertConfig, _get_bert_dir
+from .gpt2 import GPT2Tokenizer
+from ..utils import create_position_ids_from_input_ids, _get_file_name_base_on_postfix
+from ...core import logger
+
+PRETRAINED_ROBERTA_POSITIONAL_EMBEDDINGS_SIZES = {
+ "roberta-base": 512,
+ "roberta-large": 512,
+ "roberta-large-mnli": 512,
+ "distilroberta-base": 512,
+ "roberta-base-openai-detector": 512,
+ "roberta-large-openai-detector": 512,
+}
+
+
+class RobertaEmbeddings(BertEmbeddings):
+ """
+ Same as BertEmbeddings with a tiny tweak for positional embeddings indexing.
+ """
+
+ def __init__(self, config):
+ super().__init__(config)
+ self.padding_idx = 1
+ self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=self.padding_idx)
+ self.position_embeddings = nn.Embedding(
+ config.max_position_embeddings, config.hidden_size, padding_idx=self.padding_idx
+ )
+
+ def forward(self, input_ids=None, token_type_ids=None, position_ids=None, words_embeddings=None):
+ if position_ids is None:
+ if input_ids is not None:
+ # Create the position ids from the input token ids. Any padded tokens remain padded.
+ position_ids = create_position_ids_from_input_ids(input_ids, self.padding_idx).to(input_ids.device)
+ else:
+ position_ids = self.create_position_ids_from_inputs_embeds(words_embeddings)
+
+ return super().forward(
+ input_ids, token_type_ids=token_type_ids, position_ids=position_ids, words_embeddings=words_embeddings
+ )
+
+ def create_position_ids_from_inputs_embeds(self, inputs_embeds):
+ """
+ :param torch.Tensor inputs_embeds:
+ :return torch.Tensor:
+ """
+ input_shape = inputs_embeds.size()[:-1]
+ sequence_length = input_shape[1]
+
+ position_ids = torch.arange(
+ self.padding_idx + 1, sequence_length + self.padding_idx + 1, dtype=torch.long, device=inputs_embeds.device
+ )
+ return position_ids.unsqueeze(0).expand(input_shape)
+
+
+class RobertaModel(BertModel):
+ r"""
+ undocumented
+ """
+
+ def __init__(self, config):
+ super().__init__(config)
+
+ self.embeddings = RobertaEmbeddings(config)
+ self.apply(self.init_bert_weights)
+
+ def get_input_embeddings(self):
+ return self.embeddings.word_embeddings
+
+ def set_input_embeddings(self, value):
+ self.embeddings.word_embeddings = value
+
+ @classmethod
+ def from_pretrained(cls, model_dir_or_name, *inputs, **kwargs):
+ state_dict = kwargs.get('state_dict', None)
+ kwargs.pop('state_dict', None)
+ kwargs.pop('cache_dir', None)
+ kwargs.pop('from_tf', None)
+
+ # get model dir from name or dir
+ pretrained_model_dir = _get_bert_dir(model_dir_or_name)
+
+ # Load config
+ config_file = _get_file_name_base_on_postfix(pretrained_model_dir, 'config.json')
+ config = BertConfig.from_json_file(config_file)
+
+ # Load model
+ if state_dict is None:
+ weights_path = _get_file_name_base_on_postfix(pretrained_model_dir, '.bin')
+ state_dict = torch.load(weights_path, map_location='cpu')
+ else:
+ logger.error(f'Cannot load parameters through `state_dict` variable.')
+ raise RuntimeError(f'Cannot load parameters through `state_dict` variable.')
+
+ # Instantiate model.
+ model = cls(config, *inputs, **kwargs)
+
+ missing_keys = []
+ unexpected_keys = []
+ error_msgs = []
+
+ # Convert old format to new format if needed from a PyTorch state_dict
+ old_keys = []
+ new_keys = []
+ for key in state_dict.keys():
+ new_key = None
+ if "gamma" in key:
+ new_key = key.replace("gamma", "weight")
+ if "beta" in key:
+ new_key = key.replace("beta", "bias")
+ if new_key:
+ old_keys.append(key)
+ new_keys.append(new_key)
+ for old_key, new_key in zip(old_keys, new_keys):
+ state_dict[new_key] = state_dict.pop(old_key)
+
+ # copy state_dict so _load_from_state_dict can modify it
+ metadata = getattr(state_dict, "_metadata", None)
+ state_dict = state_dict.copy()
+ if metadata is not None:
+ state_dict._metadata = metadata
+
+ # PyTorch's `_load_from_state_dict` does not copy parameters in a module's descendants
+ # so we need to apply the function recursively.
+ def load(module: nn.Module, prefix=""):
+ local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
+ module._load_from_state_dict(
+ state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs,
+ )
+ for name, child in module._modules.items():
+ if child is not None:
+ load(child, prefix + name + ".")
+
+ # Make sure we are able to load base models as well as derived models (with heads)
+ start_prefix = ""
+ model_to_load = model
+ if not hasattr(model, 'roberta') and any(
+ s.startswith('roberta') for s in state_dict.keys()
+ ):
+ start_prefix = 'roberta.'
+ if hasattr(model, 'roberta') and not any(
+ s.startswith('roberta') for s in state_dict.keys()
+ ):
+ model_to_load = getattr(model, 'roberta')
+
+ load(model_to_load, prefix=start_prefix)
+
+ if model.__class__.__name__ != model_to_load.__class__.__name__:
+ base_model_state_dict = model_to_load.state_dict().keys()
+ head_model_state_dict_without_base_prefix = [
+ key.split('roberta.')[-1] for key in model.state_dict().keys()
+ ]
+
+ missing_keys.extend(head_model_state_dict_without_base_prefix - base_model_state_dict)
+
+ if len(missing_keys) > 0:
+ logger.info(
+ "Weights of {} not initialized from pretrained model: {}".format(
+ model.__class__.__name__, missing_keys
+ )
+ )
+ if len(unexpected_keys) > 0:
+ logger.info(
+ "Weights from pretrained model not used in {}: {}".format(
+ model.__class__.__name__, unexpected_keys
+ )
+ )
+ if len(error_msgs) > 0:
+ raise RuntimeError(
+ "Error(s) in loading state_dict for {}:\n\t{}".format(
+ model.__class__.__name__, "\n\t".join(error_msgs)
+ )
+ )
+
+        # Set model in evaluation mode to deactivate Dropout modules by default
+ model.eval()
+
+ logger.info(f"Load pre-trained RoBERTa parameters from file {weights_path}.")
+
+ return model
+
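+# A minimal usage sketch (the path below is hypothetical). `from_pretrained` only needs the
+# directory to contain config.json and a *.bin weight file; the matching RobertaTokenizer
+# additionally expects vocab.json and merges.txt there. The call signature mirrors fastNLP's
+# BertModel, as used by _WordRobertaModel in this patch:
+#
+#     model = RobertaModel.from_pretrained('/path/to/roberta-base')
+#     layers, pooled_cls = model(word_pieces, token_type_ids=token_type_ids,
+#                                attention_mask=attn_masks, output_all_encoded_layers=True)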
+
+class RobertaTokenizer(GPT2Tokenizer):
+
+ vocab_files_names = {
+ "vocab_file": "vocab.json",
+ "merges_file": "merges.txt",
+ }
+
+ def __init__(
+ self,
+ vocab_file,
+ merges_file,
+ errors="replace",
+ bos_token="",
+ eos_token="",
+ sep_token="",
+ cls_token="",
+ unk_token="",
+ pad_token="",
+ mask_token="",
+ **kwargs
+ ):
+ super().__init__(
+ vocab_file=vocab_file,
+ merges_file=merges_file,
+ errors=errors,
+ bos_token=bos_token,
+ eos_token=eos_token,
+ unk_token=unk_token,
+ sep_token=sep_token,
+ cls_token=cls_token,
+ pad_token=pad_token,
+ mask_token=mask_token,
+ **kwargs,
+ )
+ self.max_len_single_sentence = self.max_len - 2 # take into account special tokens
+ self.max_len_sentences_pair = self.max_len - 4 # take into account special tokens
+
+ def build_inputs_with_special_tokens(
+ self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
+ ) -> List[int]:
+ """
+ Build model inputs from a sequence or a pair of sequence for sequence classification tasks
+ by concatenating and adding special tokens.
+ A RoBERTa sequence has the following format:
+
+        - single sequence: ``<s> X </s>``
+        - pair of sequences: ``<s> A </s></s> B </s>``
+
+ Args:
+ token_ids_0 (:obj:`List[int]`):
+ List of IDs to which the special tokens will be added
+ token_ids_1 (:obj:`List[int]`, `optional`, defaults to :obj:`None`):
+ Optional second list of IDs for sequence pairs.
+
+ Returns:
+ :obj:`List[int]`: list of `input IDs <../glossary.html#input-ids>`__ with the appropriate special tokens.
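+
+        Example (``tokenizer`` is an instantiated RobertaTokenizer; in the standard RoBERTa vocabulary ``<s>`` has
+        id 0 and ``</s>`` has id 2)::
+
+            >>> tokenizer.build_inputs_with_special_tokens([8, 9])
+            [0, 8, 9, 2]
+            >>> tokenizer.build_inputs_with_special_tokens([8, 9], [10])
+            [0, 8, 9, 2, 2, 10, 2]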
+ """
+ if token_ids_1 is None:
+ return [self.cls_token_id] + token_ids_0 + [self.sep_token_id]
+ cls = [self.cls_token_id]
+ sep = [self.sep_token_id]
+ return cls + token_ids_0 + sep + sep + token_ids_1 + sep
+
+ def get_special_tokens_mask(
+ self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
+ ) -> List[int]:
+ """
+ Retrieves sequence ids from a token list that has no special tokens added. This method is called when adding
+ special tokens using the tokenizer ``prepare_for_model`` or ``encode_plus`` methods.
+
+ Args:
+ token_ids_0 (:obj:`List[int]`):
+ List of ids.
+ token_ids_1 (:obj:`List[int]`, `optional`, defaults to :obj:`None`):
+ Optional second list of IDs for sequence pairs.
+ already_has_special_tokens (:obj:`bool`, `optional`, defaults to :obj:`False`):
+ Set to True if the token list is already formatted with special tokens for the model
+
+ Returns:
+            :obj:`List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
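+
+        Example (``tokenizer`` is an instantiated RobertaTokenizer)::
+
+            >>> tokenizer.get_special_tokens_mask([8, 9])
+            [1, 0, 0, 1]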
+ """
+ if already_has_special_tokens:
+ if token_ids_1 is not None:
+ raise ValueError(
+ "You should not supply a second sequence if the provided sequence of "
+ "ids is already formated with special tokens for the model."
+ )
+ return list(map(lambda x: 1 if x in [self.sep_token_id, self.cls_token_id] else 0, token_ids_0))
+
+ if token_ids_1 is None:
+ return [1] + ([0] * len(token_ids_0)) + [1]
+ return [1] + ([0] * len(token_ids_0)) + [1, 1] + ([0] * len(token_ids_1)) + [1]
+
+ def create_token_type_ids_from_sequences(
+ self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
+ ) -> List[int]:
+ """
+ Creates a mask from the two sequences passed to be used in a sequence-pair classification task.
+ RoBERTa does not make use of token type ids, therefore a list of zeros is returned.
+
+ Args:
+ token_ids_0 (:obj:`List[int]`):
+ List of ids.
+ token_ids_1 (:obj:`List[int]`, `optional`, defaults to :obj:`None`):
+ Optional second list of IDs for sequence pairs.
+
+ Returns:
+ :obj:`List[int]`: List of zeros.
+
+ """
+ sep = [self.sep_token_id]
+ cls = [self.cls_token_id]
+
+ if token_ids_1 is None:
+ return len(cls + token_ids_0 + sep) * [0]
+ return len(cls + token_ids_0 + sep + sep + token_ids_1 + sep) * [0]
+
+ def prepare_for_tokenization(self, text, add_special_tokens=False, **kwargs):
+ if "add_prefix_space" in kwargs:
+ add_prefix_space = kwargs["add_prefix_space"]
+ else:
+ add_prefix_space = add_special_tokens
+ if add_prefix_space and not text[0].isspace():
+ text = " " + text
+ return text
+
+ @classmethod
+ def from_pretrained(cls, model_dir_or_name, *inputs, **kwargs):
+ """
+
+        :param str model_dir_or_name: a directory or a shortcut name
+ :param kwargs:
+ :return:
+ """
+        # two files are needed: vocab.json and the merges file
+ model_dir = _get_bert_dir(model_dir_or_name)
+        # the directory contains four files: vocab.json, merges.txt, config.json, model.bin
+
+ tokenizer_config_file = _get_file_name_base_on_postfix(model_dir, 'config.json')
+ with open(tokenizer_config_file, encoding="utf-8") as tokenizer_config_handle:
+ init_kwargs = json.load(tokenizer_config_handle)
+ # Set max length if needed
+ if model_dir_or_name in PRETRAINED_ROBERTA_POSITIONAL_EMBEDDINGS_SIZES:
+ # if we're using a pretrained model, ensure the tokenizer
+            # won't index sequences longer than the number of positional embeddings
+ max_len = PRETRAINED_ROBERTA_POSITIONAL_EMBEDDINGS_SIZES[model_dir_or_name]
+ if max_len is not None and isinstance(max_len, (int, float)):
+ init_kwargs["max_len"] = min(init_kwargs.get("max_len", int(1e12)), max_len)
+
+        # add the vocab and merges files to init_kwargs
+        if 'vocab_file' in kwargs:  # if a vocabulary file is specified, use it
+ init_kwargs['vocab_file'] = kwargs['vocab_file']
+ else:
+ init_kwargs['vocab_file'] = _get_file_name_base_on_postfix(model_dir, 'vocab.json')
+ init_kwargs['merges_file'] = _get_file_name_base_on_postfix(model_dir, 'merges.txt')
+
+ init_inputs = init_kwargs.pop("init_inputs", ())
+ # Instantiate tokenizer.
+ try:
+ tokenizer = cls(*init_inputs, **init_kwargs)
+ except OSError:
+            raise OSError(
+ "Unable to load vocabulary from file. "
+ "Please check that the provided vocabulary is accessible and not corrupted."
+ )
+
+ return tokenizer
+
+
diff --git a/fastNLP/modules/utils.py b/fastNLP/modules/utils.py
index 171a332d..79e2a7de 100644
--- a/fastNLP/modules/utils.py
+++ b/fastNLP/modules/utils.py
@@ -148,3 +148,14 @@ def _get_file_name_base_on_postfix(dir_path, postfix):
elif len(files) > 1:
raise FileExistsError(f"There are multiple *{postfix} files in {dir_path}")
return os.path.join(dir_path, files[0])
+
+
+def create_position_ids_from_input_ids(input_ids, padding_idx=0):
+ r""" Replace non-padding symbols with their position numbers. Position numbers begin at
+ padding_idx+1. Padding symbols are ignored. This is modified from fairseq's
+ `utils.make_positions`.
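+
+    Example::
+
+        >>> import torch
+        >>> input_ids = torch.tensor([[0, 5, 7, 1, 1]])  # 1 is the padding index here
+        >>> create_position_ids_from_input_ids(input_ids, padding_idx=1)
+        tensor([[2, 3, 4, 1, 1]])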
+ """
+ # The series of casts and type-conversions here are carefully balanced to both work with ONNX export and XLA.
+ mask = input_ids.ne(padding_idx).int()
+    incremental_indices = torch.cumsum(mask, dim=1).type_as(mask) * mask
+    return incremental_indices.long() + padding_idx