diff --git a/fastNLP/core/batch.py b/fastNLP/core/batch.py
index 7c1e64ee..94942f09 100644
--- a/fastNLP/core/batch.py
+++ b/fastNLP/core/batch.py
@@ -217,7 +217,8 @@ class BatchIter:
class DataSetIter(BatchIter):
r"""
-    DataSetIter is used to fetch data from a `DataSet` in a given order, ``batch_size`` samples at a time,
+    DataSetIter is used to fetch data from a `DataSet` in a given order, ``batch_size`` samples at a time. With DataSetIter you need not worry about
+    input padding (each column's Padder in the DataSet decides it) or about converting the data to tensors,
assembling them into `x` and `y`::
batch = DataSetIter(data_set, batch_size=16, sampler=SequentialSampler())
@@ -226,10 +227,8 @@ class DataSetIter(BatchIter):
# do stuff ...
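# each iteration yields two dicts of padded tensors; a sketch, assuming
# the DataSet defines an input field 'words':
# batch_x['words'] -> torch.LongTensor of shape [16, max_len]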
"""
- def __init__(self, dataset, batch_size=1, sampler=None, as_numpy=False,
- num_workers=0, pin_memory=False, drop_last=False,
- timeout=0, worker_init_fn=None, collate_fn=None,
- batch_sampler=None):
+ def __init__(self, dataset, batch_size=1, sampler=None, as_numpy=False, num_workers=0, pin_memory=False,
+ drop_last=False, timeout=0, worker_init_fn=None, batch_sampler=None):
r"""
:param dataset: a :class:`~fastNLP.DataSet` object, the dataset
@@ -245,13 +244,12 @@ class DataSetIter(BatchIter):
:param bool drop_last: if the last batch has fewer than batch_size samples, drop it
:param timeout: timeout for producing a single batch
:param worker_init_fn: called in each worker at startup, receiving one value, the worker's index.
-        :param collate_fn: function used to combine samples into a batch
:param batch_sampler: use this sampler when batches may hold different numbers of samples. Each iteration of batch_sampler should yield a list of indices.
When batch_sampler is not None, the arguments batch_size, sampler and drop_last are ignored.
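Example::

    # a minimal sketch of a custom batch_sampler (illustrative, not part of fastNLP):
    # each iteration yields one list of indices, so batch sizes may differ
    class FixedBatchSampler:
        def __init__(self, batches):
            self.batches = batches  # e.g. [[0, 1, 2], [3, 4]]
        def __iter__(self):
            return iter(self.batches)
        def __len__(self):
            return len(self.batches)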
"""
assert isinstance(dataset, DataSet)
dataset = DataSetGetter(dataset, as_numpy)
- collate_fn = dataset.collate_fn if collate_fn is None else collate_fn
+ collate_fn = dataset.collate_fn
if batch_sampler is not None:
batch_size = 1
sampler = None
@@ -272,8 +270,9 @@ class DataSetIter(BatchIter):
class TorchLoaderIter(BatchIter):
r"""
-    Similar to DataSetIter, but usable with data containers that are not fastNLP objects, which can then be passed to a Trainer.
-    You only need to make sure the container implements the following methods
+    Similar to DataSetIter, but it can wrap data containers that are not fastNLP objects and supports fully customized
+    batch generation, while plugging into Trainer and Tester exactly like a DataSetIter.
+    Make sure the data container you pass in implements the following methods
Example::
@@ -293,7 +292,7 @@ class TorchLoaderIter(BatchIter):
return self.num_samples
# a collate_fn must be implemented to convert the data to tensors
- def collact_fn(data_list):
+ def collate_fn(data_list):
# [(x1,y1), (x2,y2), ...]; the input here is actually the __getitem__ returns of UdfDataSet gathered into a list
xs, ys = [], []
for l in data_list:
@@ -302,10 +301,10 @@ class TorchLoaderIter(BatchIter):
ys.append(y)
# no need to move to GPU; Trainer or Tester will move it to the model's device
x,y = torch.FloatTensor(xs), torch.FloatTensor(ys)
- return {'x':x, 'y':y}, {'y':y}
+ return {'x':x, 'y':y}, {'y':y}  # the first dict acts like the input fields of a DataSet, the second like the target fields
udf_dataset = UdfDataSet(10)
- dataset = TorchLoaderIter(udf_dataset, collate_fn=collact_fn)
+ dataset = TorchLoaderIter(udf_dataset, collate_fn=collate_fn)
class Model(nn.Module):
def __init__(self):
super().__init__()
@@ -362,7 +361,7 @@ class TorchLoaderIter(BatchIter):
def __len__(self):
return self.num_samples
- def collact_fn(data_list):
+ def collate_fn(data_list):
# [(x1,y1), (x2,y2), ...]; the input here is actually the __getitem__ returns of UdfDataSet gathered into a list
xs, ys = [], []
for l in data_list:
@@ -370,10 +369,10 @@ class TorchLoaderIter(BatchIter):
xs.append(x)
ys.append(y)
x, y = torch.FloatTensor(xs), torch.FloatTensor(ys)
- return {'x': x, 'y': y}, {'y': y}
+ return {'x': x, 'y': y}, {'y': y}  # the first dict acts like the input fields of a DataSet, the second like the target fields
file_data = FileDataSet(tmp_file_path)
- dataset = TorchLoaderIter(file_data, collate_fn=collact_fn)
+ dataset = TorchLoaderIter(file_data, collate_fn=collate_fn)
class Model(nn.Module):
def __init__(self):
diff --git a/fastNLP/core/dist_trainer.py b/fastNLP/core/dist_trainer.py
index f5c0f229..680c4f80 100644
--- a/fastNLP/core/dist_trainer.py
+++ b/fastNLP/core/dist_trainer.py
@@ -205,11 +205,8 @@ class DistTrainer():
def _get_data_iter(self, dataset):
if isinstance(dataset, DataSet):
- return DataSetIter(
- dataset=dataset, batch_size=self.batch_size_per_gpu,
- num_workers=self.num_data_workers, sampler=self.sampler,
- drop_last=self.drop_last
- )
+ return DataSetIter(dataset=dataset, batch_size=self.batch_size_per_gpu, sampler=self.sampler,
+ num_workers=self.num_data_workers, drop_last=self.drop_last)
elif isinstance(dataset, BatchIter):
return dataset
else:
diff --git a/fastNLP/core/tester.py b/fastNLP/core/tester.py
index b223d35f..680782b1 100644
--- a/fastNLP/core/tester.py
+++ b/fastNLP/core/tester.py
@@ -107,8 +107,8 @@ class Tester(object):
self.logger = logger
if isinstance(data, DataSet):
- self.data_iterator = DataSetIter(
- dataset=data, batch_size=batch_size, num_workers=num_workers, sampler=SequentialSampler())
+ self.data_iterator = DataSetIter(dataset=data, batch_size=batch_size, sampler=SequentialSampler(),
+ num_workers=num_workers)
elif isinstance(data, BatchIter):
self.data_iterator = data
else:
diff --git a/fastNLP/core/trainer.py b/fastNLP/core/trainer.py
index c6390b22..b16f5ddb 100644
--- a/fastNLP/core/trainer.py
+++ b/fastNLP/core/trainer.py
@@ -487,8 +487,8 @@ class Trainer(object):
sampler.set_batch_size(batch_size)
if isinstance(train_data, DataSet):
- self.data_iterator = DataSetIter(
- dataset=train_data, batch_size=batch_size, num_workers=num_workers, sampler=sampler, drop_last=drop_last)
+ self.data_iterator = DataSetIter(dataset=train_data, batch_size=batch_size, sampler=sampler,
+ num_workers=num_workers, drop_last=drop_last)
elif isinstance(train_data, BatchIter):
self.data_iterator = train_data
train_data = train_data.dataset
diff --git a/fastNLP/embeddings/__init__.py b/fastNLP/embeddings/__init__.py
index 3b3b2dce..bf35b7d4 100644
--- a/fastNLP/embeddings/__init__.py
+++ b/fastNLP/embeddings/__init__.py
@@ -12,17 +12,26 @@ __all__ = [
"ElmoEmbedding",
"BertEmbedding",
"BertWordPieceEncoder",
+
+ "RobertaEmbedding",
+ "RobertaWordPieceEncoder",
+
+ "GPT2Embedding",
+ "GPT2WordPieceEncoder",
+
"StackEmbedding",
"LSTMCharEmbedding",
"CNNCharEmbedding",
"get_embeddings",
+
]
from .embedding import Embedding, TokenEmbedding
from .static_embedding import StaticEmbedding
from .elmo_embedding import ElmoEmbedding
from .bert_embedding import BertEmbedding, BertWordPieceEncoder
-from .roberta_embedding import RobertaEmbedding
+from .roberta_embedding import RobertaEmbedding, RobertaWordPieceEncoder
+from .gpt2_embedding import GPT2WordPieceEncoder, GPT2Embedding
from .char_embedding import CNNCharEmbedding, LSTMCharEmbedding
from .stack_embedding import StackEmbedding
from .utils import get_embeddings
diff --git a/fastNLP/embeddings/bert_embedding.py b/fastNLP/embeddings/bert_embedding.py
index 3bd448aa..3ad8cd39 100644
--- a/fastNLP/embeddings/bert_embedding.py
+++ b/fastNLP/embeddings/bert_embedding.py
@@ -11,6 +11,7 @@ __all__ = [
import collections
import warnings
from itertools import chain
+from functools import partial
import numpy as np
import torch
@@ -20,7 +21,8 @@ from .contextual_embedding import ContextualEmbedding
from ..core import logger
from ..core.vocabulary import Vocabulary
from ..io.file_utils import PRETRAINED_BERT_MODEL_DIR
-from ..modules.encoder.bert import _WordPieceBertModel, BertModel, BertTokenizer
+from ..modules.encoder.bert import BertModel
+from ..modules.tokenizer import BertTokenizer
class BertEmbedding(ContextualEmbedding):
@@ -31,6 +33,7 @@ class BertEmbedding(ContextualEmbedding):
BertEmbedding supports automatic weight download; currently supported models:
en: base-cased
+ en-base-uncased:
en-large-cased-wwm:
en-large-cased:
en-large-uncased:
@@ -63,7 +66,8 @@ class BertEmbedding(ContextualEmbedding):
:param str model_dir_or_name: the directory containing the model, or the model's name. When a directory is given, it should contain a vocabulary file (suffix .txt),
a weight file (suffix .bin) and a config file (suffix .json).
:param str layers: which layers the output embedding comes from; results from different layers are concatenated on the last dimension in the order given in layers. Separate layer numbers with ','; numbering starts
- at 0, and negative numbers index from the last layers.
+ at 0, and negative numbers index from the last layers. layer=0 is the embedding layer (wordpiece embedding
+ plus position embedding and segment embedding)
:param str pool_method: in BERT each word is represented by several word pieces; this selects how a word's representation
is computed from its word pieces. Supports ``last``, ``first``, ``avg``, ``max``.
:param float word_dropout: probability of replacing a word with unk. This both trains unk and acts as a mild regularizer.
@@ -80,6 +84,8 @@ class BertEmbedding(ContextualEmbedding):
:param kwargs:
bool only_use_pretrain_bpe: only use BPE tokens that appear in the pretrained vocabulary; a word that cannot be tokenized falls back to unk. Recommended to set
True if the embedding will not be updated.
+ int min_freq: only effective when only_use_pretrain_bpe is False; words occurring at least this many times are newly added to BERT's BPE vocabulary
+ bool truncate_embed: whether to keep only the BPE tokens actually used (this reduces memory usage and speeds things up)
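+
+            Example::
+
+                >>> # a sketch of the kwargs above (values are illustrative only)
+                >>> embed = BertEmbedding(vocab, 'en-base-uncased', min_freq=5,
+                ...                       truncate_embed=True, only_use_pretrain_bpe=False)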
"""
super(BertEmbedding, self).__init__(vocab, word_dropout=word_dropout, dropout=dropout)
@@ -92,25 +98,28 @@ class BertEmbedding(ContextualEmbedding):
" faster speed.")
warnings.warn("For Chinese bert, pooled_method should choose from 'first', 'last' in order to achieve"
" faster speed.")
-
- self._word_sep_index = None
+
+ self._word_sep_index = -100
if '[SEP]' in vocab:
self._word_sep_index = vocab['[SEP]']
+ self._word_cls_index = -100
+ if '[CLS]' in vocab:
+ self._word_cls_index = vocab['[CLS]']
only_use_pretrain_bpe = kwargs.get('only_use_pretrain_bpe', False)
-
- self.model = _WordBertModel(model_dir_or_name=model_dir_or_name, vocab=vocab, layers=layers,
+ truncate_embed = kwargs.get('truncate_embed', True)
+ min_freq = kwargs.get('min_freq', 2)
+
+ self.model = _BertWordModel(model_dir_or_name=model_dir_or_name, vocab=vocab, layers=layers,
pool_method=pool_method, include_cls_sep=include_cls_sep,
- pooled_cls=pooled_cls, auto_truncate=auto_truncate, min_freq=2,
- only_use_pretrain_bpe=only_use_pretrain_bpe)
- self._sep_index = self.model._sep_index
- self._cls_index = self.model._cls_index
+ pooled_cls=pooled_cls, auto_truncate=auto_truncate, min_freq=min_freq,
+ only_use_pretrain_bpe=only_use_pretrain_bpe, truncate_embed=truncate_embed)
self.requires_grad = requires_grad
self._embed_size = len(self.model.layers) * self.model.encoder.hidden_size
-
+
def _delete_model_weights(self):
del self.model
-
+
def forward(self, words):
r"""
Compute the BERT embedding of words. Before the computation, [CLS] is added at the start and [SEP] at the end of each sentence, and include_cls_sep decides whether
@@ -125,9 +134,9 @@ class BertEmbedding(ContextualEmbedding):
return self.dropout(outputs)
outputs = self.model(words)
outputs = torch.cat([*outputs], dim=-1)
-
+
return self.dropout(outputs)
-
+
def drop_word(self, words):
r"""
Randomly replace entries of words with unknown_index according to the configured probability.
@@ -137,15 +146,16 @@ class BertEmbedding(ContextualEmbedding):
"""
if self.word_dropout > 0 and self.training:
with torch.no_grad():
- not_sep_mask = words.ne(self._sep_index)
- not_cls_mask = words.ne(self._cls_index)
- if self._word_sep_index:
- not_sep_mask = not_sep_mask.__and__(words.ne(self._word_sep_index))
- replaceable_mask = not_sep_mask.__and__(not_cls_mask)
mask = torch.full_like(words, fill_value=self.word_dropout, dtype=torch.float, device=words.device)
mask = torch.bernoulli(mask).eq(1)  # the larger word_dropout is, the more positions become 1
- pad_mask = words.ne(0)
- mask = pad_mask.__and__(mask).__and__(replaceable_mask)  # pad positions are never set to unk
+ pad_mask = words.ne(self._word_pad_index)
+ mask = pad_mask.__and__(mask)  # pad positions are never set to unk
+ if self._word_sep_index!=-100:
+ not_sep_mask = words.ne(self._word_sep_index)
+ mask = mask.__and__(not_sep_mask)
+ if self._word_cls_index!=-100:
+ not_cls_mask = words.ne(self._word_cls_index)
+ mask = mask.__and__(not_cls_mask)
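+                # net effect (a sketch): with word_dropout=0.1, roughly 10% of
+                # positions are replaced by unk, but never pad, [SEP] or [CLS].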
words = words.masked_fill(mask, self._word_unk_index)
return words
@@ -167,21 +177,22 @@ class BertWordPieceEncoder(nn.Module):
multi-base-uncased: multilingual uncased
"""
-
+
def __init__(self, model_dir_or_name: str = 'en-base-uncased', layers: str = '-1', pooled_cls: bool = False,
word_dropout=0, dropout=0, requires_grad: bool = True):
r"""
-
+
:param str model_dir_or_name: the directory containing the model, or the model's name. Defaults to ``en-base-uncased``
- :param str layers: which layers make up the final representation. Separate layer numbers with ','; negative numbers index from the last layers
+ :param str layers: which layers make up the final representation. Separate layer numbers with ','; negative numbers index from the last layers. layer=0 is the embedding layer (wordpiece embedding
+ plus position embedding and segment embedding)
:param bool pooled_cls: whether the returned [CLS] at the start of the sentence is passed through the pretrained BertPool. Usually True when the downstream task predicts from [CLS].
:param float word_dropout: probability of replacing a word with unk. This both trains unk and acts as a mild regularizer.
:param float dropout: probability of applying Dropout to the embedding representation. 0.1 means 10% of the values are randomly zeroed.
:param bool requires_grad: whether gradients are required.
"""
super().__init__()
-
- self.model = _WordPieceBertModel(model_dir_or_name=model_dir_or_name, layers=layers, pooled_cls=pooled_cls)
+
+ self.model = _BertWordPieceModel(model_dir_or_name=model_dir_or_name, layers=layers, pooled_cls=pooled_cls)
self._sep_index = self.model._sep_index
self._cls_index = self.model._cls_index
self._wordpiece_pad_index = self.model._wordpiece_pad_index
@@ -190,19 +201,19 @@ class BertWordPieceEncoder(nn.Module):
self.requires_grad = requires_grad
self.word_dropout = word_dropout
self.dropout_layer = nn.Dropout(dropout)
-
+
@property
def embed_size(self):
return self._embed_size
-
+
@property
def embedding_dim(self):
return self._embed_size
-
+
@property
def num_embedding(self):
return self.model.encoder.config.vocab_size
-
+
def index_datasets(self, *datasets, field_name, add_cls_sep=True):
r"""
Use bert's tokenizer to generate a new word_pieces field in the datasets, set it as input, and set the pad value of the word_pieces field to
@@ -213,8 +224,8 @@ class BertWordPieceEncoder(nn.Module):
:param bool add_cls_sep: if the sequence does not already start with [CLS] and end with [SEP], they are added at the start and end.
:return:
"""
- self.model.index_dataset(*datasets, field_name=field_name, add_cls_sep=add_cls_sep)
-
+ self.model.index_datasets(*datasets, field_name=field_name, add_cls_sep=add_cls_sep)
+
def forward(self, word_pieces, token_type_ids=None):
r"""
Compute the BERT embedding of words. The words passed in should already contain the [CLS] and [SEP] tags.
@@ -224,20 +235,20 @@ class BertWordPieceEncoder(nn.Module):
positions up to and including the first [SEP] are 0; from there up to and including the second [SEP] they are 1; up to the third [SEP] they are 0 again, and so on.
:return: torch.FloatTensor. batch_size x max_len x (768*len(self.layers))
"""
- with torch.no_grad():
- sep_mask = word_pieces.eq(self._sep_index) # batch_size x max_len
- if token_type_ids is None:
+ if token_type_ids is None:
+ with torch.no_grad():
+ sep_mask = word_pieces.eq(self._sep_index) # batch_size x max_len
sep_mask_cumsum = sep_mask.long().flip(dims=[-1]).cumsum(dim=-1).flip(dims=[-1])
token_type_ids = sep_mask_cumsum.fmod(2)
if token_type_ids[0, 0].item():  # if the first position is odd, flip the result, since the first segment must be 0
token_type_ids = token_type_ids.eq(0).long()
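# worked example: for word_pieces [CLS] a b [SEP] c d [SEP],
# sep_mask is 0 0 0 1 0 0 1; the flipped cumulative sum gives
# 2 2 2 2 1 1 1, so fmod(2) yields segment ids 0 0 0 0 1 1 1.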
-
+
word_pieces = self.drop_word(word_pieces)
outputs = self.model(word_pieces, token_type_ids)
outputs = torch.cat([*outputs], dim=-1)
-
+
return self.dropout_layer(outputs)
-
+
def drop_word(self, words):
r"""
Randomly replace entries of words with unknown_index according to the configured probability.
@@ -258,38 +269,45 @@ class BertWordPieceEncoder(nn.Module):
return words
-class _WordBertModel(nn.Module):
+class _BertWordModel(nn.Module):
def __init__(self, model_dir_or_name: str, vocab: Vocabulary, layers: str = '-1', pool_method: str = 'first',
include_cls_sep: bool = False, pooled_cls: bool = False, auto_truncate: bool = False, min_freq=2,
- only_use_pretrain_bpe=False):
+ only_use_pretrain_bpe=False, truncate_embed=True):
super().__init__()
-
+
self.tokenzier = BertTokenizer.from_pretrained(model_dir_or_name)
self.encoder = BertModel.from_pretrained(model_dir_or_name)
self._max_position_embeddings = self.encoder.config.max_position_embeddings
# check that encoder_layer_number is reasonable
encoder_layer_number = len(self.encoder.encoder.layer)
- self.layers = list(map(int, layers.split(',')))
+ if isinstance(layers, list):
+ self.layers = [int(l) for l in layers]
+ elif isinstance(layers, str):
+ self.layers = list(map(int, layers.split(',')))
+ else:
+ raise TypeError("`layers` only supports str or list[int]")
for layer in self.layers:
if layer < 0:
assert -layer <= encoder_layer_number, f"The layer index:{layer} is out of scope for " \
f"a bert model with {encoder_layer_number} layers."
else:
- assert layer < encoder_layer_number, f"The layer index:{layer} is out of scope for " \
+ assert layer <= encoder_layer_number, f"The layer index:{layer} is out of scope for " \
f"a bert model with {encoder_layer_number} layers."
-
+
assert pool_method in ('avg', 'max', 'first', 'last')
self.pool_method = pool_method
self.include_cls_sep = include_cls_sep
self.pooled_cls = pooled_cls
self.auto_truncate = auto_truncate
-
+
# compute the word pieces for every word in vocab; [CLS] and [SEP] need extra handling
logger.info("Start to generate word pieces for word.")
+ self._has_sep_in_vocab = '[SEP]' in vocab  # used to decide whether the input needs token_type_ids
+
# step 1: collect the needed word pieces, then create the new embed and word_piece vocab and fill in the values
word_piece_dict = {'[CLS]': 1, '[SEP]': 1}  # word pieces that are used, plus newly added ones
- found_count = 0
- self._has_sep_in_vocab = '[SEP]' in vocab  # used to decide whether the input needs token_type_ids
+ new_add_to_bpe_vocab = 0
+ unsegment_count = 0
if '[sep]' in vocab:
warnings.warn("Lower cased [sep] detected, it cannot be correctly recognized as [SEP] by BertEmbedding.")
if "[CLS]" in vocab:
@@ -311,27 +329,42 @@ class _WordBertModel(nn.Module):
if vocab.word_count[word] >= min_freq and not vocab._is_word_no_create_entry(
word) and not only_use_pretrain_bpe:  # only add words that occur at least min_freq times
word_piece_dict[word] = 1  # add a new entry
+ new_add_to_bpe_vocab += 1
+ unsegment_count += 1
continue
for word_piece in word_pieces:
word_piece_dict[word_piece] = 1
- found_count += 1
original_embed = self.encoder.embeddings.word_embeddings.weight.data
+
# special tokens need special handling
+ if not truncate_embed:  # if not truncating, the existing vocabulary must be added as well
+ word_piece_dict.update(self.tokenzier.vocab)
embed = nn.Embedding(len(word_piece_dict), original_embed.size(1))  # the new embed
new_word_piece_vocab = collections.OrderedDict()
+
for index, token in enumerate(['[PAD]', '[UNK]']):
- word_piece_dict.pop(token, None)
- embed.weight.data[index] = original_embed[self.tokenzier.vocab[token]]
- new_word_piece_vocab[token] = index
+ index = word_piece_dict.pop(token, None)
+ if index is not None:
+ new_word_piece_vocab[token] = len(new_word_piece_vocab)
+ embed.weight.data[new_word_piece_vocab[token]] = original_embed[self.tokenzier.vocab[token]]
for token in word_piece_dict.keys():
+ if token not in new_word_piece_vocab:
+ new_word_piece_vocab[token] = len(new_word_piece_vocab)
+ index = new_word_piece_vocab[token]
if token in self.tokenzier.vocab:
- embed.weight.data[len(new_word_piece_vocab)] = original_embed[self.tokenzier.vocab[token]]
+ embed.weight.data[index] = original_embed[self.tokenzier.vocab[token]]
else:
- embed.weight.data[len(new_word_piece_vocab)] = original_embed[self.tokenzier.vocab['[UNK]']]
- new_word_piece_vocab[token] = len(new_word_piece_vocab)
+ embed.weight.data[index] = original_embed[self.tokenzier.vocab['[UNK]']]
+
self.tokenzier._reinit_on_new_vocab(new_word_piece_vocab)
self.encoder.embeddings.word_embeddings = embed
-
+ self.encoder.config.vocab_size = len(new_word_piece_vocab)
+ if unsegment_count>0:
+ if only_use_pretrain_bpe or new_add_to_bpe_vocab==0:
+ logger.info(f"{unsegment_count} words are unsegmented.")
+ else:
+ logger.info(f"{unsegment_count} words are unsegmented. Among them, {new_add_to_bpe_vocab} added to the BPE vocab.")
+
word_to_wordpieces = []
word_pieces_lengths = []
for word, index in vocab:
@@ -347,11 +380,10 @@ class _WordBertModel(nn.Module):
self._sep_index = self.tokenzier.vocab['[SEP]']
self._word_pad_index = vocab.padding_idx
self._wordpiece_pad_index = self.tokenzier.vocab['[PAD]']  # needed when generating word pieces
- logger.info("Found(Or segment into word pieces) {} words out of {}.".format(found_count, len(vocab)))
self.word_to_wordpieces = np.array(word_to_wordpieces)
self.register_buffer('word_pieces_lengths', torch.LongTensor(word_pieces_lengths))
logger.debug("Successfully generate word pieces.")
-
+
def forward(self, words):
r"""
@@ -365,8 +397,8 @@ class _WordBertModel(nn.Module):
batch_word_pieces_length = self.word_pieces_lengths[words].masked_fill(word_mask.eq(False),
0) # batch_size x max_len
word_pieces_lengths = batch_word_pieces_length.sum(dim=-1) # batch_size
- word_piece_length = batch_word_pieces_length.sum(dim=-1).max().item()  # the max word-piece length (including padding)
- if word_piece_length + 2 > self._max_position_embeddings:
+ max_word_piece_length = batch_word_pieces_length.sum(dim=-1).max().item()  # the max word-piece length (including padding)
+ if max_word_piece_length + 2 > self._max_position_embeddings:
if self.auto_truncate:
word_pieces_lengths = word_pieces_lengths.masked_fill(
word_pieces_lengths + 2 > self._max_position_embeddings,
@@ -376,9 +408,9 @@ class _WordBertModel(nn.Module):
"After split words into word pieces, the lengths of word pieces are longer than the "
f"maximum allowed sequence length:{self._max_position_embeddings} of bert. You can set "
f"`auto_truncate=True` for BertEmbedding to automatically truncate overlong input.")
-
+
# +2 because [CLS] and [SEP] must be added
- word_pieces = words.new_full((batch_size, min(word_piece_length + 2, self._max_position_embeddings)),
+ word_pieces = words.new_full((batch_size, min(max_word_piece_length + 2, self._max_position_embeddings)),
fill_value=self._wordpiece_pad_index)
attn_masks = torch.zeros_like(word_pieces)
# 1. get the word-piece ids of words and the corresponding spans
@@ -406,7 +438,7 @@ class _WordBertModel(nn.Module):
bert_outputs, pooled_cls = self.encoder(word_pieces, token_type_ids=token_type_ids, attention_mask=attn_masks,
output_all_encoded_layers=True)
# output_layers = [self.layers] # len(self.layers) x batch_size x real_word_piece_length x hidden_size
-
+
if self.include_cls_sep:
s_shift = 1
outputs = bert_outputs[-1].new_zeros(len(self.layers), batch_size, max_word_len + 2,
@@ -421,19 +453,19 @@ class _WordBertModel(nn.Module):
if self.pool_method == 'first':
batch_word_pieces_cum_length = batch_word_pieces_cum_length[:, :seq_len.max()]
- batch_word_pieces_cum_length.masked_fill_(batch_word_pieces_cum_length.ge(word_piece_length), 0)
+ batch_word_pieces_cum_length.masked_fill_(batch_word_pieces_cum_length.ge(max_word_piece_length), 0)
_batch_indexes = batch_indexes[:, None].expand((batch_size, batch_word_pieces_cum_length.size(1)))
elif self.pool_method == 'last':
batch_word_pieces_cum_length = batch_word_pieces_cum_length[:, 1:seq_len.max()+1] - 1
- batch_word_pieces_cum_length.masked_fill_(batch_word_pieces_cum_length.ge(word_piece_length), 0)
+ batch_word_pieces_cum_length.masked_fill_(batch_word_pieces_cum_length.ge(max_word_piece_length), 0)
_batch_indexes = batch_indexes[:, None].expand((batch_size, batch_word_pieces_cum_length.size(1)))
for l_index, l in enumerate(self.layers):
output_layer = bert_outputs[l]
real_word_piece_length = output_layer.size(1) - 2
- if word_piece_length > real_word_piece_length:  # if the output was actually truncated
+ if max_word_piece_length > real_word_piece_length:  # if the output was actually truncated
paddings = output_layer.new_zeros(batch_size,
- word_piece_length - real_word_piece_length,
+ max_word_piece_length - real_word_piece_length,
output_layer.size(2))
output_layer = torch.cat((output_layer, paddings), dim=1).contiguous()
# collapse from word-piece representations to word representations
@@ -462,7 +494,85 @@ class _WordBertModel(nn.Module):
outputs[l_index, :, 0] = pooled_cls
else:
outputs[l_index, :, 0] = output_layer[:, 0]
- outputs[l_index, batch_indexes, seq_len + s_shift] = output_layer[batch_indexes, seq_len + s_shift]
+ outputs[l_index, batch_indexes, seq_len + s_shift] = output_layer[batch_indexes, word_pieces_lengths + s_shift]
# 3. the final embedding result
return outputs
+
+
+class _BertWordPieceModel(nn.Module):
+ r"""
+ This module computes word_piece-level results directly.
+
+ """
+
+ def __init__(self, model_dir_or_name: str, layers: str = '-1', pooled_cls: bool=False):
+ super().__init__()
+
+ self.tokenzier = BertTokenizer.from_pretrained(model_dir_or_name)
+ self.encoder = BertModel.from_pretrained(model_dir_or_name)
+ # check that encoder_layer_number is reasonable
+ encoder_layer_number = len(self.encoder.encoder.layer)
+
+ if isinstance(layers, list):
+ self.layers = [int(l) for l in layers]
+ elif isinstance(layers, str):
+ self.layers = list(map(int, layers.split(',')))
+ else:
+ raise TypeError("`layers` only supports str or list[int]")
+
+ for layer in self.layers:
+ if layer < 0:
+ assert -layer <= encoder_layer_number, f"The layer index:{layer} is out of scope for " \
+ f"a bert model with {encoder_layer_number} layers."
+ else:
+ assert layer <= encoder_layer_number, f"The layer index:{layer} is out of scope for " \
+ f"a bert model with {encoder_layer_number} layers."
+
+ self._cls_index = self.tokenzier.cls_index
+ self._sep_index = self.tokenzier.sep_index
+ self._wordpiece_unknown_index = self.tokenzier.unk_index
+ self._wordpiece_pad_index = self.tokenzier.pad_index  # needed when generating word pieces
+ self.pooled_cls = pooled_cls
+
+ def index_datasets(self, *datasets, field_name, add_cls_sep=True):
+ r"""
+ Use bert's tokenizer to generate a new word_pieces field in the datasets and set it as input. If the sequence does not
+ already start with [CLS] and end with [SEP], they are added, and the pad value of the word_pieces field is set to bert's pad value.
+
+ :param datasets: DataSet objects
+ :param field_name: which field to build the index from
+ :param add_cls_sep: whether to add [CLS]/[SEP] at the start/end when they are missing
+ :return:
+ """
+
+ encode_func = partial(self.tokenzier.encode, add_special_tokens=add_cls_sep)
+
+ for index, dataset in enumerate(datasets):
+ try:
+ dataset.apply_field(encode_func, field_name=field_name, new_field_name='word_pieces',
+ is_input=True)
+ dataset.set_pad_val('word_pieces', self._wordpiece_pad_index)
+ except Exception as e:
+ logger.error(f"Exception happens when processing the {index} dataset.")
+ raise e
+
+ def forward(self, word_pieces, token_type_ids=None):
+ r"""
+
+ :param word_pieces: torch.LongTensor, batch_size x max_len
+ :param token_type_ids: torch.LongTensor, batch_size x max_len
+ :return: num_layers x batch_size x max_len x hidden_size, or num_layers x batch_size x (max_len+2) x hidden_size
+ """
+ batch_size, max_len = word_pieces.size()
+
+ attn_masks = word_pieces.ne(self._wordpiece_pad_index)
+ bert_outputs, pooled_cls = self.encoder(word_pieces, token_type_ids=token_type_ids, attention_mask=attn_masks,
+ output_all_encoded_layers=True)
+ # output_layers = [self.layers] # len(self.layers) x batch_size x max_word_piece_length x hidden_size
+ outputs = bert_outputs[0].new_zeros((len(self.layers), batch_size, max_len, bert_outputs[0].size(-1)))
+ for l_index, l in enumerate(self.layers):
+ bert_output = bert_outputs[l]
+ if l in (len(bert_outputs)-1, -1) and self.pooled_cls:
+ bert_output[:, 0] = pooled_cls
+ outputs[l_index] = bert_output
+ return outputs
\ No newline at end of file
diff --git a/fastNLP/embeddings/gpt2_embedding.py b/fastNLP/embeddings/gpt2_embedding.py
new file mode 100644
index 00000000..fdae4240
--- /dev/null
+++ b/fastNLP/embeddings/gpt2_embedding.py
@@ -0,0 +1,649 @@
+"""
+.. todo::
+ doc
+"""
+
+__all__ = [
+ "GPT2Embedding",
+ "GPT2WordPieceEncoder"
+]
+
+import warnings
+from functools import partial
+from itertools import chain
+from collections import OrderedDict
+
+import torch
+from torch import nn
+import numpy as np
+
+from .contextual_embedding import ContextualEmbedding
+from ..core import logger
+from ..core.utils import _get_model_device
+from ..core.vocabulary import Vocabulary
+from ..io.file_utils import PRETRAINED_BERT_MODEL_DIR
+from ..modules.tokenizer import GPT2Tokenizer
+from ..modules.encoder.gpt2 import GPT2LMHeadModel, GPT2Model
+
+
+class GPT2Embedding(ContextualEmbedding):
+ """
+ An Embedding that uses GPT2 to encode words.
+
+ Example::
+
+ >>> import torch
+ >>> from fastNLP import Vocabulary
+        >>> from fastNLP.embeddings import GPT2Embedding
+        >>> vocab = Vocabulary().add_word_lst("The weather is good .".split())
+        >>> embed = GPT2Embedding(vocab, model_dir_or_name='en-small', requires_grad=False, layers='4,-2,-1')
+        >>> words = torch.LongTensor([[vocab.to_index(word) for word in "The weather is good .".split()]])
+        >>> outputs = embed(words)
+        >>> outputs.size()
+        >>> # torch.Size([1, 5, 2304])
+ """
+
+ def __init__(self, vocab: Vocabulary, model_dir_or_name: str = 'en-small', layers: str = '-1',
+ pool_method: str = 'first', dropout=0, requires_grad: bool = True,
+ auto_truncate: bool = False, language_model: bool = False, **kwargs):
+ """
+
+        :param ~fastNLP.Vocabulary vocab: the vocabulary
+        :param str model_dir_or_name: the directory containing the model, or the model's name. When a directory is given, it should contain a vocabulary file (suffix .txt),
+        a weight file (suffix .bin) and a config file (suffix .json).
+        :param str layers: which layers the output embedding comes from; results from different layers are concatenated on the last dimension in the order given in layers. Separate layer numbers with ','; numbering starts
+        at 0, and negative numbers index from the last layers.
+        :param str pool_method: in GPT2 each word is represented by several word pieces; this selects how a word's representation
+        is computed from its word pieces. Supports ``last``, ``first``, ``avg``, ``max``.
+        :param float dropout: probability of applying Dropout to the embedding representation. 0.1 means 10% of the values are randomly zeroed.
+        :param bool requires_grad: whether gradients are needed to update the GPT2 weights.
+        :param bool auto_truncate: when a sentence's words, split into word pieces, exceed the maximum allowed sequence
+        length, automatically truncate everything beyond that length. The encode results of the
+        truncated part are set to zero.
+        :param bool language_model: whether to compute GPT2's LM loss; it can be read via get_lm_loss(). After a batch has
+        been fed in, calling get_lm_loss() returns that batch's language-model loss
+        :param **kwargs:
+        bool only_use_pretrain_bpe: only use BPE tokens that appear in the pretrained vocabulary; a word that cannot be tokenized falls back to unk. Recommended to set
+        True if the embedding will not be updated.
+        int min_freq: only effective when only_use_pretrain_bpe is False; words occurring at least this many times are newly added to GPT2's BPE vocabulary
+        bool truncate_embed: whether to keep only the BPE tokens actually used (this reduces memory usage and speeds things up)
+ """
+ super().__init__(vocab, word_dropout=0, dropout=dropout)
+
+ if model_dir_or_name.lower() in PRETRAINED_BERT_MODEL_DIR:
+ if 'cn' in model_dir_or_name.lower() and pool_method not in ('first', 'last'):
+ logger.warning("For Chinese GPT, pooled_method should choose from 'first', 'last' in order to achieve"
+ " faster speed.")
+ warnings.warn("For Chinese GPT, pooled_method should choose from 'first', 'last' in order to achieve"
+ " faster speed.")
+
+ only_use_pretrain_bpe = kwargs.get('only_use_pretrain_bpe', False)
+ truncate_embed = kwargs.get('truncate_embed', True)
+ min_freq = kwargs.get('min_freq', 2)
+
+ self.lm_loss = language_model
+ self.model = _GPT2Model(model_dir_or_name=model_dir_or_name, vocab=vocab, layers=layers,
+ pool_method=pool_method, auto_truncate=auto_truncate, language_model=language_model,
+ only_use_pretrain_bpe=only_use_pretrain_bpe, truncate_embed=truncate_embed,
+ min_freq=min_freq)
+
+ self.requires_grad = requires_grad
+ self._embed_size = len(self.model.layers) * self.model.encoder.config.n_embd
+
+ def _delete_model_weights(self):
+ del self.model
+
+ def forward(self, words):
+ """
+ Compute the GPT2 embedding of words.
+
+ :param torch.LongTensor words: [batch_size, max_len]
+ :return: torch.FloatTensor. batch_size x max_len x (768*len(self.layers))
+ """
+ outputs = self._get_sent_reprs(words)
+ if outputs is not None:
+ return self.dropout(outputs)
+ outputs = self.model(words)
+ outputs = torch.cat([*outputs], dim=-1)
+
+ return self.dropout(outputs)
+
+ def drop_word(self, words):
+ """
+ :param torch.LongTensor words: batch_size x max_len
+ :return:
+ """
+ if self.word_dropout > 0 and self.training:
+ with torch.no_grad():
+ mask = torch.full_like(words, fill_value=self.word_dropout, dtype=torch.float, device=words.device)
+ mask = torch.bernoulli(mask).eq(1)  # the larger word_dropout is, the more positions become 1
+ words = words.masked_fill(mask, self._word_unk_index)
+ return words
+
+ def get_lm_loss(self, release=True):
+ """
+ When language_model=True, this interface returns the language-model loss of the current batch.
+
+ :param bool release: if True, once lm_loss has been read it cannot be read again until the next forward pass completes
+ :return: torch.FloatTensor([])
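+
+        Example::
+
+            >>> # a sketch: run a forward pass first, then read the cached loss
+            >>> embed = GPT2Embedding(vocab, 'en-small', language_model=True)
+            >>> reps = embed(words)          # computes and caches the LM loss
+            >>> lm_loss = embed.get_lm_loss()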
+ """
+ if hasattr(self.model, '_lm_loss_value'):
+ lm_loss_value = self.model._lm_loss_value
+ if release:
+ delattr(self.model, '_lm_loss_value')
+ return lm_loss_value
+ elif self.lm_loss:
+ raise RuntimeError("Make sure you have passed a batch into GPT2Embdding before accessing loss.")
+ else:
+ raise RuntimeError("Initialize your GPT2Embedding with language_model=True.")
+
+
+class GPT2WordPieceEncoder(nn.Module):
+ """
+ The GPT2 model. Tokenize your data with this model's matching Tokenizer before use.
+
+ """
+
+ def __init__(self, model_dir_or_name: str = 'en-small', layers: str = '-1',
+ word_dropout=0, dropout=0, requires_grad: bool = True, language_model:bool=False):
+ """
+
+        :param str model_dir_or_name: the directory containing the model, or the model's name.
+        :param str,list layers: which layers make up the final representation. Separate layer numbers with ','; negative numbers index from the last layers
+        :param float word_dropout: probability of replacing a word piece with <|endoftext|>
+        :param float dropout: probability of applying Dropout to the embedding representation. 0.1 means 10% of the values are randomly zeroed.
+        :param bool language_model: whether to use the language model
+        :param bool requires_grad: whether gradients are required.
+ """
+ super().__init__()
+
+ self.model = _GPT2WordPieceModel(model_dir_or_name=model_dir_or_name, layers=layers, language_model=language_model)
+ self._wordpiece_pad_index = self.model._wordpiece_pad_index
+ self._embed_size = len(self.model.layers) * self.model.encoder.config.n_embd
+ self.requires_grad = requires_grad
+ self.dropout_layer = nn.Dropout(dropout)
+ self._wordpiece_endoftext_index = self.model._endoftext_index
+ self.word_dropout = word_dropout
+ self.language_model = language_model
+
+ @property
+ def embed_size(self):
+ return self._embed_size
+
+ @property
+ def embedding_dim(self):
+ return self._embed_size
+
+ @property
+ def num_embedding(self):
+ return self.model.encoder.config.vocab_size
+
+ def index_datasets(self, *datasets, field_name, add_endoftext=False, add_prefix_space=True):
+ """
+ Use GPT2's tokenizer to generate a new word_pieces field in the datasets, set it as input, and set the pad value of
+ the word_pieces field to GPT2's pad value.
+
+        :param ~fastNLP.DataSet datasets: DataSet objects
+        :param list[str] field_name: which field's content the word_pieces field is generated from. Each entry in that field should be a List[str].
+        :param bool add_endoftext: add <|endoftext|> at the start of the sentence.
+        :param bool add_prefix_space: whether to prepend a space to the sentence
+ :return:
+ """
+ self.model.index_datasets(*datasets, field_name=field_name, add_endoftext=add_endoftext,
+ add_prefix_space=add_prefix_space)
+
+ def forward(self, word_pieces, token_type_ids=None):
+ """
+ Compute the GPT2 embedding of words. The word_pieces passed in should contain <|endoftext|> at the start.
+
+ :param word_pieces: batch_size x max_len
+ :param token_type_ids: batch_size x max_len,
+ :return: torch.FloatTensor.
+ """
+
+ outputs = self.model(word_pieces)
+ outputs = torch.cat([*outputs], dim=-1)
+
+ return self.dropout_layer(outputs)
+
+ def drop_word(self, words):
+ """
+
+ :param torch.LongTensor words: batch_size x max_len
+ :return:
+ """
+ if self.word_dropout > 0 and self.training:
+ with torch.no_grad():
+ mask = torch.full_like(words, fill_value=self.word_dropout, dtype=torch.float, device=words.device)
+ mask = torch.bernoulli(mask).eq(1)  # the larger word_dropout is, the more positions become 1
+ endoftext_mask = words.ne(self._wordpiece_endoftext_index)
+ mask = endoftext_mask.__and__(mask)  # <|endoftext|> positions are never replaced
+ words = words.masked_fill(mask, self._wordpiece_endoftext_index)
+ return words
+
+ def generate_from_str(self, text='', max_len=40, do_sample=True, num_beams=1, temperature=1, top_k=50, top_p=1.0,
+ repetition_penalty=1.0, length_penalty=1.0):
+ """
+
+        :param str text: the beginning of the story
+        :param int max_len: how long a sequence to generate
+        :param bool do_sample: whether to generate by sampling; with sampling, the same parameters may produce different sentences.
+        :param int num_beams: the beam size to use
+        :param float temperature: used to adjust the sampling distribution
+        :param int top_k: keep only the top_k words of the vocabulary for generation. Range 1-infinity
+        :param float top_p: keep the words whose cumulative probability reaches top_p. Range 0-1.
+        :param float repetition_penalty: penalty applied to repeated tokens
+        :param float length_penalty: penalty applied to overly long sentences
+ :return: list[str]
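+
+        Example::
+
+            >>> # a sketch (requires language_model=True at construction)
+            >>> encoder = GPT2WordPieceEncoder('en-small', language_model=True)
+            >>> encoder.generate_from_str('Once upon a time', max_len=20)
+            >>> # ['Once upon a time ...']  (sampled, so output varies)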
+ """
+ if len(text)==0:
+ word_pieces = torch.LongTensor([[self.model.tokenizer.bos_index]])
+ start_idx = 1
+ else:
+ assert isinstance(text, str), "Only string input allowed."
+ assert self.language_model, "You must set `language_model=True`."
+ word_pieces = self.model.convert_words_to_word_pieces(text, add_prefix_space=True)
+ word_pieces = torch.LongTensor([word_pieces])
+ start_idx = 0
+ device = _get_model_device(self)
+ word_pieces = word_pieces.to(device)
+ outputs = self.model.encoder.generate(input_ids=word_pieces,
+ max_length=max_len,
+ do_sample=do_sample,
+ num_beams=num_beams,
+ temperature=temperature,
+ top_k=top_k,
+ top_p=top_p,
+ repetition_penalty=repetition_penalty,
+ bos_token_id=self.model.tokenizer.bos_index,
+ pad_token_id=self.model.tokenizer.eos_index,  # use <|endoftext|> in place of pad
+ eos_token_ids=self.model.tokenizer.eos_index,
+ length_penalty=length_penalty).squeeze(0)
+
+ output_strs = []
+ if outputs.dim()==1:
+ outputs = outputs[None]
+ outputs = outputs[:, start_idx:]
+ for i in range(len(outputs)):
+ str_ = self.model.tokenizer.convert_tokens_to_string(self.model.tokenizer.convert_ids_to_tokens(outputs[i].tolist()))
+ output_strs.append(str_)
+
+ return output_strs
+
+ def generate(self, word_pieces, max_len=40, do_sample=True, num_beams=1, temperature=1, top_k=50, top_p=1.0,
+ repetition_penalty=1.0, length_penalty=1.0):
+ """
+
+ :param word_pieces:
+        :param int max_len: how long a sequence to generate
+        :param bool do_sample: whether to generate by sampling; with sampling, the same parameters may produce different sentences.
+        :param int num_beams: the beam size to use
+        :param float temperature: used to adjust the sampling distribution
+        :param int top_k: keep only the top_k words of the vocabulary for generation. Range 1-infinity
+        :param float top_p: keep the words whose cumulative probability reaches top_p. Range 0-1.
+        :param float repetition_penalty: penalty applied to repeated tokens
+        :param float length_penalty: penalty applied to overly long sentences
+ :return:
+ """
+ pass
+
+ def get_lm_loss(self, release=True):
+ """
+ When language_model=True, this interface returns the language-model loss of the current batch.
+
+ :param bool release: if True, once lm_loss has been read it cannot be read again until the next forward pass completes
+ :return: torch.FloatTensor([])
+ """
+ if hasattr(self.model, '_lm_loss_value'):
+ lm_loss_value = self.model._lm_loss_value
+ if release:
+ delattr(self.model, '_lm_loss_value')
+ return lm_loss_value
+ elif self.language_model:
+ raise RuntimeError("Make sure you have passed a batch into GPT2Embedding before accessing loss.")
+ else:
+ raise RuntimeError("Initialize your GPT2Embedding with language_model=True.")
+
+
+class _GPT2Model(nn.Module):
+ def __init__(self, model_dir_or_name, vocab, layers, pool_method='first', auto_truncate=True, language_model=False,
+ only_use_pretrain_bpe=False, min_freq=2, truncate_embed=False):
+ super().__init__()
+
+ self.tokenzier = GPT2Tokenizer.from_pretrained(model_dir_or_name)
+ if language_model:
+ self.encoder = GPT2LMHeadModel.from_pretrained(model_dir_or_name)
+ else:
+ self.encoder = GPT2Model.from_pretrained(model_dir_or_name)
+
+ self.lm_loss = language_model
+ self._max_position_embeddings = self.encoder.config.max_position_embeddings
+ # check that encoder_layer_number is reasonable
+ encoder_layer_number = self.encoder.config.n_layer
+ if isinstance(layers, list):
+ self.layers = [int(l) for l in layers]
+ elif isinstance(layers, str):
+ self.layers = list(map(int, layers.split(',')))
+ else:
+ raise TypeError("`layers` only supports str or list[int]")
+ for layer in self.layers:
+ if layer < 0:
+ assert -layer <= encoder_layer_number, f"The layer index:{layer} is out of scope for " \
+ f"a GPT2 model with {encoder_layer_number} layers."
+ else:
+ assert layer <= encoder_layer_number, f"The layer index:{layer} is out of scope for " \
+ f"a GPT2 model with {encoder_layer_number} layers."
+
+ assert pool_method in ('avg', 'max', 'first', 'last')
+ self.pool_method = pool_method
+ self.auto_truncate = auto_truncate
+
+ # compute the word pieces for every word in vocab; <|endoftext|> needs extra handling
+ logger.info("Start to generate word pieces for word.")
+ # step 1: collect the needed word pieces, then create the new embed and word_piece vocab and fill in the values
+ word_piece_dict = {'<|endoftext|>': 1}  # word pieces that are used, plus newly added ones
+ found_count = 0
+ new_add_to_bpe_vocab = 0
+ unsegment_count = 0
+
+ for word, index in vocab:
+ if index == vocab.padding_idx:  # pad is a special symbol
+ word = '<|endoftext|>'
+ elif index == vocab.unknown_idx:
+ word = '<|endoftext|>'
+ # _words = self.tokenzier.basic_tokenizer._tokenize_chinese_chars(word).split()  # Chinese content is not handled here for now
+ word_pieces = []
+ word_pieces.extend(self.tokenzier.tokenize(word, add_prefix_space=True))
+ if len(word_pieces) == 1:
+ if not vocab._is_word_no_create_entry(word):  # a value from train, but not found
+ if index not in (vocab.unknown_idx, vocab.padding_idx) and word_pieces[0] == '<|endoftext|>':  # the word is not in the original vocabulary
+ if vocab.word_count[word] >= min_freq and not vocab._is_word_no_create_entry(
+ word) and not only_use_pretrain_bpe:  # only add words that occur at least min_freq times
+ word_piece_dict[word] = 1  # add a new entry
+ new_add_to_bpe_vocab += 1
+ unsegment_count += 1
+ continue
+ for word_piece in word_pieces:
+ word_piece_dict[word_piece] = 1
+ found_count += 1
+
+ if unsegment_count>0:
+ if only_use_pretrain_bpe or new_add_to_bpe_vocab==0:
+ logger.info(f"{unsegment_count} words are unsegmented.")
+ else:
+ logger.info(f"{unsegment_count} words are unsegmented. Among them, {new_add_to_bpe_vocab} added to the BPE vocab.")
+
+ original_embed = self.encoder.get_input_embeddings().weight
+ # special tokens need special handling
+ if not truncate_embed:  # if not truncating, the existing vocabulary must be added as well
+ word_piece_dict.update(self.tokenzier.encoder)
+
+ embed = nn.Embedding(len(word_piece_dict), original_embed.size(1))  # the new embed
+ new_word_piece_vocab = OrderedDict()
+
+ for index, token in enumerate(['<|endoftext|>']):
+ index = word_piece_dict.pop(token, None)
+ if index is not None:
+ new_word_piece_vocab[token] = len(new_word_piece_vocab)
+ embed.weight.data[new_word_piece_vocab[token]] = original_embed[self.tokenzier.encoder[token]]
+
+ for token in word_piece_dict.keys():
+ if token not in new_word_piece_vocab:
+ new_word_piece_vocab[token] = len(new_word_piece_vocab)
+ index = new_word_piece_vocab[token]
+ if token in self.tokenzier.encoder:
+ embed.weight.data[index] = original_embed[self.tokenzier.encoder[token]]
+ else:
+ embed.weight.data[index] = original_embed[self.tokenzier.encoder['<|endoftext|>']]
+
+ self.tokenzier._reinit_on_new_vocab(new_word_piece_vocab)
+ self.encoder.set_input_embeddings(embed)
+ self.encoder.tie_weights()
+ self.encoder.config.vocab_size = len(new_word_piece_vocab)
+
+ word_to_wordpieces = []
+ word_pieces_lengths = []
+ for word, index in vocab:
+ if index == vocab.padding_idx:  # pad is a special symbol
+ word = '<|endoftext|>'
+ elif index == vocab.unknown_idx:
+ word = '<|endoftext|>'
+ word_pieces = self.tokenzier.tokenize(word)
+ word_pieces = self.tokenzier.convert_tokens_to_ids(word_pieces)
+ word_to_wordpieces.append(word_pieces)
+ word_pieces_lengths.append(len(word_pieces))
+ self._word_pad_index = vocab.padding_idx
+ self._endoftext_index = self.tokenzier.encoder.get('<|endoftext|>')
+ self._wordpiece_pad_index = self.tokenzier.encoder.get('<|endoftext|>')  # needed when generating word pieces
+ self.word_to_wordpieces = np.array(word_to_wordpieces)
+ self.register_buffer('word_pieces_lengths', torch.LongTensor(word_pieces_lengths))
+ logger.debug("Successfully generate word pieces.")
+
+ def forward(self, words):
+ """
+
+ :param words: torch.LongTensor, batch_size x max_len
+ :return: num_layers x batch_size x max_len x hidden_size, or num_layers x batch_size x (max_len+2) x hidden_size
+ """
+ with torch.no_grad():
+ batch_size, max_word_len = words.size()
+ word_mask = words.ne(self._word_pad_index)  # positions with value 1 hold words
+ seq_len = word_mask.sum(dim=-1)
+ batch_word_pieces_length = self.word_pieces_lengths[words].masked_fill(word_mask.eq(False),
+ 0) # batch_size x max_len
+ word_pieces_lengths = batch_word_pieces_length.sum(dim=-1) # batch_size
+ max_word_piece_length = batch_word_pieces_length.sum(dim=-1).max().item()  # the max word-piece length (including padding)
+ if max_word_piece_length > self._max_position_embeddings:
+ if self.auto_truncate:
+ word_pieces_lengths = word_pieces_lengths.masked_fill(
+ word_pieces_lengths > self._max_position_embeddings,
+ self._max_position_embeddings)
+ else:
+ raise RuntimeError(
+ "After split words into word pieces, the lengths of word pieces are longer than the "
+ f"maximum allowed sequence length:{self._max_position_embeddings} of GPT2. You can set "
+ f"`auto_truncate=True` for BertEmbedding to automatically truncate overlong input.")
+
+ word_pieces = words.new_full((batch_size, min(max_word_piece_length, self._max_position_embeddings)),
+ fill_value=self._wordpiece_pad_index)
+ word_labels = word_pieces.clone()
+ attn_masks = torch.zeros_like(word_pieces)
+ # 1. get the word-piece ids of words and the corresponding spans
+ word_indexes = words.cpu().numpy()
+ for i in range(batch_size):
+ word_pieces_i = list(chain(*self.word_to_wordpieces[word_indexes[i, :seq_len[i]]]))
+ if self.auto_truncate and len(word_pieces_i) > self._max_position_embeddings:
+ word_pieces_i = word_pieces_i[:self._max_position_embeddings]
+ word_pieces[i, :word_pieces_lengths[i]] = torch.LongTensor(word_pieces_i)
+ word_labels[i, word_pieces_lengths[i]:].fill_(-100)  # used when computing lm_loss
+ attn_masks[i, :word_pieces_lengths[i]].fill_(1)
+ # prepend <|endoftext|>; disabled by default now
+ # word_pieces[:, 0].fill_(self._endoftext_index)
+ batch_indexes = torch.arange(batch_size).to(words)
+ # 2. get the hidden states and pool them according to word_pieces
+ # all_outputs: [batch_size x max_len x hidden_size, batch_size x max_len x hidden_size, ...]
+ if self.lm_loss:
+ gpt2_outputs = self.encoder(word_pieces, token_type_ids=None, attention_mask=attn_masks, labels=word_labels,
+ output_attentions=False)
+ gpt2_outputs, self._lm_loss_value = gpt2_outputs[-1], gpt2_outputs[0] # n_layers x batch_size x max_len x hidden_size
+ else:
+ gpt2_outputs = self.encoder(word_pieces, token_type_ids=None, attention_mask=attn_masks,
+ output_attentions=False)[-1]
+ outputs = gpt2_outputs[-1].new_zeros(len(self.layers), batch_size, max_word_len,
+ gpt2_outputs[-1].size(-1))
+
+ batch_word_pieces_cum_length = batch_word_pieces_length.new_zeros(batch_size, max_word_len+1)
+ batch_word_pieces_cum_length[:, 1:] = batch_word_pieces_length.cumsum(dim=-1) # batch_size x max_len
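+            # e.g. piece lengths [2, 1, 3] give cumulative offsets [0, 2, 3, 6]:
+            # entry j is the offset of word j's first piece ('first' pooling),
+            # and entry j+1 minus one the offset of its last piece ('last').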
+
+ if self.pool_method == 'first':
+ batch_word_pieces_cum_length = batch_word_pieces_cum_length[:, :seq_len.max()]
+ batch_word_pieces_cum_length.masked_fill_(batch_word_pieces_cum_length.ge(max_word_piece_length), 0)
+ _batch_indexes = batch_indexes[:, None].expand((batch_size, batch_word_pieces_cum_length.size(1)))
+ elif self.pool_method == 'last':
+ batch_word_pieces_cum_length = batch_word_pieces_cum_length[:, :seq_len.max()] - 1
+ batch_word_pieces_cum_length.masked_fill_(batch_word_pieces_cum_length.ge(max_word_piece_length), 0)
+ _batch_indexes = batch_indexes[:, None].expand((batch_size, batch_word_pieces_cum_length.size(1)))
+
+ for l_index, l in enumerate(self.layers):
+ output_layer = gpt2_outputs[l]
+ real_word_piece_length = output_layer.size(1)
+ if max_word_piece_length > real_word_piece_length:  # if the output was actually truncated
+ paddings = output_layer.new_zeros(batch_size,
+ max_word_piece_length - real_word_piece_length,
+ output_layer.size(2))
+ output_layer = torch.cat((output_layer, paddings), dim=1).contiguous()
+ # collapse from word-piece representations to word representations
+ # truncate_output_layer = output_layer  # drop endoftext; batch_size x len x hidden_size
+ if self.pool_method == 'first':
+ tmp = output_layer[_batch_indexes, batch_word_pieces_cum_length]
+ tmp = tmp.masked_fill(word_mask[:, :batch_word_pieces_cum_length.size(1), None].eq(False), 0)
+ outputs[l_index, :, :batch_word_pieces_cum_length.size(1)] = tmp
+ elif self.pool_method == 'last':
+ tmp = output_layer[_batch_indexes, batch_word_pieces_cum_length]
+ tmp = tmp.masked_fill(word_mask[:, :batch_word_pieces_cum_length.size(1), None].eq(False), 0)
+ outputs[l_index, :, :batch_word_pieces_cum_length.size(1)] = tmp
+ elif self.pool_method == 'max':
+ for i in range(batch_size):
+ for j in range(seq_len[i]):
+ start, end = batch_word_pieces_cum_length[i, j], batch_word_pieces_cum_length[i, j + 1]
+ outputs[l_index, i, j], _ = torch.max(output_layer[i, start:end], dim=-2)
+ else:
+ for i in range(batch_size):
+ for j in range(seq_len[i]):
+ start, end = batch_word_pieces_cum_length[i, j], batch_word_pieces_cum_length[i, j + 1]
+ outputs[l_index, i, j] = torch.mean(output_layer[i, start:end], dim=-2)
+
+ # 3. the final embedding result
+ return outputs
+
+ def get_lm_loss(self):
+ """
+ When language_model is True, this interface returns the language-model loss of the most recently processed batch
+
+ :return:
+ """
+ return self._lm_loss_value
+
+
+class _GPT2WordPieceModel(nn.Module):
+ """
+ This module computes word_piece-level results directly.
+
+ """
+
+ def __init__(self, model_dir_or_name: str, layers: str = '-1', language_model: bool=False):
+ super().__init__()
+
+ self.tokenizer = GPT2Tokenizer.from_pretrained(model_dir_or_name)
+ if language_model:
+ self.encoder = GPT2LMHeadModel.from_pretrained(model_dir_or_name)
+ else:
+ self.encoder = GPT2Model.from_pretrained(model_dir_or_name)
+
+ self.lm_loss = language_model
+
+ # check that encoder_layer_number is reasonable
+ encoder_layer_number = self.encoder.config.n_layer
+
+ if isinstance(layers, list):
+ self.layers = [int(l) for l in layers]
+ elif isinstance(layers, str):
+ self.layers = list(map(int, layers.split(',')))
+ else:
+ raise TypeError("`layers` only supports str or list[int]")
+
+ for layer in self.layers:
+ if layer < 0:
+ assert -layer <= encoder_layer_number, f"The layer index:{layer} is out of scope for " \
+ f"a gpt2 model with {encoder_layer_number} layers."
+ else:
+ assert layer <= encoder_layer_number, f"The layer index:{layer} is out of scope for " \
+ f"a gpt2 model with {encoder_layer_number} layers."
+
+ self._endoftext_index = self.tokenizer.encoder.get('<|endoftext|>')
+ self._wordpiece_pad_index = self.tokenizer.encoder.get('<|endoftext|>')  # there is no real pad token, so this value stands in; the pad value hardly matters since computation runs left to right
+ self._max_position_embeddings = self.encoder.config.max_position_embeddings
+
+ def index_datasets(self, *datasets, field_name, add_endoftext=False, add_prefix_space=True):
+ """
+        Use gpt2's tokenizer to generate a new word_pieces field in the datasets and set it as input. If add_endoftext is True
+        and a sequence does not start with <|endoftext|>, it is prepended; the pad value of the word_pieces field is set to gpt2's pad value.
+
+        :param datasets: DataSet objects
+        :param field_name: which field to build the index from
+        :param bool add_endoftext: whether to prepend <|endoftext|> to each sentence
+        :param bool add_prefix_space: whether to prepend a space to each sentence
+ :return:
+ """
+ convert_words_to_word_pieces = partial(self.convert_words_to_word_pieces, add_endoftext=add_endoftext,
+ add_prefix_space=add_prefix_space)
+ for index, dataset in enumerate(datasets):
+ try:
+ dataset.apply_field(convert_words_to_word_pieces, field_name=field_name, new_field_name='word_pieces',
+ is_input=True)
+ dataset.set_pad_val('word_pieces', self._wordpiece_pad_index)
+ except Exception as e:
+ logger.error(f"Exception happens when processing the {index} dataset.")
+ raise e
+
+ def convert_words_to_word_pieces(self, words, add_endoftext=False, add_prefix_space=True):
+ """
+
+        :param list[str],str words: the str data to convert to indices
+        :param bool add_endoftext: whether to add <|endoftext|> at the start of the sentence
+        :param bool add_prefix_space: whether to prepend a space to the sentence
+ :return:
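+
+        Example::
+
+            >>> # a sketch: raw strings and pre-tokenized lists are both accepted,
+            >>> # and both return a flat list of BPE ids
+            >>> model.convert_words_to_word_pieces("hello world")
+            >>> model.convert_words_to_word_pieces(["hello", "world"])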
+ """
+ word_pieces = []
+ if isinstance(words, str):
+ words = self.tokenizer.tokenize(words, add_prefix_space=add_prefix_space)
+ word_piece_ids = self.tokenizer.convert_tokens_to_ids(words)
+ word_pieces.extend(word_piece_ids)
+ else:
+ for word in words:
+ tokens = self.tokenizer.tokenize(word, add_prefix_space=add_prefix_space)
+ word_piece_ids = self.tokenizer.convert_tokens_to_ids(tokens)
+ word_pieces.extend(word_piece_ids)
+ if add_endoftext:
+ if word_pieces[0] != self._endoftext_index:
+ word_pieces.insert(0, self._endoftext_index)
+ if len(word_pieces) > self._max_position_embeddings:
+ word_pieces[self._max_position_embeddings - 1] = word_pieces[-1]
+ word_pieces = word_pieces[:self._max_position_embeddings]
+ return word_pieces
+
+ def forward(self, word_pieces, token_type_ids=None):
+ """
+
+ :param word_pieces: torch.LongTensor, batch_size x max_len
+ :param token_type_ids: torch.LongTensor, batch_size x max_len
+ :return: num_layers x batch_size x max_len x hidden_size, or num_layers x batch_size x (max_len+2) x hidden_size
+ """
+ batch_size, max_len = word_pieces.size()
+
+ attn_masks = word_pieces.ne(self._wordpiece_pad_index)  # may wrongly cause a leading <|endoftext|> to be masked out
+ word_pieces = word_pieces.masked_fill(attn_masks.eq(0), self._endoftext_index)  # replace the pad values
+ if self.lm_loss:
+ labels = word_pieces.clone()
+ labels = labels.masked_fill(labels.eq(self._wordpiece_pad_index), -100)
+ gpt_outputs = self.encoder(word_pieces, token_type_ids=token_type_ids, attention_mask=attn_masks,
+ output_attentions=False, labels=labels)
+ gpt_outputs, self._lm_loss_value = gpt_outputs[-1], gpt_outputs[0] # n_layers x batch_size x max_len x hidden_size
+ else:
+ gpt_outputs = self.encoder(word_pieces, token_type_ids=token_type_ids, attention_mask=attn_masks,
+ output_attentions=False)
+ gpt_outputs = gpt_outputs[-1]
+ # output_layers = [self.layers] # len(self.layers) x batch_size x max_word_piece_length x hidden_size
+ outputs = gpt_outputs[0].new_zeros((len(self.layers), batch_size, max_len, gpt_outputs[0].size(-1)))
+ for l_index, l in enumerate(self.layers):
+ outputs[l_index] = gpt_outputs[l]
+ return outputs
+
+ def get_lm_loss(self):
+ """
+ When language_model is True, this interface returns the language-model loss of the most recently processed batch
+
+ :return:
+ """
+ return self._lm_loss_value
+
diff --git a/fastNLP/embeddings/roberta_embedding.py b/fastNLP/embeddings/roberta_embedding.py
index 46b4ebb2..4e77a310 100644
--- a/fastNLP/embeddings/roberta_embedding.py
+++ b/fastNLP/embeddings/roberta_embedding.py
@@ -1,5 +1,10 @@
+r"""
+.. todo::
+ doc
+"""
-import os
+
+from functools import partial
import collections
import warnings
from itertools import chain
@@ -10,7 +15,8 @@ import torch.nn as nn
from .contextual_embedding import ContextualEmbedding
from ..core import logger, Vocabulary
-from ..modules.encoder.roberta import RobertaModel, RobertaTokenizer
+from ..modules.encoder.roberta import RobertaModel
+from ..modules.tokenizer import RobertaTokenizer
class RobertaEmbedding(ContextualEmbedding):
@@ -20,7 +26,8 @@ class RobertaEmbedding(ContextualEmbedding):
when split), and after splitting the length may exceed the maximum length limit.
RobertaEmbedding supports automatic weight download; currently supported models:
- ..TODO
+ en: roberta-base
+ en-large: roberta-large
Example::
@@ -43,8 +50,8 @@ class RobertaEmbedding(ContextualEmbedding):
:param ~fastNLP.Vocabulary vocab: the vocabulary
:param str model_dir_or_name: the directory containing the model, or the model's name. When a directory is given, it should contain a vocabulary file
(suffix vocab.json), a weight file (suffix .bin) and a config file (suffix config.json).
- :param str layers: which layers the output embedding comes from; results from different layers are concatenated on the last dimension in the order given in layers. Separate layer numbers with ','; numbering starts
- at 0, and negative numbers index from the last layers.
+ :param str,list layers: which layers the output embedding comes from; results from different layers are concatenated on the last dimension in the order given in layers. Separate layer numbers with ','; numbering starts
+ at 0, and negative numbers index from the last layers. layer=0 is the embedding layer (wordpiece embedding plus position embedding)
:param str pool_method: each word is represented by several word pieces; this selects how a word's representation
is computed from its word pieces. Supports ``last``, ``first``, ``avg``, ``max``.
:param float word_dropout: probability of replacing a word with unk. This both trains unk and acts as a mild regularizer.
@@ -61,24 +68,30 @@ class RobertaEmbedding(ContextualEmbedding):
:param kwargs:
bool only_use_pretrain_bpe: only use BPE tokens that appear in the pretrained vocabulary; a word that cannot be tokenized falls back to unk. Recommended to set
True if the embedding will not be updated.
+ int min_freq: only effective when only_use_pretrain_bpe is False; words occurring at least this many times are newly added to RoBERTa's BPE vocabulary
+ bool truncate_embed: whether to keep only the BPE tokens actually used (this reduces memory usage and speeds things up)
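+
+        Example::
+
+            >>> # a sketch of the kwargs above (values are illustrative only)
+            >>> embed = RobertaEmbedding(vocab, 'en', min_freq=5, truncate_embed=True)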
"""
super().__init__(vocab, word_dropout=word_dropout, dropout=dropout)
if word_dropout > 0:
assert vocab.unknown is not None, "When word_drop > 0, Vocabulary must contain the unknown token."
- self._word_sep_index = None
+ self._word_sep_index = -100
if '</s>' in vocab:
self._word_sep_index = vocab['</s>']
+ self._word_cls_index = -100
+ if '<s>' in vocab:
+ self._word_cls_index = vocab['<s>']
+
only_use_pretrain_bpe = kwargs.get('only_use_pretrain_bpe', False)
+ truncate_embed = kwargs.get('truncate_embed', True)
+ min_freq = kwargs.get('min_freq', 2)
- self.model = _WordRobertaModel(model_dir_or_name=model_dir_or_name, vocab=vocab, layers=layers,
+ self.model = _RobertaWordModel(model_dir_or_name=model_dir_or_name, vocab=vocab, layers=layers,
pool_method=pool_method, include_cls_sep=include_cls_sep,
- pooled_cls=pooled_cls, auto_truncate=auto_truncate, min_freq=2,
- only_use_pretrain_bpe=only_use_pretrain_bpe)
- self._sep_index = self.model._sep_index
- self._cls_index = self.model._cls_index
+ pooled_cls=pooled_cls, auto_truncate=auto_truncate, min_freq=min_freq,
+ only_use_pretrain_bpe=only_use_pretrain_bpe, truncate_embed=truncate_embed)
self.requires_grad = requires_grad
self._embed_size = len(self.model.layers) * self.model.encoder.hidden_size
@@ -111,37 +124,46 @@ class RobertaEmbedding(ContextualEmbedding):
"""
if self.word_dropout > 0 and self.training:
with torch.no_grad():
- not_sep_mask = words.ne(self._sep_index)
- not_cls_mask = words.ne(self._cls_index)
- if self._word_sep_index:
- not_sep_mask = not_sep_mask.__and__(words.ne(self._word_sep_index))
- replaceable_mask = not_sep_mask.__and__(not_cls_mask)
mask = torch.full_like(words, fill_value=self.word_dropout, dtype=torch.float, device=words.device)
mask = torch.bernoulli(mask).eq(1)  # the larger word_dropout is, the more positions become 1
pad_mask = words.ne(self._word_pad_index)
- mask = pad_mask.__and__(mask).__and__(replaceable_mask)  # pad positions are never set to unk
+ mask = pad_mask.__and__(mask)  # pad positions are never set to unk
+ if self._word_sep_index!=-100:
+ not_sep_mask = words.ne(self._word_sep_index)
+ mask = mask.__and__(not_sep_mask)
+ if self._word_cls_index!=-100:
+ not_cls_mask = words.ne(self._word_cls_index)
+ mask = mask.__and__(not_cls_mask)
words = words.masked_fill(mask, self._word_unk_index)
return words
-class _WordRobertaModel(nn.Module):
+class _RobertaWordModel(nn.Module):
def __init__(self, model_dir_or_name: str, vocab: Vocabulary, layers: str = '-1', pool_method: str = 'first',
include_cls_sep: bool = False, pooled_cls: bool = False, auto_truncate: bool = False, min_freq=2,
- only_use_pretrain_bpe=False):
+ only_use_pretrain_bpe=False, truncate_embed=True):
super().__init__()
self.tokenzier = RobertaTokenizer.from_pretrained(model_dir_or_name)
self.encoder = RobertaModel.from_pretrained(model_dir_or_name)
- self._max_position_embeddings = self.encoder.config.max_position_embeddings
+ # RobertaEmbedding sets padding_idx to 1 and uses a rather magical position computation, hence the -2
+ self._max_position_embeddings = self.encoder.config.max_position_embeddings - 2
# check that encoder_layer_number is reasonable
encoder_layer_number = len(self.encoder.encoder.layer)
- self.layers = list(map(int, layers.split(',')))
+
+ if isinstance(layers, list):
+ self.layers = [int(l) for l in layers]
+ elif isinstance(layers, str):
+ self.layers = list(map(int, layers.split(',')))
+ else:
+ raise TypeError("`layers` only supports str or list[int]")
+
for layer in self.layers:
if layer < 0:
assert -layer <= encoder_layer_number, f"The layer index:{layer} is out of scope for " \
f"a roberta model with {encoder_layer_number} layers."
else:
- assert layer < encoder_layer_number, f"The layer index:{layer} is out of scope for " \
+ assert layer <= encoder_layer_number, f"The layer index:{layer} is out of scope for " \
f"a roberta model with {encoder_layer_number} layers."
assert pool_method in ('avg', 'max', 'first', 'last')
@@ -155,7 +177,8 @@ class _WordRobertaModel(nn.Module):
# step 1: collect the word pieces that are needed, then build the new embed and word_piece vocab and fill in the values
word_piece_dict = {'<s>': 1, '</s>': 1} # word pieces in use, plus newly added ones
found_count = 0
- self._has_sep_in_vocab = '</s>' in vocab # used to decide whether the input needs token_type_ids
+ new_add_to_bpe_vocab = 0
+ unsegment_count = 0
if "" in vocab:
warnings.warn(" detected in your vocabulary. RobertaEmbedding will add and to the begin "
"and end of the input automatically, make sure you don't add and at the begin"
@@ -167,33 +190,53 @@ class _WordRobertaModel(nn.Module):
word = '<unk>'
# _words = self.tokenzier.basic_tokenizer._tokenize_chinese_chars(word).split() # Chinese content is not handled here for now
word_pieces = []
- word_pieces.extend(self.tokenzier.tokenize(word))
+ # tokenize the word as if it were not at the start of a sentence
+ word_pieces.extend(self.tokenzier.tokenize(word, add_prefix_space=True))
if len(word_pieces) == 1:
if not vocab._is_word_no_create_entry(word): # the word appears in train, yet was not found
if index != vocab.unknown_idx and word_pieces[0] == '<unk>': # i.e. the word is not in the original wordpiece vocab
if vocab.word_count[word] >= min_freq and not vocab._is_word_no_create_entry(
word) and not only_use_pretrain_bpe: # only add words that occur at least this often
word_piece_dict[word] = 1 # add a new entry
+ new_add_to_bpe_vocab += 1
+ unsegment_count += 1
continue
+ found_count += 1
for word_piece in word_pieces:
word_piece_dict[word_piece] = 1
- found_count += 1
+ # if this word is at the start of a sentence
+
original_embed = self.encoder.embeddings.word_embeddings.weight.data
# special tokens need special handling
+ if not truncate_embed: # if we do not truncate, the existing pretrained vocab must be merged back in
+ word_piece_dict.update(self.tokenzier.encoder)
+
embed = nn.Embedding(len(word_piece_dict), original_embed.size(1)) # the new embed
new_word_piece_vocab = collections.OrderedDict()
- for index, token in enumerate(['<s>', '<pad>']):
- word_piece_dict.pop(token, None)
- embed.weight.data[index] = original_embed[self.tokenzier.encoder[token]]
- new_word_piece_vocab[token] = index
+
+ for index, token in enumerate(['<s>', '<pad>', '</s>', '<unk>']):
+ index = word_piece_dict.pop(token, None)
+ if index is not None:
+ new_word_piece_vocab[token] = len(new_word_piece_vocab)
+ embed.weight.data[new_word_piece_vocab[token]] = original_embed[self.tokenzier.encoder[token]]
for token in word_piece_dict.keys():
+ if token not in new_word_piece_vocab:
+ new_word_piece_vocab[token] = len(new_word_piece_vocab)
+ index = new_word_piece_vocab[token]
if token in self.tokenzier.encoder:
- embed.weight.data[len(new_word_piece_vocab)] = original_embed[self.tokenzier.encoder[token]]
+ embed.weight.data[index] = original_embed[self.tokenzier.encoder[token]]
else:
- embed.weight.data[len(new_word_piece_vocab)] = original_embed[self.tokenzier.encoder['<unk>']]
- new_word_piece_vocab[token] = len(new_word_piece_vocab)
- self._reinit_on_new_vocab(new_word_piece_vocab, model_dir_or_name)
+ embed.weight.data[index] = original_embed[self.tokenzier.encoder['<unk>']]
+
+ self.tokenzier._reinit_on_new_vocab(new_word_piece_vocab)
self.encoder.embeddings.word_embeddings = embed
+ self.encoder.config.vocab_size = len(new_word_piece_vocab)
+
+ if unsegment_count>0:
+ if only_use_pretrain_bpe or new_add_to_bpe_vocab==0:
+ logger.info(f"{unsegment_count} words are unsegmented.")
+ else:
+ logger.info(f"{unsegment_count} words are unsegmented. Among them, {new_add_to_bpe_vocab} added to the BPE vocab.")
word_to_wordpieces = []
word_pieces_lengths = []
@@ -210,18 +253,10 @@ class _WordRobertaModel(nn.Module):
self._sep_index = self.tokenzier.encoder['</s>']
self._word_pad_index = vocab.padding_idx
self._wordpiece_pad_index = self.tokenzier.encoder['<pad>'] # needed when generating word_pieces
- logger.info("Found(Or segment into word pieces) {} words out of {}.".format(found_count, len(vocab)))
self.word_to_wordpieces = np.array(word_to_wordpieces)
self.register_buffer('word_pieces_lengths', torch.LongTensor(word_pieces_lengths))
logger.debug("Successfully generate word pieces.")
- def _reinit_on_new_vocab(self, vocab, model_dir_or_name):
- import json
- with open('./.tmp-new-vocab-file.json', 'w') as f:
- json.dump(vocab, f)
- self.tokenzier = RobertaTokenizer.from_pretrained(model_dir_or_name, vocab_file='./.tmp-new-vocab-file.json')
- os.remove('./.tmp-new-vocab-file.json')
-
def forward(self, words):
r"""
@@ -232,15 +267,13 @@ class _WordRobertaModel(nn.Module):
batch_size, max_word_len = words.size()
word_mask = words.ne(self._word_pad_index) # positions equal to 1 hold a word
seq_len = word_mask.sum(dim=-1)
- batch_word_pieces_length = self.word_pieces_lengths[words].masked_fill(word_mask.eq(False),
- 0) # batch_size x max_len
+ batch_word_pieces_length = self.word_pieces_lengths[words].masked_fill(word_mask.eq(False), 0) # batch_size x max_len
word_pieces_lengths = batch_word_pieces_length.sum(dim=-1) # batch_size
- word_piece_length = batch_word_pieces_length.sum(dim=-1).max().item() # the max word-piece length (padding included)
- if word_piece_length + 2 > self._max_position_embeddings:
+ max_word_piece_length = batch_word_pieces_length.sum(dim=-1).max().item() # the max word-piece length (padding included)
+ if max_word_piece_length + 2 > self._max_position_embeddings:
if self.auto_truncate:
word_pieces_lengths = word_pieces_lengths.masked_fill(
- word_pieces_lengths + 2 > self._max_position_embeddings,
- self._max_position_embeddings - 2)
+ word_pieces_lengths + 2 > self._max_position_embeddings, self._max_position_embeddings - 2)
else:
raise RuntimeError(
"After split words into word pieces, the lengths of word pieces are longer than the "
@@ -248,7 +281,7 @@ class _WordRobertaModel(nn.Module):
f"`auto_truncate=True` for BertEmbedding to automatically truncate overlong input.")
# +2 because <s> and </s> must be added
- word_pieces = words.new_full((batch_size, min(word_piece_length + 2, self._max_position_embeddings)),
+ word_pieces = words.new_full((batch_size, min(max_word_piece_length + 2, self._max_position_embeddings)),
fill_value=self._wordpiece_pad_index)
attn_masks = torch.zeros_like(word_pieces)
# 1. get the word_piece ids for the words, along with the corresponding span ranges
@@ -259,17 +292,9 @@ class _WordRobertaModel(nn.Module):
word_pieces_i = word_pieces_i[:self._max_position_embeddings - 2]
word_pieces[i, 1:word_pieces_lengths[i] + 1] = torch.LongTensor(word_pieces_i)
attn_masks[i, :word_pieces_lengths[i] + 2].fill_(1)
- # add [cls] and [sep]
word_pieces[:, 0].fill_(self._cls_index)
batch_indexes = torch.arange(batch_size).to(words)
word_pieces[batch_indexes, word_pieces_lengths + 1] = self._sep_index
- # if self._has_sep_in_vocab: # token_type_ids should only be needed if the sep token appears in the vocab
- # sep_mask = word_pieces.eq(self._sep_index).long() # batch_size x max_len
- # sep_mask_cumsum = sep_mask.flip(dims=[-1]).cumsum(dim=-1).flip(dims=[-1])
- # token_type_ids = sep_mask_cumsum.fmod(2)
- # if token_type_ids[0, 0].item(): # if the first position is odd, flip the result, since the sequence must start with 0
- # token_type_ids = token_type_ids.eq(0).long()
- # else: # RoBERTa does not need extra token_type_ids
token_type_ids = torch.zeros_like(word_pieces)
# 2. get the hidden states and pool them into word representations according to word_pieces
# all_outputs: [batch_size x max_len x hidden_size, batch_size x max_len x hidden_size, ...]
@@ -292,19 +317,19 @@ class _WordRobertaModel(nn.Module):
if self.pool_method == 'first':
batch_word_pieces_cum_length = batch_word_pieces_cum_length[:, :seq_len.max()]
- batch_word_pieces_cum_length.masked_fill_(batch_word_pieces_cum_length.ge(word_piece_length), 0)
+ batch_word_pieces_cum_length.masked_fill_(batch_word_pieces_cum_length.ge(max_word_piece_length), 0)
_batch_indexes = batch_indexes[:, None].expand((batch_size, batch_word_pieces_cum_length.size(1)))
elif self.pool_method == 'last':
batch_word_pieces_cum_length = batch_word_pieces_cum_length[:, 1:seq_len.max() + 1] - 1
- batch_word_pieces_cum_length.masked_fill_(batch_word_pieces_cum_length.ge(word_piece_length), 0)
+ batch_word_pieces_cum_length.masked_fill_(batch_word_pieces_cum_length.ge(max_word_piece_length), 0)
_batch_indexes = batch_indexes[:, None].expand((batch_size, batch_word_pieces_cum_length.size(1)))
for l_index, l in enumerate(self.layers):
output_layer = bert_outputs[l]
real_word_piece_length = output_layer.size(1) - 2
- if word_piece_length > real_word_piece_length: # i.e. the output was actually truncated
+ if max_word_piece_length > real_word_piece_length: # i.e. the output was actually truncated
paddings = output_layer.new_zeros(batch_size,
- word_piece_length - real_word_piece_length,
+ max_word_piece_length - real_word_piece_length,
output_layer.size(2))
output_layer = torch.cat((output_layer, paddings), dim=1).contiguous()
# collapse the word_piece representations into word representations
@@ -333,7 +358,176 @@ class _WordRobertaModel(nn.Module):
outputs[l_index, :, 0] = pooled_cls
else:
outputs[l_index, :, 0] = output_layer[:, 0]
- outputs[l_index, batch_indexes, seq_len + s_shift] = output_layer[batch_indexes, seq_len + s_shift]
+ outputs[l_index, batch_indexes, seq_len + s_shift] = output_layer[batch_indexes, word_pieces_lengths + s_shift]
# 3. the final embedding result
return outputs
+
+
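# --- Editor's sketch (not part of the diff): the 'first' pooling used by the forward pass above,
# in isolation. For each word we gather the hidden state of its first word piece via cumulative
# piece lengths; shapes are made up.
import torch
hidden = torch.randn(1, 7, 4)                    # batch x n_word_pieces x hidden
piece_lengths = torch.LongTensor([[2, 1, 3]])    # 3 words split into 2/1/3 pieces
first_piece = torch.cat([piece_lengths.new_zeros(1, 1),
                         piece_lengths.cumsum(dim=-1)[:, :-1]], dim=-1)   # [[0, 2, 3]]
batch_idx = torch.zeros(1, 3, dtype=torch.long)
word_repr = hidden[batch_idx, first_piece]       # batch x n_words x hidden
assert word_repr.size() == (1, 3, 4)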
+class RobertaWordPieceEncoder(nn.Module):
+ r"""
+ 读取bert模型,读取之后调用index_dataset方法在dataset中生成word_pieces这一列。
+
+ BertWordPieceEncoder可以支持自动下载权重,当前支持的模型:
+ en: roberta-base
+ en-large: roberta-large
+
+ """
+
+ def __init__(self, model_dir_or_name: str = 'en', layers: str = '-1', pooled_cls: bool = False,
+ word_dropout=0, dropout=0, requires_grad: bool = True):
+ r"""
+
+ :param str model_dir_or_name: the directory containing the model, or the model's name. Defaults to ``en``
+ :param str layers: which layers make up the final representation. Layer numbers separated by ',', negative numbers
+ index from the last layer. layer=0 is the embedding layer (wordpiece embedding plus position embedding)
+ :param bool pooled_cls: whether to map the leading <s> through the pretrained pooler. Usually True if the downstream
+ task predicts from the <s> position.
+ :param float word_dropout: probability of replacing a word with unk. This both trains unk and provides some regularization.
+ :param float dropout: probability of applying Dropout to the embedding output. 0.1 randomly zeroes 10% of the values.
+ :param bool requires_grad: whether gradients are required.
+ """
+ super().__init__()
+
+ self.model = _WordPieceRobertaModel(model_dir_or_name=model_dir_or_name, layers=layers, pooled_cls=pooled_cls)
+ self._sep_index = self.model._sep_index
+ self._cls_index = self.model._cls_index
+ self._wordpiece_pad_index = self.model._wordpiece_pad_index
+ self._wordpiece_unk_index = self.model._wordpiece_unknown_index
+ self._embed_size = len(self.model.layers) * self.model.encoder.hidden_size
+ self.requires_grad = requires_grad
+ self.word_dropout = word_dropout
+ self.dropout_layer = nn.Dropout(dropout)
+
+ @property
+ def embed_size(self):
+ return self._embed_size
+
+ @property
+ def embedding_dim(self):
+ return self._embed_size
+
+ @property
+ def num_embedding(self):
+ return self.model.encoder.config.vocab_size
+
+ def index_datasets(self, *datasets, field_name, add_cls_sep=True, add_prefix_space=True):
+ r"""
+ 使用bert的tokenizer新生成word_pieces列加入到datasets中,并将他们设置为input,且将word_pieces这一列的pad value设置为了
+ bert的pad value。
+
+ :param ~fastNLP.DataSet datasets: DataSet对象
+ :param str field_name: 基于哪一列的内容生成word_pieces列。这一列中每个数据应该是List[str]的形式。
+ :param bool add_cls_sep: 如果首尾不是与会在首尾额外加入与。
+ :param bool add_prefix_spance: 是否在句首添加额外的空格,RoBERTa预训练时该值为True
+ :return:
+ """
+ self.model.index_datasets(*datasets, field_name=field_name, add_cls_sep=add_cls_sep, add_prefix_space=add_prefix_space)
+
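# --- Editor's sketch (not part of the diff): index_datasets usage; the dataset contents and the
# checkpoint path are made up.
from fastNLP import DataSet
ds = DataSet({'words': [['hello', 'world'], ['fastNLP', 'rocks']]})
encoder = RobertaWordPieceEncoder(model_dir_or_name='/path/to/roberta-base')
encoder.index_datasets(ds, field_name='words', add_cls_sep=True, add_prefix_space=True)
# ds now has an input field 'word_pieces', padded with RoBERTa's <pad> id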
+ def forward(self, word_pieces, token_type_ids=None):
+ r"""
+ Compute the RoBERTa embedding representation of words. The word_pieces passed in should already contain the [CLS] and [SEP] tags.
+
+ :param words: batch_size x max_len
+ :param token_type_ids: batch_size x max_len, distinguishes the first sentence from the second. If not given, it is
+ generated automatically (in most cases it is not needed): positions up to and including the first [SEP] are 0, from there up to and including the second [SEP] are 1, then 0 again, and so on.
+ :return: torch.FloatTensor. batch_size x max_len x (768*len(self.layers))
+ """
+ word_pieces = self.drop_word(word_pieces)
+ outputs = self.model(word_pieces)
+ outputs = torch.cat([*outputs], dim=-1)
+
+ return self.dropout_layer(outputs)
+
+ def drop_word(self, words):
+ r"""
+ Randomly replaces words with unknown_index according to the configured probability.
+
+ :param torch.LongTensor words: batch_size x max_len
+ :return:
+ """
+ if self.word_dropout > 0 and self.training:
+ with torch.no_grad():
+ not_sep_mask = words.ne(self._sep_index)
+ not_cls_mask = words.ne(self._cls_index)
+ replaceable_mask = not_sep_mask.__and__(not_cls_mask)
+ mask = torch.full_like(words, fill_value=self.word_dropout, dtype=torch.float, device=words.device)
+ mask = torch.bernoulli(mask).eq(1) # the larger word_dropout is, the more positions become 1
+ pad_mask = words.ne(self._wordpiece_pad_index)
+ mask = pad_mask.__and__(mask).__and__(replaceable_mask) # pad positions are never turned into unk
+ words = words.masked_fill(mask, self._wordpiece_unk_index)
+ return words
+
+
+class _WordPieceRobertaModel(nn.Module):
+ def __init__(self, model_dir_or_name: str, layers: str = '-1', pooled_cls: bool=False):
+ super().__init__()
+
+ self.tokenzier = RobertaTokenizer.from_pretrained(model_dir_or_name)
+ self.encoder = RobertaModel.from_pretrained(model_dir_or_name)
+ # sanity-check that the requested layer indices are valid
+ encoder_layer_number = len(self.encoder.encoder.layer)
+
+ if isinstance(layers, list):
+ self.layers = [int(l) for l in layers]
+ elif isinstance(layers, str):
+ self.layers = list(map(int, layers.split(',')))
+ else:
+ raise TypeError("`layers` only supports str or list[int]")
+
+ for layer in self.layers:
+ if layer < 0:
+ assert -layer <= encoder_layer_number, f"The layer index:{layer} is out of scope for " \
+ f"a RoBERTa model with {encoder_layer_number} layers."
+ else:
+ assert layer <= encoder_layer_number, f"The layer index:{layer} is out of scope for " \
+ f"a RoBERTa model with {encoder_layer_number} layers."
+
+ self._cls_index = self.tokenzier.encoder['<s>']
+ self._sep_index = self.tokenzier.encoder['</s>']
+ self._wordpiece_pad_index = self.tokenzier.encoder['<pad>'] # needed when generating word_pieces
+ self._wordpiece_unknown_index = self.tokenzier.encoder['<unk>']
+ self.pooled_cls = pooled_cls
+
+ def index_datasets(self, *datasets, field_name, add_cls_sep=True, add_prefix_space=True):
+ r"""
+ Use the tokenizer to generate a new word_pieces column in the datasets and set it as input. If the sequence does not
+ start with [CLS] and end with [SEP], they are added automatically, and the pad value of the word_pieces column is set to the model's pad value.
+
+ :param datasets: DataSet objects
+ :param field_name: which column to index
+ :param bool add_cls_sep: whether to add the cls and sep indices at the start and end of each sentence
+ :param bool add_prefix_space: whether to prepend a space to each sentence; this was True during RoBERTa pretraining
+ :return:
+ """
+
+ encode_func = partial(self.tokenzier.encode, add_special_tokens=add_cls_sep, add_prefix_space=add_prefix_space)
+
+ for index, dataset in enumerate(datasets):
+ try:
+ dataset.apply_field(encode_func, field_name=field_name, new_field_name='word_pieces',
+ is_input=True)
+ dataset.set_pad_val('word_pieces', self._wordpiece_pad_index)
+ except Exception as e:
+ logger.error(f"Exception happened while processing dataset {index}.")
+ raise e
+
+ def forward(self, word_pieces):
+ r"""
+
+ :param word_pieces: torch.LongTensor, batch_size x max_len
+ :return: num_layers x batch_size x max_len x hidden_size, or num_layers x batch_size x (max_len+2) x hidden_size
+ """
+ batch_size, max_len = word_pieces.size()
+
+ attn_masks = word_pieces.ne(self._wordpiece_pad_index)
+ roberta_outputs, pooled_cls = self.encoder(word_pieces, token_type_ids=torch.zeros_like(word_pieces),
+ attention_mask=attn_masks,
+ output_all_encoded_layers=True)
+ # output_layers = [self.layers] # len(self.layers) x batch_size x max_word_piece_length x hidden_size
+ outputs = roberta_outputs[0].new_zeros((len(self.layers), batch_size, max_len, roberta_outputs[0].size(-1)))
+ for l_index, l in enumerate(self.layers):
+ roberta_output = roberta_outputs[l]
+ if l in (len(roberta_outputs)-1, -1) and self.pooled_cls:
+ roberta_output[:, 0] = pooled_cls
+ outputs[l_index] = roberta_output
+ return outputs
\ No newline at end of file
diff --git a/fastNLP/io/file_utils.py b/fastNLP/io/file_utils.py
index fe697699..96a9c1ed 100644
--- a/fastNLP/io/file_utils.py
+++ b/fastNLP/io/file_utils.py
@@ -48,6 +48,18 @@ PRETRAINED_BERT_MODEL_DIR = {
'cn-wwm-ext': "bert-chinese-wwm-ext.zip"
}
+PRETRAINED_GPT2_MODEL_DIR = {
+ 'en': 'gpt2.zip',
+ 'en-medium': 'gpt2-medium.zip',
+ 'en-large': 'gpt2-large.zip',
+ 'en-xl': 'gpt2-xl.zip'
+}
+
+PRETRAINED_ROBERTA_MODEL_DIR = {
+ 'en': 'roberta-base.zip',
+ 'en-large': 'roberta-large.zip'
+}
+
PRETRAINED_ELMO_MODEL_DIR = {
'en': 'elmo_en_Medium.zip',
'en-small': "elmo_en_Small.zip",
@@ -127,14 +139,18 @@ DATASET_DIR = {
PRETRAIN_MAP = {'elmo': PRETRAINED_ELMO_MODEL_DIR,
"bert": PRETRAINED_BERT_MODEL_DIR,
- "static": PRETRAIN_STATIC_FILES}
+ "static": PRETRAIN_STATIC_FILES,
+ 'gpt2': PRETRAINED_GPT2_MODEL_DIR,
+ 'roberta': PRETRAINED_ROBERTA_MODEL_DIR}
# used to extend fastNLP's download sources
FASTNLP_EXTEND_DATASET_URL = 'fastnlp_dataset_url.txt'
FASTNLP_EXTEND_EMBEDDING_URL = {'elmo': 'fastnlp_elmo_url.txt',
- 'bert':'fastnlp_bert_url.txt',
- 'static': 'fastnlp_static_url.txt'
-}
+ 'bert':'fastnlp_bert_url.txt',
+ 'static': 'fastnlp_static_url.txt',
+ 'gpt2': 'fastnlp_gpt2_url.txt',
+ 'roberta': 'fastnlp_roberta_url.txt'
+ }
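# --- Editor's sketch (not part of the diff): how the registries above are consulted. A model type
# selects one of the *_MODEL_DIR dicts, and a short name maps to a downloadable archive.
embed_map = PRETRAIN_MAP['roberta']            # -> PRETRAINED_ROBERTA_MODEL_DIR
assert embed_map['en'] == 'roberta-base.zip'
assert PRETRAIN_MAP['gpt2']['en-xl'] == 'gpt2-xl.zip'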
def cached_path(url_or_filename: str, cache_dir: str = None, name=None) -> Path:
@@ -273,7 +289,7 @@ def _get_embedding_url(embed_type, name):
return url
raise KeyError("There is no {}. Only supports {}.".format(name, list(embed_map.keys())))
else:
- raise KeyError(f"There is no {embed_type}. Only supports bert, elmo, static")
+ raise KeyError(f"There is no {embed_type}. Only supports bert, elmo, static, gpt2, roberta")
def _read_extend_url_file(filename, name)->str:
r"""
@@ -281,7 +297,7 @@ def _read_extend_url_file(filename, name)->str:
:param str filename: look for this file under the default cache path
:param str name: the name of the resource to look for
- :return: str or None
+ :return: str,None
"""
cache_dir = get_cache_path()
filepath = os.path.join(cache_dir, filename)
@@ -488,3 +504,42 @@ def match_file(dir_name: str, cache_dir: Path) -> str:
return matched_filenames[-1]
else:
raise RuntimeError(f"Duplicate matched files:{matched_filenames}, this should be caused by a bug.")
+
+
+def _get_bert_dir(model_dir_or_name: str = 'en-base-uncased'):
+ if model_dir_or_name.lower() in PRETRAINED_BERT_MODEL_DIR:
+ model_url = _get_embedding_url('bert', model_dir_or_name.lower())
+ model_dir = cached_path(model_url, name='embedding')
+ # check whether it exists as a local directory
+ elif os.path.isdir(os.path.abspath(os.path.expanduser(model_dir_or_name))):
+ model_dir = os.path.abspath(os.path.expanduser(model_dir_or_name))
+ else:
+ logger.error(f"Cannot recognize BERT dir or name ``{model_dir_or_name}``.")
+ raise ValueError(f"Cannot recognize BERT dir or name ``{model_dir_or_name}``.")
+ return str(model_dir)
+
+
+def _get_gpt2_dir(model_dir_or_name: str = 'en'):
+ if model_dir_or_name.lower() in PRETRAINED_GPT2_MODEL_DIR:
+ model_url = _get_embedding_url('gpt2', model_dir_or_name.lower())
+ model_dir = cached_path(model_url, name='embedding')
+ # check whether it exists as a local directory
+ elif os.path.isdir(os.path.abspath(os.path.expanduser(model_dir_or_name))):
+ model_dir = os.path.abspath(os.path.expanduser(model_dir_or_name))
+ else:
+ logger.error(f"Cannot recognize GPT2 dir or name ``{model_dir_or_name}``.")
+ raise ValueError(f"Cannot recognize GPT2 dir or name ``{model_dir_or_name}``.")
+ return str(model_dir)
+
+
+def _get_roberta_dir(model_dir_or_name: str = 'en'):
+ if model_dir_or_name.lower() in PRETRAINED_ROBERTA_MODEL_DIR:
+ model_url = _get_embedding_url('roberta', model_dir_or_name.lower())
+ model_dir = cached_path(model_url, name='embedding')
+ # check whether it exists as a local directory
+ elif os.path.isdir(os.path.abspath(os.path.expanduser(model_dir_or_name))):
+ model_dir = os.path.abspath(os.path.expanduser(model_dir_or_name))
+ else:
+ logger.error(f"Cannot recognize RoBERTa dir or name ``{model_dir_or_name}``.")
+ raise ValueError(f"Cannot recognize RoBERTa dir or name ``{model_dir_or_name}``.")
+ return str(model_dir)
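# --- Editor's sketch (not part of the diff): the resolution order shared by the three helpers
# above. A registered name is downloaded to the cache; anything else must already exist as a
# local directory, otherwise ValueError is raised. Paths are placeholders.
model_dir = _get_roberta_dir('en')                        # registry name -> cached download
model_dir = _get_roberta_dir('~/models/roberta-base')     # existing dir -> expanded absolute path
# _get_roberta_dir('no-such-model')                       # would raise ValueError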
diff --git a/fastNLP/modules/__init__.py b/fastNLP/modules/__init__.py
index 53651b59..d8eab276 100644
--- a/fastNLP/modules/__init__.py
+++ b/fastNLP/modules/__init__.py
@@ -49,7 +49,15 @@ __all__ = [
"TimestepDropout",
- 'summary'
+ 'summary',
+
+ "BertTokenizer",
+ "BertModel",
+
+ "RobertaTokenizer",
+ "RobertaModel",
+
+ "GPT2Tokenizer"
]
import sys
@@ -61,5 +69,6 @@ from .dropout import TimestepDropout
from .encoder import *
from .utils import summary
from ..doc_utils import doc_process
+from .tokenizer import *
doc_process(sys.modules[__name__])
diff --git a/fastNLP/modules/decoder/seq2seq_decoder.py b/fastNLP/modules/decoder/seq2seq_decoder.py
new file mode 100755
index 00000000..3933867a
--- /dev/null
+++ b/fastNLP/modules/decoder/seq2seq_decoder.py
@@ -0,0 +1,109 @@
+# coding=utf-8
+__all__ = [
+ "TransformerPast",
+ "Past",
+ "Decoder"
+]
+import torch
+from torch import nn
+import abc
+import torch.nn.functional as F
+from ...embeddings import StaticEmbedding
+import numpy as np
+from typing import Union, Tuple
+from ...embeddings.utils import get_embeddings
+from torch.nn import LayerNorm
+import math
+
+
+class Past:
+ def __init__(self):
+ pass
+
+ @abc.abstractmethod
+ def num_samples(self):
+ pass
+
+ @abc.abstractmethod
+ def reorder_past(self, indices: torch.LongTensor):
+ """
+ 根据indices中的index,将past的中状态置为正确的顺序。inplace改变
+
+ :param torch.LongTensor indices:
+ :param Past past:
+ :return:
+ """
+ raise NotImplementedError
+
+
+class TransformerPast(Past):
+ def __init__(self, encoder_outputs: torch.Tensor = None, encoder_mask: torch.Tensor = None,
+ num_decoder_layer: int = 6):
+ """
+
+ :param encoder_outputs: (batch,src_seq_len,dim)
+ :param encoder_mask: (batch,src_seq_len)
+ :param encoder_key: list of (batch, src_seq_len, dim)
+ :param encoder_value:
+ :param decoder_prev_key:
+ :param decoder_prev_value:
+ """
+ super().__init__()
+ self.encoder_outputs = encoder_outputs
+ self.encoder_mask = encoder_mask
+ self.encoder_key = [None] * num_decoder_layer
+ self.encoder_value = [None] * num_decoder_layer
+ self.decoder_prev_key = [None] * num_decoder_layer
+ self.decoder_prev_value = [None] * num_decoder_layer
+
+ def num_samples(self):
+ if self.encoder_outputs is not None:
+ return self.encoder_outputs.size(0)
+ return None
+
+ def _reorder_state(self, state, indices):
+ if type(state) == torch.Tensor:
+ state = state.index_select(index=indices, dim=0)
+ elif type(state) == list:
+ for i in range(len(state)):
+ assert state[i] is not None
+ state[i] = state[i].index_select(index=indices, dim=0)
+ else:
+ raise ValueError('Unsupported state format; expected torch.Tensor or list')
+
+ return state
+
+ def reorder_past(self, indices: torch.LongTensor):
+ self.encoder_outputs = self._reorder_state(self.encoder_outputs, indices)
+ self.encoder_mask = self._reorder_state(self.encoder_mask, indices)
+ self.encoder_key = self._reorder_state(self.encoder_key, indices)
+ self.encoder_value = self._reorder_state(self.encoder_value, indices)
+ self.decoder_prev_key = self._reorder_state(self.decoder_prev_key, indices)
+ self.decoder_prev_value = self._reorder_state(self.decoder_prev_value, indices)
+ return self
+
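# --- Editor's sketch (not part of the diff): why reorder_past exists. During beam search the batch
# dimension is re-shuffled at every step, so cached tensors must follow the surviving beams. Shapes
# are made up; the None caches must be filled first, since _reorder_state asserts non-None entries.
import torch
past = TransformerPast(encoder_outputs=torch.randn(4, 5, 8),
                       encoder_mask=torch.ones(4, 5, dtype=torch.bool), num_decoder_layer=2)
for name in ('encoder_key', 'encoder_value', 'decoder_prev_key', 'decoder_prev_value'):
    setattr(past, name, [torch.randn(4, 5, 8) for _ in range(2)])
past.reorder_past(torch.LongTensor([2, 2, 0, 1]))   # beam 2 survives twice, beam 3 is dropped
assert past.encoder_outputs.size(0) == 4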
+
+class Decoder(nn.Module):
+ def __init__(self):
+ super().__init__()
+
+ @abc.abstractmethod
+ def decode(self, *args, **kwargs) -> Tuple[torch.Tensor, Past]:
+ """
+ 当模型进行解码时,使用这个函数。返回一个batch_size x vocab_size的结果与更新的Past状态。需要考虑一种特殊情况,即tokens长度不是1,即给定了
+ 解码句子开头的情况,这种情况需要查看Past中是否正确计算了decode的状态。
+
+ :return: tensor:batch_size x vocab_size, past: Past
+ """
+ raise NotImplementedError
+
+ @abc.abstractmethod
+ def reorder_past(self, indices: torch.LongTensor, past: Past):
+ """
+ Reorder the states held in past according to indices. Changes are made in place.
+
+ :param torch.LongTensor indices:
+ :param Past past:
+ :return:
+ """
+ raise NotImplementedError
\ No newline at end of file
diff --git a/fastNLP/modules/encoder/__init__.py b/fastNLP/modules/encoder/__init__.py
index 3c9af22d..fccb2c00 100644
--- a/fastNLP/modules/encoder/__init__.py
+++ b/fastNLP/modules/encoder/__init__.py
@@ -30,6 +30,10 @@ __all__ = [
"MultiHeadAttention",
"BiAttention",
"SelfAttention",
+
+ "BertModel",
+
+ "RobertaModel",
]
from .attention import MultiHeadAttention, BiAttention, SelfAttention
diff --git a/fastNLP/modules/encoder/bert.py b/fastNLP/modules/encoder/bert.py
index 32edafbe..bfa1c6a1 100644
--- a/fastNLP/modules/encoder/bert.py
+++ b/fastNLP/modules/encoder/bert.py
@@ -4,26 +4,23 @@ r"""undocumented
"""
__all__ = [
- "BertModel"
+ "BertModel",
]
-import collections
import copy
import json
import math
-import os
-import unicodedata
import torch
from torch import nn
import numpy as np
from ..utils import _get_file_name_base_on_postfix
-from ...io.file_utils import _get_embedding_url, cached_path, PRETRAINED_BERT_MODEL_DIR
+from ...io.file_utils import _get_bert_dir
from ...core import logger
+
CONFIG_FILE = 'bert_config.json'
-VOCAB_NAME = 'vocab.txt'
BERT_KEY_RENAME_MAP_1 = {
'gamma': 'weight',
@@ -152,33 +149,22 @@ def swish(x):
ACT2FN = {"gelu": gelu, "relu": torch.nn.functional.relu, "swish": swish}
-def _get_bert_dir(model_dir_or_name: str = 'en-base-uncased'):
- if model_dir_or_name.lower() in PRETRAINED_BERT_MODEL_DIR:
- model_url = _get_embedding_url('bert', model_dir_or_name.lower())
- model_dir = cached_path(model_url, name='embedding')
- # check whether it exists as a local directory
- elif os.path.isdir(os.path.abspath(os.path.expanduser(model_dir_or_name))):
- model_dir = os.path.abspath(os.path.expanduser(model_dir_or_name))
- else:
- logger.error(f"Cannot recognize BERT dir or name ``{model_dir_or_name}``.")
- raise ValueError(f"Cannot recognize BERT dir or name ``{model_dir_or_name}``.")
- return str(model_dir)
-
-
-class BertLayerNorm(nn.Module):
- def __init__(self, hidden_size, eps=1e-12):
- r"""Construct a layernorm module in the TF style (epsilon inside the square root).
- """
- super(BertLayerNorm, self).__init__()
- self.weight = nn.Parameter(torch.ones(hidden_size))
- self.bias = nn.Parameter(torch.zeros(hidden_size))
- self.variance_epsilon = eps
+# class BertLayerNorm(nn.Module):
+# def __init__(self, hidden_size, eps=1e-12):
+# r"""Construct a layernorm module in the TF style (epsilon inside the square root).
+# """
+# super(BertLayerNorm, self).__init__()
+# self.weight = nn.Parameter(torch.ones(hidden_size))
+# self.bias = nn.Parameter(torch.zeros(hidden_size))
+# self.variance_epsilon = eps
+#
+# def forward(self, x):
+# u = x.mean(-1, keepdim=True)
+# s = (x - u).pow(2).mean(-1, keepdim=True)
+# x = (x - u) / torch.sqrt(s + self.variance_epsilon)
+# return self.weight * x + self.bias
- def forward(self, x):
- u = x.mean(-1, keepdim=True)
- s = (x - u).pow(2).mean(-1, keepdim=True)
- x = (x - u) / torch.sqrt(s + self.variance_epsilon)
- return self.weight * x + self.bias
+BertLayerNorm = torch.nn.LayerNorm
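# --- Editor's sketch (not part of the diff): why the alias is safe. The handwritten TF-style
# layer norm above computes exactly what torch.nn.LayerNorm computes (eps inside the sqrt,
# biased variance), up to floating-point error.
import torch
x = torch.randn(2, 5, 16)
ln = torch.nn.LayerNorm(16, eps=1e-12)
u = x.mean(-1, keepdim=True)
s = (x - u).pow(2).mean(-1, keepdim=True)
manual = (x - u) / torch.sqrt(s + 1e-12) * ln.weight + ln.bias
assert torch.allclose(ln(x), manual, atol=1e-6)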
class DistilBertEmbeddings(nn.Module):
@@ -518,6 +504,7 @@ class BertModel(nn.Module):
pooled_output = sequence_output[:, 0]
if not output_all_encoded_layers:
encoded_layers = encoded_layers[-1]
+ encoded_layers.insert(0, embedding_output)
return encoded_layers, pooled_output
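# --- Editor's sketch (not part of the diff): effect of the inserted line. With
# output_all_encoded_layers=True the returned list now has n_layer + 1 entries, with the embedding
# output at index 0; that is what lets callers request layer=0 for the embedding layer.
def split_embedding_and_layers(encoded_layers):
    # encoded_layers: list returned by BertModel(..., output_all_encoded_layers=True)
    embedding_output, transformer_layers = encoded_layers[0], encoded_layers[1:]
    return embedding_output, transformer_layers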
@classmethod
@@ -615,435 +602,3 @@ class BertModel(nn.Module):
logger.info(f"Load pre-trained {model_type} parameters from file {weights_path}.")
return model
-
-def whitespace_tokenize(text):
- r"""Runs basic whitespace cleaning and splitting on a piece of text."""
- text = text.strip()
- if not text:
- return []
- tokens = text.split()
- return tokens
-
-
-class WordpieceTokenizer(object):
- r"""Runs WordPiece tokenization."""
-
- def __init__(self, vocab, unk_token="[UNK]", max_input_chars_per_word=100):
- self.vocab = vocab
- self.unk_token = unk_token
- self.max_input_chars_per_word = max_input_chars_per_word
-
- def tokenize(self, text):
- r"""Tokenizes a piece of text into its word pieces.
-
- This uses a greedy longest-match-first algorithm to perform tokenization
- using the given vocabulary.
-
- For example:
- input = "unaffable"
- output = ["un", "##aff", "##able"]
-
- Args:
- text: A single token or whitespace separated tokens. This should have
- already been passed through `BasicTokenizer`.
-
- Returns:
- A list of wordpiece tokens.
- """
-
- output_tokens = []
- for token in whitespace_tokenize(text):
- chars = list(token)
- if len(chars) > self.max_input_chars_per_word:
- output_tokens.append(self.unk_token)
- continue
-
- is_bad = False
- start = 0
- sub_tokens = []
- while start < len(chars):
- end = len(chars)
- cur_substr = None
- while start < end:
- substr = "".join(chars[start:end])
- if start > 0:
- substr = "##" + substr
- if substr in self.vocab:
- cur_substr = substr
- break
- end -= 1
- if cur_substr is None:
- is_bad = True
- break
- sub_tokens.append(cur_substr)
- start = end
-
- if is_bad:
- output_tokens.append(self.unk_token)
- else:
- output_tokens.extend(sub_tokens)
- if len(output_tokens) == 0: # guard against input that is all whitespace or newlines
- return [self.unk_token]
- return output_tokens
-
-
-def load_vocab(vocab_file):
- r"""Loads a vocabulary file into a dictionary."""
- vocab = collections.OrderedDict()
- index = 0
- with open(vocab_file, "r", encoding="utf-8") as reader:
- while True:
- token = reader.readline()
- if not token:
- break
- token = token.strip()
- vocab[token] = index
- index += 1
- return vocab
-
-
-class BasicTokenizer(object):
- r"""Runs basic tokenization (punctuation splitting, lower casing, etc.)."""
-
- def __init__(self,
- do_lower_case=True,
- never_split=("[UNK]", "[SEP]", "[PAD]", "[CLS]", "[MASK]")):
- r"""Constructs a BasicTokenizer.
-
- Args:
- do_lower_case: Whether to lower case the input.
- """
- self.do_lower_case = do_lower_case
- self.never_split = never_split
-
- def tokenize(self, text):
- r"""Tokenizes a piece of text."""
- text = self._clean_text(text)
- # This was added on November 1st, 2018 for the multilingual and Chinese
- # models. This is also applied to the English models now, but it doesn't
- # matter since the English models were not trained on any Chinese data
- # and generally don't have any Chinese data in them (there are Chinese
- # characters in the vocabulary because Wikipedia does have some Chinese
- # words in the English Wikipedia.).
- text = self._tokenize_chinese_chars(text)
- orig_tokens = whitespace_tokenize(text)
- split_tokens = []
- for token in orig_tokens:
- if self.do_lower_case and token not in self.never_split:
- token = token.lower()
- token = self._run_strip_accents(token)
- split_tokens.extend(self._run_split_on_punc(token))
-
- output_tokens = whitespace_tokenize(" ".join(split_tokens))
- return output_tokens
-
- def _run_strip_accents(self, text):
- r"""Strips accents from a piece of text."""
- text = unicodedata.normalize("NFD", text)
- output = []
- for char in text:
- cat = unicodedata.category(char)
- if cat == "Mn":
- continue
- output.append(char)
- return "".join(output)
-
- def _run_split_on_punc(self, text):
- r"""Splits punctuation on a piece of text."""
- if text in self.never_split:
- return [text]
- chars = list(text)
- i = 0
- start_new_word = True
- output = []
- while i < len(chars):
- char = chars[i]
- if _is_punctuation(char):
- output.append([char])
- start_new_word = True
- else:
- if start_new_word:
- output.append([])
- start_new_word = False
- output[-1].append(char)
- i += 1
-
- return ["".join(x) for x in output]
-
- def _tokenize_chinese_chars(self, text):
- r"""Adds whitespace around any CJK character."""
- output = []
- for char in text:
- cp = ord(char)
- if self._is_chinese_char(cp):
- output.append(" ")
- output.append(char)
- output.append(" ")
- else:
- output.append(char)
- return "".join(output)
-
- def _is_chinese_char(self, cp):
- r"""Checks whether CP is the codepoint of a CJK character."""
- # This defines a "chinese character" as anything in the CJK Unicode block:
- # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
- #
- # Note that the CJK Unicode block is NOT all Japanese and Korean characters,
- # despite its name. The modern Korean Hangul alphabet is a different block,
- # as is Japanese Hiragana and Katakana. Those alphabets are used to write
- # space-separated words, so they are not treated specially and handled
- # like the all of the other languages.
- if (((cp >= 0x4E00) and (cp <= 0x9FFF)) or #
- ((cp >= 0x3400) and (cp <= 0x4DBF)) or #
- ((cp >= 0x20000) and (cp <= 0x2A6DF)) or #
- ((cp >= 0x2A700) and (cp <= 0x2B73F)) or #
- ((cp >= 0x2B740) and (cp <= 0x2B81F)) or #
- ((cp >= 0x2B820) and (cp <= 0x2CEAF)) or
- ((cp >= 0xF900) and (cp <= 0xFAFF)) or #
- ((cp >= 0x2F800) and (cp <= 0x2FA1F))): #
- return True
-
- return False
-
- def _clean_text(self, text):
- r"""Performs invalid character removal and whitespace cleanup on text."""
- output = []
- for char in text:
- cp = ord(char)
- if cp == 0 or cp == 0xfffd or _is_control(char):
- continue
- if _is_whitespace(char):
- output.append(" ")
- else:
- output.append(char)
- return "".join(output)
-
-
-def _is_whitespace(char):
- r"""Checks whether `chars` is a whitespace character."""
- # \t, \n, and \r are technically contorl characters but we treat them
- # as whitespace since they are generally considered as such.
- if char == " " or char == "\t" or char == "\n" or char == "\r":
- return True
- cat = unicodedata.category(char)
- if cat == "Zs":
- return True
- return False
-
-
-def _is_control(char):
- r"""Checks whether `chars` is a control character."""
- # These are technically control characters but we count them as whitespace
- # characters.
- if char == "\t" or char == "\n" or char == "\r":
- return False
- cat = unicodedata.category(char)
- if cat.startswith("C"):
- return True
- return False
-
-
-def _is_punctuation(char):
- r"""Checks whether `chars` is a punctuation character."""
- cp = ord(char)
- # We treat all non-letter/number ASCII as punctuation.
- # Characters such as "^", "$", and "`" are not in the Unicode
- # Punctuation class but we treat them as punctuation anyways, for
- # consistency.
- if (((cp >= 33) and (cp <= 47)) or ((cp >= 58) and (cp <= 64)) or
- ((cp >= 91) and (cp <= 96)) or ((cp >= 123) and (cp <= 126))):
- return True
- cat = unicodedata.category(char)
- if cat.startswith("P"):
- return True
- return False
-
-
-class BertTokenizer(object):
- r"""Runs end-to-end tokenization: punctuation splitting + wordpiece"""
-
- def __init__(self, vocab_file, do_lower_case=True, max_len=None, do_basic_tokenize=True,
- never_split=("[UNK]", "[SEP]", "[PAD]", "[CLS]", "[MASK]")):
- r"""Constructs a BertTokenizer.
-
- Args:
- vocab_file: Path to a one-wordpiece-per-line vocabulary file
- do_lower_case: Whether to lower case the input
- Only has an effect when do_wordpiece_only=False
- do_basic_tokenize: Whether to do basic tokenization before wordpiece.
- max_len: An artificial maximum length to truncate tokenized sequences to;
- Effective maximum length is always the minimum of this
- value (if specified) and the underlying BERT model's
- sequence length.
- never_split: List of tokens which will never be split during tokenization.
- Only has an effect when do_wordpiece_only=False
- """
- if not os.path.isfile(vocab_file):
- raise ValueError(
- "Can't find a vocabulary file at path '{}'. To load the vocabulary from a Google pretrained "
- "model use `tokenizer = BertTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`".format(vocab_file))
- self.vocab = load_vocab(vocab_file)
- self.ids_to_tokens = collections.OrderedDict(
- [(ids, tok) for tok, ids in self.vocab.items()])
- self.do_basic_tokenize = do_basic_tokenize
- if do_basic_tokenize:
- self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case,
- never_split=never_split)
- self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab)
- self.max_len = max_len if max_len is not None else int(1e12)
-
- def _reinit_on_new_vocab(self, vocab):
- r"""
- After loading BERT the vocab may have been rearranged; call this afterwards to reinitialize everything that depends on the vocab
-
- :param vocab:
- :return:
- """
- self.vocab = vocab
- self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab)
-
- def tokenize(self, text):
- split_tokens = []
- if self.do_basic_tokenize:
- for token in self.basic_tokenizer.tokenize(text):
- for sub_token in self.wordpiece_tokenizer.tokenize(token):
- split_tokens.append(sub_token)
- else:
- split_tokens = self.wordpiece_tokenizer.tokenize(text)
- return split_tokens
-
- def convert_tokens_to_ids(self, tokens):
- r"""Converts a sequence of tokens into ids using the vocab."""
- ids = []
- for token in tokens:
- ids.append(self.vocab[token])
- if len(ids) > self.max_len:
- logger.warning(
- "Token indices sequence length is longer than the specified maximum "
- " sequence length for this BERT model ({} > {}). Running this"
- " sequence through BERT will result in indexing errors".format(len(ids), self.max_len)
- )
- return ids
-
- def convert_ids_to_tokens(self, ids):
- r"""Converts a sequence of ids in wordpiece tokens using the vocab."""
- tokens = []
- for i in ids:
- tokens.append(self.ids_to_tokens[i])
- return tokens
-
- def save_vocabulary(self, vocab_path):
- r"""Save the tokenizer vocabulary to a directory or file."""
- index = 0
- if os.path.isdir(vocab_path):
- vocab_file = os.path.join(vocab_path, VOCAB_NAME)
- else:
- vocab_file = vocab_path
- with open(vocab_file, "w", encoding="utf-8") as writer:
- for token, token_index in sorted(self.vocab.items(), key=lambda kv: kv[1]):
- if index != token_index:
- logger.warning("Saving vocabulary to {}: vocabulary indices are not consecutive."
- " Please check that the vocabulary is not corrupted!".format(vocab_file))
- index = token_index
- writer.write(token + u'\n')
- index += 1
- return vocab_file
-
- @classmethod
- def from_pretrained(cls, model_dir_or_name, *inputs, **kwargs):
- r"""
- Given a model name or path, read the vocab directly.
- """
- model_dir = _get_bert_dir(model_dir_or_name)
- pretrained_model_name_or_path = _get_file_name_base_on_postfix(model_dir, '.txt')
- logger.info("loading vocabulary file {}".format(pretrained_model_name_or_path))
- max_len = 512
- kwargs['max_len'] = min(kwargs.get('max_position_embeddings', int(1e12)), max_len)
- # Instantiate tokenizer.
- tokenizer = cls(pretrained_model_name_or_path, *inputs, **kwargs)
- return tokenizer
-
-
-class _WordPieceBertModel(nn.Module):
- r"""
- This module computes word_piece-level results directly.
-
- """
-
- def __init__(self, model_dir_or_name: str, layers: str = '-1', pooled_cls: bool=False):
- super().__init__()
-
- self.tokenzier = BertTokenizer.from_pretrained(model_dir_or_name)
- self.encoder = BertModel.from_pretrained(model_dir_or_name)
- # sanity-check that the requested layer indices are valid
- encoder_layer_number = len(self.encoder.encoder.layer)
- self.layers = list(map(int, layers.split(',')))
- for layer in self.layers:
- if layer < 0:
- assert -layer <= encoder_layer_number, f"The layer index:{layer} is out of scope for " \
- f"a bert model with {encoder_layer_number} layers."
- else:
- assert layer < encoder_layer_number, f"The layer index:{layer} is out of scope for " \
- f"a bert model with {encoder_layer_number} layers."
-
- self._cls_index = self.tokenzier.vocab['[CLS]']
- self._sep_index = self.tokenzier.vocab['[SEP]']
- self._wordpiece_unknown_index = self.tokenzier.vocab['[UNK]']
- self._wordpiece_pad_index = self.tokenzier.vocab['[PAD]'] # needed when generating word_pieces
- self.pooled_cls = pooled_cls
-
- def index_dataset(self, *datasets, field_name, add_cls_sep=True):
- r"""
- Use BERT's tokenizer to generate a new word_pieces column in the datasets and set it as input. If the sequence does
- not start with [CLS] and end with [SEP], they are added automatically, and the pad value of the word_pieces column is set to BERT's pad value.
-
- :param datasets: DataSet objects
- :param field_name: which column to index
- :return:
- """
-
- def convert_words_to_word_pieces(words):
- word_pieces = []
- for word in words:
- _words = self.tokenzier.basic_tokenizer._tokenize_chinese_chars(word).split()
- tokens = []
- for word in _words:
- tokens.extend(self.tokenzier.wordpiece_tokenizer.tokenize(word))
- word_piece_ids = self.tokenzier.convert_tokens_to_ids(tokens)
- word_pieces.extend(word_piece_ids)
- if add_cls_sep:
- if word_pieces[0] != self._cls_index:
- word_pieces.insert(0, self._cls_index)
- if word_pieces[-1] != self._sep_index:
- word_pieces.insert(-1, self._sep_index)
- return word_pieces
-
- for index, dataset in enumerate(datasets):
- try:
- dataset.apply_field(convert_words_to_word_pieces, field_name=field_name, new_field_name='word_pieces',
- is_input=True)
- dataset.set_pad_val('word_pieces', self._wordpiece_pad_index)
- except Exception as e:
- logger.error(f"Exception happens when processing the {index} dataset.")
- raise e
-
- def forward(self, word_pieces, token_type_ids=None):
- r"""
-
- :param word_pieces: torch.LongTensor, batch_size x max_len
- :param token_type_ids: torch.LongTensor, batch_size x max_len
- :return: num_layers x batch_size x max_len x hidden_size, or num_layers x batch_size x (max_len+2) x hidden_size
- """
- batch_size, max_len = word_pieces.size()
-
- attn_masks = word_pieces.ne(self._wordpiece_pad_index)
- bert_outputs, pooled_cls = self.encoder(word_pieces, token_type_ids=token_type_ids, attention_mask=attn_masks,
- output_all_encoded_layers=True)
- # output_layers = [self.layers] # len(self.layers) x batch_size x max_word_piece_length x hidden_size
- outputs = bert_outputs[0].new_zeros((len(self.layers), batch_size, max_len, bert_outputs[0].size(-1)))
- for l_index, l in enumerate(self.layers):
- bert_output = bert_outputs[l]
- if l in (len(bert_outputs)-1, -1) and self.pooled_cls:
- bert_output[:, 0] = pooled_cls
- outputs[l_index] = bert_output
- return outputs
diff --git a/fastNLP/modules/encoder/gpt2.py b/fastNLP/modules/encoder/gpt2.py
index 5b692253..c1d3e2d9 100644
--- a/fastNLP/modules/encoder/gpt2.py
+++ b/fastNLP/modules/encoder/gpt2.py
@@ -1,773 +1,1069 @@
+r"""
-from functools import lru_cache
-import json
-import regex as re
-import itertools
-
+"""
-from ...io.file_utils import _get_embedding_url, cached_path
-from ...core import logger
+from torch import nn
+import torch
+from fastNLP.core import logger
import os
+import copy
+import json
+import math
+from torch.nn import CrossEntropyLoss
+from ..utils import _get_file_name_base_on_postfix
-PRETRAINED_GPT2_MODEL_DIR = PRETRAINED_BERT_MODEL_DIR = {
- 'en-small': 'gpt2-small.zip',
- 'en-median': 'gpt2-medium.zip',
- 'en': 'gpt2-medium.zip'
-}
+from fastNLP.modules.decoder.seq2seq_decoder import Decoder, Past
+from fastNLP.modules.generator.seq2seq_generator import SequenceGenerator
+from typing import Tuple
-def _get_gpt2_dir(model_dir_or_name: str = 'en-median'):
- if model_dir_or_name.lower() in PRETRAINED_GPT2_MODEL_DIR:
- model_url = _get_embedding_url('gpt2', model_dir_or_name.lower())
- model_dir = cached_path(model_url, name='embedding')
- # check whether it exists as a local directory
- elif os.path.isdir(os.path.abspath(os.path.expanduser(model_dir_or_name))):
- model_dir = os.path.abspath(os.path.expanduser(model_dir_or_name))
- else:
- logger.error(f"Cannot recognize GPT2 dir or name ``{model_dir_or_name}``.")
- raise ValueError(f"Cannot recognize GPT2 dir or name ``{model_dir_or_name}``.")
- return str(model_dir)
+GELU_CONSTANT = math.sqrt(2 / math.pi)
-def _get_filepath_based_on_postfix(folder, postfix):
- """
- Look in folder for a file whose name ends with postfix, e.g. one ending in vocab.txt. Only the first match is
- returned; multiple matches do not raise an error, but no match raises one. Returns the file's full path.
+from ...io.file_utils import _get_gpt2_dir
- :param str folder:
- :param str postfix:
- :return:
- """
- for filename in os.listdir(folder):
- if os.path.isfile(os.path.join(folder, filename)):
- if filename.endswith(postfix):
- return os.path.join(folder, filename)
- raise FileNotFoundError(f"File {postfix} is not found in {folder}.")
+class GPT2Config:
+ """Configuration class to store the configuration of a `GPT2Model`.
-@lru_cache()
-def bytes_to_unicode():
- """
- Returns list of utf-8 byte and a mapping to unicode strings.
- We specifically avoids mapping to whitespace/control characters the bpe code barfs on.
-
- The reversible bpe codes work on unicode strings.
- This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
- When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
- This is a significant percentage of your normal, say, 32K bpe vocab.
- To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
- """
- bs = (
- list(range(ord("!"), ord("~") + 1)) + list(range(ord("¡"), ord("¬") + 1)) + list(range(ord("®"), ord("ÿ") + 1))
- )
- cs = bs[:]
- n = 0
- for b in range(2 ** 8):
- if b not in bs:
- bs.append(b)
- cs.append(2 ** 8 + n)
- n += 1
- cs = [chr(n) for n in cs]
- return dict(zip(bs, cs))
-
-
-def get_pairs(word):
- """Return set of symbol pairs in a word.
-
- Word is represented as tuple of symbols (symbols being variable-length strings).
- """
- pairs = set()
- prev_char = word[0]
- for char in word[1:]:
- pairs.add((prev_char, char))
- prev_char = char
- return pairs
-
-
-VOCAB_FILES_NAMES = {
- "vocab_file": "vocab.json",
- "merges_file": "merges.txt",
-}
-
-
-PRETRAINED_VOCAB_FILES_MAP = {
- "vocab_file": {
- "gpt2": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-vocab.json",
- "gpt2-medium": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-medium-vocab.json",
- "gpt2-large": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-large-vocab.json",
- "gpt2-xl": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-xl-vocab.json",
- "distilgpt2": "https://s3.amazonaws.com/models.huggingface.co/bert/distilgpt2-vocab.json",
- },
- "merges_file": {
- "gpt2": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-merges.txt",
- "gpt2-medium": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-medium-merges.txt",
- "gpt2-large": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-large-merges.txt",
- "gpt2-xl": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-xl-merges.txt",
- "distilgpt2": "https://s3.amazonaws.com/models.huggingface.co/bert/distilgpt2-merges.txt",
- },
-}
-
-
-PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
- "en-small": 1024,
- 'en': 1024,
- "en-medium": 1024,
- "en-large": 1024,
- "en-xl": 1024,
- "en-distilgpt2": 1024,
-}
-
-PATTERN = re.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""")
-
-
-def gpt2_tokenize(text, add_prefix_space=True):
+ Args:
+ vocab_size: Vocabulary size of `inputs_ids` in `GPT2Model` or a configuration json file.
+ n_positions: Number of positional embeddings.
+ n_ctx: Size of the causal mask (usually same as n_positions).
+ n_embd: Dimensionality of the embeddings and hidden states.
+ n_layer: Number of hidden layers in the Transformer encoder.
+ n_head: Number of attention heads for each attention layer in
+ the Transformer encoder.
+ layer_norm_epsilon: epsilon to use in the layer norm layers
+ resid_pdrop: The dropout probability for all fully connected
+ layers in the embeddings, encoder, and pooler.
+ attn_pdrop: The dropout ratio for the attention
+ probabilities.
+ embd_pdrop: The dropout ratio for the embeddings.
+ initializer_range: The stddev of the truncated_normal_initializer for
+ initializing all weight matrices.
"""
- :param str text:
- :param bool add_prefix_space: whether to prepend a space to the sentence; required to stay consistent with how GPT-2 was trained
- :return: []
- """
- if text is '':
- return []
- if add_prefix_space:
- text = ' ' + text
- tokens = []
- for token in re.findall(PATTERN, text):
- tokens.append(token)
- return tokens
-
-
-class GPT2Tokenizer:
- """
- GPT-2 BPE tokenizer. Peculiarities:
- - Byte-level Byte-Pair-Encoding
- - Requires a space to start the input string => the encoding and tokenize methods should be called with the
- ``add_prefix_space`` flag set to ``True``.
- Otherwise, this tokenizer's ``encode``, ``decode``, and ``tokenize`` methods will not conserve
- the spaces at the beginning of a string: `tokenizer.decode(tokenizer.encode(" Hello")) = "Hello"`
- """
-
- vocab_files_names = VOCAB_FILES_NAMES
- pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
-
- SPECIAL_TOKENS_ATTRIBUTES = [
- "bos_token",
- "eos_token",
- "unk_token",
- "pad_token",
- "cls_token",
- "mask_token",
- ]
-
- padding_side = "right"
-
def __init__(
self,
- vocab_file,
- merges_file,
- errors="replace",
- unk_token="<|endoftext|>",
- bos_token="<|endoftext|>",
- eos_token="<|endoftext|>",
+ vocab_size=50257,
+ n_positions=1024,
+ n_ctx=1024,
+ n_embd=768,
+ n_layer=12,
+ n_head=12,
+ resid_pdrop=0.1,
+ embd_pdrop=0.1,
+ attn_pdrop=0.1,
+ layer_norm_epsilon=1e-5,
+ initializer_range=0.02,
+ summary_type="cls_index",
+ summary_use_proj=True,
+ summary_activation=None,
+ summary_proj_to_labels=True,
+ summary_first_dropout=0.1,
**kwargs
):
- self._bos_token = None
- self._eos_token = None
- self._unk_token = None
- self._sep_token = None
- self._pad_token = None
- self._cls_token = None
- self._mask_token = None
- self._pad_token_type_id = 0
-
- self.bos_token = bos_token
- self.eos_token = eos_token
- self.unk_token = unk_token
-
- self.max_len = int(1e12)
- self.padding_side = kwargs.pop("padding_side", self.padding_side)
- self.added_tokens_encoder = {}
- self.unique_added_tokens_encoder = set()
- self.added_tokens_decoder = {}
- # inputs and kwargs for saving and re-loading (see ``from_pretrained`` and ``save_pretrained``)
- self.init_inputs = ()
- self.init_kwargs = {}
+ """Constructs GPT2Config.
+ Args:
+ vocab_size: Vocabulary size of `inputs_ids` in `GPT2Model` or a configuration json file.
+ n_positions: Number of positional embeddings.
+ n_ctx: Size of the causal mask (usually same as n_positions).
+ n_embd: Dimensionality of the embeddings and hidden states.
+ n_layer: Number of hidden layers in the Transformer encoder.
+ n_head: Number of attention heads for each attention layer in
+ the Transformer encoder.
+ layer_norm_epsilon: epsilon to use in the layer norm layers
+ resid_pdrop: The dropout probability for all fully connected
+ layers in the embeddings, encoder, and pooler.
+ attn_pdrop: The dropout ratio for the attention
+ probabilities.
+ embd_pdrop: The dropout ratio for the embeddings.
+ initializer_range: The stddev of the truncated_normal_initializer for
+ initializing all weight matrices.
+ """
+ self.output_attentions = kwargs.pop("output_attentions", False)
+ self.output_hidden_states = kwargs.pop("output_hidden_states", False)
+ self.output_past = kwargs.pop("output_past", True) # Not used by all models
+ self.torchscript = kwargs.pop("torchscript", False) # Only used by PyTorch models
+ self.use_bfloat16 = kwargs.pop("use_bfloat16", False)
+ self.pruned_heads = kwargs.pop("pruned_heads", {})
+
+ # Is decoder is used in encoder-decoder models to differentiate encoder from decoder
+ self.is_decoder = kwargs.pop("is_decoder", False)
+
+ # Parameters for sequence generation
+ self.max_length = kwargs.pop("max_length", 20)
+ self.do_sample = kwargs.pop("do_sample", False)
+ self.num_beams = kwargs.pop("num_beams", 1)
+ self.temperature = kwargs.pop("temperature", 1.0)
+ self.top_k = kwargs.pop("top_k", 50)
+ self.top_p = kwargs.pop("top_p", 1.0)
+ self.repetition_penalty = kwargs.pop("repetition_penalty", 1.0)
+ self.bos_token_id = kwargs.pop("bos_token_id", 0)
+ self.pad_token_id = kwargs.pop("pad_token_id", 0)
+ self.eos_token_ids = kwargs.pop("eos_token_ids", 0)
+ self.length_penalty = kwargs.pop("length_penalty", 1.0)
+ self.num_return_sequences = kwargs.pop("num_return_sequences", 1)
+
+ # Fine-tuning task arguments
+ self.finetuning_task = kwargs.pop("finetuning_task", None)
+ self.num_labels = kwargs.pop("num_labels", 2)
+ self.id2label = kwargs.pop("id2label", {i: "LABEL_{}".format(i) for i in range(self.num_labels)})
+ self.id2label = dict((int(key), value) for key, value in self.id2label.items())
+ self.label2id = kwargs.pop("label2id", dict(zip(self.id2label.values(), self.id2label.keys())))
+ self.label2id = dict((key, int(value)) for key, value in self.label2id.items())
+
+ # Additional attributes without default values
for key, value in kwargs.items():
- if key in self.SPECIAL_TOKENS_ATTRIBUTES:
- if key == "additional_special_tokens":
- assert isinstance(value, (list, tuple)) and all(isinstance(t, str) for t in value)
- else:
- assert isinstance(value, str)
+ try:
setattr(self, key, value)
+ except AttributeError as err:
+ logger.error("Can't set {} with value {} for {}".format(key, value, self))
+ raise err
+
+ self.vocab_size = vocab_size
+ self.n_ctx = n_ctx
+ self.n_positions = n_positions
+ self.n_embd = n_embd
+ self.n_layer = n_layer
+ self.n_head = n_head
+ self.resid_pdrop = resid_pdrop
+ self.embd_pdrop = embd_pdrop
+ self.attn_pdrop = attn_pdrop
+ self.layer_norm_epsilon = layer_norm_epsilon
+ self.initializer_range = initializer_range
+ self.summary_type = summary_type
+ self.summary_use_proj = summary_use_proj
+ self.summary_activation = summary_activation
+ self.summary_first_dropout = summary_first_dropout
+ self.summary_proj_to_labels = summary_proj_to_labels
- self.max_len_single_sentence = (
- self.max_len
- ) # no default special tokens - you can update this value if you add special tokens
- self.max_len_sentences_pair = (
- self.max_len
- ) # no default special tokens - you can update this value if you add special tokens
-
- with open(vocab_file, encoding="utf-8") as vocab_handle:
- self.encoder = json.load(vocab_handle)
- self.decoder = {v: k for k, v in self.encoder.items()}
- self.errors = errors # how to handle errors in decoding
- self.byte_encoder = bytes_to_unicode()
- self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
- with open(merges_file, encoding="utf-8") as merges_handle:
- bpe_merges = merges_handle.read().split("\n")[1:-1]
- bpe_merges = [tuple(merge.split()) for merge in bpe_merges]
- self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges))))
- self.cache = {}
-
- def add_special_tokens(self, special_tokens_dict):
- """
- Add a dictionary of special tokens (eos, pad, cls...) to the encoder and link them
- to class attributes. If special tokens are NOT in the vocabulary, they are added
- to it (indexed starting from the last index of the current vocabulary).
+ @property
+ def max_position_embeddings(self):
+ return self.n_positions
- Using `add_special_tokens` will ensure your special tokens can be used in several ways:
+ @property
+ def hidden_size(self):
+ return self.n_embd
- - special tokens are carefully handled by the tokenizer (they are never split)
- - you can easily refer to special tokens using tokenizer class attributes like `tokenizer.cls_token`. This makes it easy to develop model-agnostic training and fine-tuning scripts.
+ @property
+ def num_attention_heads(self):
+ return self.n_head
- When possible, special tokens are already registered for provided pretrained models (ex: BertTokenizer cls_token is already registered to be '[CLS]' and XLM's one is also registered to be '')
+ @property
+ def num_hidden_layers(self):
+ return self.n_layer
- Args:
- special_tokens_dict: dict of string. Keys should be in the list of predefined special attributes:
- [``bos_token``, ``eos_token``, ``unk_token``, ``sep_token``, ``pad_token``, ``cls_token``, ``mask_token``,
- ``additional_special_tokens``].
+ def save_pretrained(self, save_directory):
+ """ Save a configuration object to the directory `save_directory`, so that it
+ can be re-loaded using the :func:`~transformers.PretrainedConfig.from_pretrained` class method.
+ """
+ assert os.path.isdir(
+ save_directory
+ ), "Saving path should be a directory where the model and configuration can be saved"
- Tokens are only added if they are not already in the vocabulary (tested by checking if the tokenizer assign the index of the ``unk_token`` to them).
+ # If we save using the predefined names, we can load using `from_pretrained`
+ output_config_file = os.path.join(save_directory, 'config.json')
- Returns:
- Number of tokens added to the vocabulary.
+ self.to_json_file(output_config_file)
- Examples::
+ def to_json_file(self, json_file_path):
+ """ Save this instance to a json file."""
+ with open(json_file_path, "w", encoding="utf-8") as writer:
+ writer.write(self.to_json_string())
- # Let's see how to add a new classification token to GPT-2
- tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
- model = GPT2Model.from_pretrained('gpt2')
+ def to_dict(self):
+ """Serializes this instance to a Python dictionary."""
+ output = copy.deepcopy(self.__dict__)
+ return output
- special_tokens_dict = {'cls_token': ''}
+ def to_json_string(self):
+ """Serializes this instance to a JSON string."""
+ return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n"
- num_added_toks = tokenizer.add_special_tokens(special_tokens_dict)
- print('We have added', num_added_toks, 'tokens')
- model.resize_token_embeddings(len(tokenizer)) # Notice: resize_token_embeddings expect to receive the full size of the new vocabulary, i.e. the length of the tokenizer.
+ @classmethod
+ def from_json_file(cls, json_file):
+ """Constructs a `Config` from a json file of parameters."""
+ with open(json_file, "r", encoding="utf-8") as reader:
+ text = reader.read()
+ dict_obj = json.loads(text)
+ return cls(**dict_obj)
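# --- Editor's sketch (not part of the diff): round-trip through the (de)serialization helpers
# above; the file path is a placeholder.
cfg = GPT2Config(n_layer=2, n_head=4, n_embd=128)
cfg.to_json_file('/tmp/gpt2-tiny-config.json')
same = GPT2Config.from_json_file('/tmp/gpt2-tiny-config.json')
assert same.n_layer == 2 and same.hidden_size == 128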
- assert tokenizer.cls_token == ''
- """
- if not special_tokens_dict:
- return 0
-
- added_tokens = 0
- for key, value in special_tokens_dict.items():
- assert key in self.SPECIAL_TOKENS_ATTRIBUTES
- if key == "additional_special_tokens":
- assert isinstance(value, (list, tuple)) and all(isinstance(t, str) for t in value)
- added_tokens += self.add_tokens(value)
- else:
- assert isinstance(value, str)
- added_tokens += self.add_tokens([value])
- logger.debug("Assigning %s to the %s key of the tokenizer", value, key)
- setattr(self, key, value)
-
- return added_tokens
-
- def add_tokens(self, new_tokens):
- """
- Add a list of new tokens to the tokenizer class. If the new tokens are not in the
- vocabulary, they are added to it with indices starting from length of the current vocabulary.
+ @classmethod
+ def from_pretrained(cls, model_dir_or_name, **kwargs):
+ r""" Instantiate a :class:`~transformers.PretrainedConfig` (or a derived class) from a pre-trained model configuration.
- Args:
- new_tokens: list of string. Each string is a token to add. Tokens are only added if they are not already in the vocabulary (tested by checking if the tokenizer assign the index of the ``unk_token`` to them).
+ Parameters:
+            model_dir_or_name: the shortcut name of a pretrained GPT-2 model, or a path to a directory containing the ``config.json`` file.
- Returns:
- Number of tokens added to the vocabulary.
+ """
+ model_dir = _get_gpt2_dir(model_dir_or_name)
+        config_file = _get_file_name_base_on_postfix(model_dir, 'config.json')
- Examples::
+        config = cls.from_json_file(config_file)
- # Let's see how to increase the vocabulary of Bert model and tokenizer
- tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
- model = BertModel.from_pretrained('bert-base-uncased')
- num_added_toks = tokenizer.add_tokens(['new_tok1', 'my_new-tok2'])
- print('We have added', num_added_toks, 'tokens')
- model.resize_token_embeddings(len(tokenizer)) # Notice: resize_token_embeddings expect to receive the full size of the new vocabulary, i.e. the length of the tokenizer.
- """
- if not new_tokens:
- return 0
-
- to_add_tokens = []
- for token in new_tokens:
- assert isinstance(token, str)
- if self.init_kwargs.get("do_lower_case", False) and token not in self.all_special_tokens:
- token = token.lower()
- if (
- token != self.unk_token
- and self.convert_tokens_to_ids(token) == self.convert_tokens_to_ids(self.unk_token)
- and token not in to_add_tokens
- ):
- to_add_tokens.append(token)
- logger.debug("Adding %s to the vocabulary", token)
-
- added_tok_encoder = dict((tok, len(self) + i) for i, tok in enumerate(to_add_tokens))
- added_tok_decoder = {v: k for k, v in added_tok_encoder.items()}
- self.added_tokens_encoder.update(added_tok_encoder)
- self.unique_added_tokens_encoder = set(self.added_tokens_encoder.keys()).union(set(self.all_special_tokens))
- self.added_tokens_decoder.update(added_tok_decoder)
-
- return len(to_add_tokens)
+ if hasattr(config, "pruned_heads"):
+ config.pruned_heads = dict((int(key), value) for key, value in config.pruned_heads.items())
- @property
- def bos_token(self):
- """ Beginning of sentence token (string). Log an error if used while not having been set. """
- if self._bos_token is None:
- logger.error("Using bos_token, but it is not set yet.")
- return self._bos_token
+ # Update config with kwargs if needed
+ to_remove = []
+ for key, value in kwargs.items():
+ if hasattr(config, key):
+ setattr(config, key, value)
+ to_remove.append(key)
+ for key in to_remove:
+ kwargs.pop(key, None)
- @property
- def eos_token(self):
- """ End of sentence token (string). Log an error if used while not having been set. """
- if self._eos_token is None:
- logger.error("Using eos_token, but it is not set yet.")
- return self._eos_token
+ return config
- @property
- def unk_token(self):
- """ Unknown token (string). Log an error if used while not having been set. """
- if self._unk_token is None:
- logger.error("Using unk_token, but it is not set yet.")
- return self._unk_token
- @property
- def pad_token(self):
- """ Padding token (string). Log an error if used while not having been set. """
- if self._pad_token is None:
- logger.error("Using pad_token, but it is not set yet.")
- return self._pad_token
+def gelu(x):
+ return 0.5 * x * (1 + torch.tanh(GELU_CONSTANT * (x + 0.044715 * torch.pow(x, 3))))
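+
+
+# The constant in this approximation is sqrt(2/pi); the sketch below checks the tanh
+# form against the exact GELU x * Phi(x) (with Phi the standard normal CDF), assuming
+# GELU_CONSTANT is defined that way at module level::
+#
+#     import math
+#     import torch
+#
+#     GELU_CONSTANT = math.sqrt(2 / math.pi)  # assumed value of the module-level constant
+#
+#     def gelu_exact(x):
+#         return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))
+#
+#     x = torch.linspace(-3, 3, steps=101)
+#     approx = 0.5 * x * (1 + torch.tanh(GELU_CONSTANT * (x + 0.044715 * torch.pow(x, 3))))
+#     print((approx - gelu_exact(x)).abs().max())  # on the order of 1e-4: the forms nearly coincide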
- @property
- def cls_token(self):
- """ Classification token (string). E.g. to extract a summary of an input sequence leveraging self-attention along the full depth of the model. Log an error if used while not having been set. """
- if self._cls_token is None:
- logger.error("Using cls_token, but it is not set yet.")
- return self._cls_token
- @property
- def mask_token(self):
- """ Mask token (string). E.g. when training a model with masked-language modeling. Log an error if used while not having been set. """
- if self._mask_token is None:
- logger.error("Using mask_token, but it is not set yet.")
- return self._mask_token
+def prune_conv1d_layer(layer, index, dim=1):
+ """ Prune a Conv1D layer (a model parameters) to keep only entries in index.
+ A Conv1D work as a Linear layer (see e.g. BERT) but the weights are transposed.
+ Return the pruned layer as a new layer with requires_grad=True.
+ Used to remove heads.
+ """
+ index = index.to(layer.weight.device)
+ W = layer.weight.index_select(dim, index).clone().detach()
+ if dim == 0:
+ b = layer.bias.clone().detach()
+ else:
+ b = layer.bias[index].clone().detach()
+ new_size = list(layer.weight.size())
+ new_size[dim] = len(index)
+ new_layer = Conv1D(new_size[1], new_size[0]).to(layer.weight.device)
+ new_layer.weight.requires_grad = False
+ new_layer.weight.copy_(W.contiguous())
+ new_layer.weight.requires_grad = True
+ new_layer.bias.requires_grad = False
+ new_layer.bias.copy_(b.contiguous())
+ new_layer.bias.requires_grad = True
+ return new_layer
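+
+
+# A minimal usage sketch (Conv1D is defined further down in this file; the sizes are
+# illustrative): keep the first 512 of 768 output features of a projection::
+#
+#     import torch
+#
+#     layer = Conv1D(768, 768)                           # weight stored as (nx, nf) = (768, 768)
+#     index = torch.arange(512)
+#     pruned = prune_conv1d_layer(layer, index, dim=1)   # prune along the output dimension
+#     assert pruned.weight.shape == (768, 512) and pruned.bias.shape == (512,)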
+
+
+class Attention(nn.Module):
+ def __init__(self, nx, n_ctx, config, scale=False):
+ super(Attention, self).__init__()
+
+ n_state = nx # in Attention: n_state=768 (nx=n_embd)
+ # [switch nx => n_state from Block to Attention to keep identical to TF implem]
+ assert n_state % config.n_head == 0
+ self.register_buffer("bias", torch.tril(torch.ones(n_ctx, n_ctx)).view(1, 1, n_ctx, n_ctx))
+ self.n_head = config.n_head
+ self.split_size = n_state
+ self.scale = scale
+
+ self.c_attn = Conv1D(n_state * 3, nx)
+ self.c_proj = Conv1D(n_state, nx)
+ self.attn_dropout = nn.Dropout(config.attn_pdrop)
+ self.resid_dropout = nn.Dropout(config.resid_pdrop)
+ self.pruned_heads = set()
+
+ def prune_heads(self, heads):
+ if len(heads) == 0:
+ return
+ mask = torch.ones(self.n_head, self.split_size // self.n_head)
+            heads = set(heads) - self.pruned_heads  # Convert to a set and remove already pruned heads
+ for head in heads:
+ # Compute how many pruned heads are before the head and move the index accordingly
+ head = head - sum(1 if h < head else 0 for h in self.pruned_heads)
+ mask[head] = 0
+ mask = mask.view(-1).contiguous().eq(1)
+ index = torch.arange(len(mask))[mask].long()
+ index_attn = torch.cat([index, index + self.split_size, index + (2 * self.split_size)])
+
+ # Prune conv1d layers
+ self.c_attn = prune_conv1d_layer(self.c_attn, index_attn, dim=1)
+ self.c_proj = prune_conv1d_layer(self.c_proj, index, dim=0)
+
+ # Update hyper params
+ self.split_size = (self.split_size // self.n_head) * (self.n_head - len(heads))
+ self.n_head = self.n_head - len(heads)
+ self.pruned_heads = self.pruned_heads.union(heads)
+
+ def _attn(self, q, k, v, attention_mask=None, head_mask=None):
+ w = torch.matmul(q, k) # batch_size x n_head x pre_len x (past_len+pre_len)
+ if self.scale:
+ w = w / math.sqrt(v.size(-1))
+ nd, ns = w.size(-2), w.size(-1)
+ b = self.bias[:, :, ns - nd : ns, :ns] # 1 x 1 x pre_len x (past_len + pre_len)
+ w = w * b - 1e4 * (1 - b) # batch_size x n_head x pre_len x (past_len + pre_len)
+
+ if attention_mask is not None:
+ # Apply the attention mask
+ w = w + attention_mask
+
+ w = nn.Softmax(dim=-1)(w)
+ w = self.attn_dropout(w)
+
+ # Mask heads if we want to
+ if head_mask is not None:
+ w = w * head_mask
+
+ outputs = [torch.matmul(w, v)]
+ outputs.append(w)
+ return outputs
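+
+    # To make the causal slice above concrete: with, say, 3 cached tokens and one new
+    # query token, nd = 1 and ns = 4, so the slice keeps only the last row of the
+    # triangular mask and the new token may attend to every earlier position.
+    # A worked check with an illustrative n_ctx = 5::
+    #
+    #     import torch
+    #
+    #     bias = torch.tril(torch.ones(5, 5)).view(1, 1, 5, 5)
+    #     nd, ns = 1, 4
+    #     print(bias[:, :, ns - nd:ns, :ns])  # tensor([[[[1., 1., 1., 1.]]]])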
+
+ def merge_heads(self, x):
+ x = x.permute(0, 2, 1, 3).contiguous()
+ new_x_shape = x.size()[:-2] + (x.size(-2) * x.size(-1),)
+ return x.view(*new_x_shape) # in Tensorflow implem: fct merge_states
+
+ def split_heads(self, x, k=False):
+ new_x_shape = x.size()[:-1] + (self.n_head, x.size(-1) // self.n_head)
+ x = x.view(*new_x_shape) # in Tensorflow implem: fct split_states
+ if k:
+ return x.permute(0, 2, 3, 1) # (batch, head, head_features, seq_length)
+ else:
+ return x.permute(0, 2, 1, 3) # (batch, head, seq_length, head_features)
+
+ def forward(self, x, layer_past=None, attention_mask=None, head_mask=None):
+ x = self.c_attn(x)
+ query, key, value = x.split(self.split_size, dim=2)
+ query = self.split_heads(query) # (batch, head, seq_length, head_features)
+ key = self.split_heads(key, k=True)
+ value = self.split_heads(value)
+ if layer_past is not None:
+ past_key, past_value = layer_past[0].transpose(-2, -1), layer_past[1] # transpose back cf below
+ # key: (batch, head, head_features, seq_length)
+ key = torch.cat((past_key, key), dim=-1)
+ # value: (batch, head, seq_length, head_features)
+ value = torch.cat((past_value, value), dim=-2)
+ present = torch.stack((key.transpose(-2, -1), value)) # transpose to have same shapes for stacking
+
+ attn_outputs = self._attn(query, key, value, attention_mask, head_mask)
+ a = attn_outputs[0]
+
+ a = self.merge_heads(a)
+ a = self.c_proj(a)
+ a = self.resid_dropout(a)
+
+ outputs = [a, present] + attn_outputs[1:]
+ return outputs # a, present, (attentions)
+
+
+class Conv1D(nn.Module):
+ def __init__(self, nf, nx):
+ """ Conv1D layer as defined by Radford et al. for OpenAI GPT (and also used in GPT-2)
+ Basically works like a Linear layer but the weights are transposed
+ """
+ super(Conv1D, self).__init__()
+ self.nf = nf
+ w = torch.empty(nx, nf)
+ nn.init.normal_(w, std=0.02)
+ self.weight = nn.Parameter(w)
+ self.bias = nn.Parameter(torch.zeros(nf))
+
+ def forward(self, x):
+ size_out = x.size()[:-1] + (self.nf,)
+ x = torch.addmm(self.bias, x.view(-1, x.size(-1)), self.weight)
+ x = x.view(*size_out)
+ return x
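+
+
+# A quick sketch of the equivalence with nn.Linear, which stores its weight as
+# (out_features, in_features) while Conv1D stores (nx, nf), hence the transpose::
+#
+#     import torch
+#     import torch.nn as nn
+#
+#     conv = Conv1D(4, 3)                     # nf=4 outputs, nx=3 inputs
+#     linear = nn.Linear(3, 4)
+#     linear.weight.data = conv.weight.data.t()
+#     linear.bias.data = conv.bias.data
+#     x = torch.randn(2, 5, 3)
+#     assert torch.allclose(conv(x), linear(x), atol=1e-6)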
+
+
+class MLP(nn.Module):
+ def __init__(self, n_state, config): # in MLP: n_state=3072 (4 * n_embd)
+ super(MLP, self).__init__()
+ nx = config.n_embd
+ self.c_fc = Conv1D(n_state, nx)
+ self.c_proj = Conv1D(nx, n_state)
+ self.act = gelu
+ self.dropout = nn.Dropout(config.resid_pdrop)
+
+ def forward(self, x):
+ h = self.act(self.c_fc(x))
+ h2 = self.c_proj(h)
+ return self.dropout(h2)
+
+
+class Block(nn.Module):
+ def __init__(self, n_ctx, config, scale=False):
+ super(Block, self).__init__()
+ nx = config.n_embd
+ self.ln_1 = nn.LayerNorm(nx, eps=config.layer_norm_epsilon)
+ self.attn = Attention(nx, n_ctx, config, scale)
+ self.ln_2 = nn.LayerNorm(nx, eps=config.layer_norm_epsilon)
+ self.mlp = MLP(4 * nx, config)
+
+ def forward(self, x, layer_past=None, attention_mask=None, head_mask=None):
+ output_attn = self.attn(
+ self.ln_1(x), layer_past=layer_past, attention_mask=attention_mask, head_mask=head_mask
+ )
+ a = output_attn[0] # output_attn: a, present, (attentions)
- @bos_token.setter
- def bos_token(self, value):
- self._bos_token = value
+ x = x + a
+ m = self.mlp(self.ln_2(x))
+ x = x + m
- @eos_token.setter
- def eos_token(self, value):
- self._eos_token = value
+ outputs = [x] + output_attn[1:]
+ return outputs # x, present, (attentions)
- @unk_token.setter
- def unk_token(self, value):
- self._unk_token = value
- @pad_token.setter
- def pad_token(self, value):
- self._pad_token = value
+class GPT2PreTrainedModel(nn.Module):
+ """ An abstract class to handle weights initialization and
+        a simple interface for downloading and loading pretrained models.
+ """
- @cls_token.setter
- def cls_token(self, value):
- self._cls_token = value
+ config_class = GPT2Config
+ base_model_prefix = "transformer"
- @mask_token.setter
- def mask_token(self, value):
- self._mask_token = value
+ def _init_weights(self, module):
+ """ Initialize the weights.
+ """
+ if isinstance(module, (nn.Linear, nn.Embedding, Conv1D)):
+ # Slightly different from the TF version which uses truncated_normal for initialization
+ # cf https://github.com/pytorch/pytorch/pull/5617
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+ if isinstance(module, (nn.Linear, Conv1D)) and module.bias is not None:
+ module.bias.data.zero_()
+ elif isinstance(module, nn.LayerNorm):
+ module.bias.data.zero_()
+ module.weight.data.fill_(1.0)
+
+ def __init__(self, config, *inputs, **kwargs):
+ super().__init__()
+ if not isinstance(config, GPT2Config):
+ raise ValueError(
+ "Parameter config in `{}(config)` should be an instance of class `PretrainedConfig`. "
+ "To create a model from a pretrained model use "
+ "`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`".format(
+ self.__class__.__name__, self.__class__.__name__
+ )
+ )
+ # Save config in model
+ self.config = config
@property
- def bos_token_id(self):
- """ Id of the beginning of sentence token in the vocabulary. Log an error if used while not having been set. """
- return self.convert_tokens_to_ids(self.bos_token)
+ def base_model(self):
+ return getattr(self, self.base_model_prefix, self)
- @property
- def eos_token_id(self):
- """ Id of the end of sentence token in the vocabulary. Log an error if used while not having been set. """
- return self.convert_tokens_to_ids(self.eos_token)
+ def get_input_embeddings(self):
+ """ Get model's input embeddings
+ """
+ base_model = getattr(self, self.base_model_prefix, self)
+ if base_model is not self:
+ return base_model.get_input_embeddings()
+ else:
+ raise NotImplementedError
- @property
- def unk_token_id(self):
- """ Id of the unknown token in the vocabulary. Log an error if used while not having been set. """
- return self.convert_tokens_to_ids(self.unk_token)
+ def set_input_embeddings(self, value):
+ """ Set model's input embeddings
+ """
+ base_model = getattr(self, self.base_model_prefix, self)
+ if base_model is not self:
+ base_model.set_input_embeddings(value)
+ else:
+ raise NotImplementedError
- @property
- def pad_token_id(self):
- """ Id of the padding token in the vocabulary. Log an error if used while not having been set. """
- return self.convert_tokens_to_ids(self.pad_token)
+ def get_output_embeddings(self):
+ """ Get model's output embeddings
+ Return None if the model doesn't have output embeddings
+ """
+ return None # Overwrite for models with output embeddings
- @property
- def pad_token_type_id(self):
- """ Id of the padding token type in the vocabulary."""
- return self._pad_token_type_id
+ def tie_weights(self):
+ """ Make sure we are sharing the input and output embeddings.
+ Export to TorchScript can't handle parameter sharing so we are cloning them instead.
+ """
+ output_embeddings = self.get_output_embeddings()
+ if output_embeddings is not None:
+ self._tie_or_clone_weights(output_embeddings, self.get_input_embeddings())
- @property
- def cls_token_id(self):
- """ Id of the classification token in the vocabulary. E.g. to extract a summary of an input sequence leveraging self-attention along the full depth of the model. Log an error if used while not having been set. """
- return self.convert_tokens_to_ids(self.cls_token)
+ def _tie_or_clone_weights(self, output_embeddings, input_embeddings):
+ """ Tie or clone module weights depending of weither we are using TorchScript or not
+ """
+ if self.config.torchscript:
+ output_embeddings.weight = nn.Parameter(input_embeddings.weight.clone())
+ else:
+ output_embeddings.weight = input_embeddings.weight
+
+ if hasattr(output_embeddings, "bias") and output_embeddings.bias is not None:
+ output_embeddings.bias.data = torch.nn.functional.pad(
+ output_embeddings.bias.data,
+ (0, output_embeddings.weight.shape[0] - output_embeddings.bias.shape[0]),
+ "constant",
+ 0,
+ )
+ if hasattr(output_embeddings, "out_features") and hasattr(input_embeddings, "num_embeddings"):
+ output_embeddings.out_features = input_embeddings.num_embeddings
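+
+    # A minimal sketch of what tying achieves, assuming config.torchscript is False and
+    # using the GPT2LMHeadModel defined below (whose __init__ already runs init_weights
+    # and hence tie_weights): the LM head and the token embedding share one Parameter::
+    #
+    #     model = GPT2LMHeadModel(config)   # `config` is an existing GPT2Config instance
+    #     assert model.lm_head.weight is model.transformer.wte.weight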
- @property
- def mask_token_id(self):
- """ Id of the mask token in the vocabulary. E.g. when training a model with masked-language modeling. Log an error if used while not having been set. """
- return self.convert_tokens_to_ids(self.mask_token)
+ def init_weights(self):
+ """ Initialize and prunes weights if needed. """
+ # Initialize weights
+ self.apply(self._init_weights)
- @property
- def vocab_size(self):
- return len(self.encoder)
-
- def bpe(self, token):
- if token in self.cache:
- return self.cache[token]
- word = tuple(token)
- pairs = get_pairs(word)
-
- if not pairs:
- return token
-
- while True:
- bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf")))
- if bigram not in self.bpe_ranks:
- break
- first, second = bigram
- new_word = []
- i = 0
- while i < len(word):
- try:
- j = word.index(first, i)
- except ValueError:
- new_word.extend(word[i:])
- break
- else:
- new_word.extend(word[i:j])
- i = j
-
- if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
- new_word.append(first + second)
- i += 2
- else:
- new_word.append(word[i])
- i += 1
- new_word = tuple(new_word)
- word = new_word
- if len(word) == 1:
- break
- else:
- pairs = get_pairs(word)
- word = " ".join(word)
- self.cache[token] = word
- return word
-
- def _tokenize(self, text, add_prefix_space=False):
- """ Tokenize a string.
- Args:
- - add_prefix_space (boolean, default False):
- Begin the sentence with at least one space to get invariance to word order in GPT-2 (and RoBERTa) tokenizers.
+ # Prune heads if needed
+ if self.config.pruned_heads:
+ self.prune_heads(self.config.pruned_heads)
+
+ # Tie weights if needed
+ self.tie_weights()
+
+ def prune_heads(self, heads_to_prune):
+ """ Prunes heads of the base model.
+
+ Arguments:
+
+ heads_to_prune: dict with keys being selected layer indices (`int`) and associated values being the list of heads to prune in said layer (list of `int`).
+ E.g. {1: [0, 2], 2: [2, 3]} will prune heads 0 and 2 on layer 1 and heads 2 and 3 on layer 2.
"""
- bpe_tokens = []
- for token in gpt2_tokenize(text, add_prefix_space=add_prefix_space):
- token = "".join(
- self.byte_encoder[b] for b in token.encode("utf-8")
- ) # Maps all our bytes to unicode strings, avoiding controle tokens of the BPE (spaces in our case)
- bpe_tokens.extend(bpe_token for bpe_token in self.bpe(token).split(" "))
- return bpe_tokens
-
- def _convert_token_to_id(self, token):
- """ Converts a token (str) in an id using the vocab. """
- return self.encoder.get(token, self.encoder.get(self.unk_token))
-
- def _convert_id_to_token(self, index):
- """Converts an index (integer) in a token (str) using the vocab."""
- return self.decoder.get(index)
-
- def convert_tokens_to_string(self, tokens):
- """ Converts a sequence of tokens (string) in a single string. """
- text = "".join(tokens)
- text = bytearray([self.byte_decoder[c] for c in text]).decode("utf-8", errors=self.errors)
- return text
-
- def save_vocabulary(self, save_directory):
- """Save the tokenizer vocabulary and merge files to a directory."""
- if not os.path.isdir(save_directory):
- logger.error("Vocabulary path ({}) should be a directory".format(save_directory))
- return
- vocab_file = os.path.join(save_directory, VOCAB_FILES_NAMES["vocab_file"])
- merge_file = os.path.join(save_directory, VOCAB_FILES_NAMES["merges_file"])
-
- with open(vocab_file, "w", encoding="utf-8") as f:
- f.write(json.dumps(self.encoder, ensure_ascii=False))
-
- index = 0
- with open(merge_file, "w", encoding="utf-8") as writer:
- writer.write("#version: 0.2\n")
- for bpe_tokens, token_index in sorted(self.bpe_ranks.items(), key=lambda kv: kv[1]):
- if index != token_index:
- logger.warning(
- "Saving vocabulary to {}: BPE merge indices are not consecutive."
- " Please check that the tokenizer is not corrupted!".format(merge_file)
- )
- index = token_index
- writer.write(" ".join(bpe_tokens) + "\n")
- index += 1
-
- return vocab_file, merge_file
+ # save new sets of pruned heads as union of previously stored pruned heads and newly pruned heads
+ for layer, heads in heads_to_prune.items():
+ union_heads = set(self.config.pruned_heads.get(layer, [])) | set(heads)
+ self.config.pruned_heads[layer] = list(union_heads) # Unfortunately we have to store it as list for JSON
- @classmethod
- def from_pretrained(cls, model_dir_or_name):
- r"""
+ self.base_model._prune_heads(heads_to_prune)
+
+ def save_pretrained(self, save_directory):
+ """ Save a model and its configuration file to a directory, so that it
+            can be re-loaded using the :func:`~transformers.PreTrainedModel.from_pretrained` class method.
"""
- return cls._from_pretrained(model_dir_or_name)
+ assert os.path.isdir(
+ save_directory
+ ), "Saving path should be a directory where the model and configuration can be saved"
+
+ # Only save the model itself if we are using distributed training
+ model_to_save = self.module if hasattr(self, "module") else self
+
+ # Save configuration file
+ model_to_save.config.save_pretrained(save_directory)
+
+ # If we save using the predefined names, we can load using `from_pretrained`
+ output_model_file = os.path.join(save_directory, "pytorch_model.bin")
+ torch.save(model_to_save.state_dict(), output_model_file)
+ logger.info("Model weights saved in {}".format(output_model_file))
- # 将它修改一定传入文件夹
@classmethod
- def _from_pretrained(cls, model_dir_or_name):
- """
+ def from_pretrained(cls, model_dir_or_name, *model_args, **kwargs):
+ r"""Instantiate a pretrained pytorch model from a pre-trained model configuration.
+
+ The model is set in evaluation mode by default using ``model.eval()`` (Dropout modules are deactivated)
+ To train the model, you should first set it back in training mode with ``model.train()``
+
+ The warning ``Weights from XXX not initialized from pretrained model`` means that the weights of XXX do not come pre-trained with the rest of the model.
+ It is up to you to train those weights with a downstream fine-tuning task.
+
+ The warning ``Weights from XXX not used in YYY`` means that the layer XXX is not used by YYY, therefore those weights are discarded.
+
+ Parameters:
+ model_dir_or_name: either:
+
+                - a string with the `shortcut name` of a pre-trained GPT-2 model to load from cache or download, e.g.: ``gpt2``.
+                - a path to a `directory` containing model weights saved using :func:`~transformers.PreTrainedModel.save_pretrained`, e.g.: ``./my_model_directory/``.
+
+ Examples::
+
+            model = GPT2Model.from_pretrained('gpt2')                 # download the model and configuration and cache them
+            model = GPT2Model.from_pretrained('./test/saved_model/')  # e.g. the model was saved using `save_pretrained('./test/saved_model/')`
- :param str model_dir_or_name: 目录或者缩写名
- :param init_inputs:
- :param kwargs:
- :return:
"""
- # 它需要两个文件,第一个是vocab.json,第二个是merge_file?
+ config = kwargs.pop("config", None)
+ state_dict = kwargs.pop("state_dict", None)
+
model_dir = _get_gpt2_dir(model_dir_or_name)
- # 里面会包含四个文件vocab.json, merge.txt, config.json, model.bin
-
- tokenizer_config_file = _get_filepath_based_on_postfix(model_dir, 'config.json')
- with open(tokenizer_config_file, encoding="utf-8") as tokenizer_config_handle:
- init_kwargs = json.load(tokenizer_config_handle)
- # Set max length if needed
- if model_dir_or_name in PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES:
- # if we're using a pretrained model, ensure the tokenizer
- # wont index sequences longer than the number of positional embeddings
- max_len = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES[model_dir_or_name]
- if max_len is not None and isinstance(max_len, (int, float)):
- init_kwargs["max_len"] = min(init_kwargs.get("max_len", int(1e12)), max_len)
-
- # 将vocab, merge加入到init_kwargs中
- init_kwargs['vocab_file'] = _get_filepath_based_on_postfix(model_dir, 'vocab.json')
- init_kwargs['merges_file'] = _get_filepath_based_on_postfix(model_dir, 'merges.txt')
-
- init_inputs = init_kwargs.pop("init_inputs", ())
- # Instantiate tokenizer.
- try:
- tokenizer = cls(*init_inputs, **init_kwargs)
- except OSError:
- OSError(
- "Unable to load vocabulary from file. "
- "Please check that the provided vocabulary is accessible and not corrupted."
+
+ # Load config if we don't provide a configuration
+ model_kwargs = {}
+ if not isinstance(config, GPT2Config):
+ config = cls.config_class.from_pretrained(
+ model_dir,
+ *model_args,
+ **kwargs
+ )
+ else:
+ model_kwargs = kwargs
+
+ # Instantiate model.
+ model = cls(config, *model_args, **model_kwargs)
+
+        if state_dict is None:
+            model_path = _get_file_name_base_on_postfix(model_dir, 'model.bin')
+            state_dict = torch.load(model_path, map_location="cpu")
+
+ missing_keys = []
+ unexpected_keys = []
+ error_msgs = []
+
+ # Convert old format to new format if needed from a PyTorch state_dict
+ old_keys = []
+ new_keys = []
+ for key in state_dict.keys():
+ new_key = None
+ if "gamma" in key:
+ new_key = key.replace("gamma", "weight")
+ if "beta" in key:
+ new_key = key.replace("beta", "bias")
+ if new_key:
+ old_keys.append(key)
+ new_keys.append(new_key)
+ for old_key, new_key in zip(old_keys, new_keys):
+ state_dict[new_key] = state_dict.pop(old_key)
+
+ # copy state_dict so _load_from_state_dict can modify it
+ metadata = getattr(state_dict, "_metadata", None)
+ state_dict = state_dict.copy()
+ if metadata is not None:
+ state_dict._metadata = metadata
+
+ # PyTorch's `_load_from_state_dict` does not copy parameters in a module's descendants
+ # so we need to apply the function recursively.
+ def load(module, prefix=""):
+ local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
+ module._load_from_state_dict(
+ state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs
+ )
+ for name, child in module._modules.items():
+ if child is not None:
+ load(child, prefix + name + ".")
+
+ # Make sure we are able to load base models as well as derived models (with heads)
+ start_prefix = ""
+ model_to_load = model
+ if not hasattr(model, cls.base_model_prefix) and any(
+ s.startswith(cls.base_model_prefix) for s in state_dict.keys()
+ ):
+ start_prefix = cls.base_model_prefix + "."
+ if hasattr(model, cls.base_model_prefix) and not any(
+ s.startswith(cls.base_model_prefix) for s in state_dict.keys()
+ ):
+ model_to_load = getattr(model, cls.base_model_prefix)
+
+ load(model_to_load, prefix=start_prefix)
+ if len(missing_keys) > 0:
+ logger.info(
+ "Weights of {} not initialized from pretrained model: {}".format(
+ model.__class__.__name__, missing_keys
+ )
+ )
+ if len(unexpected_keys) > 0:
+ logger.info(
+ "Weights from pretrained model not used in {}: {}".format(
+ model.__class__.__name__, unexpected_keys
+ )
+ )
+ if len(error_msgs) > 0:
+ raise RuntimeError(
+ "Error(s) in loading state_dict for {}:\n\t{}".format(
+ model.__class__.__name__, "\n\t".join(error_msgs)
+ )
)
- return tokenizer
+ model.tie_weights() # make sure word embedding weights are still tied if needed
- def __len__(self):
- """ Size of the full vocabulary with the added tokens """
- return self.vocab_size + len(self.added_tokens_encoder)
-
- def tokenize(self, text, add_prefix_space=True):
- """ Converts a string in a sequence of tokens (string), using the tokenizer.
- Split in words for word-based vocabulary or sub-words for sub-word-based
- vocabularies (BPE/SentencePieces/WordPieces).
-
- Take care of added tokens.
- Args:
- - text: The sequence to be encoded.
- - add_prefix_space (boolean, default True):
- Begin the sentence with at least one space to get invariance to word order in GPT-2 (and RoBERTa) tokenizers.
- """
- all_special_tokens = self.all_special_tokens
-
- def lowercase_text(t):
- # convert non-special tokens to lowercase
- escaped_special_toks = [re.escape(s_tok) for s_tok in all_special_tokens]
- pattern = r'(' + r'|'.join(escaped_special_toks) + r')|' + \
- r'(.+?)'
- return re.sub(
- pattern,
- lambda m: m.groups()[0] or m.groups()[1].lower(),
- t)
-
- if self.init_kwargs.get('do_lower_case', False):
- text = lowercase_text(text)
-
- def split_on_token(tok, text):
- result = []
- split_text = text.split(tok)
- for i, sub_text in enumerate(split_text):
- sub_text = sub_text.strip()
- if i == 0 and not sub_text:
- result += [tok]
- elif i == len(split_text) - 1:
- if sub_text:
- result += [sub_text]
- else:
- pass
- else:
- if sub_text:
- result += [sub_text]
- result += [tok]
- return result
-
- def split_on_tokens(tok_list, text):
- if not text.strip():
- return []
- if not tok_list:
- return self._tokenize(text, add_prefix_space=add_prefix_space)
-
- tokenized_text = []
- text_list = [text]
- for tok in tok_list:
- tokenized_text = []
- for sub_text in text_list:
- if sub_text not in self.added_tokens_encoder \
- and sub_text not in all_special_tokens:
- tokenized_text += split_on_token(tok, sub_text)
- else:
- tokenized_text += [sub_text]
- text_list = tokenized_text
-
- return list(itertools.chain.from_iterable((self._tokenize(token, add_prefix_space=add_prefix_space) if token not \
- in self.added_tokens_encoder and token not in all_special_tokens \
- else [token] for token in tokenized_text)))
-
- added_tokens = list(self.added_tokens_encoder.keys()) + all_special_tokens
- tokenized_text = split_on_tokens(added_tokens, text)
- return tokenized_text
-
- def convert_tokens_to_ids(self, tokens):
- """ Converts a single token, or a sequence of tokens, (str) in a single integer id
- (resp. a sequence of ids), using the vocabulary.
+        # Set model in evaluation mode to deactivate Dropout modules by default
+ model.eval()
+
+ return model
+
+ def prepare_inputs_for_generation(self, input_ids, **kwargs):
+ return {"input_ids": input_ids, **kwargs}
+
+ @torch.no_grad()
+ def generate(
+ self,
+ input_ids,
+ max_length=None,
+ do_sample=None,
+ num_beams=None,
+ temperature=None,
+ top_k=None,
+ top_p=None,
+ repetition_penalty=None,
+ bos_token_id=None,
+ pad_token_id=None,
+ eos_token_ids=None,
+ length_penalty=None):
+ """ Sequence generator for models with a LM head.
+
+ The method currently supports greedy or penalized greedy decoding, sampling with top-k or nucleus sampling
+ and beam-search.
+
+ Params:
+ **input_ids**: (`optional`) `torch.LongTensor` of shape (1, sequence_length)
+ The sequence used as a prompt for the generation. If `None` the method initializes
+ it as an empty `torch.LongTensor` of shape (1,)
+ **max_length**: (`optional`) int
+ The max length of the sequence to be generated. Between 1 and infinity. Default to 20.
+            **do_sample**: (`optional`) bool
+                If set to `False`, greedy decoding is used; otherwise sampling. Default to greedy decoding.
+            **num_beams**: (`optional`) int
+                Number of beams for beam search. 1 means no beam search. Default to 1.
+ **temperature**: (`optional`) float
+ The value used to module the next token probabilities.
+ **top_k**: (`optional`) int
+ The number of highest probability vocabulary tokens to keep for top-k-filtering. Between 1 and infinity. Default to 50.
+            **top_p**: (`optional`) float
+                Nucleus-sampling threshold: only the highest-probability tokens whose cumulative probability mass reaches ``top_p`` are kept. Must be between 0 and 1. Default to 1.
+ **repetition_penalty**: (`optional`) float
+ The parameter for repetition penalty. Between 1.0 and + infinity. 1.0 means no penalty. Default to 1.
+ **bos_token_id**: (`optional`) int
+ Beginning of sentence token if no prompt is provided. Default to 0.
+ **eos_token_ids**: (`optional`) int or list of int
+ End of sequence token or list of tokens to stop the generation. Default to 0.
+            **pad_token_id**: (`optional`) int
+                Id of the padding token used to pad finished sequences.
+            **length_penalty**: (`optional`) float
+                Exponential penalty to the length. 1.0 means no penalty. Default to 1.
"""
- if tokens is None:
- return None
+ decoder = _GPT2Decoder(self)
+ generator = SequenceGenerator(decoder=decoder, max_length=max_length, num_beams=num_beams,
+ do_sample=do_sample, temperature=temperature, top_k=top_k, top_p=top_p,
+ bos_token_id=bos_token_id, eos_token_id=eos_token_ids,
+ repetition_penalty=repetition_penalty, length_penalty=length_penalty,
+ pad_token_id=pad_token_id)
+ results = generator.generate(input_ids, past=None)
+ return results
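+
+    # A hedged usage sketch of generate with the LM-head model defined below. Every
+    # option is passed explicitly because this wrapper forwards them verbatim to
+    # SequenceGenerator; the shortcut name, prompt ids and eos id are illustrative::
+    #
+    #     import torch
+    #
+    #     model = GPT2LMHeadModel.from_pretrained('gpt2')   # assumed valid shortcut or directory
+    #     prompt = torch.LongTensor([[464, 3290]])          # batch_size x prompt_len
+    #     output = model.generate(prompt, max_length=20, do_sample=False, num_beams=1,
+    #                             temperature=1.0, top_k=50, top_p=1.0, repetition_penalty=1.0,
+    #                             bos_token_id=None, pad_token_id=0, eos_token_ids=50256,
+    #                             length_penalty=1.0)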
+
+
+class GPT2Model(GPT2PreTrainedModel):
+ r"""
+ Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
+ **last_hidden_state**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, hidden_size)``
+ Sequence of hidden-states at the last layer of the model.
+ **past**:
+ list of ``torch.FloatTensor`` (one for each layer) of shape ``(2, batch_size, num_heads, sequence_length, embed_size_per_head)``:
+ that contains pre-computed hidden-states (key and values in the attention blocks).
+ Can be used (see `past` input) to speed up sequential decoding. The token ids which have their past given to this model
+ should not be passed as input ids as they have already been computed.
+ **hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
+ list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
+ of shape ``(batch_size, sequence_length, hidden_size)``:
+ Hidden-states of the model at the output of each layer plus the initial embedding outputs.
+ **attentions**: (`optional`, returned when ``config.output_attentions=True``)
+ list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.
+
+ Examples::
+
+ tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
+ model = GPT2Model.from_pretrained('gpt2')
+ input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True)).unsqueeze(0) # Batch size 1
+ outputs = model(input_ids)
+ last_hidden_states = outputs[0] # The last hidden-state is the first element of the output tuple
- if isinstance(tokens, str):
- return self._convert_token_to_id_with_added_voc(tokens)
+ """
- ids = []
- for token in tokens:
- ids.append(self._convert_token_to_id_with_added_voc(token))
- return ids
+ def __init__(self, config):
+ super().__init__(config)
- def _convert_token_to_id_with_added_voc(self, token):
- if token is None:
- return None
+ self.wte = nn.Embedding(config.vocab_size, config.n_embd)
+ self.wpe = nn.Embedding(config.n_positions, config.n_embd)
+ self.drop = nn.Dropout(config.embd_pdrop)
+ self.h = nn.ModuleList([Block(config.n_ctx, config, scale=True) for _ in range(config.n_layer)])
+ self.ln_f = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
- if token in self.added_tokens_encoder:
- return self.added_tokens_encoder[token]
- return self._convert_token_to_id(token)
+ self.init_weights()
- def convert_ids_to_tokens(self, ids, skip_special_tokens=False):
- """ Converts a single index or a sequence of indices (integers) in a token "
- (resp.) a sequence of tokens (str), using the vocabulary and added tokens.
+ def get_input_embeddings(self):
+ return self.wte
- Args:
- skip_special_tokens: Don't decode special tokens (self.all_special_tokens). Default: False
+ def set_input_embeddings(self, new_embeddings):
+ self.wte = new_embeddings
+
+ def _prune_heads(self, heads_to_prune):
+ """ Prunes heads of the model.
+ heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
"""
- if isinstance(ids, int):
- return self._convert_id_to_token(ids)
- tokens = []
- for index in ids:
- index = int(index)
- if skip_special_tokens and index in self.all_special_ids:
- continue
- tokens.append(self._convert_id_to_token(index))
- return tokens
-
- def convert_id_to_tokens(self, token_ids, skip_special_tokens=False, clean_up_tokenization_spaces=True):
+ for layer, heads in heads_to_prune.items():
+ self.h[layer].attn.prune_heads(heads)
+
+ def forward(
+ self,
+ input_ids,
+ past=None,
+ attention_mask=None,
+ token_type_ids=None,
+ position_ids=None,
+ head_mask=None,
+ output_attentions=True
+ ):
"""
- Converts a sequence of ids (integer) in a string, using the tokenizer and vocabulary
- with options to remove special tokens and clean up tokenization spaces.
- Similar to doing ``self.convert_tokens_to_string(self.convert_ids_to_tokens(token_ids))``.
- Args:
- token_ids: list of tokenized input ids. Can be obtained using the `encode` or `encode_plus` methods.
- skip_special_tokens: if set to True, will replace special tokens.
- clean_up_tokenization_spaces: if set to True, will clean up the tokenization spaces.
+ :param torch.LongTensor input_ids: batch_size x max_len or batch_size x beam_size x 1
+        :param GPT2Past past: the cached key/value states from previous steps
+        :param torch.ByteTensor attention_mask: batch_size x (pre_len+past_len), the same size as the
+            concatenation of input_ids and past. Positions equal to 0 are padding.
+        :param torch.LongTensor token_type_ids: batch_size x max_len.
+        :param torch.LongTensor position_ids: positions corresponding to input_ids
+        :param head_mask:
+        :param bool output_attentions: whether to return the attention weights
+ :return:
"""
- filtered_tokens = self.convert_ids_to_tokens(token_ids, skip_special_tokens=skip_special_tokens)
-
- # To avoid mixing byte-level and unicode for byte-level BPT
- # we need to build string separatly for added tokens and byte-level tokens
- # cf. https://github.com/huggingface/transformers/issues/1133
- sub_texts = []
- current_sub_text = []
- for token in filtered_tokens:
- if skip_special_tokens and token in self.all_special_ids:
- continue
- if token in self.added_tokens_encoder:
- if current_sub_text:
- sub_texts.append(self.convert_tokens_to_string(current_sub_text))
- current_sub_text = []
- sub_texts.append(token)
- else:
- current_sub_text.append(token)
- if current_sub_text:
- sub_texts.append(self.convert_tokens_to_string(current_sub_text))
- text = " ".join(sub_texts)
-
- if clean_up_tokenization_spaces:
- clean_text = self.clean_up_tokenization(text)
- return clean_text
+        input_shape = input_ids.size()  # batch_size x max_len or batch_size x beam_size x 1
+        input_ids = input_ids.view(-1, input_shape[-1])  # now batch_size' x max_len
+
+ if token_type_ids is not None:
+ token_type_ids = token_type_ids.view(-1, input_shape[-1])
+ if position_ids is not None:
+ position_ids = position_ids.view(-1, input_shape[-1])
+
+        if past is None or len(past) == 0:
+            past_length = 0
+            past = [None] * len(self.h)  # len(self.h) is the number of layers
+ else:
+ past_length = past[0][0].size(-2)
+        if position_ids is None:  # generate position ids when none are provided
+ device = input_ids.device
+ position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device)
+ position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1])
+
+ # Attention mask.
+ if attention_mask is not None:
+ attention_mask = attention_mask.view(-1, input_shape[-1])
+ # We create a 3D attention mask from a 2D tensor mask.
+ # Sizes are [batch_size, 1, 1, to_seq_length]
+ # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
+ # this attention mask is more simple than the triangular masking of causal attention
+ # used in OpenAI GPT, we just need to prepare the broadcast dimension here.
+ attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
+
+ # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
+ # masked positions, this operation will create a tensor which is 0.0 for
+ # positions we want to attend and -10000.0 for masked positions.
+ # Since we are adding it to the raw scores before the softmax, this is
+ # effectively the same as removing these entirely.
+ attention_mask = attention_mask.to(dtype=next(self.parameters()).dtype) # fp16 compatibility
+ attention_mask = (1.0 - attention_mask) * -10000.0
+
+ # Prepare head mask if needed
+ # 1.0 in head_mask indicate we keep the head
+ # attention_probs has shape bsz x n_heads x N x N
+ # head_mask has shape n_layer x batch x n_heads x N x N
+ if head_mask is not None:
+ if head_mask.dim() == 1:
+ head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
+ head_mask = head_mask.expand(self.config.n_layer, -1, -1, -1, -1)
+ elif head_mask.dim() == 2:
+ head_mask = (
+ head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1)
+ ) # We can specify head_mask for each layer
+ head_mask = head_mask.to(
+ dtype=next(self.parameters()).dtype
+            )  # switch to float if needed + fp16 compatibility
else:
- return text
+ head_mask = [None] * self.config.n_layer
- @property
- def special_tokens_map(self):
- """ A dictionary mapping special token class attribute (cls_token, unk_token...) to their
- values ('', ''...)
- """
- set_attr = {}
- for attr in self.SPECIAL_TOKENS_ATTRIBUTES:
- attr_value = getattr(self, "_" + attr)
- if attr_value:
- set_attr[attr] = attr_value
- return set_attr
+ inputs_embeds = self.wte(input_ids)
+ position_embeds = self.wpe(position_ids)
+ if token_type_ids is not None:
+ token_type_embeds = self.wte(token_type_ids)
+ else:
+ token_type_embeds = 0
+ hidden_states = inputs_embeds + position_embeds + token_type_embeds
+ hidden_states = self.drop(hidden_states)
- @property
- def all_special_tokens(self):
- """ List all the special tokens ('', ''...) mapped to class attributes
- (cls_token, unk_token...).
- """
- all_toks = []
- set_attr = self.special_tokens_map
- for attr_value in set_attr.values():
- all_toks = all_toks + (list(attr_value) if isinstance(attr_value, (list, tuple)) else [attr_value])
- all_toks = list(set(all_toks))
- return all_toks
+ # batch_size x max_len x embed_size
+ output_shape = input_shape + (hidden_states.size(-1),)
- @property
- def all_special_ids(self):
- """ List the vocabulary indices of the special tokens ('', ''...) mapped to
- class attributes (cls_token, unk_token...).
+ presents = ()
+ all_attentions = []
+ all_hidden_states = ()
+ for i, (block, layer_past) in enumerate(zip(self.h, past)):
+ all_hidden_states = all_hidden_states + (hidden_states.view(*output_shape),)
+
+ outputs = block(
+ hidden_states, layer_past=layer_past, attention_mask=attention_mask, head_mask=head_mask[i]
+ )
+
+ hidden_states, present = outputs[:2]
+ presents = presents + (present,)
+
+ all_attentions.append(outputs[2])
+
+ hidden_states = self.ln_f(hidden_states)
+
+ hidden_states = hidden_states.view(*output_shape)
+ # Add last hidden state
+ all_hidden_states = all_hidden_states + (hidden_states,)
+
+ outputs = (hidden_states,)
+ outputs = outputs + (presents,)
+
+ outputs = outputs + (all_hidden_states,)
+ if output_attentions:
+ # let the number of heads free (-1) so we can extract attention even after head pruning
+ attention_output_shape = input_shape[:-1] + (-1,) + all_attentions[0].shape[-2:]
+ all_attentions = tuple(t.view(*attention_output_shape) for t in all_attentions)
+ outputs = outputs + (all_attentions,)
+        # Shapes of all outputs:
+        # last hidden states, Tensor: batch_size x max_len x embed_size
+        # presents, tuple: n_layer x 2 x batch_size x n_head x (max_len+past_len) x head_dim; in the second
+        #   dim the first half is the key and the second half the value
+ # all hidden states, tuple: n_layer x batch_size x max_len x embed_size,
+ # attention, tuple: n_layer x batch_size x n_head' x src_len x tgt_len
+ return outputs # last hidden state, (presents), (all hidden_states), (attentions)
+
+
+class GPT2Past(Past):
+ def __init__(self):
+ super().__init__()
+ self.past = None # tuple [n_layer, 2 x batch_size x n_head x past_len x head_dim]
+
+ def num_samples(self):
+ if self.past is not None:
+ return self.past[0].size(1)
+ return None
+
+ def reorder_past(self, indices):
+ for i in range(len(self.past)):
+ assert self.past[i] is not None
+ self.past[i] = self.past[i].index_select(index=indices, dim=1)
+
+ def __iter__(self):
+ for p in self.past:
+ yield p
+
+ def __getitem__(self, item):
+ assert isinstance(item, int)
+ return self.past[item]
+
+ def __len__(self):
+ if self.past is not None:
+ return len(self.past)
+ return 0
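+
+
+# A sketch of how the beam-search generator uses this container: reorder_past selects
+# along dim 1 (the batch/beam dimension of each cached 2 x batch x n_head x past_len x
+# head_dim tensor) so each surviving beam keeps the history of its parent. Shapes are
+# illustrative::
+#
+#     import torch
+#
+#     past = GPT2Past()
+#     past.past = [torch.randn(2, 4, 12, 7, 64) for _ in range(3)]  # 3 layers, 4 beams
+#     past.reorder_past(torch.LongTensor([0, 0, 2, 3]))             # beam 1 restarts from beam 0
+#     assert past.num_samples() == 4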
+
+
+class _GPT2Decoder(Decoder):
+ def __init__(self, gpt_model):
+ super().__init__()
+ self.gpt_model = gpt_model
+
+ def decode(self, tokens, past=None) -> Tuple[torch.Tensor, Past]:
+ if past is None:
+ past = GPT2Past()
+ lm_logits, presents, _ = self.gpt_model(input_ids=tokens,
+ past=past,
+ attention_mask=None,
+ token_type_ids=None,
+ position_ids=None,
+ head_mask=None,
+ output_attentions=False)
+ past.past = list(presents)
+ return lm_logits[:, -1], past
+
+ def reorder_past(self, indices: torch.LongTensor, past: GPT2Past) -> GPT2Past:
+ past.reorder_past(indices)
+ return past
+
+
+class GPT2LMHeadModel(GPT2PreTrainedModel):
+ r"""
+ **labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
+ Labels for language modeling.
+ Note that the labels **are shifted** inside the model, i.e. you can set ``lm_labels = input_ids``
+ Indices are selected in ``[-1, 0, ..., config.vocab_size]``
+ All labels set to ``-100`` are ignored (masked), the loss is only
+ computed for labels in ``[0, ..., config.vocab_size]``
+
+ Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
+ **loss**: (`optional`, returned when ``labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``:
+ Language modeling loss.
+ **prediction_scores**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, config.vocab_size)``
+ Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
+ **past**:
+ list of ``torch.FloatTensor`` (one for each layer) of shape ``(2, batch_size, num_heads, sequence_length, embed_size_per_head)``:
+ that contains pre-computed hidden-states (key and values in the attention blocks).
+ Can be used (see `past` input) to speed up sequential decoding. The token ids which have their past given to this model
+ should not be passed as input ids as they have already been computed.
+ **hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
+ list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
+ of shape ``(batch_size, sequence_length, hidden_size)``:
+ Hidden-states of the model at the output of each layer plus the initial embedding outputs.
+ **attentions**: (`optional`, returned when ``config.output_attentions=True``)
+ list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.
+ """
+
+ def __init__(self, config):
+ super(GPT2LMHeadModel, self).__init__(config)
+ self.transformer = GPT2Model(config)
+ self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
+
+ self.init_weights()
+
+ def get_output_embeddings(self):
+ return self.lm_head
+
+ def get_input_embeddings(self):
+ return self.transformer.wte
+
+ def forward(
+ self,
+ input_ids,
+ past=None,
+ attention_mask=None,
+ token_type_ids=None,
+ position_ids=None,
+ head_mask=None,
+ labels=None,
+ output_attentions=False
+ ):
"""
- all_toks = self.all_special_tokens
- all_ids = self.convert_tokens_to_ids(all_toks)
- return all_ids
- @staticmethod
- def clean_up_tokenization(out_string):
- """ Clean up a list of simple English tokenization artifacts like spaces before punctuations and abreviated forms.
+ :param torch.LongTensor input_ids: batch_size x max_len or batch_size x beam_size x 1
+        :param tuple past: num_layers x 2 x batch_size x n_head x max_len' x head_dim. The presents returned
+            at the previous step can be passed in here.
+        :param torch.ByteTensor attention_mask: batch_size x max_len, the same size as input_ids. Positions
+            equal to 0 are padding.
+        :param torch.LongTensor token_type_ids: batch_size x max_len.
+        :param torch.LongTensor position_ids: positions corresponding to input_ids
+        :param head_mask:
+        :param labels: the target values for the language model. If None, no language-model loss is computed.
+            Padding positions are best set to -100 so that the loss ignores them.
+        :param output_attentions: whether to return the attention weights
+ :return:
"""
- out_string = (
- out_string.replace(" .", ".")
- .replace(" ?", "?")
- .replace(" !", "!")
- .replace(" ,", ",")
- .replace(" ' ", "'")
- .replace(" n't", "n't")
- .replace(" 'm", "'m")
- .replace(" do not", " don't")
- .replace(" 's", "'s")
- .replace(" 've", "'ve")
- .replace(" 're", "'re")
+ transformer_outputs = self.transformer(
+ input_ids,
+ past=past,
+ attention_mask=attention_mask,
+ token_type_ids=token_type_ids,
+ position_ids=position_ids,
+ head_mask=head_mask,
+ output_attentions=output_attentions
)
- return out_string
+ hidden_states = transformer_outputs[0]
+
+ lm_logits = self.lm_head(hidden_states)
+
+ outputs = (lm_logits,) + transformer_outputs[1:]
+ if labels is not None:
+ # Shift so that tokens < n predict n
+ shift_logits = lm_logits[..., :-1, :].contiguous()
+ shift_labels = labels[..., 1:].contiguous()
+ # Flatten the tokens
+ loss_fct = CrossEntropyLoss()
+ loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
+ outputs = (loss,) + outputs
+
+        # Return values:
+        # loss: torch.FloatTensor; absent if labels is None
+        # lm_logits: batch_size x max_len x vocab_size
+        # presents, tuple: n_layer x 2 x batch_size x n_head x (max_len+past_len) x head_dim; in the second
+        #   dim the first half is the key and the second half the value
+ # all hidden states, tuple: n_layer x batch_size x max_len x embed_size,
+ # attention, tuple: n_layer x batch_size x n_head' x src_len x tgt_len
+ return outputs # (loss), lm_logits, presents, all hidden_states, (attentions)
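+
+    # A usage sketch of the loss path: set labels to the input ids with padding replaced
+    # by -100, which CrossEntropyLoss ignores by default (`model` is assumed to be a
+    # GPT2LMHeadModel instance; the token ids and the use of 0 as padding are illustrative)::
+    #
+    #     import torch
+    #
+    #     input_ids = torch.LongTensor([[464, 3290, 318, 0, 0]])
+    #     labels = input_ids.masked_fill(input_ids.eq(0), -100)
+    #     loss, lm_logits, presents, hidden_states = model(input_ids, labels=labels)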
+
+
diff --git a/fastNLP/modules/encoder/roberta.py b/fastNLP/modules/encoder/roberta.py
index af8795c6..02b9df42 100644
--- a/fastNLP/modules/encoder/roberta.py
+++ b/fastNLP/modules/encoder/roberta.py
@@ -1,13 +1,19 @@
-from typing import List, Optional
-import json
+r"""undocumented
+The code on this page draws heavily on (and copy-pastes from) https://github.com/huggingface/pytorch-pretrained-BERT;
+if you find this code useful, please cite them as well.
+"""
+
+__all__ = [
+ 'RobertaModel'
+]
import torch
import torch.nn as nn
-from .bert import BertEmbeddings, BertModel, BertConfig, _get_bert_dir
-from .gpt2 import GPT2Tokenizer
-from ..utils import create_position_ids_from_input_ids, _get_file_name_base_on_postfix
+from .bert import BertEmbeddings, BertModel, BertConfig
+from ..utils import _get_file_name_base_on_postfix
+from ...io.file_utils import _get_roberta_dir
from ...core import logger
PRETRAINED_ROBERTA_POSITIONAL_EMBEDDINGS_SIZES = {
@@ -33,30 +39,24 @@ class RobertaEmbeddings(BertEmbeddings):
config.max_position_embeddings, config.hidden_size, padding_idx=self.padding_idx
)
- def forward(self, input_ids=None, token_type_ids=None, position_ids=None, words_embeddings=None):
- if position_ids is None:
- if input_ids is not None:
- # Create the position ids from the input token ids. Any padded tokens remain padded.
- position_ids = create_position_ids_from_input_ids(input_ids, self.padding_idx).to(input_ids.device)
- else:
- position_ids = self.create_position_ids_from_inputs_embeds(words_embeddings)
+ def forward(self, input_ids, token_type_ids, words_embeddings=None):
+ position_ids = self.create_position_ids_from_input_ids(input_ids)
return super().forward(
input_ids, token_type_ids=token_type_ids, position_ids=position_ids, words_embeddings=words_embeddings
)
- def create_position_ids_from_inputs_embeds(self, inputs_embeds):
- """
- :param torch.Tensor inputs_embeds:
+ def create_position_ids_from_input_ids(self, x):
+ """ Replace non-padding symbols with their position numbers. Position numbers begin at
+ padding_idx+1. Padding symbols are ignored. This is modified from fairseq's
+ `utils.make_positions`.
+
+ :param torch.Tensor x:
:return torch.Tensor:
"""
- input_shape = inputs_embeds.size()[:-1]
- sequence_length = input_shape[1]
-
- position_ids = torch.arange(
- self.padding_idx + 1, sequence_length + self.padding_idx + 1, dtype=torch.long, device=inputs_embeds.device
- )
- return position_ids.unsqueeze(0).expand(input_shape)
+ mask = x.ne(self.padding_idx).long()
+        incremental_indices = torch.cumsum(mask, dim=1) * mask
+        return incremental_indices + self.padding_idx
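+
+    # A worked example of the scheme above with padding_idx = 1: non-padding tokens are
+    # numbered from padding_idx + 1 onwards and padding positions stay at padding_idx::
+    #
+    #     import torch
+    #
+    #     x = torch.LongTensor([[5, 6, 7, 1, 1]])
+    #     mask = x.ne(1).long()                         # tensor([[1, 1, 1, 0, 0]])
+    #     print(torch.cumsum(mask, dim=1) * mask + 1)   # tensor([[2, 3, 4, 1, 1]])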
class RobertaModel(BertModel):
@@ -70,12 +70,6 @@ class RobertaModel(BertModel):
self.embeddings = RobertaEmbeddings(config)
self.apply(self.init_bert_weights)
- def get_input_embeddings(self):
- return self.embeddings.word_embeddings
-
- def set_input_embeddings(self, value):
- self.embeddings.word_embeddings = value
-
@classmethod
def from_pretrained(cls, model_dir_or_name, *inputs, **kwargs):
state_dict = kwargs.get('state_dict', None)
@@ -84,7 +78,7 @@ class RobertaModel(BertModel):
kwargs.pop('from_tf', None)
# get model dir from name or dir
- pretrained_model_dir = _get_bert_dir(model_dir_or_name)
+ pretrained_model_dir = _get_roberta_dir(model_dir_or_name)
# Load config
config_file = _get_file_name_base_on_postfix(pretrained_model_dir, 'config.json')
@@ -186,172 +180,3 @@ class RobertaModel(BertModel):
return model
-class RobertaTokenizer(GPT2Tokenizer):
-
- vocab_files_names = {
- "vocab_file": "vocab.json",
- "merges_file": "merges.txt",
- }
-
- def __init__(
- self,
- vocab_file,
- merges_file,
- errors="replace",
- bos_token="",
- eos_token="",
- sep_token="",
- cls_token="",
- unk_token="",
- pad_token="",
- mask_token="",
- **kwargs
- ):
- super().__init__(
- vocab_file=vocab_file,
- merges_file=merges_file,
- errors=errors,
- bos_token=bos_token,
- eos_token=eos_token,
- unk_token=unk_token,
- sep_token=sep_token,
- cls_token=cls_token,
- pad_token=pad_token,
- mask_token=mask_token,
- **kwargs,
- )
- self.max_len_single_sentence = self.max_len - 2 # take into account special tokens
- self.max_len_sentences_pair = self.max_len - 4 # take into account special tokens
-
- def build_inputs_with_special_tokens(
- self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
- ) -> List[int]:
- """
- Build model inputs from a sequence or a pair of sequence for sequence classification tasks
- by concatenating and adding special tokens.
- A RoBERTa sequence has the following format:
-
- - single sequence: `` X ``
- - pair of sequences: `` A B ``
-
- Args:
- token_ids_0 (:obj:`List[int]`):
- List of IDs to which the special tokens will be added
- token_ids_1 (:obj:`List[int]`, `optional`, defaults to :obj:`None`):
- Optional second list of IDs for sequence pairs.
-
- Returns:
- :obj:`List[int]`: list of `input IDs <../glossary.html#input-ids>`__ with the appropriate special tokens.
- """
- if token_ids_1 is None:
- return [self.cls_token_id] + token_ids_0 + [self.sep_token_id]
- cls = [self.cls_token_id]
- sep = [self.sep_token_id]
- return cls + token_ids_0 + sep + sep + token_ids_1 + sep
-
- def get_special_tokens_mask(
- self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
- ) -> List[int]:
- """
- Retrieves sequence ids from a token list that has no special tokens added. This method is called when adding
- special tokens using the tokenizer ``prepare_for_model`` or ``encode_plus`` methods.
-
- Args:
- token_ids_0 (:obj:`List[int]`):
- List of ids.
- token_ids_1 (:obj:`List[int]`, `optional`, defaults to :obj:`None`):
- Optional second list of IDs for sequence pairs.
- already_has_special_tokens (:obj:`bool`, `optional`, defaults to :obj:`False`):
- Set to True if the token list is already formatted with special tokens for the model
-
- Returns:
- :obj:`List[int]`: A list of integers in the range [0, 1]: 0 for a special token, 1 for a sequence token.
- """
- if already_has_special_tokens:
- if token_ids_1 is not None:
- raise ValueError(
- "You should not supply a second sequence if the provided sequence of "
- "ids is already formated with special tokens for the model."
- )
- return list(map(lambda x: 1 if x in [self.sep_token_id, self.cls_token_id] else 0, token_ids_0))
-
- if token_ids_1 is None:
- return [1] + ([0] * len(token_ids_0)) + [1]
- return [1] + ([0] * len(token_ids_0)) + [1, 1] + ([0] * len(token_ids_1)) + [1]
-
- def create_token_type_ids_from_sequences(
- self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
- ) -> List[int]:
- """
- Creates a mask from the two sequences passed to be used in a sequence-pair classification task.
- RoBERTa does not make use of token type ids, therefore a list of zeros is returned.
-
- Args:
- token_ids_0 (:obj:`List[int]`):
- List of ids.
- token_ids_1 (:obj:`List[int]`, `optional`, defaults to :obj:`None`):
- Optional second list of IDs for sequence pairs.
-
- Returns:
- :obj:`List[int]`: List of zeros.
-
- """
- sep = [self.sep_token_id]
- cls = [self.cls_token_id]
-
- if token_ids_1 is None:
- return len(cls + token_ids_0 + sep) * [0]
- return len(cls + token_ids_0 + sep + sep + token_ids_1 + sep) * [0]
-
- def prepare_for_tokenization(self, text, add_special_tokens=False, **kwargs):
- if "add_prefix_space" in kwargs:
- add_prefix_space = kwargs["add_prefix_space"]
- else:
- add_prefix_space = add_special_tokens
- if add_prefix_space and not text[0].isspace():
- text = " " + text
- return text
-
- @classmethod
- def from_pretrained(cls, model_dir_or_name, *inputs, **kwargs):
- """
-
-        :param str model_dir_or_name: a directory or a shortcut name
- :param kwargs:
- :return:
- """
-        # It needs two files: the first is vocab.json, the second the merges file.
- model_dir = _get_bert_dir(model_dir_or_name)
-        # The directory contains four files: vocab.json, merges.txt, config.json, model.bin
-
- tokenizer_config_file = _get_file_name_base_on_postfix(model_dir, 'config.json')
- with open(tokenizer_config_file, encoding="utf-8") as tokenizer_config_handle:
- init_kwargs = json.load(tokenizer_config_handle)
- # Set max length if needed
- if model_dir_or_name in PRETRAINED_ROBERTA_POSITIONAL_EMBEDDINGS_SIZES:
- # if we're using a pretrained model, ensure the tokenizer
- # wont index sequences longer than the number of positional embeddings
- max_len = PRETRAINED_ROBERTA_POSITIONAL_EMBEDDINGS_SIZES[model_dir_or_name]
- if max_len is not None and isinstance(max_len, (int, float)):
- init_kwargs["max_len"] = min(init_kwargs.get("max_len", int(1e12)), max_len)
-
-        # add the vocab and merges files to init_kwargs
-        if 'vocab_file' in kwargs:  # use the user-specified vocabulary if one is given
- init_kwargs['vocab_file'] = kwargs['vocab_file']
- else:
- init_kwargs['vocab_file'] = _get_file_name_base_on_postfix(model_dir, 'vocab.json')
- init_kwargs['merges_file'] = _get_file_name_base_on_postfix(model_dir, 'merges.txt')
-
- init_inputs = init_kwargs.pop("init_inputs", ())
- # Instantiate tokenizer.
- try:
- tokenizer = cls(*init_inputs, **init_kwargs)
- except OSError:
- OSError(
- "Unable to load vocabulary from file. "
- "Please check that the provided vocabulary is accessible and not corrupted."
- )
-
- return tokenizer
-
-
diff --git a/fastNLP/modules/generator/__init__.py b/fastNLP/modules/generator/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/fastNLP/modules/generator/seq2seq_generator.py b/fastNLP/modules/generator/seq2seq_generator.py
new file mode 100755
index 00000000..d332cc2f
--- /dev/null
+++ b/fastNLP/modules/generator/seq2seq_generator.py
@@ -0,0 +1,444 @@
+import torch
+from ..decoder.seq2seq_decoder import Decoder
+import torch.nn.functional as F
+from fastNLP.core.utils import _get_model_device
+from functools import partial
+
+
+class SequenceGenerator:
+ def __init__(self, decoder: Decoder, max_length=20, num_beams=1,
+ do_sample=True, temperature=1.0, top_k=50, top_p=1.0, bos_token_id=None, eos_token_id=None,
+ repetition_penalty=1, length_penalty=1.0, pad_token_id=0):
+ if do_sample:
+ self.generate_func = partial(sample_generate, decoder=decoder, max_length=max_length, num_beams=num_beams,
+ temperature=temperature, top_k=top_k, top_p=top_p, bos_token_id=bos_token_id,
+ eos_token_id=eos_token_id, repetition_penalty=repetition_penalty,
+ length_penalty=length_penalty, pad_token_id=pad_token_id)
+ else:
+ self.generate_func = partial(greedy_generate, decoder=decoder, max_length=max_length, num_beams=num_beams,
+ bos_token_id=bos_token_id, eos_token_id=eos_token_id,
+ repetition_penalty=repetition_penalty,
+ length_penalty=length_penalty, pad_token_id=pad_token_id)
+ self.do_sample = do_sample
+ self.max_length = max_length
+ self.num_beams = num_beams
+ self.temperature = temperature
+ self.top_k = top_k
+ self.top_p = top_p
+ self.bos_token_id = bos_token_id
+ self.eos_token_id = eos_token_id
+ self.repetition_penalty = repetition_penalty
+ self.length_penalty = length_penalty
+ self.decoder = decoder
+
+ @torch.no_grad()
+ def generate(self, tokens=None, past=None):
+ """
+
+        :param torch.LongTensor tokens: batch_size x length, the tokens to start generation from
+ :param past:
+ :return:
+ """
+        # TODO check whether decoding still works directly when the length of tokens is greater than 1
+ return self.generate_func(tokens=tokens, past=past)
+
+
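A minimal usage sketch of ``SequenceGenerator`` (not part of the patch). The ``decoder`` and ``past`` objects are hypothetical stand-ins for any ``Decoder`` subclass implementing ``decode(tokens, past)``/``reorder_past(indices, past)`` and its matching ``Past``; the token ids are assumed::

    # hypothetical decoder/past; only the call pattern is taken from this file
    generator = SequenceGenerator(decoder, max_length=30, num_beams=4,
                                  do_sample=False, bos_token_id=1,
                                  eos_token_id=2, pad_token_id=0)
    token_ids = generator.generate(tokens=None, past=past)  # batch_size x seq_len
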
+@torch.no_grad()
+def greedy_generate(decoder, tokens=None, past=None, max_length=20, num_beams=1,
+ bos_token_id=None, eos_token_id=None, pad_token_id=0,
+ repetition_penalty=1, length_penalty=1.0):
+ """
+    Greedy search for sentences.
+
+    :param Decoder decoder: a Decoder object
+    :param torch.LongTensor tokens: batch_size x len, the input to the decoder; if None, generation starts from bos_token_id
+    :param Past past: should contain the relevant encoder outputs.
+    :param int max_length: maximum length of the generated sentences.
+    :param int num_beams: beam size used for decoding.
+    :param int bos_token_id: if tokens is None, decoding starts from bos_token_id.
+    :param int eos_token_id: the end-of-sequence token; if None, decoding always runs to max_length.
+    :param int pad_token_id: token id used to pad finished sequences.
+    :param float repetition_penalty: how strongly repeated tokens are penalized.
+    :param float length_penalty: length penalty applied to every token except eos.
+ :return:
+ """
+ if num_beams == 1:
+ token_ids = _no_beam_search_generate(decoder, tokens, past, max_length, temperature=1, top_k=50, top_p=1,
+ bos_token_id=bos_token_id, eos_token_id=eos_token_id, do_sample=False,
+ repetition_penalty=repetition_penalty, length_penalty=length_penalty,
+ pad_token_id=pad_token_id)
+ else:
+ token_ids = _beam_search_generate(decoder, tokens, past, max_length, num_beams=num_beams,
+ temperature=1, top_k=50, top_p=1,
+ bos_token_id=bos_token_id, eos_token_id=eos_token_id, do_sample=False,
+ repetition_penalty=repetition_penalty, length_penalty=length_penalty,
+ pad_token_id=pad_token_id)
+
+ return token_ids
+
+
+@torch.no_grad()
+def sample_generate(decoder, tokens=None, past=None, max_length=20, num_beams=1, temperature=1.0, top_k=50,
+ top_p=1.0, bos_token_id=None, eos_token_id=None, pad_token_id=0, repetition_penalty=1.0,
+ length_penalty=1.0):
+ """
+    Generate sentences by sampling.
+
+    :param Decoder decoder: a Decoder object
+    :param torch.LongTensor tokens: batch_size x len, the input to the decoder; if None, generation starts from bos_token_id
+    :param Past past: should contain the relevant encoder outputs.
+    :param int max_length: maximum length of the generated sentences.
+    :param int num_beams: beam size used for decoding.
+    :param float temperature: temperature applied when sampling.
+    :param int top_k: sample only from the top_k candidates.
+    :param float top_p: a value between 0 and 1; the nucleus filtering threshold.
+    :param int bos_token_id: if tokens is None, decoding starts from bos_token_id.
+    :param int eos_token_id: the end-of-sequence token; if None, decoding always runs to max_length.
+    :param int pad_token_id: the pad token id.
+    :param float repetition_penalty: how strongly repeated tokens are penalized.
+    :param float length_penalty: length penalty applied to every token except eos.
+ :return:
+ """
+    # each position is generated by sampling
+ if num_beams == 1:
+ token_ids = _no_beam_search_generate(decoder, tokens, past, max_length, temperature=temperature,
+ top_k=top_k, top_p=top_p,
+ bos_token_id=bos_token_id, eos_token_id=eos_token_id, do_sample=True,
+ repetition_penalty=repetition_penalty, length_penalty=length_penalty,
+ pad_token_id=pad_token_id)
+ else:
+ token_ids = _beam_search_generate(decoder, tokens, past, max_length, num_beams=num_beams,
+ temperature=temperature, top_k=top_k, top_p=top_p,
+ bos_token_id=bos_token_id, eos_token_id=eos_token_id, do_sample=True,
+ repetition_penalty=repetition_penalty, length_penalty=length_penalty,
+ pad_token_id=pad_token_id)
+ return token_ids
+
+
+def _no_beam_search_generate(decoder: Decoder, tokens=None, past=None, max_length=20, temperature=1.0, top_k=50,
+ top_p=1.0, bos_token_id=None, eos_token_id=None, do_sample=True,
+ repetition_penalty=1.0, length_penalty=1.0, pad_token_id=0):
+ device = _get_model_device(decoder)
+ if tokens is None:
+ if bos_token_id is None:
+ raise RuntimeError("You have to specify either `tokens` or `bos_token_id`.")
+ if past is None:
+ raise RuntimeError("You have to specify either `past` or `tokens`.")
+ batch_size = past.num_samples()
+ if batch_size is None:
+ raise RuntimeError("Cannot infer the number of samples from `past`.")
+ tokens = torch.full([batch_size, 1], fill_value=bos_token_id, dtype=torch.long).to(device)
+ batch_size = tokens.size(0)
+ if past is not None:
+ assert past.num_samples() == batch_size, "The number of samples in `tokens` and `past` should match."
+
+ if eos_token_id is None:
+ _eos_token_id = float('nan')
+ else:
+ _eos_token_id = eos_token_id
+
+ # for i in range(tokens.size(1)):
+ # scores, past = decoder.decode_one(tokens[:, :i + 1], past) # batch_size x vocab_size, Past
+ scores, past = decoder.decode(tokens, past)
+
+ token_ids = tokens.clone()
+ cur_len = token_ids.size(1)
+ dones = token_ids.new_zeros(batch_size).eq(1)
+ # tokens = tokens[:, -1:]
+
+ while cur_len < max_length:
+ # scores, past = decoder.decode_one(tokens, past) # batch_size x vocab_size, Past
+ scores, past = decoder.decode(tokens, past) # batch_size x vocab_size, Past
+
+ if repetition_penalty != 1.0:
+ token_scores = scores.gather(dim=1, index=token_ids)
+ lt_zero_mask = token_scores.lt(0).float()
+ ge_zero_mask = lt_zero_mask.eq(0).float()
+ token_scores = lt_zero_mask * repetition_penalty * token_scores + ge_zero_mask / repetition_penalty * token_scores
+ scores.scatter_(dim=1, index=token_ids, src=token_scores)
+
+ if eos_token_id is not None and length_penalty != 1.0:
+ token_scores = scores / cur_len ** length_penalty # batch_size x vocab_size
+ eos_mask = scores.new_ones(scores.size(1))
+ eos_mask[eos_token_id] = 0
+ eos_mask = eos_mask.unsqueeze(0).eq(1)
+            scores = scores.masked_scatter(eos_mask, token_scores)  # i.e. every token except eos has its score scaled up/down
+
+ if do_sample:
+ if temperature > 0 and temperature != 1:
+ scores = scores / temperature
+
+ scores = top_k_top_p_filtering(scores, top_k, top_p, min_tokens_to_keep=2)
+ probs = F.softmax(scores, dim=-1)
+
+            # make sure at least one non-eos candidate remains
+ next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1) # batch_size
+ else:
+ next_tokens = torch.argmax(scores, dim=-1) # batch_size
+
+        next_tokens = next_tokens.masked_fill(dones, pad_token_id)  # pad the samples that have already finished
+ tokens = next_tokens.unsqueeze(1)
+
+ token_ids = torch.cat([token_ids, tokens], dim=-1) # batch_size x max_len
+
+ end_mask = next_tokens.eq(_eos_token_id)
+ dones = dones.__or__(end_mask)
+ cur_len += 1
+
+ if dones.min() == 1:
+ break
+
+ if eos_token_id is not None:
+ if cur_len == max_length:
+            token_ids[:, -1].masked_fill_(~dones, eos_token_id)  # if max length is reached without eos, force the last token to be eos
+
+ return token_ids
+
+
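A self-contained sketch of the repetition-penalty step used above, on toy values; the tensors are assumed, the arithmetic mirrors the code::

    import torch

    scores = torch.tensor([[2.0, -1.0, 0.5]])   # batch_size=1, vocab_size=3
    token_ids = torch.tensor([[1, 2]])          # tokens generated so far
    penalty = 1.2
    token_scores = scores.gather(dim=1, index=token_ids)  # [[-1.0, 0.5]]
    lt_zero_mask = token_scores.lt(0).float()
    ge_zero_mask = lt_zero_mask.eq(0).float()
    # negative scores get more negative, positive scores shrink
    token_scores = lt_zero_mask * penalty * token_scores + ge_zero_mask / penalty * token_scores
    scores.scatter_(dim=1, index=token_ids, src=token_scores)
    # scores is now [[2.0, -1.2, 0.4167]]
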
+def _beam_search_generate(decoder: Decoder, tokens=None, past=None, max_length=20, num_beams=4, temperature=1.0,
+ top_k=50, top_p=1.0, bos_token_id=None, eos_token_id=None, do_sample=True,
+ repetition_penalty=1.0, length_penalty=None, pad_token_id=0) -> torch.LongTensor:
+    # beam search
+ device = _get_model_device(decoder)
+ if tokens is None:
+ if bos_token_id is None:
+ raise RuntimeError("You have to specify either `tokens` or `bos_token_id`.")
+ if past is None:
+ raise RuntimeError("You have to specify either `past` or `tokens`.")
+ batch_size = past.num_samples()
+ if batch_size is None:
+ raise RuntimeError("Cannot infer the number of samples from `past`.")
+ tokens = torch.full([batch_size, 1], fill_value=bos_token_id, dtype=torch.long).to(device)
+ batch_size = tokens.size(0)
+ if past is not None:
+ assert past.num_samples() == batch_size, "The number of samples in `tokens` and `past` should match."
+
+    # for i in range(tokens.size(1) - 1):  # if the input is long, decode it first
+    #     scores, past = decoder.decode_one(tokens[:, :i + 1],
+    #                                       past)  # (batch_size, vocab_size), Past
+    # scores, past = decoder.decode_one(tokens, past)  # the whole sentence length must be passed in here
+    scores, past = decoder.decode(tokens, past)  # the whole sentence length must be passed in here
+ vocab_size = scores.size(1)
+    assert vocab_size >= num_beams, "num_beams should not be larger than the vocabulary size."
+
+ if do_sample:
+ probs = F.softmax(scores, dim=-1)
+ next_tokens = torch.multinomial(probs, num_samples=num_beams) # (batch_size, num_beams)
+ logits = probs.log()
+ next_scores = logits.gather(dim=1, index=next_tokens) # (batch_size, num_beams)
+ else:
+ scores = F.log_softmax(scores, dim=-1) # (batch_size, vocab_size)
+        # yields (batch_size, num_beams), (batch_size, num_beams)
+ next_scores, next_tokens = torch.topk(scores, num_beams, dim=1, largest=True, sorted=True)
+
+ indices = torch.arange(batch_size, dtype=torch.long).to(device)
+ indices = indices.repeat_interleave(num_beams)
+ decoder.reorder_past(indices, past)
+
+ tokens = tokens.index_select(dim=0, index=indices) # batch_size * num_beams x length
+    # record the generated tokens (batch_size', cur_len)
+ token_ids = torch.cat([tokens, next_tokens.view(-1, 1)], dim=-1)
+ dones = [False] * batch_size
+ tokens = next_tokens.view(-1, 1)
+
+ beam_scores = next_scores.view(-1) # batch_size * num_beams
+
+    # tracks the length of the tokens generated so far
+ cur_len = token_ids.size(1)
+
+ hypos = [
+ BeamHypotheses(num_beams, max_length, length_penalty, early_stopping=False) for _ in range(batch_size)
+ ]
+ # 0,num_beams, 2*num_beams, ...
+ batch_inds_with_numbeams_interval = (torch.arange(batch_size) * num_beams).view(-1, 1).to(token_ids)
+
+ while cur_len < max_length:
+ # scores, past = decoder.decode_one(tokens, past) # batch_size * num_beams x vocab_size, Past
+ scores, past = decoder.decode(tokens, past)
+ if repetition_penalty != 1.0:
+ token_scores = scores.gather(dim=1, index=token_ids)
+ lt_zero_mask = token_scores.lt(0).float()
+ ge_zero_mask = lt_zero_mask.eq(0).float()
+ token_scores = lt_zero_mask * repetition_penalty * token_scores + ge_zero_mask / repetition_penalty * token_scores
+ scores.scatter_(dim=1, index=token_ids, src=token_scores)
+
+ if do_sample:
+ if temperature > 0 and temperature != 1:
+ scores = scores / temperature
+
+            # recall one extra candidate in case eos is drawn
+ scores = top_k_top_p_filtering(scores, top_k, top_p, min_tokens_to_keep=num_beams + 1)
+ probs = F.softmax(scores, dim=-1)
+
+            # make sure at least one non-eos candidate remains
+ _tokens = torch.multinomial(probs, num_samples=num_beams + 1) # batch_size' x (num_beams+1)
+
+ logits = probs.log()
+            # avoid all picks coming from a single beam, and account for eos being drawn
+ _scores = logits.gather(dim=1, index=_tokens) # batch_size' x (num_beams+1)
+ _scores = _scores + beam_scores[:, None] # batch_size' x (num_beams+1)
+            # then pick the top 2*num_beams among these
+ _scores = _scores.view(batch_size, num_beams * (num_beams + 1))
+ next_scores, ids = _scores.topk(2 * num_beams, dim=1, largest=True, sorted=True)
+ _tokens = _tokens.view(batch_size, num_beams * (num_beams + 1))
+ next_tokens = _tokens.gather(dim=1, index=ids) # (batch_size, 2*num_beams)
+ from_which_beam = ids // (num_beams + 1) # (batch_size, 2*num_beams)
+ else:
+ scores = F.log_softmax(scores, dim=-1) # (batch_size * num_beams, vocab_size)
+ _scores = scores + beam_scores[:, None] # (batch_size * num_beams, vocab_size)
+ _scores = _scores.view(batch_size, -1) # (batch_size, num_beams*vocab_size)
+ next_scores, ids = torch.topk(_scores, 2 * num_beams, dim=1, largest=True, sorted=True)
+ from_which_beam = ids // vocab_size # (batch_size, 2*num_beams)
+ next_tokens = ids % vocab_size # (batch_size, 2*num_beams)
+
+        # assemble the results for the next step
+        # decide which candidates to keep
+ next_scores, sorted_inds = next_scores.sort(dim=-1, descending=True)
+ next_tokens = next_tokens.gather(dim=1, index=sorted_inds)
+ from_which_beam = from_which_beam.gather(dim=1, index=sorted_inds)
+
+        not_eos_mask = next_tokens.ne(eos_token_id)  # 1 where the token is not eos
+        keep_mask = not_eos_mask.cumsum(dim=1).le(num_beams)  # 1 where the candidate should be kept
+        keep_mask = not_eos_mask.__and__(keep_mask)  # 1 where the candidate continues to the next search step
+
+ _next_tokens = next_tokens.masked_select(keep_mask).view(-1, 1)
+        _from_which_beam = from_which_beam.masked_select(keep_mask).view(batch_size, num_beams)  # which beam each kept token came from
+ _next_scores = next_scores.masked_select(keep_mask).view(batch_size, num_beams)
+ beam_scores = _next_scores.view(-1)
+
+        # reorder the past states and reassemble token_ids
+        reorder_inds = (batch_inds_with_numbeams_interval + _from_which_beam).view(-1)  # flatten to 1-D
+ decoder.reorder_past(reorder_inds, past)
+
+ flag = True
+ if cur_len + 1 == max_length:
+ eos_batch_idx = torch.arange(batch_size).to(next_tokens).repeat_interleave(repeats=num_beams, dim=0)
+            eos_beam_ind = torch.arange(num_beams).to(token_ids).repeat(batch_size)  # the beam indices
+            eos_beam_idx = from_which_beam[:, :num_beams].reshape(-1)  # which beam each sequence came from
+ else:
+            # add the sequences that end within the top num_beams of each batch to the finished set; 1 marks positions that should end
+ effective_eos_mask = next_tokens[:, :num_beams].eq(eos_token_id) # batch_size x num_beams
+ if effective_eos_mask.sum().gt(0):
+ eos_batch_idx, eos_beam_ind = effective_eos_mask.nonzero(as_tuple=True)
+                # from_which_beam is (batch_size, 2*num_beams), hence the factor 2*num_beams
+ eos_beam_idx = eos_batch_idx * num_beams * 2 + eos_beam_ind
+                eos_beam_idx = from_which_beam.view(-1)[eos_beam_idx]  # recover which beam the eos actually came from
+ else:
+ flag = False
+ if flag:
+ for batch_idx, beam_ind, beam_idx in zip(eos_batch_idx.tolist(), eos_beam_ind.tolist(),
+ eos_beam_idx.tolist()):
+ if not dones[batch_idx]:
+ score = next_scores[batch_idx, beam_ind].item()
+ hypos[batch_idx].add(token_ids[batch_idx * num_beams + beam_idx, :cur_len].clone(), score)
+
+        # reorganize the state of token_ids
+ tokens = _next_tokens
+ token_ids = torch.cat([token_ids.index_select(index=reorder_inds, dim=0), tokens], dim=-1)
+
+ for batch_idx in range(batch_size):
+ dones[batch_idx] = dones[batch_idx] or hypos[batch_idx].is_done(next_scores[batch_idx, 0].item())
+
+ cur_len += 1
+
+ if all(dones):
+ break
+
+ # select the best hypotheses
+ tgt_len = token_ids.new(batch_size)
+ best = []
+
+ for i, hypotheses in enumerate(hypos):
+ best_hyp = max(hypotheses.hyp, key=lambda x: x[0])[1]
+        tgt_len[i] = len(best_hyp) + 1  # +1 for the <EOS> symbol
+ best.append(best_hyp)
+
+ # generate target batch
+ decoded = token_ids.new(batch_size, tgt_len.max().item()).fill_(pad_token_id)
+ for i, hypo in enumerate(best):
+ decoded[i, :tgt_len[i] - 1] = hypo
+ if eos_token_id is not None:
+ decoded[i, tgt_len[i] - 1] = eos_token_id
+
+ return decoded
+
+
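A toy illustration of how ``batch_inds_with_numbeams_interval`` converts per-sample beam choices into flat row indices for ``reorder_past``; sizes and beam picks are assumed::

    import torch

    batch_size, num_beams = 2, 3
    offsets = (torch.arange(batch_size) * num_beams).view(-1, 1)  # [[0], [3]]
    from_which_beam = torch.tensor([[2, 0, 1],   # beams kept for sample 0
                                    [1, 1, 0]])  # beams kept for sample 1
    reorder_inds = (offsets + from_which_beam).view(-1)
    # tensor([2, 0, 1, 4, 4, 3]): row indices into the flattened
    # (batch_size * num_beams) decoder state
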
+class BeamHypotheses(object):
+ def __init__(self, num_beams, max_length, length_penalty, early_stopping):
+ """
+ Initialize n-best list of hypotheses.
+ """
+ self.max_length = max_length - 1 # ignoring bos_token
+ self.length_penalty = length_penalty
+ self.early_stopping = early_stopping
+ self.num_beams = num_beams
+ self.hyp = []
+ self.worst_score = 1e9
+
+ def __len__(self):
+ """
+ Number of hypotheses in the list.
+ """
+ return len(self.hyp)
+
+ def add(self, hyp, sum_logprobs):
+ """
+ Add a new hypothesis to the list.
+ """
+ score = sum_logprobs / len(hyp) ** self.length_penalty
+ if len(self) < self.num_beams or score > self.worst_score:
+ self.hyp.append((score, hyp))
+ if len(self) > self.num_beams:
+ sorted_scores = sorted([(s, idx) for idx, (s, _) in enumerate(self.hyp)])
+ del self.hyp[sorted_scores[0][1]]
+ self.worst_score = sorted_scores[1][0]
+ else:
+ self.worst_score = min(score, self.worst_score)
+
+ def is_done(self, best_sum_logprobs):
+ """
+ If there are enough hypotheses and that none of the hypotheses being generated
+ can become better than the worst one in the heap, then we are done with this sentence.
+ """
+ if len(self) < self.num_beams:
+ return False
+ elif self.early_stopping:
+ return True
+ else:
+ return self.worst_score >= best_sum_logprobs / self.max_length ** self.length_penalty
+
+
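A minimal run of ``BeamHypotheses`` with made-up scores, showing eviction of the worst hypothesis and the stopping test::

    import torch

    hypos = BeamHypotheses(num_beams=2, max_length=10, length_penalty=1.0,
                           early_stopping=False)
    hypos.add(torch.tensor([1, 5, 7]), sum_logprobs=-1.2)  # score -0.4
    hypos.add(torch.tensor([1, 4]), sum_logprobs=-3.0)     # score -1.5
    hypos.add(torch.tensor([1, 6, 8]), sum_logprobs=-0.9)  # score -0.3, evicts -1.5
    print(len(hypos))                              # 2
    print(hypos.is_done(best_sum_logprobs=-20.0))  # True: no pending beam can catch up
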
+def top_k_top_p_filtering(logits, top_k=0, top_p=1.0, filter_value=-float("Inf"), min_tokens_to_keep=1):
+ """
+    According to top_k and top_p, set logits that do not qualify to filter_value.
+
+    :param torch.Tensor logits: bsz x vocab_size
+    :param int top_k: if greater than 0, keep only the probabilities of the top_k tokens; the remaining positions are set to filter_value
+    :param float top_p: nucleus filtering as described in http://arxiv.org/abs/1904.09751
+    :param float filter_value: the value assigned to filtered positions.
+    :param int min_tokens_to_keep: each sample keeps at least this many tokens with non-zero probability in the returned distribution
+ :return:
+ """
+ if top_k > 0:
+ top_k = min(max(top_k, min_tokens_to_keep), logits.size(-1)) # Safety check
+ # Remove all tokens with a probability less than the last token of the top-k
+ indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
+ logits[indices_to_remove] = filter_value
+
+ if top_p < 1.0:
+ sorted_logits, sorted_indices = torch.sort(logits, descending=True)
+ cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
+
+ # Remove tokens with cumulative probability above the threshold (token with 0 are kept)
+ sorted_indices_to_remove = cumulative_probs > top_p
+ if min_tokens_to_keep > 1:
+ # Keep at least min_tokens_to_keep (set to min_tokens_to_keep-1 because we add the first one below)
+ sorted_indices_to_remove[..., :min_tokens_to_keep] = 0
+ # Shift the indices to the right to keep also the first token above the threshold
+ sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
+ sorted_indices_to_remove[..., 0] = 0
+
+ # scatter sorted tensors to original indexing
+ indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
+ logits[indices_to_remove] = filter_value
+ return logits
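A quick demonstration of ``top_k_top_p_filtering`` on toy logits (values assumed)::

    import torch

    logits = torch.tensor([[2.0, 1.0, 0.5, -1.0]])
    filtered = top_k_top_p_filtering(logits.clone(), top_k=2, top_p=1.0)
    # everything below the 2nd-largest logit is set to -inf:
    # tensor([[2., 1., -inf, -inf]])
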
diff --git a/fastNLP/modules/tokenizer/__init__.py b/fastNLP/modules/tokenizer/__init__.py
new file mode 100644
index 00000000..f3c4faae
--- /dev/null
+++ b/fastNLP/modules/tokenizer/__init__.py
@@ -0,0 +1,14 @@
+r"""
+
+"""
+__all__ = [
+ 'BertTokenizer',
+
+ "GPT2Tokenizer",
+
+ "RobertaTokenizer"
+]
+
+from .bert_tokenizer import BertTokenizer
+from .gpt2_tokenizer import GPT2Tokenizer
+from .roberta_tokenizer import RobertaTokenizer
\ No newline at end of file
diff --git a/fastNLP/modules/tokenizer/bert_tokenizer.py b/fastNLP/modules/tokenizer/bert_tokenizer.py
new file mode 100644
index 00000000..7df6b52d
--- /dev/null
+++ b/fastNLP/modules/tokenizer/bert_tokenizer.py
@@ -0,0 +1,447 @@
+r"""
+
+"""
+
+__all__ = [
+ 'BertTokenizer'
+]
+
+import os
+import collections
+import unicodedata
+from ...core import logger
+from ..utils import _get_file_name_base_on_postfix
+from ...io.file_utils import _get_bert_dir
+
+VOCAB_NAME = 'vocab.txt'
+
+PRETRAINED_INIT_CONFIGURATION = {
+ "en": {"do_lower_case": False},
+ "en-base-uncased": {'do_lower_case': True},
+ 'en-base-cased': {'do_lower_case':False},
+ "en-large-cased-wwm": {"do_lower_case": False},
+ 'en-large-cased': {'do_lower_case':False},
+ 'en-large-uncased': {'do_lower_case':True},
+ 'en-large-uncased-wwm': {'do_lower_case':True},
+ 'cn': {'do_lower_case':True},
+ 'cn-base': {'do_lower_case': True},
+ 'cn-wwm-ext': {'do_lower_case': True},
+ 'multi-base-cased': {'do_lower_case': False},
+ 'multi-base-uncased': {'do_lower_case': True},
+}
+
+def _is_control(char):
+ r"""Checks whether `chars` is a control character."""
+ # These are technically control characters but we count them as whitespace
+ # characters.
+ if char == "\t" or char == "\n" or char == "\r":
+ return False
+ cat = unicodedata.category(char)
+ if cat.startswith("C"):
+ return True
+ return False
+
+
+def _is_punctuation(char):
+ r"""Checks whether `chars` is a punctuation character."""
+ cp = ord(char)
+ # We treat all non-letter/number ASCII as punctuation.
+ # Characters such as "^", "$", and "`" are not in the Unicode
+ # Punctuation class but we treat them as punctuation anyways, for
+ # consistency.
+ if (((cp >= 33) and (cp <= 47)) or ((cp >= 58) and (cp <= 64)) or
+ ((cp >= 91) and (cp <= 96)) or ((cp >= 123) and (cp <= 126))):
+ return True
+ cat = unicodedata.category(char)
+ if cat.startswith("P"):
+ return True
+ return False
+
+
+def _is_whitespace(char):
+ r"""Checks whether `chars` is a whitespace character."""
+    # \t, \n, and \r are technically control characters but we treat them
+ # as whitespace since they are generally considered as such.
+ if char == " " or char == "\t" or char == "\n" or char == "\r":
+ return True
+ cat = unicodedata.category(char)
+ if cat == "Zs":
+ return True
+ return False
+
+
+def whitespace_tokenize(text):
+ r"""Runs basic whitespace cleaning and splitting on a piece of text."""
+ text = text.strip()
+ if not text:
+ return []
+ tokens = text.split()
+ return tokens
+
+
+class BasicTokenizer(object):
+ r"""Runs basic tokenization (punctuation splitting, lower casing, etc.)."""
+
+ def __init__(self,
+ do_lower_case=True,
+ never_split=("[UNK]", "[SEP]", "[PAD]", "[CLS]", "[MASK]")):
+ r"""Constructs a BasicTokenizer.
+
+ Args:
+ do_lower_case: Whether to lower case the input.
+ """
+ self.do_lower_case = do_lower_case
+ self.never_split = never_split
+
+ def tokenize(self, text):
+ r"""Tokenizes a piece of text."""
+ text = self._clean_text(text)
+ # This was added on November 1st, 2018 for the multilingual and Chinese
+ # models. This is also applied to the English models now, but it doesn't
+ # matter since the English models were not trained on any Chinese data
+ # and generally don't have any Chinese data in them (there are Chinese
+ # characters in the vocabulary because Wikipedia does have some Chinese
+ # words in the English Wikipedia.).
+ text = self._tokenize_chinese_chars(text)
+ orig_tokens = whitespace_tokenize(text)
+ split_tokens = []
+ for token in orig_tokens:
+ if self.do_lower_case and token not in self.never_split:
+ token = token.lower()
+ token = self._run_strip_accents(token)
+ split_tokens.extend(self._run_split_on_punc(token))
+
+ output_tokens = whitespace_tokenize(" ".join(split_tokens))
+ return output_tokens
+
+ def _run_strip_accents(self, text):
+ r"""Strips accents from a piece of text."""
+ text = unicodedata.normalize("NFD", text)
+ output = []
+ for char in text:
+ cat = unicodedata.category(char)
+ if cat == "Mn":
+ continue
+ output.append(char)
+ return "".join(output)
+
+ def _run_split_on_punc(self, text):
+ r"""Splits punctuation on a piece of text."""
+ if text in self.never_split:
+ return [text]
+ chars = list(text)
+ i = 0
+ start_new_word = True
+ output = []
+ while i < len(chars):
+ char = chars[i]
+ if _is_punctuation(char):
+ output.append([char])
+ start_new_word = True
+ else:
+ if start_new_word:
+ output.append([])
+ start_new_word = False
+ output[-1].append(char)
+ i += 1
+
+ return ["".join(x) for x in output]
+
+ def _tokenize_chinese_chars(self, text):
+ r"""Adds whitespace around any CJK character."""
+ output = []
+ for char in text:
+ cp = ord(char)
+ if self._is_chinese_char(cp):
+ output.append(" ")
+ output.append(char)
+ output.append(" ")
+ else:
+ output.append(char)
+ return "".join(output)
+
+ def _is_chinese_char(self, cp):
+ r"""Checks whether CP is the codepoint of a CJK character."""
+ # This defines a "chinese character" as anything in the CJK Unicode block:
+ # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
+ #
+ # Note that the CJK Unicode block is NOT all Japanese and Korean characters,
+ # despite its name. The modern Korean Hangul alphabet is a different block,
+ # as is Japanese Hiragana and Katakana. Those alphabets are used to write
+ # space-separated words, so they are not treated specially and handled
+        # like all of the other languages.
+ if (((cp >= 0x4E00) and (cp <= 0x9FFF)) or #
+ ((cp >= 0x3400) and (cp <= 0x4DBF)) or #
+ ((cp >= 0x20000) and (cp <= 0x2A6DF)) or #
+ ((cp >= 0x2A700) and (cp <= 0x2B73F)) or #
+ ((cp >= 0x2B740) and (cp <= 0x2B81F)) or #
+ ((cp >= 0x2B820) and (cp <= 0x2CEAF)) or
+ ((cp >= 0xF900) and (cp <= 0xFAFF)) or #
+ ((cp >= 0x2F800) and (cp <= 0x2FA1F))): #
+ return True
+
+ return False
+
+ def _clean_text(self, text):
+ r"""Performs invalid character removal and whitespace cleanup on text."""
+ output = []
+ for char in text:
+ cp = ord(char)
+ if cp == 0 or cp == 0xfffd or _is_control(char):
+ continue
+ if _is_whitespace(char):
+ output.append(" ")
+ else:
+ output.append(char)
+ return "".join(output)
+
+
+def load_vocab(vocab_file):
+ r"""Loads a vocabulary file into a dictionary."""
+ vocab = collections.OrderedDict()
+ index = 0
+ with open(vocab_file, "r", encoding="utf-8") as reader:
+ while True:
+ token = reader.readline()
+ if not token:
+ break
+ token = token.strip()
+ vocab[token] = index
+ index += 1
+ return vocab
+
+
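``load_vocab`` expects one wordpiece per line and uses the line number as the id; a sketch with an assumed vocab.txt::

    # vocab.txt contains, one token per line:
    #   [PAD]
    #   [UNK]
    #   the
    #   ##ing
    vocab = load_vocab("vocab.txt")
    # OrderedDict([('[PAD]', 0), ('[UNK]', 1), ('the', 2), ('##ing', 3)])
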
+class WordpieceTokenizer(object):
+ r"""Runs WordPiece tokenization."""
+
+ def __init__(self, vocab, unk_token="[UNK]", max_input_chars_per_word=100):
+ self.vocab = vocab
+ self.unk_token = unk_token
+ self.max_input_chars_per_word = max_input_chars_per_word
+
+ def tokenize(self, text):
+ r"""Tokenizes a piece of text into its word pieces.
+
+ This uses a greedy longest-match-first algorithm to perform tokenization
+ using the given vocabulary.
+
+ For example:
+ input = "unaffable"
+ output = ["un", "##aff", "##able"]
+
+ Args:
+ text: A single token or whitespace separated tokens. This should have
+ already been passed through `BasicTokenizer`.
+
+ Returns:
+ A list of wordpiece tokens.
+ """
+
+ output_tokens = []
+ for token in whitespace_tokenize(text):
+ chars = list(token)
+ if len(chars) > self.max_input_chars_per_word:
+ output_tokens.append(self.unk_token)
+ continue
+
+ is_bad = False
+ start = 0
+ sub_tokens = []
+ while start < len(chars):
+ end = len(chars)
+ cur_substr = None
+ while start < end:
+ substr = "".join(chars[start:end])
+ if start > 0:
+ substr = "##" + substr
+ if substr in self.vocab:
+ cur_substr = substr
+ break
+ end -= 1
+ if cur_substr is None:
+ is_bad = True
+ break
+ sub_tokens.append(cur_substr)
+ start = end
+
+ if is_bad:
+ output_tokens.append(self.unk_token)
+ else:
+ output_tokens.extend(sub_tokens)
+        if len(output_tokens) == 0:  # guard against input that is all whitespace or newline characters
+ return [self.unk_token]
+ return output_tokens
+
+
+class BertTokenizer(object):
+ r"""Runs end-to-end tokenization: punctuation splitting + wordpiece"""
+
+ def __init__(self, vocab_file, do_lower_case=True, max_len=None, do_basic_tokenize=True,
+ never_split=("[UNK]", "[SEP]", "[PAD]", "[CLS]", "[MASK]")):
+ r"""Constructs a BertTokenizer.
+
+ Args:
+ vocab_file: Path to a one-wordpiece-per-line vocabulary file
+ do_lower_case: Whether to lower case the input
+ Only has an effect when do_wordpiece_only=False
+ do_basic_tokenize: Whether to do basic tokenization before wordpiece.
+ max_len: An artificial maximum length to truncate tokenized sequences to;
+ Effective maximum length is always the minimum of this
+ value (if specified) and the underlying BERT model's
+ sequence length.
+ never_split: List of tokens which will never be split during tokenization.
+ Only has an effect when do_wordpiece_only=False
+ """
+ if not os.path.isfile(vocab_file):
+ raise ValueError(
+ "Can't find a vocabulary file at path '{}'. To load the vocabulary from a Google pretrained "
+ "model use `tokenizer = BertTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`".format(vocab_file))
+ self.vocab = load_vocab(vocab_file)
+ self.ids_to_tokens = collections.OrderedDict(
+ [(ids, tok) for tok, ids in self.vocab.items()])
+ self.do_basic_tokenize = do_basic_tokenize
+ if do_basic_tokenize:
+ self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case,
+ never_split=never_split)
+ self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab)
+ self.max_len = max_len if max_len is not None else int(1e12)
+
+ @property
+ def unk_index(self):
+ return self.vocab['[UNK]']
+
+ @property
+ def pad_index(self):
+ return self.vocab['[PAD]']
+
+ @property
+ def cls_index(self):
+ return self.vocab['[CLS]']
+
+ @property
+ def sep_index(self):
+ return self.vocab['[SEP]']
+
+ def _reinit_on_new_vocab(self, vocab):
+ r"""
+        After loading BERT, the vocab may get rearranged. Call this function afterwards to re-initialize everything that depends on the vocab.
+
+        :param vocab: the new vocabulary dict
+ :return:
+ """
+ self.vocab = vocab
+ self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab)
+
+ def tokenize(self, text):
+ split_tokens = []
+ if self.do_basic_tokenize:
+ for token in self.basic_tokenizer.tokenize(text):
+ for sub_token in self.wordpiece_tokenizer.tokenize(token):
+ split_tokens.append(sub_token)
+ else:
+ split_tokens = self.wordpiece_tokenizer.tokenize(text)
+ return split_tokens
+
+ def convert_tokens_to_ids(self, tokens):
+ r"""Converts a sequence of tokens into ids using the vocab."""
+ ids = []
+ for token in tokens:
+ ids.append(self.vocab[token])
+ if len(ids) > self.max_len:
+ logger.warning(
+ "Token indices sequence length is longer than the specified maximum "
+                "sequence length for this BERT model ({} > {}). Running this"
+ " sequence through BERT will result in indexing errors".format(len(ids), self.max_len)
+ )
+ return ids
+
+ def convert_ids_to_tokens(self, ids):
+        r"""Convert token ids back into a single string."""
+ tokens = []
+ for i in ids:
+ tokens.append(self.ids_to_tokens[i])
+ return self._convert_tokens_to_string(tokens)
+
+ def _convert_tokens_to_string(self, tokens):
+ """ Converts a sequence of tokens (string) in a single string. """
+ out_string = " ".join(tokens).replace(" ##", "").strip()
+ return out_string
+
+ def save_vocabulary(self, vocab_path):
+ r"""Save the tokenizer vocabulary to a directory or file."""
+ index = 0
+ if os.path.isdir(vocab_path):
+ vocab_file = os.path.join(vocab_path, VOCAB_NAME)
+ else:
+ vocab_file = vocab_path
+ with open(vocab_file, "w", encoding="utf-8") as writer:
+ for token, token_index in sorted(self.vocab.items(), key=lambda kv: kv[1]):
+ if index != token_index:
+ logger.warning("Saving vocabulary to {}: vocabulary indices are not consecutive."
+ " Please check that the vocabulary is not corrupted!".format(vocab_file))
+ index = token_index
+ writer.write(token + u'\n')
+ index += 1
+ return vocab_file
+
+ @classmethod
+ def from_pretrained(cls, model_dir_or_name, *inputs, **kwargs):
+ r"""
+        Given a model name or path, read the vocab directly.
+ """
+ model_dir = _get_bert_dir(model_dir_or_name)
+ pretrained_model_name_or_path = _get_file_name_base_on_postfix(model_dir, '.txt')
+ logger.info("loading vocabulary file {}".format(pretrained_model_name_or_path))
+ max_len = 512
+ kwargs['max_len'] = min(kwargs.get('max_position_embeddings', int(1e12)), max_len)
+ # Instantiate tokenizer.
+ if 'do_lower_case' not in kwargs:
+ if model_dir_or_name in PRETRAINED_INIT_CONFIGURATION:
+ kwargs['do_lower_case'] = PRETRAINED_INIT_CONFIGURATION[model_dir_or_name]['do_lower_case']
+ else:
+                if 'uncase' in model_dir_or_name:
+                    kwargs['do_lower_case'] = True
+                elif 'case' in model_dir_or_name:
+                    kwargs['do_lower_case'] = False
+
+ tokenizer = cls(pretrained_model_name_or_path, *inputs, **kwargs)
+ return tokenizer
+
+ def encode(self, text, add_special_tokens=True):
+ """
+        Encode the given text input into indices.
+
+ Example::
+
+ >>> from fastNLP.modules import BertTokenizer
+ >>> bert_tokenizer = BertTokenizer.from_pretrained('en')
+ >>> print(bert_tokenizer.encode('from'))
+ >>> print(bert_tokenizer.encode("This is a demo sentence"))
+ >>> print(bert_tokenizer.encode(["This", "is", 'a']))
+
+
+        :param List[str],str text: the input; a str is treated as a single sentence.
+        :param bool add_special_tokens: whether to ensure the sequence starts with cls and ends with sep.
+ :return:
+ """
+
+ word_pieces = []
+ if isinstance(text, str):
+ words = text.split()
+ elif isinstance(text, list):
+ words = text
+ else:
+ raise TypeError("Only support str or List[str]")
+ for word in words:
+ _words = self.basic_tokenizer._tokenize_chinese_chars(word).split()
+ tokens = []
+ for word in _words:
+ tokens.extend(self.wordpiece_tokenizer.tokenize(word))
+ word_piece_ids = self.convert_tokens_to_ids(tokens)
+ word_pieces.extend(word_piece_ids)
+ if add_special_tokens:
+ if word_pieces[0] != self.cls_index:
+ word_pieces.insert(0, self.cls_index)
+ if word_pieces[-1] != self.sep_index:
+ word_pieces.append(self.sep_index)
+ return word_pieces
diff --git a/fastNLP/modules/tokenizer/gpt2_tokenizer.py b/fastNLP/modules/tokenizer/gpt2_tokenizer.py
new file mode 100644
index 00000000..08675a23
--- /dev/null
+++ b/fastNLP/modules/tokenizer/gpt2_tokenizer.py
@@ -0,0 +1,758 @@
+r"""undocumented
+The code on this page is largely adapted (copy-pasted) from https://github.com/huggingface/pytorch-pretrained-BERT;
+ if you find it useful, please cite them as well.
+"""
+
+__all__ = [
+ 'GPT2Tokenizer'
+]
+
+from functools import lru_cache
+import json
+import regex as re
+import itertools
+
+
+from ...io.file_utils import _get_gpt2_dir
+from ...core import logger
+from ..utils import _get_file_name_base_on_postfix
+
+
+import os
+
+PRETRAINED_GPT2_MODEL_DIR = PRETRAINED_BERT_MODEL_DIR = {
+ 'en-small': 'gpt2-small.zip',
+ 'en-median': 'gpt2-medium.zip',
+ 'en': 'gpt2-medium.zip'
+}
+
+
+@lru_cache()
+def bytes_to_unicode():
+ """
+ Returns list of utf-8 byte and a mapping to unicode strings.
+    We specifically avoid mapping to whitespace/control characters the bpe code barfs on.
+
+ The reversible bpe codes work on unicode strings.
+ This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
+ When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
+    This is a significant percentage of your normal, say, 32K bpe vocab.
+ To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
+ """
+ bs = (
+ list(range(ord("!"), ord("~") + 1)) + list(range(ord("¡"), ord("¬") + 1)) + list(range(ord("®"), ord("ÿ") + 1))
+ )
+ cs = bs[:]
+ n = 0
+ for b in range(2 ** 8):
+ if b not in bs:
+ bs.append(b)
+ cs.append(2 ** 8 + n)
+ n += 1
+ cs = [chr(n) for n in cs]
+ return dict(zip(bs, cs))
+
+
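A quick look at the byte-to-unicode table built above; the values follow directly from the function::

    byte_encoder = bytes_to_unicode()
    print(byte_encoder[ord("A")])  # 'A'  -- printable ASCII maps to itself
    print(byte_encoder[ord(" ")])  # 'Ġ'  -- the space byte 0x20 is remapped
    print(len(byte_encoder))       # 256 -- every byte value gets a character
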
+def get_pairs(word):
+ """Return set of symbol pairs in a word.
+
+ Word is represented as tuple of symbols (symbols being variable-length strings).
+ """
+ pairs = set()
+ prev_char = word[0]
+ for char in word[1:]:
+ pairs.add((prev_char, char))
+ prev_char = char
+ return pairs
+
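``get_pairs`` on a toy symbol tuple::

    print(get_pairs(("h", "e", "ll", "o")))
    # {('h', 'e'), ('e', 'll'), ('ll', 'o')}
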
+
+VOCAB_FILES_NAMES = {
+ "vocab_file": "vocab.json",
+ "merges_file": "merges.txt",
+}
+
+
+PRETRAINED_VOCAB_FILES_MAP = {
+ "vocab_file": {
+ "gpt2": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-vocab.json",
+ "gpt2-medium": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-medium-vocab.json",
+ "gpt2-large": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-large-vocab.json",
+ "gpt2-xl": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-xl-vocab.json",
+ "distilgpt2": "https://s3.amazonaws.com/models.huggingface.co/bert/distilgpt2-vocab.json",
+ },
+ "merges_file": {
+ "gpt2": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-merges.txt",
+ "gpt2-medium": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-medium-merges.txt",
+ "gpt2-large": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-large-merges.txt",
+ "gpt2-xl": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-xl-merges.txt",
+ "distilgpt2": "https://s3.amazonaws.com/models.huggingface.co/bert/distilgpt2-merges.txt",
+ },
+}
+
+
+PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
+ "en-small": 1024,
+ 'en': 1024,
+ "en-medium": 1024,
+ "en-large": 1024,
+ "en-xl": 1024,
+ "en-distilgpt2": 1024,
+}
+
+PATTERN = re.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""")
+
+
+def gpt2_tokenize(text, add_prefix_space=True):
+ """
+
+ :param str text:
+    :param bool add_prefix_space: whether to prepend a space to the sentence; needed to stay consistent with how GPT-2 was trained
+    :return: a list of tokens (possibly empty)
+ """
+    if text == '':
+ return []
+ if add_prefix_space:
+ text = ' ' + text
+ tokens = []
+ for token in re.findall(PATTERN, text):
+ tokens.append(token)
+ return tokens
+
+
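A small demo of the pre-tokenization regex; the output is derived by hand from PATTERN, so treat it as an assumption::

    print(gpt2_tokenize("I'm happy, today!", add_prefix_space=True))
    # [' I', "'m", ' happy', ',', ' today', '!']
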
+class GPT2Tokenizer:
+ """
+ GPT-2 BPE tokenizer. Peculiarities:
+ - Byte-level Byte-Pair-Encoding
+ - Requires a space to start the input string => the encoding and tokenize methods should be called with the
+ ``add_prefix_space`` flag set to ``True``.
+ Otherwise, this tokenizer's ``encode``, ``decode``, and ``tokenize`` methods will not conserve
+ the spaces at the beginning of a string: `tokenizer.decode(tokenizer.encode(" Hello")) = "Hello"`
+ """
+
+ vocab_files_names = VOCAB_FILES_NAMES
+ pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
+
+ SPECIAL_TOKENS_ATTRIBUTES = [
+ "bos_token",
+ "eos_token",
+ "unk_token",
+ "pad_token",
+ "cls_token",
+ "mask_token",
+ "sep_token",
+ ]
+
+ padding_side = "right"
+
+ def __init__(
+ self,
+ vocab_file,
+ merges_file,
+ errors="replace",
+ unk_token="<|endoftext|>",
+ bos_token="<|endoftext|>",
+ eos_token="<|endoftext|>",
+ **kwargs
+ ):
+ self._bos_token = None
+ self._eos_token = None
+ self._unk_token = None
+ self._sep_token = None
+ self._pad_token = None
+ self._cls_token = None
+ self._mask_token = None
+ self._pad_token_type_id = 0
+
+ self.bos_token = bos_token
+ self.eos_token = eos_token
+ self.unk_token = unk_token
+
+ self.max_len = int(1e12)
+ self.padding_side = kwargs.pop("padding_side", self.padding_side)
+ self.added_tokens_encoder = {}
+ self.unique_added_tokens_encoder = set()
+ self.added_tokens_decoder = {}
+ # inputs and kwargs for saving and re-loading (see ``from_pretrained`` and ``save_pretrained``)
+ self.init_inputs = ()
+ self.init_kwargs = {}
+
+ for key, value in kwargs.items():
+ if key in self.SPECIAL_TOKENS_ATTRIBUTES:
+ if key == "additional_special_tokens":
+ assert isinstance(value, (list, tuple)) and all(isinstance(t, str) for t in value)
+ else:
+ assert isinstance(value, str)
+ setattr(self, key, value)
+
+ self.max_len_single_sentence = (
+ self.max_len
+ ) # no default special tokens - you can update this value if you add special tokens
+ self.max_len_sentences_pair = (
+ self.max_len
+ ) # no default special tokens - you can update this value if you add special tokens
+
+ with open(vocab_file, encoding="utf-8") as vocab_handle:
+ self.encoder = json.load(vocab_handle)
+ self.decoder = {v: k for k, v in self.encoder.items()}
+ self.errors = errors # how to handle errors in decoding
+ self.byte_encoder = bytes_to_unicode()
+ self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
+ with open(merges_file, encoding="utf-8") as merges_handle:
+ bpe_merges = merges_handle.read().split("\n")[1:-1]
+ bpe_merges = [tuple(merge.split()) for merge in bpe_merges]
+ self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges))))
+ self.cache = {}
+
+ def _reinit_on_new_vocab(self, vocab):
+ self.encoder = {k:v for k,v in vocab.items()}
+ self.decoder = {v:k for k,v in vocab.items()}
+ self.cache = {}
+
+ @property
+ def bos_token(self):
+ """ Beginning of sentence token (string). Log an error if used while not having been set. """
+ if self._bos_token is None:
+ logger.error("Using bos_token, but it is not set yet.")
+ return self._bos_token
+
+ @property
+ def eos_token(self):
+ """ End of sentence token (string). Log an error if used while not having been set. """
+ if self._eos_token is None:
+ logger.error("Using eos_token, but it is not set yet.")
+ return self._eos_token
+
+ @property
+ def unk_token(self):
+ """ Unknown token (string). Log an error if used while not having been set. """
+ if self._unk_token is None:
+ logger.error("Using unk_token, but it is not set yet.")
+ return self._unk_token
+
+ @property
+ def pad_token(self):
+ """ Padding token (string). Log an error if used while not having been set. """
+ if self._pad_token is None:
+ logger.error("Using pad_token, but it is not set yet.")
+ return self._pad_token
+
+ @property
+ def cls_token(self):
+ """ Classification token (string). E.g. to extract a summary of an input sequence leveraging self-attention along the full depth of the model. Log an error if used while not having been set. """
+ if self._cls_token is None:
+ logger.error("Using cls_token, but it is not set yet.")
+ return self._cls_token
+
+ @property
+ def sep_token(self):
+ if self._sep_token is None:
+ logger.error("Using sep_token, but it is not set yet.")
+ return self._sep_token
+
+ @property
+ def mask_token(self):
+ """ Mask token (string). E.g. when training a model with masked-language modeling. Log an error if used while not having been set. """
+ if self._mask_token is None:
+ logger.error("Using mask_token, but it is not set yet.")
+ return self._mask_token
+
+ @bos_token.setter
+ def bos_token(self, value):
+ self._bos_token = value
+
+ @eos_token.setter
+ def eos_token(self, value):
+ self._eos_token = value
+
+ @unk_token.setter
+ def unk_token(self, value):
+ self._unk_token = value
+
+ @pad_token.setter
+ def pad_token(self, value):
+ self._pad_token = value
+
+ @cls_token.setter
+ def cls_token(self, value):
+ self._cls_token = value
+
+ @sep_token.setter
+ def sep_token(self, value):
+ self._sep_token = value
+
+ @mask_token.setter
+ def mask_token(self, value):
+ self._mask_token = value
+
+ @property
+ def bos_index(self):
+ """ Id of the beginning of sentence token in the vocabulary. Log an error if used while not having been set. """
+ return self.convert_tokens_to_ids(self.bos_token)
+
+ @property
+ def sep_index(self):
+ return self.convert_tokens_to_ids(self.sep_token)
+
+ @property
+ def eos_index(self):
+ """ Id of the end of sentence token in the vocabulary. Log an error if used while not having been set. """
+ return self.convert_tokens_to_ids(self.eos_token)
+
+ @property
+ def unk_index(self):
+ """ Id of the unknown token in the vocabulary. Log an error if used while not having been set. """
+ return self.convert_tokens_to_ids(self.unk_token)
+
+ @property
+ def pad_index(self):
+ """ Id of the padding token in the vocabulary. Log an error if used while not having been set. """
+ return self.convert_tokens_to_ids(self.pad_token)
+
+ @property
+ def pad_token_type_id(self):
+ """ Id of the padding token type in the vocabulary."""
+ return self._pad_token_type_id
+
+ @property
+ def cls_index(self):
+ """ Id of the classification token in the vocabulary. E.g. to extract a summary of an input sequence leveraging self-attention along the full depth of the model. Log an error if used while not having been set. """
+ return self.convert_tokens_to_ids(self.cls_token)
+
+ @property
+ def mask_index(self):
+ """ Id of the mask token in the vocabulary. E.g. when training a model with masked-language modeling. Log an error if used while not having been set. """
+ return self.convert_tokens_to_ids(self.mask_token)
+
+ @property
+ def vocab_size(self):
+ return len(self.encoder)
+
+ def bpe(self, token):
+        # if the token is not found, it is split into characters and returned
+ if token in self.cache:
+ return self.cache[token]
+ word = tuple(token)
+        pairs = get_pairs(word)  # if word is abcd, pairs is ((a,b), (b,c), (c,d))
+
+ if not pairs:
+ return token
+
+ while True:
+            # first find the highest-ranked pair
+ bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf")))
+ if bigram not in self.bpe_ranks:
+ break
+ first, second = bigram
+ new_word = []
+ i = 0
+ while i < len(word):
+ try:
+ j = word.index(first, i)
+ except ValueError:
+ new_word.extend(word[i:])
+ break
+ else:
+                    new_word.extend(word[i:j])  # everything before the first match
+ i = j
+
+ if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
+ new_word.append(first + second)
+ i += 2
+ else:
+ new_word.append(word[i])
+ i += 1
+ new_word = tuple(new_word)
+ word = new_word
+ if len(word) == 1:
+ break
+ else:
+ pairs = get_pairs(word)
+ word = " ".join(word)
+ self.cache[token] = word
+ return word
+
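A hand trace of the merge loop in ``bpe`` under an assumed three-entry merge table; real ``bpe_ranks`` come from merges.txt::

    # assume self.bpe_ranks == {('h','e'): 0, ('l','l'): 1, ('he','ll'): 2}
    # bpe('hello') proceeds:
    #   ('h','e','l','l','o') -> ('he','l','l','o')  # merge ('h','e'),  rank 0
    #                         -> ('he','ll','o')     # merge ('l','l'),  rank 1
    #                         -> ('hell','o')        # merge ('he','ll'), rank 2
    # no remaining pair is ranked, so the result is 'hell o'
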
+ def _tokenize(self, text, add_prefix_space=False):
+ """ Tokenize a string.
+ Args:
+ - add_prefix_space (boolean, default False):
+ Begin the sentence with at least one space to get invariance to word order in GPT-2 (and RoBERTa) tokenizers.
+ """
+ bpe_tokens = []
+ for token in gpt2_tokenize(text, add_prefix_space=add_prefix_space):
+ token = "".join(
+ self.byte_encoder[b] for b in token.encode("utf-8")
+ ) # Maps all our bytes to unicode strings, avoiding controle tokens of the BPE (spaces in our case)
+ bpe_tokens.extend(bpe_token for bpe_token in self.bpe(token).split(" "))
+ return bpe_tokens
+
+ def _convert_token_to_id(self, token):
+ """ Converts a token (str) in an id using the vocab. """
+ return self.encoder.get(token, self.encoder.get(self.unk_token))
+
+ def _convert_id_to_token(self, index):
+ """Converts an index (integer) in a token (str) using the vocab."""
+ return self.decoder.get(index)
+
+ def convert_tokens_to_string(self, tokens):
+ """ Converts a sequence of tokens (string) in a single string. """
+ text = "".join(tokens)
+ text = bytearray([self.byte_decoder[c] for c in text]).decode("utf-8", errors=self.errors)
+ return text
+
+ def save_vocabulary(self, save_directory):
+ """Save the tokenizer vocabulary and merge files to a directory."""
+ if not os.path.isdir(save_directory):
+ logger.error("Vocabulary path ({}) should be a directory".format(save_directory))
+ return
+ vocab_file = os.path.join(save_directory, VOCAB_FILES_NAMES["vocab_file"])
+ merge_file = os.path.join(save_directory, VOCAB_FILES_NAMES["merges_file"])
+
+ with open(vocab_file, "w", encoding="utf-8") as f:
+ f.write(json.dumps(self.encoder, ensure_ascii=False))
+
+ index = 0
+ with open(merge_file, "w", encoding="utf-8") as writer:
+ writer.write("#version: 0.2\n")
+ for bpe_tokens, token_index in sorted(self.bpe_ranks.items(), key=lambda kv: kv[1]):
+ if index != token_index:
+ logger.warning(
+ "Saving vocabulary to {}: BPE merge indices are not consecutive."
+ " Please check that the tokenizer is not corrupted!".format(merge_file)
+ )
+ index = token_index
+ writer.write(" ".join(bpe_tokens) + "\n")
+ index += 1
+
+ return vocab_file, merge_file
+
+ @classmethod
+ def from_pretrained(cls, model_dir_or_name):
+ r"""
+ """
+ return cls._from_pretrained(model_dir_or_name)
+
+    # modified so that a directory is always passed in
+ @classmethod
+ def _from_pretrained(cls, model_dir_or_name):
+ """
+
+        :param str model_dir_or_name: a directory or a shortcut name
+ :return:
+ """
+        # It needs two files: the first is vocab.json, the second the merges file.
+ model_dir = _get_gpt2_dir(model_dir_or_name)
+        # The directory contains four files: vocab.json, merges.txt, config.json, model.bin
+
+ tokenizer_config_file = _get_file_name_base_on_postfix(model_dir, 'config.json')
+ with open(tokenizer_config_file, encoding="utf-8") as tokenizer_config_handle:
+ init_kwargs = json.load(tokenizer_config_handle)
+ if 'max_len' not in init_kwargs:
+ init_kwargs['max_len'] = 1024
+ # Set max length if needed
+ if model_dir_or_name in PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES:
+ # if we're using a pretrained model, ensure the tokenizer
+ # wont index sequences longer than the number of positional embeddings
+ max_len = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES[model_dir_or_name]
+ if max_len is not None and isinstance(max_len, (int, float)):
+ init_kwargs["max_len"] = min(init_kwargs.get("max_len", int(1e12)), max_len)
+
+        # add the vocab and merges files to init_kwargs
+ init_kwargs['vocab_file'] = _get_file_name_base_on_postfix(model_dir, 'vocab.json')
+ init_kwargs['merges_file'] = _get_file_name_base_on_postfix(model_dir, 'merges.txt')
+
+ init_inputs = init_kwargs.pop("init_inputs", ())
+ # Instantiate tokenizer.
+ try:
+ tokenizer = cls(*init_inputs, **init_kwargs)
+ except OSError:
+            raise OSError(
+ "Unable to load vocabulary from file. "
+ "Please check that the provided vocabulary is accessible and not corrupted."
+ )
+
+ return tokenizer
+
+ def __len__(self):
+ """ Size of the full vocabulary with the added tokens """
+ return self.vocab_size + len(self.added_tokens_encoder)
+
+ def tokenize(self, text, add_prefix_space=True):
+ """ Converts a string in a sequence of tokens (string), using the tokenizer.
+ Split in words for word-based vocabulary or sub-words for sub-word-based
+ vocabularies (BPE/SentencePieces/WordPieces).
+
+ Take care of added tokens.
+ Args:
+ - text: The sequence to be encoded.
+ - add_prefix_space (boolean, default True):
+ Begin the sentence with at least one space to get invariance to word order in GPT-2 (and RoBERTa) tokenizers.
+ """
+ all_special_tokens = self.all_special_tokens
+
+ def lowercase_text(t):
+ # convert non-special tokens to lowercase
+ escaped_special_toks = [re.escape(s_tok) for s_tok in all_special_tokens]
+ pattern = r'(' + r'|'.join(escaped_special_toks) + r')|' + \
+ r'(.+?)'
+ return re.sub(
+ pattern,
+ lambda m: m.groups()[0] or m.groups()[1].lower(),
+ t)
+
+ if self.init_kwargs.get('do_lower_case', False):
+ text = lowercase_text(text)
+
+ def split_on_token(tok, text):
+ result = []
+ split_text = text.split(tok)
+ for i, sub_text in enumerate(split_text):
+ sub_text = sub_text.strip()
+ if i == 0 and not sub_text:
+ result += [tok]
+ elif i == len(split_text) - 1:
+ if sub_text:
+ result += [sub_text]
+ else:
+ pass
+ else:
+ if sub_text:
+ result += [sub_text]
+ result += [tok]
+ return result
+
+ def split_on_tokens(tok_list, text):
+ if not text.strip():
+ return []
+ if not tok_list:
+ return self._tokenize(text, add_prefix_space=add_prefix_space)
+
+ tokenized_text = []
+ text_list = [text]
+ for tok in tok_list:
+ tokenized_text = []
+ for sub_text in text_list:
+ if sub_text not in self.added_tokens_encoder \
+ and sub_text not in all_special_tokens:
+ tokenized_text += split_on_token(tok, sub_text)
+ else:
+ tokenized_text += [sub_text]
+ text_list = tokenized_text
+
+ return list(itertools.chain.from_iterable((self._tokenize(token, add_prefix_space=add_prefix_space) if token not \
+ in self.added_tokens_encoder and token not in all_special_tokens \
+ else [token] for token in tokenized_text)))
+
+ added_tokens = list(self.added_tokens_encoder.keys()) + all_special_tokens
+ tokenized_text = split_on_tokens(added_tokens, text)
+ return tokenized_text
+
+ def convert_tokens_to_ids(self, tokens):
+ """ Converts a single token, or a sequence of tokens, (str) in a single integer id
+ (resp. a sequence of ids), using the vocabulary.
+ """
+ if tokens is None:
+ return None
+
+ if isinstance(tokens, str):
+ return self._convert_token_to_id_with_added_voc(tokens)
+
+ ids = []
+ for token in tokens:
+ ids.append(self._convert_token_to_id_with_added_voc(token))
+ return ids
+
+ def _convert_token_to_id_with_added_voc(self, token):
+ if token is None:
+ return None
+
+ if token in self.added_tokens_encoder:
+ return self.added_tokens_encoder[token]
+ return self._convert_token_to_id(token)
+
+ def convert_ids_to_tokens(self, ids, skip_special_tokens=False):
+        """ Converts a single index or a sequence of indices (integers) in a token
+            (resp. a sequence of tokens (str)), using the vocabulary and added tokens.
+
+ Args:
+ skip_special_tokens: Don't decode special tokens (self.all_special_tokens). Default: False
+ """
+ if isinstance(ids, int):
+ return self._convert_id_to_token(ids)
+ tokens = []
+ for index in ids:
+ index = int(index)
+ if skip_special_tokens and index in self.all_special_ids:
+ continue
+ tokens.append(self._convert_id_to_token(index))
+ return tokens
+
+ def convert_id_to_tokens(self, token_ids, skip_special_tokens=False, clean_up_tokenization_spaces=True):
+ """
+ Converts a sequence of ids (integer) in a string, using the tokenizer and vocabulary
+ with options to remove special tokens and clean up tokenization spaces.
+ Similar to doing ``self.convert_tokens_to_string(self.convert_ids_to_tokens(token_ids))``.
+
+ Args:
+ token_ids: list of tokenized input ids. Can be obtained using the `encode` or `encode_plus` methods.
+            skip_special_tokens: if set to True, will remove special tokens in the decoding.
+ clean_up_tokenization_spaces: if set to True, will clean up the tokenization spaces.
+ """
+ filtered_tokens = self.convert_ids_to_tokens(token_ids, skip_special_tokens=skip_special_tokens)
+
+        # To avoid mixing byte-level and unicode for byte-level BPE
+        # we need to build the string separately for added tokens and byte-level tokens
+ # cf. https://github.com/huggingface/transformers/issues/1133
+ sub_texts = []
+ current_sub_text = []
+ for token in filtered_tokens:
+ if skip_special_tokens and token in self.all_special_ids:
+ continue
+ if token in self.added_tokens_encoder:
+ if current_sub_text:
+ sub_texts.append(self.convert_tokens_to_string(current_sub_text))
+ current_sub_text = []
+ sub_texts.append(token)
+ else:
+ current_sub_text.append(token)
+ if current_sub_text:
+ sub_texts.append(self.convert_tokens_to_string(current_sub_text))
+ text = " ".join(sub_texts)
+
+ if clean_up_tokenization_spaces:
+ clean_text = self.clean_up_tokenization(text)
+ return clean_text
+ else:
+ return text
+
+ @property
+ def special_tokens_map(self):
+ """ A dictionary mapping special token class attribute (cls_token, unk_token...) to their
+            values ('<unk>', '<cls>'...)
+ """
+ set_attr = {}
+ for attr in self.SPECIAL_TOKENS_ATTRIBUTES:
+ attr_value = getattr(self, "_" + attr)
+ if attr_value:
+ set_attr[attr] = attr_value
+ return set_attr
+
+ @property
+ def all_special_tokens(self):
+        """ List all the special tokens ('<unk>', '<cls>'...) mapped to class attributes
+ (cls_token, unk_token...).
+ """
+ all_toks = []
+ set_attr = self.special_tokens_map
+ for attr_value in set_attr.values():
+ all_toks = all_toks + (list(attr_value) if isinstance(attr_value, (list, tuple)) else [attr_value])
+ all_toks = list(set(all_toks))
+ return all_toks
+
+ @property
+ def all_special_ids(self):
+        """ List the vocabulary indices of the special tokens ('<unk>', '<cls>'...) mapped to
+ class attributes (cls_token, unk_token...).
+ """
+ all_toks = self.all_special_tokens
+ all_ids = self.convert_tokens_to_ids(all_toks)
+ return all_ids
+
+ @staticmethod
+ def clean_up_tokenization(out_string):
+        """ Clean up a list of simple English tokenization artifacts like spaces before punctuation and abbreviated forms.
+ """
+ out_string = (
+ out_string.replace(" .", ".")
+ .replace(" ?", "?")
+ .replace(" !", "!")
+ .replace(" ,", ",")
+ .replace(" ' ", "'")
+ .replace(" n't", "n't")
+ .replace(" 'm", "'m")
+ .replace(" do not", " don't")
+ .replace(" 's", "'s")
+ .replace(" 've", "'ve")
+ .replace(" 're", "'re")
+ )
+ return out_string
+
+ def encode(self, text, add_special_tokens=False, add_prefix_space=True):
+ """
+ Encode the given text input into a list of token indices.
+
+ Example::
+
+ >>> from fastNLP.modules import GPT2Tokenizer
+ >>> gpt2_tokenizer = GPT2Tokenizer.from_pretrained('en')
+ >>> print(gpt2_tokenizer.encode('from'))
+ >>> print(gpt2_tokenizer.encode("This is a demo sentence"))
+ >>> print(gpt2_tokenizer.encode(["This", "is", 'a']))
+
+
+ :param List[str],str text: the input; treated as a single sentence.
+ :param bool add_special_tokens: whether to ensure the sequence starts with cls and ends with sep. GPT2 has no cls/sep tokens.
+ :return:
+ """
+ if isinstance(text, str):
+ words = text.split()
+ elif isinstance(text, list):
+ words = text
+ else:
+ raise TypeError("Only support str or List[str]")
+
+ word_pieces = []
+ for word in words:
+ tokens = self.tokenize(word, add_prefix_space=add_prefix_space)
+ word_piece_ids = self.convert_tokens_to_ids(tokens)
+ word_pieces.extend(word_piece_ids)
+ if add_special_tokens:
+ if self._cls_token is not None and word_pieces[0] != self.cls_index:
+ word_pieces.insert(0, self.cls_index)
+ if self._sep_token is not None and word_pieces[-1] != self.sep_index:
+ word_pieces.append(self.sep_index)
+ return word_pieces
+
+ def get_used_merge_pair_vocab(self, token):
+ # If the token is not found, it is split into single characters and returned.
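+ # Sketch of the logic below: BPE merges are applied greedily in rank order, and every
+ # (first, second) pair actually used is recorded in used_pairs, so a reduced
+ # merges.txt/vocab.json can later be rebuilt from a few sample sentences
+ # (see the small_gpt2/small_roberta fixture generators in the tests).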
+ used_pairs = {}
+ word = tuple(token)
+ pairs = get_pairs(word) # if word is "abcd", pairs is ((a,b), (b,c), (c,d))
+
+ if not pairs:
+ return token, used_pairs
+
+ while True:
+ # first find the pair with the lowest merge rank (i.e. the most frequent pair)
+ bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf")))
+ if bigram not in self.bpe_ranks:
+ break
+ used_pairs[bigram] = self.bpe_ranks[bigram]
+ first, second = bigram
+ new_word = []
+ i = 0
+ while i < len(word):
+ try:
+ j = word.index(first, i)
+ except ValueError:
+ new_word.extend(word[i:])
+ break
+ else:
+ new_word.extend(word[i:j]) # copy everything before the first occurrence of `first`
+ i = j
+
+ if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
+ new_word.append(first + second)
+ i += 2
+ else:
+ new_word.append(word[i])
+ i += 1
+ new_word = tuple(new_word)
+ word = new_word
+ if len(word) == 1:
+ break
+ else:
+ pairs = get_pairs(word)
+ word = " ".join(word)
+ return word, used_pairs
\ No newline at end of file
diff --git a/fastNLP/modules/tokenizer/roberta_tokenizer.py b/fastNLP/modules/tokenizer/roberta_tokenizer.py
new file mode 100644
index 00000000..ee2e5e97
--- /dev/null
+++ b/fastNLP/modules/tokenizer/roberta_tokenizer.py
@@ -0,0 +1,102 @@
+r"""
+
+"""
+
+__all__ = [
+ "RobertaTokenizer"
+]
+
+import json
+from .gpt2_tokenizer import GPT2Tokenizer
+from ..utils import _get_file_name_base_on_postfix
+from ...io.file_utils import _get_roberta_dir
+
+PRETRAINED_ROBERTA_POSITIONAL_EMBEDDINGS_SIZES = {
+ "roberta-base": 512,
+ "roberta-large": 512,
+ "roberta-large-mnli": 512,
+ "distilroberta-base": 512,
+ "roberta-base-openai-detector": 512,
+ "roberta-large-openai-detector": 512,
+}
+
+
+class RobertaTokenizer(GPT2Tokenizer):
+
+ vocab_files_names = {
+ "vocab_file": "vocab.json",
+ "merges_file": "merges.txt",
+ }
+
+ def __init__(
+ self,
+ vocab_file,
+ merges_file,
+ errors="replace",
+ bos_token="",
+ eos_token="",
+ sep_token="",
+ cls_token="",
+ unk_token="",
+ pad_token="",
+ mask_token="",
+ **kwargs
+ ):
+ super().__init__(
+ vocab_file=vocab_file,
+ merges_file=merges_file,
+ errors=errors,
+ bos_token=bos_token,
+ eos_token=eos_token,
+ unk_token=unk_token,
+ sep_token=sep_token,
+ cls_token=cls_token,
+ pad_token=pad_token,
+ mask_token=mask_token,
+ **kwargs,
+ )
+ self.max_len_single_sentence = self.max_len - 2 # take into account special tokens
+ self.max_len_sentences_pair = self.max_len - 4 # take into account special tokens
+
+ @classmethod
+ def from_pretrained(cls, model_dir_or_name, *inputs, **kwargs):
+ """
+
+ :param str model_dir_or_name: a directory or a shortcut model name
+ :param kwargs:
+ :return:
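+
+ Example (a sketch using the small_roberta fixture directory added in this PR)::
+
+ >>> tokenizer = RobertaTokenizer.from_pretrained('test/data_for_tests/embedding/small_roberta')
+ >>> tokenizer.encode("This is a demo sentence")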
+ """
+ # two files are required: the first is vocab.json, the second is the merges file
+ model_dir = _get_roberta_dir(model_dir_or_name)
+ # the directory contains four files: vocab.json, merges.txt, config.json and the model .bin
+
+ tokenizer_config_file = _get_file_name_base_on_postfix(model_dir, 'config.json')
+ with open(tokenizer_config_file, encoding="utf-8") as tokenizer_config_handle:
+ init_kwargs = json.load(tokenizer_config_handle)
+ # Set max length if needed
+ if model_dir_or_name in PRETRAINED_ROBERTA_POSITIONAL_EMBEDDINGS_SIZES:
+ # if we're using a pretrained model, ensure the tokenizer
+ # won't index sequences longer than the number of positional embeddings
+ max_len = PRETRAINED_ROBERTA_POSITIONAL_EMBEDDINGS_SIZES[model_dir_or_name]
+ if max_len is not None and isinstance(max_len, (int, float)):
+ init_kwargs["max_len"] = min(init_kwargs.get("max_len", int(1e12)), max_len)
+
+ # add the vocab and merges files to init_kwargs
+ if 'vocab_file' in kwargs: # use the user-specified vocabulary if one is given
+ init_kwargs['vocab_file'] = kwargs['vocab_file']
+ else:
+ init_kwargs['vocab_file'] = _get_file_name_base_on_postfix(model_dir, RobertaTokenizer.vocab_files_names['vocab_file'])
+ init_kwargs['merges_file'] = _get_file_name_base_on_postfix(model_dir, RobertaTokenizer.vocab_files_names['merges_file'])
+
+ init_inputs = init_kwargs.pop("init_inputs", ())
+ # Instantiate tokenizer.
+ try:
+ tokenizer = cls(*init_inputs, **init_kwargs)
+ except OSError:
+ raise OSError(
+ "Unable to load vocabulary from file. "
+ "Please check that the provided vocabulary is accessible and not corrupted."
+ )
+
+ return tokenizer
+
diff --git a/fastNLP/modules/utils.py b/fastNLP/modules/utils.py
index 79e2a7de..061cd8ae 100644
--- a/fastNLP/modules/utils.py
+++ b/fastNLP/modules/utils.py
@@ -144,18 +144,8 @@ def _get_file_name_base_on_postfix(dir_path, postfix):
"""
files = list(filter(lambda filename: filename.endswith(postfix), os.listdir(os.path.join(dir_path))))
if len(files) == 0:
- raise FileNotFoundError(f"There is no file endswith *{postfix} file in {dir_path}")
+ raise FileNotFoundError(f"There is no file endswith {postfix} file in {dir_path}")
elif len(files) > 1:
raise FileExistsError(f"There are multiple *{postfix} files in {dir_path}")
return os.path.join(dir_path, files[0])
-
-def create_position_ids_from_input_ids(input_ids, padding_idx=0):
- r""" Replace non-padding symbols with their position numbers. Position numbers begin at
- padding_idx+1. Padding symbols are ignored. This is modified from fairseq's
- `utils.make_positions`.
- """
- # The series of casts and type-conversions here are carefully balanced to both work with ONNX export and XLA.
- mask = input_ids.ne(padding_idx).int()
- incremental_indicies = torch.cumsum(mask, dim=1).type_as(mask) * mask
- return incremental_indicies.long() + padding_idx
diff --git a/reproduction/Summarization/Baseline/train_origin.py b/reproduction/Summarization/Baseline/train_origin.py
index 36a2b716..7c4d2f12 100644
--- a/reproduction/Summarization/Baseline/train_origin.py
+++ b/reproduction/Summarization/Baseline/train_origin.py
@@ -687,16 +687,16 @@ def main():
if hps.mode == 'train':
trainset = dataInfo.datasets["train"]
train_sampler = BucketSampler(batch_size=hps.batch_size, seq_len_field_name=Const.INPUT)
- train_batch = DataSetIter(batch_size=hps.batch_size, dataset=trainset, sampler=train_sampler)
+ train_batch = DataSetIter(dataset=trainset, batch_size=hps.batch_size, sampler=train_sampler)
validset = dataInfo.datasets["valid"]
validset.set_input("text", "summary")
- valid_batch = DataSetIter(batch_size=hps.batch_size, dataset=validset)
+ valid_batch = DataSetIter(dataset=validset, batch_size=hps.batch_size)
setup_training(model, train_batch, valid_batch, hps)
elif hps.mode == 'test':
logger.info("[INFO] Decoding...")
testset = dataInfo.datasets["test"]
testset.set_input("text", "summary")
- test_batch = DataSetIter(batch_size=hps.batch_size, dataset=testset)
+ test_batch = DataSetIter(dataset=testset, batch_size=hps.batch_size)
run_test(model, test_batch, hps, limited=hps.limited)
else:
logger.error("The 'mode' flag must be one of train/eval/test")
diff --git a/reproduction/multi-criteria-cws/main.py b/reproduction/multi-criteria-cws/main.py
index 049a1974..8ee1f81e 100644
--- a/reproduction/multi-criteria-cws/main.py
+++ b/reproduction/multi-criteria-cws/main.py
@@ -406,18 +406,8 @@ if not options.test:
logger.info("Number training instances: {}".format(len(train_set)))
logger.info("Number dev instances: {}".format(len(dev_set)))
- train_batch = DataSetIter(
- batch_size=options.batch_size,
- dataset=train_set,
- sampler=train_sampler,
- num_workers=4,
- )
- dev_batch = DataSetIter(
- batch_size=options.batch_size,
- dataset=dev_set,
- sampler=dev_sampler,
- num_workers=4,
- )
+ train_batch = DataSetIter(dataset=train_set, batch_size=options.batch_size, sampler=train_sampler, num_workers=4)
+ dev_batch = DataSetIter(dataset=dev_set, batch_size=options.batch_size, sampler=dev_sampler, num_workers=4)
best_f1 = 0.0
for epoch in range(int(options.num_epochs)):
diff --git a/test/core/test_batch.py b/test/core/test_batch.py
index 18cbf59d..6a340d36 100644
--- a/test/core/test_batch.py
+++ b/test/core/test_batch.py
@@ -279,7 +279,7 @@ class TestCase1(unittest.TestCase):
data.add_collate_fn(concat_collate_fn)
- for batch_x, batch_y in DataSetIter(data, sampler=SequentialSampler(), batch_size=2):
+ for batch_x, batch_y in DataSetIter(data, batch_size=2, sampler=SequentialSampler()):
print("batch_x:", batch_x)
print("batch_y:", batch_y)
# batch_x: {'x': tensor([[0, 1, 3, 0],
@@ -302,7 +302,7 @@ class TestCase1(unittest.TestCase):
return b_x, b_y
data.delete_collate_fn() # 删除之前的collate_fn
data.add_collate_fn(ConCollateFn(max_len=3))
- for batch_x, batch_y in DataSetIter(data, sampler=SequentialSampler(), batch_size=2):
+ for batch_x, batch_y in DataSetIter(data, batch_size=2, sampler=SequentialSampler()):
print("batch_x:", batch_x)
print("batch_y:", batch_y)
# batch_x: {'x': tensor([[0, 1, 3],
@@ -362,10 +362,9 @@ class TestCase1(unittest.TestCase):
batch_sampler = BatchSampler(ds)
- data_iter = DataSetIter(ds, batch_size=10, sampler=batch_sampler, as_numpy=False,
- num_workers=0, pin_memory=False, drop_last=False,
- timeout=0, worker_init_fn=None, collate_fn=None,
- batch_sampler=batch_sampler)
+ data_iter = DataSetIter(ds, batch_size=10, sampler=batch_sampler, as_numpy=False, num_workers=0,
+ pin_memory=False, drop_last=False, timeout=0, worker_init_fn=None,
+ batch_sampler=batch_sampler)
num_samples = [len(ds)//2, len(ds)-len(ds)//2]
for idx, (batch_x, batch_y) in enumerate(data_iter):
self.assertEqual(num_samples[idx], len(batch_x['1']))
diff --git a/test/core/test_dataset.py b/test/core/test_dataset.py
index d048191f..03f24ad1 100644
--- a/test/core/test_dataset.py
+++ b/test/core/test_dataset.py
@@ -264,7 +264,6 @@ class TestDataSetMethods(unittest.TestCase):
self.assertEqual(ans.content, [[5, 6]] * 10)
def test_add_null(self):
- # TODO test failed because 'fastNLP\core\field.py:143: RuntimeError'
ds = DataSet()
with self.assertRaises(RuntimeError) as RE:
ds.add_field('test', [])
diff --git a/test/data_for_tests/embedding/small_gpt2/config.json b/test/data_for_tests/embedding/small_gpt2/config.json
new file mode 100644
index 00000000..b2f61bdc
--- /dev/null
+++ b/test/data_for_tests/embedding/small_gpt2/config.json
@@ -0,0 +1 @@
+{"architectures": ["GPT2LMHeadModel"], "initializer_range": 0.02, "layer_norm_epsilon": 1e-05, "n_ctx": 20, "n_embd": 16, "n_head": 4, "n_layer": 2, "n_positions": 20, "vocab_size": 64}
\ No newline at end of file
diff --git a/test/data_for_tests/embedding/small_gpt2/merges.txt b/test/data_for_tests/embedding/small_gpt2/merges.txt
new file mode 100644
index 00000000..5e4f2b9b
--- /dev/null
+++ b/test/data_for_tests/embedding/small_gpt2/merges.txt
@@ -0,0 +1,39 @@
+#version: small
+a b
+c e
+e l
+e m
+e n
+en ce
+en t
+h e
+he r
+i s
+o c
+o d
+o t
+ot her
+x t
+Ġ T
+Ġ a
+Ġ d
+Ġ is
+Ġ m
+Ġ s
+Ġ t
+Ġ v
+ĠT h
+ĠTh is
+Ġa n
+Ġan other
+Ġd em
+Ġdem o
+Ġm od
+Ġmod el
+Ġs ent
+Ġsent ence
+Ġt e
+Ġt h
+Ġte xt
+Ġth is
+Ġv oc
diff --git a/test/data_for_tests/embedding/small_gpt2/small_pytorch_model.bin b/test/data_for_tests/embedding/small_gpt2/small_pytorch_model.bin
new file mode 100644
index 00000000..ec2f48d7
Binary files /dev/null and b/test/data_for_tests/embedding/small_gpt2/small_pytorch_model.bin differ
diff --git a/test/data_for_tests/embedding/small_gpt2/vocab.json b/test/data_for_tests/embedding/small_gpt2/vocab.json
new file mode 100644
index 00000000..8f9feeda
--- /dev/null
+++ b/test/data_for_tests/embedding/small_gpt2/vocab.json
@@ -0,0 +1 @@
+{"\u0120This": 0, "\u0120is": 1, "\u0120a": 2, "\u0120demo": 3, "\u0120sentence": 4, "\u0120another": 5, "\u0120this": 6, "\u0120text": 7, "a": 8, "\u0120model": 9, "\u0120voc": 10, "ab": 11, "<|endoftext|>": 12, "A": 13, "B": 14, "C": 15, "D": 16, "E": 17, "F": 18, "G": 19, "H": 20, "I": 21, "J": 22, "K": 23, "L": 24, "M": 25, "N": 26, "O": 27, "P": 28, "Q": 29, "R": 30, "S": 31, "T": 32, "U": 33, "V": 34, "W": 35, "X": 36, "Y": 37, "Z": 38, "b": 39, "c": 40, "d": 41, "e": 42, "f": 43, "g": 44, "h": 45, "i": 46, "j": 47, "k": 48, "l": 49, "m": 50, "n": 51, "o": 52, "p": 53, "q": 54, "r": 55, "s": 56, "t": 57, "u": 58, "v": 59, "w": 60, "x": 61, "y": 62, "z": 63}
\ No newline at end of file
diff --git a/test/data_for_tests/embedding/small_roberta/config.json b/test/data_for_tests/embedding/small_roberta/config.json
new file mode 100644
index 00000000..4814927b
--- /dev/null
+++ b/test/data_for_tests/embedding/small_roberta/config.json
@@ -0,0 +1 @@
+{"architectures": ["RobertaForMaskedLM"], "attention_probs_dropout_prob": 0.1, "finetuning_task": null, "hidden_act": "gelu", "hidden_dropout_prob": 0.1, "hidden_size": 16, "initializer_range": 0.02, "intermediate_size": 20, "layer_norm_eps": 1e-05, "max_position_embeddings": 20, "num_attention_heads": 4, "num_hidden_layers": 2, "num_labels": 2, "output_attentions": false, "output_hidden_states": false, "torchscript": false, "type_vocab_size": 1, "vocab_size": 68}
\ No newline at end of file
diff --git a/test/data_for_tests/embedding/small_roberta/merges.txt b/test/data_for_tests/embedding/small_roberta/merges.txt
new file mode 100644
index 00000000..2af8d178
--- /dev/null
+++ b/test/data_for_tests/embedding/small_roberta/merges.txt
@@ -0,0 +1,39 @@
+#version: tiny
+a b
+c e
+e l
+e m
+e n
+en ce
+en t
+h e
+he r
+i s
+o c
+o d
+o t
+ot her
+x t
+Ġ T
+Ġ a
+Ġ d
+Ġ is
+Ġ m
+Ġ s
+Ġ t
+Ġ v
+ĠT h
+ĠTh is
+Ġa n
+Ġan other
+Ġd em
+Ġdem o
+Ġm od
+Ġmod el
+Ġs ent
+Ġsent ence
+Ġt e
+Ġt h
+Ġte xt
+Ġth is
+Ġv oc
diff --git a/test/data_for_tests/embedding/small_roberta/small_pytorch_model.bin b/test/data_for_tests/embedding/small_roberta/small_pytorch_model.bin
new file mode 100644
index 00000000..73282346
Binary files /dev/null and b/test/data_for_tests/embedding/small_roberta/small_pytorch_model.bin differ
diff --git a/test/data_for_tests/embedding/small_roberta/vocab.json b/test/data_for_tests/embedding/small_roberta/vocab.json
new file mode 100644
index 00000000..376b658f
--- /dev/null
+++ b/test/data_for_tests/embedding/small_roberta/vocab.json
@@ -0,0 +1 @@
+{"": 0, "": 1, "": 2, "": 3, "": 4, "A": 5, "B": 6, "C": 7, "D": 8, "E": 9, "F": 10, "G": 11, "H": 12, "I": 13, "J": 14, "K": 15, "L": 16, "M": 17, "N": 18, "O": 19, "P": 20, "Q": 21, "R": 22, "S": 23, "T": 24, "U": 25, "V": 26, "W": 27, "X": 28, "Y": 29, "Z": 30, "a": 31, "b": 32, "c": 33, "d": 34, "e": 35, "f": 36, "g": 37, "h": 38, "i": 39, "j": 40, "k": 41, "l": 42, "m": 43, "n": 44, "o": 45, "p": 46, "q": 47, "r": 48, "s": 49, "t": 50, "u": 51, "v": 52, "w": 53, "x": 54, "y": 55, "z": 56, "\u0120This": 57, "\u0120is": 58, "\u0120a": 59, "\u0120demo": 60, "\u0120sentence": 61, "\u0120another": 62, "\u0120this": 63, "\u0120text": 64, "\u0120model": 65, "\u0120voc": 66, "ab": 67}
\ No newline at end of file
diff --git a/test/embeddings/test_bert_embedding.py b/test/embeddings/test_bert_embedding.py
index fe4702ab..1593c53f 100644
--- a/test/embeddings/test_bert_embedding.py
+++ b/test/embeddings/test_bert_embedding.py
@@ -3,6 +3,8 @@ from fastNLP import Vocabulary
from fastNLP.embeddings import BertEmbedding, BertWordPieceEncoder
import torch
import os
+from fastNLP import DataSet
+
@unittest.skipIf('TRAVIS' in os.environ, "Skip in travis")
class TestDownload(unittest.TestCase):
@@ -45,12 +47,83 @@ class TestBertEmbedding(unittest.TestCase):
result = embed(words)
self.assertEqual(result.size(), (1, 4, 16))
+ # automatically truncates instead of raising an error
+ embed = BertEmbedding(vocab, model_dir_or_name='test/data_for_tests/embedding/small_bert', word_dropout=0.1,
+ only_use_pretrain_bpe=True, auto_truncate=True)
+ words = torch.LongTensor([[2, 3, 4, 1]*10,
+ [2, 3]+[0]*38])
+ result = embed(words)
+ self.assertEqual(result.size(), (2, 40, 16))
+
+ def test_bert_embedding_2(self):
+ # test that only_use_pretrain_bpe and truncate_embed work as expected
+ with open('test/data_for_tests/embedding/small_bert/vocab.txt', 'r', encoding='utf-8') as f:
+ num_word = len(f.readlines())
+ Embedding = BertEmbedding
+ vocab = Vocabulary().add_word_lst("this is a texta and [SEP] NotInBERT".split())
+ embed1 = Embedding(vocab, model_dir_or_name='test/data_for_tests/embedding/small_bert',
+ only_use_pretrain_bpe=True, truncate_embed=True, min_freq=1)
+ embed_bpe_vocab_size = len(vocab)-1 + 2 # excludes NotInBERT; additionally adds ##a and [CLS]
+ self.assertEqual(embed_bpe_vocab_size, len(embed1.model.tokenzier.vocab))
+
+ embed2 = Embedding(vocab, model_dir_or_name='test/data_for_tests/embedding/small_bert',
+ only_use_pretrain_bpe=True, truncate_embed=False, min_freq=1)
+ embed_bpe_vocab_size = num_word # excludes NotInBERT
+ self.assertEqual(embed_bpe_vocab_size, len(embed2.model.tokenzier.vocab))
+
+ embed3 = Embedding(vocab, model_dir_or_name='test/data_for_tests/embedding/small_bert',
+ only_use_pretrain_bpe=False, truncate_embed=True, min_freq=1)
+ embed_bpe_vocab_size = len(vocab)+2 # adds ##a and [CLS]
+ self.assertEqual(embed_bpe_vocab_size, len(embed3.model.tokenzier.vocab))
+
+ embed4 = Embedding(vocab, model_dir_or_name='test/data_for_tests/embedding/small_bert',
+ only_use_pretrain_bpe=False, truncate_embed=False, min_freq=1)
+ embed_bpe_vocab_size = num_word+1 # adds ##a
+ self.assertEqual(embed_bpe_vocab_size, len(embed4.model.tokenzier.vocab))
+
+ # check that the following tensors are equal in every setting
+ embed1.eval()
+ embed2.eval()
+ embed3.eval()
+ embed4.eval()
+ tensor = torch.LongTensor([[vocab.to_index(w) for w in 'this is a texta and'.split()]])
+ t1 = embed1(tensor)
+ t2 = embed2(tensor)
+ t3 = embed3(tensor)
+ t4 = embed4(tensor)
+
+ self.assertEqual((t1-t2).sum(), 0)
+ self.assertEqual((t1-t3).sum(), 0)
+ self.assertEqual((t1-t4).sum(), 0)
+
class TestBertWordPieceEncoder(unittest.TestCase):
def test_bert_word_piece_encoder(self):
embed = BertWordPieceEncoder(model_dir_or_name='test/data_for_tests/embedding/small_bert', word_dropout=0.1)
- from fastNLP import DataSet
ds = DataSet({'words': ["this is a test . [SEP]".split()]})
embed.index_datasets(ds, field_name='words')
self.assertTrue(ds.has_field('word_pieces'))
result = embed(torch.LongTensor([[1,2,3,4]]))
+
+ def test_bert_embed_eq_bert_piece_encoder(self):
+ ds = DataSet({'words': ["this is a texta model vocab".split(), 'this is'.split()]})
+ encoder = BertWordPieceEncoder(model_dir_or_name='test/data_for_tests/embedding/small_bert')
+ encoder.eval()
+ encoder.index_datasets(ds, field_name='words')
+ word_pieces = torch.LongTensor(ds['word_pieces'].get([0, 1]))
+ word_pieces_res = encoder(word_pieces)
+
+ vocab = Vocabulary()
+ vocab.from_dataset(ds, field_name='words')
+ vocab.index_dataset(ds, field_name='words', new_field_name='words')
+ ds.set_input('words')
+ words = torch.LongTensor(ds['words'].get([0, 1]))
+ embed = BertEmbedding(vocab, model_dir_or_name='test/data_for_tests/embedding/small_bert',
+ pool_method='first', include_cls_sep=True, pooled_cls=False)
+ embed.eval()
+ words_res = embed(words)
+
+ # check that the word-piece alignment works correctly
+ self.assertEqual((word_pieces_res[0, :5]-words_res[0, :5]).sum(), 0)
+ self.assertEqual((word_pieces_res[0, 6:]-words_res[0, 5:]).sum(), 0)
+ self.assertEqual((word_pieces_res[1, :3]-words_res[1, :3]).sum(), 0)
\ No newline at end of file
diff --git a/test/embeddings/test_gpt2_embedding.py b/test/embeddings/test_gpt2_embedding.py
new file mode 100644
index 00000000..01e00410
--- /dev/null
+++ b/test/embeddings/test_gpt2_embedding.py
@@ -0,0 +1,268 @@
+
+import unittest
+import torch
+import os
+
+from fastNLP.modules.tokenizer.gpt2_tokenizer import GPT2Tokenizer
+from fastNLP.embeddings import GPT2WordPieceEncoder, GPT2Embedding
+from fastNLP import DataSet, Vocabulary
+
+
+class TestGPT2Embedding(unittest.TestCase):
+ @unittest.skipIf('TRAVIS' in os.environ, "Skip in travis")
+ def test_download(self):
+ vocab = Vocabulary().add_word_lst("This is a test .".split())
+ embed = GPT2Embedding(vocab, model_dir_or_name='en')
+ words = torch.LongTensor([[2, 3, 4, 0]])
+ print(embed(words).size())
+
+ for pool_method in ['first', 'last', 'max', 'avg']:
+ embed = GPT2Embedding(vocab, model_dir_or_name='en', pool_method=pool_method)
+ print(embed(words).size())
+
+ def test_gpt2_embedding(self):
+ weight_path = 'test/data_for_tests/embedding/small_gpt2'
+ vocab = Vocabulary().add_word_lst("this is a texta sentence".split())
+ embed = GPT2Embedding(vocab, model_dir_or_name=weight_path, word_dropout=0.1)
+ requires_grad = embed.requires_grad
+ embed.requires_grad = not requires_grad
+ embed.train()
+ words = torch.LongTensor([[2, 3, 4, 0]])
+ result = embed(words)
+ self.assertEqual(result.size(), (1, 4, 16))
+
+ embed = GPT2Embedding(vocab, model_dir_or_name=weight_path, word_dropout=0.1,
+ only_use_pretrain_bpe=False, language_model=True)
+ embed.eval()
+ words = torch.LongTensor([[2, 3, 4, 0]])
+ result = embed(words)
+ self.assertEqual(result.size(), (1, 4, 16))
+ embed.get_lm_loss()
+
+ vocab.add_word("NotInGpt2")
+ embed = GPT2Embedding(vocab, model_dir_or_name=weight_path, word_dropout=0.1,
+ only_use_pretrain_bpe=False, auto_truncate=True, min_freq=1)
+ words = torch.LongTensor([[2, 3, 4, 0]*20])
+ result = embed(words)
+ self.assertEqual(result.size(), (1, 80, 16))
+
+ def test_gpt2_embedding_2(self):
+ # test that only_use_pretrain_bpe and truncate_embed work as expected
+ Embedding = GPT2Embedding
+ weight_path = 'test/data_for_tests/embedding/small_gpt2'
+ vocab = Vocabulary().add_word_lst("this is a texta and".split())
+ embed1 = Embedding(vocab, model_dir_or_name=weight_path, layers=list(range(3)),
+ only_use_pretrain_bpe=True, truncate_embed=True, min_freq=1)
+ # embed_bpe_vocab_size = len(vocab)-1 + 2 # excludes NotInBERT; additionally adds ##a and [CLS]
+ # self.assertEqual(embed_bpe_vocab_size, len(embed1.model.tokenzier.vocab))
+
+ embed2 = Embedding(vocab, model_dir_or_name=weight_path, layers=list(range(3)),
+ only_use_pretrain_bpe=True, truncate_embed=False, min_freq=1)
+ # embed_bpe_vocab_size = num_word # excludes NotInBERT
+ # self.assertEqual(embed_bpe_vocab_size, len(embed2.model.tokenzier.vocab))
+
+ embed3 = Embedding(vocab, model_dir_or_name=weight_path, layers=list(range(3)),
+ only_use_pretrain_bpe=False, truncate_embed=True, min_freq=1)
+ # embed_bpe_vocab_size = len(vocab)+2 # adds ##a and [CLS]
+ # self.assertEqual(embed_bpe_vocab_size, len(embed3.model.tokenzier.vocab))
+
+ embed4 = Embedding(vocab, model_dir_or_name=weight_path, layers=list(range(3)),
+ only_use_pretrain_bpe=False, truncate_embed=False, min_freq=1)
+ # embed_bpe_vocab_size = num_word+1 # adds ##a
+ # self.assertEqual(embed_bpe_vocab_size, len(embed4.model.tokenzier.vocab))
+
+ # check that the following tensors are equal in every setting
+ embed1.eval()
+ embed2.eval()
+ embed3.eval()
+ embed4.eval()
+ tensor = torch.LongTensor([[vocab.to_index(w) for w in 'this is a texta and'.split()]])
+ t1 = embed1(tensor)
+ t2 = embed2(tensor)
+ t3 = embed3(tensor)
+ t4 = embed4(tensor)
+
+ self.assertEqual((t1-t2).sum(), 0)
+ self.assertEqual((t1-t3).sum(), 0)
+ self.assertEqual((t1-t4).sum(), 0)
+
+ def test_gpt2_tokenizer(self):
+ from fastNLP.modules.tokenizer import GPT2Tokenizer
+
+ tokenizer = GPT2Tokenizer.from_pretrained('test/data_for_tests/embedding/small_gpt2')
+ print(tokenizer.encode("this is a texta a sentence"))
+ print(tokenizer.encode('this is'))
+
+ def test_gpt2_embed_eq_gpt2_piece_encoder(self):
+ # mainly checks that the embedding results match the word piece encoder results
+ weight_path = 'test/data_for_tests/embedding/small_gpt2'
+ ds = DataSet({'words': ["this is a texta a sentence".split(), 'this is'.split()]})
+ encoder = GPT2WordPieceEncoder(model_dir_or_name=weight_path)
+ encoder.eval()
+ encoder.index_datasets(ds, field_name='words')
+ word_pieces = torch.LongTensor(ds['word_pieces'].get([0, 1]))
+ word_pieces_res = encoder(word_pieces)
+
+ vocab = Vocabulary()
+ vocab.from_dataset(ds, field_name='words')
+ vocab.index_dataset(ds, field_name='words', new_field_name='words')
+ ds.set_input('words')
+ words = torch.LongTensor(ds['words'].get([0, 1]))
+ embed = GPT2Embedding(vocab, model_dir_or_name=weight_path, pool_method='first')
+ embed.eval()
+ words_res = embed(words)
+
+ # check that the word-piece alignment works correctly
+ self.assertEqual((word_pieces_res[0, :4]-words_res[0, :4]).sum(), 0)
+ self.assertEqual((word_pieces_res[0, 5:]-words_res[0, 4:]).sum(), 0)
+ self.assertEqual((word_pieces_res[1, :2]-words_res[1, :2]).sum(), 0)
+
+
+class TestGPT2WordPieceEncoder(unittest.TestCase):
+ @unittest.skipIf(True, "Only for local debugging")
+ def test_eq_transformers(self):
+ # test that we can reproduce results matching transformers
+ weight_path = ''
+
+
+ ds = DataSet({'words': ["this this this a is texta model vocab".split(), 'this is'.split()]})
+
+ import transformers
+ input1 = ' '.join(ds[0]['words'])
+ input2 = ' '.join(ds[1]['words'])
+ tokenizer = transformers.GPT2Tokenizer.from_pretrained(weight_path)
+ idx_list1 = tokenizer.encode(input1)
+ idx_list2 = tokenizer.encode(input2)
+
+ pad_value = tokenizer.encode('<|endoftext|>')[0]
+ tensor = torch.nn.utils.rnn.pad_sequence([torch.LongTensor(idx_list1),
+ torch.LongTensor(idx_list2)],
+ batch_first=True,
+ padding_value=pad_value)
+ gpt2 = transformers.GPT2Model.from_pretrained(weight_path, output_hidden_states=True)
+ gpt2.eval()
+ output, _, trans_hidden_states = gpt2(tensor, attention_mask=tensor.ne(pad_value))
+
+ encoder = GPT2WordPieceEncoder(model_dir_or_name=weight_path, layers=list(range(13)))
+ encoder.eval()
+ encoder.index_datasets(ds, field_name='words', add_endoftext=False)
+ word_pieces = torch.LongTensor(ds['word_pieces'].get([0, 1]))
+
+ self.assertEqual(idx_list1, ds[0]['word_pieces'])
+ self.assertEqual(idx_list2, ds[1]['word_pieces'])
+
+ word_pieces_res = encoder(word_pieces)
+
+ self.assertEqual((torch.cat(trans_hidden_states, dim=-1)-word_pieces_res).sum(), 0)
+
+ @unittest.skipIf(True, "Only for local usage")
+ def test_generate_small_gpt2(self):
+ # Because GPT2 uses a BPE tokenizer, the small weights cannot be generated directly; use the approach below
+ weight_path = ''
+ tokenizer = GPT2Tokenizer.from_pretrained(weight_path)
+
+ used_pairs = {}
+ used_vocab = {}
+ # modify the sentences here to cover more data
+ sent1 = "This is a demo sentence"
+ sent2 = "another demo"
+ sent3 = 'this is a texta model vocab'
+ all_tokens = []
+
+ for sent in [sent1, sent2, sent3]:
+ tokens = []
+ for word in sent.split():
+ word = ' ' + word
+ token = "".join(
+ tokenizer.byte_encoder[b] for b in word.encode("utf-8")
+ )
+ _token, _used_pairs = tokenizer.get_used_merge_pair_vocab(token)
+ tokens.extend(_token.split())
+ used_pairs.update(_used_pairs)
+ all_tokens.extend(tokens)
+ token_ids = tokenizer.convert_tokens_to_ids(tokens)
+ used_vocab.update({t:i for t,i in zip(tokens, token_ids)})
+
+ print(used_pairs)
+ import json
+ with open('test/data_for_tests/embedding/small_gpt2/vocab.json', 'w') as f:
+ new_used_vocab = {}
+ for idx, key in enumerate(used_vocab.keys()):
+ new_used_vocab[key] = len(new_used_vocab)
+ new_used_vocab['<|endoftext|>'] = len(new_used_vocab)
+ for i in range(65, 91):
+ if chr(i) not in new_used_vocab:
+ new_used_vocab[chr(i)] = len(new_used_vocab)
+ for i in range(97, 123):
+ if chr(i) not in new_used_vocab:
+ new_used_vocab[chr(i)] = len(new_used_vocab)
+
+ json.dump(new_used_vocab, f)
+
+ with open('test/data_for_tests/embedding/small_gpt2/merges.txt', 'w') as f:
+ f.write('#version: small\n')
+ for k,v in sorted(sorted(used_pairs.items(), key=lambda kv:kv[1])):
+ f.write('{} {}\n'.format(k[0], k[1]))
+
+ new_tokenizer = GPT2Tokenizer.from_pretrained('test/data_for_tests/embedding/small_gpt2')
+ new_all_tokens = []
+ for sent in [sent1, sent2, sent3]:
+ tokens = new_tokenizer.tokenize(sent, add_prefix_space=True)
+ new_all_tokens.extend(tokens)
+ print(all_tokens, new_all_tokens)
+
+ self.assertSequenceEqual(all_tokens, new_all_tokens)
+ config = {
+ "architectures": [
+ "GPT2LMHeadModel"
+ ],
+ "initializer_range": 0.02,
+ "layer_norm_epsilon": 1e-05,
+ "n_ctx": 20,
+ "n_embd": 16,
+ "n_head": 4,
+ "n_layer": 2,
+ "n_positions": 20,
+ "vocab_size": len(new_used_vocab)
+ }
+ with open('test/data_for_tests/embedding/small_gpt2/config.json', 'w') as f:
+ json.dump(config, f)
+
+ # the smaller merges.txt and vocab.json above were generated by recording the values used by the tokenizer;
+ # now create matching small model weights
+ from fastNLP.modules.encoder.gpt2 import GPT2LMHeadModel, GPT2Config
+
+ config = GPT2Config.from_pretrained('test/data_for_tests/embedding/small_gpt2')
+
+ model = GPT2LMHeadModel(config)
+ torch.save(model.state_dict(), 'test/data_for_tests/embedding/small_gpt2/small_pytorch_model.bin')
+ print(model(torch.LongTensor([[0,1,2,3]])))
+
+ def test_gpt2_word_piece_encoder(self):
+ # mainly checks that it runs without error
+ weight_path = 'test/data_for_tests/embedding/small_gpt2'
+ ds = DataSet({'words': ["this is a test sentence".split()]})
+ embed = GPT2WordPieceEncoder(model_dir_or_name=weight_path, word_dropout=0.1)
+ embed.index_datasets(ds, field_name='words')
+ self.assertTrue(ds.has_field('word_pieces'))
+ result = embed(torch.LongTensor([[1, 2, 3, 4]]))
+
+ embed = GPT2WordPieceEncoder(model_dir_or_name=weight_path, word_dropout=0.1,
+ language_model=True)
+ embed.index_datasets(ds, field_name='words')
+ self.assertTrue(ds.has_field('word_pieces'))
+ result = embed(torch.LongTensor([[1, 2, 3, 4]]))
+
+ def test_generate(self):
+ weight_path = 'test/data_for_tests/embedding/small_gpt2'
+
+ encoder = GPT2WordPieceEncoder(model_dir_or_name=weight_path, language_model=True)
+
+ # check that the various generation options all work
+ print(encoder.generate_from_str('this', max_len=20, do_sample=False, num_beams=1, temperature=1, top_k=50, top_p=1.0,
+ repetition_penalty=1.0, length_penalty=1.0))
+ print(encoder.generate_from_str('this', max_len=20, do_sample=True, num_beams=3, temperature=1, top_k=50, top_p=1.0,
+ repetition_penalty=1.0, length_penalty=1.0))
+ print(encoder.generate_from_str('this', max_len=20, do_sample=True, num_beams=3, temperature=2, top_k=20, top_p=2.0,
+ repetition_penalty=2.0, length_penalty=1.5))
diff --git a/test/embeddings/test_roberta_embedding.py b/test/embeddings/test_roberta_embedding.py
new file mode 100644
index 00000000..c2e80a8a
--- /dev/null
+++ b/test/embeddings/test_roberta_embedding.py
@@ -0,0 +1,252 @@
+
+import unittest
+
+import torch
+import os
+
+from fastNLP import DataSet, Vocabulary
+from fastNLP.embeddings.roberta_embedding import RobertaWordPieceEncoder, RobertaEmbedding
+
+
+class TestRobertaWordPieceEncoder(unittest.TestCase):
+ @unittest.skipIf('TRAVIS' in os.environ, "Skip in travis")
+ def test_download(self):
+ vocab = Vocabulary().add_word_lst("This is a test .".split())
+ embed = RobertaEmbedding(vocab, model_dir_or_name='en')
+ words = torch.LongTensor([[2, 3, 4, 0]])
+ print(embed(words).size())
+
+ for pool_method in ['first', 'last', 'max', 'avg']:
+ for include_cls_sep in [True, False]:
+ embed = RobertaEmbedding(vocab, model_dir_or_name='en', pool_method=pool_method,
+ include_cls_sep=include_cls_sep)
+ print(embed(words).size())
+
+ def test_roberta_word_piece_encoder(self):
+ # just needs to run without error
+ weight_path = 'test/data_for_tests/embedding/small_roberta'
+ encoder = RobertaWordPieceEncoder(model_dir_or_name=weight_path, word_dropout=0.1)
+ ds = DataSet({'words': ["this is a test . [SEP]".split()]})
+ encoder.index_datasets(ds, field_name='words')
+ self.assertTrue(ds.has_field('word_pieces'))
+ result = encoder(torch.LongTensor([[1,2,3,4]]))
+
+ def test_roberta_embed_eq_roberta_piece_encoder(self):
+ # mainly checks that the embedding results match the word piece encoder results
+ weight_path = 'test/data_for_tests/embedding/small_roberta'
+ ds = DataSet({'words': ["this is a texta a sentence".split(), 'this is'.split()]})
+ encoder = RobertaWordPieceEncoder(model_dir_or_name=weight_path)
+ encoder.eval()
+ encoder.index_datasets(ds, field_name='words')
+ word_pieces = torch.LongTensor(ds['word_pieces'].get([0, 1]))
+ word_pieces_res = encoder(word_pieces)
+
+ vocab = Vocabulary()
+ vocab.from_dataset(ds, field_name='words')
+ vocab.index_dataset(ds, field_name='words', new_field_name='words')
+ ds.set_input('words')
+ words = torch.LongTensor(ds['words'].get([0, 1]))
+ embed = RobertaEmbedding(vocab, model_dir_or_name=weight_path,
+ pool_method='first', include_cls_sep=True, pooled_cls=False)
+ embed.eval()
+ words_res = embed(words)
+
+ # check that the word-piece alignment works correctly
+ self.assertEqual((word_pieces_res[0, :5]-words_res[0, :5]).sum(), 0)
+ self.assertEqual((word_pieces_res[0, 6:]-words_res[0, 5:]).sum(), 0)
+ self.assertEqual((word_pieces_res[1, :3]-words_res[1, :3]).sum(), 0)
+
+ @unittest.skipIf(True, "Only for local debugging")
+ def test_eq_transformers(self):
+ weight_path = ''
+ ds = DataSet({'words': ["this is a texta model vocab".split(), 'this is'.split()]})
+ encoder = RobertaWordPieceEncoder(model_dir_or_name=weight_path)
+ encoder.eval()
+ encoder.index_datasets(ds, field_name='words')
+ word_pieces = torch.LongTensor(ds['word_pieces'].get([0, 1]))
+ word_pieces_res = encoder(word_pieces)
+
+ import transformers
+ input1 = ' '.join(ds[0]['words'])
+ input2 = ' '.join(ds[1]['words'])
+ tokenizer = transformers.RobertaTokenizer.from_pretrained(weight_path)
+ idx_list1 = tokenizer.encode(input1)
+ idx_list2 = tokenizer.encode(input2)
+ self.assertEqual(idx_list1, ds[0]['word_pieces'])
+ self.assertEqual(idx_list2, ds[1]['word_pieces'])
+
+ pad_value = tokenizer.encode('<pad>')[0]
+ tensor = torch.nn.utils.rnn.pad_sequence([torch.LongTensor(idx_list1),
+ torch.LongTensor(idx_list2)],
+ batch_first=True,
+ padding_value=pad_value)
+ roberta = transformers.RobertaModel.from_pretrained(weight_path, output_hidden_states=True)
+ roberta.eval()
+ output, pooled_output, hidden_states = roberta(tensor, attention_mask=tensor.ne(pad_value))
+
+ self.assertEqual((output-word_pieces_res).sum(), 0)
+
+ @unittest.skipIf(True, "Only for local usage")
+ def test_generate_small_roberta(self):
+ """
+ Because RoBERTa uses GPT2's tokenizer, the weights cannot be generated directly; use the approach below.
+
+ :return:
+ """
+ weight_path = ''
+ from fastNLP.modules.tokenizer import RobertaTokenizer
+ tokenizer = RobertaTokenizer.from_pretrained(weight_path)
+
+ used_pairs = {}
+ used_vocab = {}
+ # modify the sentences here to cover more data
+ sent1 = "This is a demo sentence"
+ sent2 = "another demo"
+ sent3 = 'this is a texta model vocab'
+ all_tokens = []
+
+ for sent in [sent1, sent2, sent3]:
+ tokens = []
+ for word in sent.split():
+ word = ' ' + word
+ token = "".join(
+ tokenizer.byte_encoder[b] for b in word.encode("utf-8")
+ )
+ _token, _used_pairs = tokenizer.get_used_merge_pair_vocab(token)
+ tokens.extend(_token.split())
+ used_pairs.update(_used_pairs)
+ all_tokens.extend(tokens)
+ token_ids = tokenizer.convert_tokens_to_ids(tokens)
+ used_vocab.update({t:i for t,i in zip(tokens, token_ids)})
+
+ import json
+ with open('test/data_for_tests/embedding/small_roberta/vocab.json', 'w') as f:
+ new_used_vocab = {}
+ for token in ['<s>', '<pad>', '</s>', '<unk>', '<mask>']: # '<pad>' must be index 1
+ new_used_vocab[token] = len(new_used_vocab)
+ for i in range(65, 91):
+ if chr(i) not in new_used_vocab:
+ new_used_vocab[chr(i)] = len(new_used_vocab)
+ for i in range(97, 123):
+ if chr(i) not in new_used_vocab:
+ new_used_vocab[chr(i)] = len(new_used_vocab)
+ for idx, key in enumerate(used_vocab.keys()):
+ if key not in new_used_vocab:
+ new_used_vocab[key] = len(new_used_vocab)
+ json.dump(new_used_vocab, f)
+
+ with open('test/data_for_tests/embedding/small_roberta/merges.txt', 'w') as f:
+ f.write('#version: tiny\n')
+ for k,v in sorted(sorted(used_pairs.items(), key=lambda kv:kv[1])):
+ f.write('{} {}\n'.format(k[0], k[1]))
+
+ config = {
+ "architectures": [
+ "RobertaForMaskedLM"
+ ],
+ "attention_probs_dropout_prob": 0.1,
+ "finetuning_task": None,
+ "hidden_act": "gelu",
+ "hidden_dropout_prob": 0.1,
+ "hidden_size": 16,
+ "initializer_range": 0.02,
+ "intermediate_size": 20,
+ "layer_norm_eps": 1e-05,
+ "max_position_embeddings": 20,
+ "num_attention_heads": 4,
+ "num_hidden_layers": 2,
+ "num_labels": 2,
+ "output_attentions": False,
+ "output_hidden_states": False,
+ "torchscript": False,
+ "type_vocab_size": 1,
+ "vocab_size": len(new_used_vocab)
+ }
+ with open('test/data_for_tests/embedding/small_roberta/config.json', 'w') as f:
+ json.dump(config, f)
+
+ new_tokenizer = RobertaTokenizer.from_pretrained('test/data_for_tests/embedding/small_roberta')
+ new_all_tokens = []
+ for sent in [sent1, sent2, sent3]:
+ tokens = new_tokenizer.tokenize(sent, add_prefix_space=True)
+ new_all_tokens.extend(tokens)
+ print(all_tokens, new_all_tokens)
+
+ self.assertSequenceEqual(all_tokens, new_all_tokens)
+
+ # the smaller merges.txt and vocab.json above were generated by recording the values used by the tokenizer;
+ # now create matching small model weights
+ from fastNLP.modules.encoder.roberta import RobertaModel, BertConfig
+
+ config = BertConfig.from_json_file('test/data_for_tests/embedding/small_roberta/config.json')
+
+ model = RobertaModel(config)
+ torch.save(model.state_dict(), 'test/data_for_tests/embedding/small_roberta/small_pytorch_model.bin')
+ print(model(torch.LongTensor([[0,1,2,3]])))
+
+
+class TestRobertaEmbedding(unittest.TestCase):
+ def test_roberta_embedding_1(self):
+ weight_path = 'test/data_for_tests/embedding/small_roberta'
+ vocab = Vocabulary().add_word_lst("this is a test . [SEP] NotInRoberta".split())
+ embed = RobertaEmbedding(vocab, model_dir_or_name=weight_path, word_dropout=0.1)
+ requires_grad = embed.requires_grad
+ embed.requires_grad = not requires_grad
+ embed.train()
+ words = torch.LongTensor([[2, 3, 4, 1]])
+ result = embed(words)
+ self.assertEqual(result.size(), (1, 4, 16))
+
+ embed = RobertaEmbedding(vocab, model_dir_or_name=weight_path, word_dropout=0.1,
+ only_use_pretrain_bpe=True)
+ embed.eval()
+ words = torch.LongTensor([[2, 3, 4, 1]])
+ result = embed(words)
+ self.assertEqual(result.size(), (1, 4, 16))
+
+ # automatically truncates instead of raising an error
+ embed = RobertaEmbedding(vocab, model_dir_or_name=weight_path, word_dropout=0.1,
+ only_use_pretrain_bpe=True, auto_truncate=True)
+ words = torch.LongTensor([[2, 3, 4, 1]*10,
+ [2, 3]+[0]*38])
+ result = embed(words)
+ self.assertEqual(result.size(), (2, 40, 16))
+
+ def test_roberta_embedding_2(self):
+ # test that only_use_pretrain_bpe and truncate_embed work as expected
+ Embedding = RobertaEmbedding
+ weight_path = 'test/data_for_tests/embedding/small_roberta'
+ vocab = Vocabulary().add_word_lst("this is a texta and".split())
+ embed1 = Embedding(vocab, model_dir_or_name=weight_path, layers=list(range(3)),
+ only_use_pretrain_bpe=True, truncate_embed=True, min_freq=1)
+ # embed_bpe_vocab_size = len(vocab)-1 + 2 # excludes NotInBERT; additionally adds ##a and [CLS]
+ # self.assertEqual(embed_bpe_vocab_size, len(embed1.model.tokenzier.vocab))
+
+ embed2 = Embedding(vocab, model_dir_or_name=weight_path, layers=list(range(3)),
+ only_use_pretrain_bpe=True, truncate_embed=False, min_freq=1)
+ # embed_bpe_vocab_size = num_word # excludes NotInBERT
+ # self.assertEqual(embed_bpe_vocab_size, len(embed2.model.tokenzier.vocab))
+
+ embed3 = Embedding(vocab, model_dir_or_name=weight_path, layers=list(range(3)),
+ only_use_pretrain_bpe=False, truncate_embed=True, min_freq=1)
+ # embed_bpe_vocab_size = len(vocab)+2 # adds ##a and [CLS]
+ # self.assertEqual(embed_bpe_vocab_size, len(embed3.model.tokenzier.vocab))
+
+ embed4 = Embedding(vocab, model_dir_or_name=weight_path, layers=list(range(3)),
+ only_use_pretrain_bpe=False, truncate_embed=False, min_freq=1)
+ # embed_bpe_vocab_size = num_word+1 # adds ##a
+ # self.assertEqual(embed_bpe_vocab_size, len(embed4.model.tokenzier.vocab))
+
+ # check that the following tensors are equal in every setting
+ embed1.eval()
+ embed2.eval()
+ embed3.eval()
+ embed4.eval()
+ tensor = torch.LongTensor([[vocab.to_index(w) for w in 'this is a texta and'.split()]])
+ t1 = embed1(tensor)
+ t2 = embed2(tensor)
+ t3 = embed3(tensor)
+ t4 = embed4(tensor)
+
+ self.assertEqual((t1-t2).sum(), 0)
+ self.assertEqual((t1-t3).sum(), 0)
+ self.assertEqual((t1-t4).sum(), 0)
diff --git a/test/modules/encoder/test_bert.py b/test/modules/encoder/test_bert.py
new file mode 100644
index 00000000..35802811
--- /dev/null
+++ b/test/modules/encoder/test_bert.py
@@ -0,0 +1,24 @@
+import unittest
+
+
+from fastNLP.modules import BertTokenizer
+
+
+class TestBertTokenizer(unittest.TestCase):
+ def test_run(self):
+ # test the two supported encode call styles (str and List[str] input, with and without special tokens)
+ tokenizer = BertTokenizer.from_pretrained('test/data_for_tests/embedding/small_bert')
+
+ tokens1 = tokenizer.encode("This is a demo")
+ tokens2 = tokenizer.encode("This is a demo")
+ tokens3 = tokenizer.encode("This is a demo".split())
+ tokens4 = tokenizer.encode("This is a demo".split())
+
+ self.assertEqual(len(tokens1)-2, len(tokens2))
+ self.assertEqual(len(tokens3)-2, len(tokens4))
+
+ self.assertEqual(tokens1[0], tokenizer.cls_index)
+ self.assertEqual(tokens1[-1], tokenizer.sep_index)
+
+
+