@@ -217,7 +217,8 @@ class BatchIter:

class DataSetIter(BatchIter):
    r"""
-    DataSetIter fetches data from a `DataSet` in a given order, ``batch_size`` samples at a time,
+    DataSetIter fetches data from a `DataSet` in a given order, ``batch_size`` samples at a time. With DataSetIter there
+    is no need to handle padding of the input (each field's Padder in the DataSet takes care of it) or to convert the data to tensors.
    The batches are assembled into `x` and `y`::

        batch = DataSetIter(data_set, batch_size=16, sampler=SequentialSampler())
@@ -226,10 +227,8 @@ class DataSetIter(BatchIter):
            # do stuff ...
    """
-    def __init__(self, dataset, batch_size=1, sampler=None, as_numpy=False,
-                 num_workers=0, pin_memory=False, drop_last=False,
-                 timeout=0, worker_init_fn=None, collate_fn=None,
-                 batch_sampler=None):
+    def __init__(self, dataset, batch_size=1, sampler=None, as_numpy=False, num_workers=0, pin_memory=False,
+                 drop_last=False, timeout=0, worker_init_fn=None, batch_sampler=None):
r""" | r""" | ||||
:param dataset: :class:`~fastNLP.DataSet` 对象, 数据集 | :param dataset: :class:`~fastNLP.DataSet` 对象, 数据集 | ||||
@@ -245,13 +244,12 @@ class DataSetIter(BatchIter): | |||||
:param bool drop_last: 如果最后一个batch没有batch_size这么多sample,就扔掉最后一个 | :param bool drop_last: 如果最后一个batch没有batch_size这么多sample,就扔掉最后一个 | ||||
:param timeout: 生成一个batch的timeout值 | :param timeout: 生成一个batch的timeout值 | ||||
:param worker_init_fn: 在每个worker启动时调用该函数,会传入一个值,该值是worker的index。 | :param worker_init_fn: 在每个worker启动时调用该函数,会传入一个值,该值是worker的index。 | ||||
:param collate_fn: 用于将样本组合成batch的函数 | |||||
:param batch_sampler: 当每次batch取出的数据数量不一致时,可以使用该sampler。batch_sampler每次iter应该输出一个list的index。 | :param batch_sampler: 当每次batch取出的数据数量不一致时,可以使用该sampler。batch_sampler每次iter应该输出一个list的index。 | ||||
当batch_sampler不为None时,参数batch_size, sampler, drop_last会被忽略。 | 当batch_sampler不为None时,参数batch_size, sampler, drop_last会被忽略。 | ||||
""" | """ | ||||
        assert isinstance(dataset, DataSet)
        dataset = DataSetGetter(dataset, as_numpy)
-        collate_fn = dataset.collate_fn if collate_fn is None else collate_fn
+        collate_fn = dataset.collate_fn
        if batch_sampler is not None:
            batch_size = 1
            sampler = None
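A minimal sketch of the batch_sampler behaviour described in the docstring above, assuming a toy DataSet with a single 'x' input field; the index grouping is illustrative only::

    from fastNLP import DataSet, DataSetIter

    ds = DataSet({'x': [[1, 2], [3, 4, 5], [6], [7, 8]]})
    ds.set_input('x')
    # an iterable of index lists; batch_size, sampler and drop_last are then ignored
    batches = [[0, 1, 2], [3]]
    for batch_x, batch_y in DataSetIter(ds, batch_sampler=batches):
        print(batch_x['x'].shape)  # first (3, 3) after padding, then (1, 2)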
@@ -272,8 +270,9 @@ class DataSetIter(BatchIter):

class TorchLoaderIter(BatchIter):
    r"""
-    Similar to DataSetIter, but usable with non-fastNLP data containers, which can then be passed to Trainer.
-    The data container only needs to implement the following methods
+    Similar to DataSetIter, but usable with non-fastNLP data container objects, and it allows a fully customized way of
+    generating batches; it then plugs into Trainer and Tester exactly like DataSetIter.
+    The data container passed in must implement the following methods

    Example::
@@ -293,7 +292,7 @@ class TorchLoaderIter(BatchIter):
                return self.num_samples

        # a collate_fn must be implemented to convert the data to tensors
-        def collact_fn(data_list):
+        def collate_fn(data_list):
            # [(x1,y1), (x2,y2), ...]; the input here is the return values of UdfDataSet.__getitem__ gathered into a list
            xs, ys = [], []
            for l in data_list:
@@ -302,10 +301,10 @@ class TorchLoaderIter(BatchIter):
                ys.append(y)
            # no need to move to gpu; Trainer or Tester will move the data to the model's device
            x, y = torch.FloatTensor(xs), torch.FloatTensor(ys)
-            return {'x':x, 'y':y}, {'y':y}
+            return {'x':x, 'y':y}, {'y':y}  # the first dict plays the role of the DataSet input fields, the second of the target fields
        udf_dataset = UdfDataSet(10)
-        dataset = TorchLoaderIter(udf_dataset, collate_fn=collact_fn)
+        dataset = TorchLoaderIter(udf_dataset, collate_fn=collate_fn)
        class Model(nn.Module):
            def __init__(self):
                super().__init__()
@@ -362,7 +361,7 @@ class TorchLoaderIter(BatchIter):
            def __len__(self):
                return self.num_samples

-        def collact_fn(data_list):
+        def collate_fn(data_list):
            # [(x1,y1), (x2,y2), ...]; the input here is the return values of FileDataSet.__getitem__ gathered into a list
            xs, ys = [], []
            for l in data_list:
@@ -370,10 +369,10 @@ class TorchLoaderIter(BatchIter):
                xs.append(x)
                ys.append(y)
            x, y = torch.FloatTensor(xs), torch.FloatTensor(ys)
-            return {'x': x, 'y': y}, {'y': y}
+            return {'x': x, 'y': y}, {'y': y}  # the first dict plays the role of the DataSet input fields, the second of the target fields
        file_data = FileDataSet(tmp_file_path)
-        dataset = TorchLoaderIter(file_data, collate_fn=collact_fn)
+        dataset = TorchLoaderIter(file_data, collate_fn=collate_fn)

        class Model(nn.Module):
            def __init__(self):
@@ -214,11 +214,8 @@ class DistTrainer():
    def _get_data_iter(self, dataset):
        if isinstance(dataset, DataSet):
-            return DataSetIter(
-                dataset=dataset, batch_size=self.batch_size_per_gpu,
-                num_workers=self.num_data_workers, sampler=self.sampler,
-                drop_last=self.drop_last
-            )
+            return DataSetIter(dataset=dataset, batch_size=self.batch_size_per_gpu, sampler=self.sampler,
+                               num_workers=self.num_data_workers, drop_last=self.drop_last)
        elif isinstance(dataset, BatchIter):
            return dataset
        else:
@@ -107,8 +107,8 @@ class Tester(object):
        self.logger = logger

        if isinstance(data, DataSet):
-            self.data_iterator = DataSetIter(
-                dataset=data, batch_size=batch_size, num_workers=num_workers, sampler=SequentialSampler())
+            self.data_iterator = DataSetIter(dataset=data, batch_size=batch_size, sampler=SequentialSampler(),
+                                             num_workers=num_workers)
        elif isinstance(data, BatchIter):
            self.data_iterator = data
        else:
@@ -487,8 +487,8 @@ class Trainer(object):
            sampler.set_batch_size(batch_size)

        if isinstance(train_data, DataSet):
-            self.data_iterator = DataSetIter(
-                dataset=train_data, batch_size=batch_size, num_workers=num_workers, sampler=sampler, drop_last=drop_last)
+            self.data_iterator = DataSetIter(dataset=train_data, batch_size=batch_size, sampler=sampler,
+                                             num_workers=num_workers, drop_last=drop_last)
        elif isinstance(train_data, BatchIter):
            self.data_iterator = train_data
            train_data = train_data.dataset
@@ -12,16 +12,26 @@ __all__ = [
    "ElmoEmbedding",

    "BertEmbedding",
    "BertWordPieceEncoder",

+    "RobertaEmbedding",
+    "RobertaWordPieceEncoder",
+
+    "GPT2Embedding",
+    "GPT2WordPieceEncoder",
+
    "StackEmbedding",
    "LSTMCharEmbedding",
    "CNNCharEmbedding",

    "get_embeddings",
]

from .embedding import Embedding, TokenEmbedding
from .static_embedding import StaticEmbedding
from .elmo_embedding import ElmoEmbedding
from .bert_embedding import BertEmbedding, BertWordPieceEncoder
+from .roberta_embedding import RobertaEmbedding, RobertaWordPieceEncoder
+from .gpt2_embedding import GPT2WordPieceEncoder, GPT2Embedding
from .char_embedding import CNNCharEmbedding, LSTMCharEmbedding
from .stack_embedding import StackEmbedding
from .utils import get_embeddings
@@ -11,6 +11,7 @@ __all__ = [
import collections
import warnings
from itertools import chain
+from functools import partial

import numpy as np
import torch

@@ -20,7 +21,8 @@ from .contextual_embedding import ContextualEmbedding
from ..core import logger
from ..core.vocabulary import Vocabulary
from ..io.file_utils import PRETRAINED_BERT_MODEL_DIR
-from ..modules.encoder.bert import _WordPieceBertModel, BertModel, BertTokenizer
+from ..modules.encoder.bert import BertModel
+from ..modules.tokenizer import BertTokenizer

class BertEmbedding(ContextualEmbedding):
@@ -31,6 +33,7 @@ class BertEmbedding(ContextualEmbedding):
    BertEmbedding supports automatic downloading of weights; currently supported models:

        en: base-cased
+        en-base-uncased:
        en-large-cased-wwm:
        en-large-cased:
        en-large-uncased:
@@ -63,7 +66,8 @@ class BertEmbedding(ContextualEmbedding):
        :param str model_dir_or_name: directory of the model, or a model name. When a directory is given, it should contain
            a vocab file (.txt suffix), a weights file (.bin suffix) and a config file (.json suffix).
        :param str layers: which layers the output embedding comes from; the results of the different layers are concatenated
            along the last dimension in the order given in layers. Comma-separated layer indices, numbered
-            from 0; negative numbers index layers from the end.
+            from 0; negative numbers index layers from the end. layer=0 is the embedding layer (wordpiece embedding
+            plus position embedding and segment embedding)
        :param str pool_method: in bert every word is represented by several word pieces; this controls how a word's
            representation is computed from its word pieces. One of ``last``, ``first``, ``avg``, ``max``.
        :param float word_dropout: probability of replacing a word with unk, which both trains unk and provides some regularization.
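A minimal sketch of the layers semantics described above, assuming the 'en-base-uncased' weights can be downloaded; names and sizes are illustrative::

    from fastNLP import Vocabulary
    from fastNLP.embeddings import BertEmbedding

    vocab = Vocabulary().add_word_lst("this is a test .".split())
    # layer 0 is the embedding layer, -1 the last encoder layer; their outputs
    # are concatenated, so embed_size becomes 2 * hidden_size
    embed = BertEmbedding(vocab, model_dir_or_name='en-base-uncased', layers='0,-1')
    print(embed.embed_size)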
@@ -95,20 +99,22 @@ class BertEmbedding(ContextualEmbedding):
            warnings.warn("For Chinese bert, pooled_method should choose from 'first', 'last' in order to achieve"
                          " faster speed.")

-        self._word_sep_index = None
+        self._word_sep_index = -100
        if '[SEP]' in vocab:
            self._word_sep_index = vocab['[SEP]']
+        self._word_cls_index = -100
+        if '[CLS]' in vocab:
+            self._word_cls_index = vocab['[CLS]']

        only_use_pretrain_bpe = kwargs.get('only_use_pretrain_bpe', False)
        truncate_embed = kwargs.get('truncate_embed', True)
        min_freq = kwargs.get('min_freq', 2)

-        self.model = _WordBertModel(model_dir_or_name=model_dir_or_name, vocab=vocab, layers=layers,
+        self.model = _BertWordModel(model_dir_or_name=model_dir_or_name, vocab=vocab, layers=layers,
                                    pool_method=pool_method, include_cls_sep=include_cls_sep,
                                    pooled_cls=pooled_cls, auto_truncate=auto_truncate, min_freq=min_freq,
                                    only_use_pretrain_bpe=only_use_pretrain_bpe, truncate_embed=truncate_embed)
-        self._sep_index = self.model._sep_index
-        self._cls_index = self.model._cls_index
        self.requires_grad = requires_grad
        self._embed_size = len(self.model.layers) * self.model.encoder.hidden_size
@@ -141,15 +147,16 @@ class BertEmbedding(ContextualEmbedding):
        """
        if self.word_dropout > 0 and self.training:
            with torch.no_grad():
-                not_sep_mask = words.ne(self._sep_index)
-                not_cls_mask = words.ne(self._cls_index)
-                if self._word_sep_index:
-                    not_sep_mask = not_sep_mask.__and__(words.ne(self._word_sep_index))
-                replaceable_mask = not_sep_mask.__and__(not_cls_mask)
                mask = torch.full_like(words, fill_value=self.word_dropout, dtype=torch.float, device=words.device)
                mask = torch.bernoulli(mask).eq(1)  # the larger dropout_word is, the more positions are set to 1
-                pad_mask = words.ne(0)
-                mask = pad_mask.__and__(mask).__and__(replaceable_mask)  # pad positions are not replaced by unk
+                pad_mask = words.ne(self._word_pad_index)
+                mask = pad_mask.__and__(mask)  # pad positions are not replaced by unk
+                if self._word_sep_index != -100:
+                    not_sep_mask = words.ne(self._word_sep_index)
+                    mask = mask.__and__(not_sep_mask)
+                if self._word_cls_index != -100:
+                    not_cls_mask = words.ne(self._word_cls_index)
+                    mask = mask.__and__(not_cls_mask)
                words = words.masked_fill(mask, self._word_unk_index)
        return words
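A self-contained sketch of the masking logic above, with hand-picked indices standing in for pad, [SEP] and [CLS] (all values illustrative)::

    import torch

    words = torch.LongTensor([[2, 5, 7, 3, 0, 0]])  # 0 = pad, 7 = [SEP], 3 = [CLS]
    word_dropout, unk_index = 0.5, 1
    mask = torch.bernoulli(torch.full_like(words, word_dropout, dtype=torch.float)).eq(1)
    mask = mask & words.ne(0) & words.ne(7) & words.ne(3)  # never replace pad/[SEP]/[CLS]
    print(words.masked_fill(mask, unk_index))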
@@ -177,7 +184,8 @@ class BertWordPieceEncoder(nn.Module):
    r"""
        :param str model_dir_or_name: directory of the model, or a model name. Defaults to ``en-base-uncased``
-        :param str layers: the layers used in the final representation. Comma-separated layer indices; negative numbers index layers from the end
+        :param str layers: the layers used in the final representation. Comma-separated layer indices; negative numbers index
+            layers from the end. layer=0 is the embedding layer (wordpiece embedding plus position embedding and segment embedding)
        :param bool pooled_cls: whether the leading [CLS] is mapped through the pretrained BertPool. If the downstream task
            takes [CLS] for prediction, this should usually be True.
        :param float word_dropout: probability of replacing a word with unk, which both trains unk and provides some regularization.
        :param float dropout: dropout probability on the embedding representation. 0.1 means 10% of the values are randomly zeroed.
@@ -185,7 +193,7 @@ class BertWordPieceEncoder(nn.Module):
        """
        super().__init__()

-        self.model = _WordPieceBertModel(model_dir_or_name=model_dir_or_name, layers=layers, pooled_cls=pooled_cls)
+        self.model = _BertWordPieceModel(model_dir_or_name=model_dir_or_name, layers=layers, pooled_cls=pooled_cls)
        self._sep_index = self.model._sep_index
        self._cls_index = self.model._cls_index
        self._wordpiece_pad_index = self.model._wordpiece_pad_index
@@ -217,7 +225,8 @@ class BertWordPieceEncoder(nn.Module):
        :param bool add_cls_sep: if the sequences do not start with [CLS] and end with [SEP], add them at the ends.
        :return:
        """
-        self.model.index_dataset(*datasets, field_name=field_name, add_cls_sep=add_cls_sep)
+        self.model.index_datasets(*datasets, field_name=field_name, add_cls_sep=add_cls_sep)

    def forward(self, word_pieces, token_type_ids=None):
        r"""
@@ -262,7 +271,7 @@ class BertWordPieceEncoder(nn.Module):
        return words


-class _WordBertModel(nn.Module):
+class _BertWordModel(nn.Module):
    def __init__(self, model_dir_or_name: str, vocab: Vocabulary, layers: str = '-1', pool_method: str = 'first',
                 include_cls_sep: bool = False, pooled_cls: bool = False, auto_truncate: bool = False, min_freq=2,
                 only_use_pretrain_bpe=False, truncate_embed=True):
@@ -273,13 +282,18 @@ class _WordBertModel(nn.Module):
        self._max_position_embeddings = self.encoder.config.max_position_embeddings
        # check that encoder_layer_number is reasonable
        encoder_layer_number = len(self.encoder.encoder.layer)
-        self.layers = list(map(int, layers.split(',')))
+        if isinstance(layers, list):
+            self.layers = [int(l) for l in layers]
+        elif isinstance(layers, str):
+            self.layers = list(map(int, layers.split(',')))
+        else:
+            raise TypeError("`layers` only supports str or list[int]")
        for layer in self.layers:
            if layer < 0:
                assert -layer <= encoder_layer_number, f"The layer index:{layer} is out of scope for " \
                                                       f"a bert model with {encoder_layer_number} layers."
            else:
-                assert layer < encoder_layer_number, f"The layer index:{layer} is out of scope for " \
+                assert layer <= encoder_layer_number, f"The layer index:{layer} is out of scope for " \
                                                      f"a bert model with {encoder_layer_number} layers."

        assert pool_method in ('avg', 'max', 'first', 'last')
@@ -295,7 +309,8 @@ class _WordBertModel(nn.Module):
        # step 1: collect the needed word_pieces, then build the new embed and word_piece_vocab and fill in the values
        word_piece_dict = {'[CLS]': 1, '[SEP]': 1}  # word_pieces in use, plus newly added ones
        new_add_to_bpe_vocab = 0
-        unsegment_word = 0
+        unsegment_count = 0
        if '[sep]' in vocab:
            warnings.warn("Lower cased [sep] detected, it cannot be correctly recognized as [SEP] by BertEmbedding.")
        if "[CLS]" in vocab:
@@ -318,7 +333,8 @@ class _WordBertModel(nn.Module):
                            word) and not only_use_pretrain_bpe:  # only add words that occur at least this many times
                        word_piece_dict[word] = 1  # add a new entry
                        new_add_to_bpe_vocab += 1
-                unsegment_word += 1
+                unsegment_count += 1
                continue
            for word_piece in word_pieces:
                word_piece_dict[word_piece] = 1
@@ -331,21 +347,28 @@ class _WordBertModel(nn.Module):
        new_word_piece_vocab = collections.OrderedDict()

        for index, token in enumerate(['[PAD]', '[UNK]']):
-            word_piece_dict.pop(token, None)
-            embed.weight.data[index] = original_embed[self.tokenzier.vocab[token]]
-            new_word_piece_vocab[token] = index
+            index = word_piece_dict.pop(token, None)
+            if index is not None:
+                new_word_piece_vocab[token] = len(new_word_piece_vocab)
+                embed.weight.data[new_word_piece_vocab[token]] = original_embed[self.tokenzier.vocab[token]]

        for token in word_piece_dict.keys():
+            if token not in new_word_piece_vocab:
+                new_word_piece_vocab[token] = len(new_word_piece_vocab)
+            index = new_word_piece_vocab[token]
            if token in self.tokenzier.vocab:
-                embed.weight.data[len(new_word_piece_vocab)] = original_embed[self.tokenzier.vocab[token]]
+                embed.weight.data[index] = original_embed[self.tokenzier.vocab[token]]
            else:
-                embed.weight.data[len(new_word_piece_vocab)] = original_embed[self.tokenzier.vocab['[UNK]']]
-                new_word_piece_vocab[token] = len(new_word_piece_vocab)
+                embed.weight.data[index] = original_embed[self.tokenzier.vocab['[UNK]']]

        self.tokenzier._reinit_on_new_vocab(new_word_piece_vocab)
        self.encoder.embeddings.word_embeddings = embed
-        if only_use_pretrain_bpe:
-            logger.info(f"{unsegment_word} words are unsegmented.")
-        else:
-            logger.info(f"{unsegment_word} words are unsegmented. Among them, {new_add_to_bpe_vocab} added to the BPE vocab.")
+        self.encoder.config.vocab_size = len(new_word_piece_vocab)
+        if unsegment_count > 0:
+            if only_use_pretrain_bpe or new_add_to_bpe_vocab == 0:
+                logger.info(f"{unsegment_count} words are unsegmented.")
+            else:
+                logger.info(f"{unsegment_count} words are unsegmented. Among them, {new_add_to_bpe_vocab} added to the BPE vocab.")
        word_to_wordpieces = []
        word_pieces_lengths = []
@@ -379,8 +402,8 @@ class _WordBertModel(nn.Module):
            batch_word_pieces_length = self.word_pieces_lengths[words].masked_fill(word_mask.eq(False),
                                                                                   0)  # batch_size x max_len
            word_pieces_lengths = batch_word_pieces_length.sum(dim=-1)  # batch_size
-            word_piece_length = batch_word_pieces_length.sum(dim=-1).max().item()  # the word piece length (including padding)
-            if word_piece_length + 2 > self._max_position_embeddings:
+            max_word_piece_length = batch_word_pieces_length.sum(dim=-1).max().item()  # the word piece length (including padding)
+            if max_word_piece_length + 2 > self._max_position_embeddings:
                if self.auto_truncate:
                    word_pieces_lengths = word_pieces_lengths.masked_fill(
                        word_pieces_lengths + 2 > self._max_position_embeddings,
@@ -392,7 +415,7 @@ class _WordBertModel(nn.Module):
                        f"`auto_truncate=True` for BertEmbedding to automatically truncate overlong input.")

            # +2 because [CLS] and [SEP] need to be added
-            word_pieces = words.new_full((batch_size, min(word_piece_length + 2, self._max_position_embeddings)),
+            word_pieces = words.new_full((batch_size, min(max_word_piece_length + 2, self._max_position_embeddings)),
                                         fill_value=self._wordpiece_pad_index)
            attn_masks = torch.zeros_like(word_pieces)
            # 1. get the word_piece ids of words, and their span ranges
@@ -435,19 +458,19 @@ class _WordBertModel(nn.Module):
            if self.pool_method == 'first':
                batch_word_pieces_cum_length = batch_word_pieces_cum_length[:, :seq_len.max()]
-                batch_word_pieces_cum_length.masked_fill_(batch_word_pieces_cum_length.ge(word_piece_length), 0)
+                batch_word_pieces_cum_length.masked_fill_(batch_word_pieces_cum_length.ge(max_word_piece_length), 0)
                _batch_indexes = batch_indexes[:, None].expand((batch_size, batch_word_pieces_cum_length.size(1)))
            elif self.pool_method == 'last':
                batch_word_pieces_cum_length = batch_word_pieces_cum_length[:, 1:seq_len.max()+1] - 1
-                batch_word_pieces_cum_length.masked_fill_(batch_word_pieces_cum_length.ge(word_piece_length), 0)
+                batch_word_pieces_cum_length.masked_fill_(batch_word_pieces_cum_length.ge(max_word_piece_length), 0)
                _batch_indexes = batch_indexes[:, None].expand((batch_size, batch_word_pieces_cum_length.size(1)))

            for l_index, l in enumerate(self.layers):
                output_layer = bert_outputs[l]
                real_word_piece_length = output_layer.size(1) - 2
-                if word_piece_length > real_word_piece_length:  # if it was actually truncated
+                if max_word_piece_length > real_word_piece_length:  # if it was actually truncated
                    paddings = output_layer.new_zeros(batch_size,
-                                                      word_piece_length - real_word_piece_length,
+                                                      max_word_piece_length - real_word_piece_length,
                                                      output_layer.size(2))
                    output_layer = torch.cat((output_layer, paddings), dim=1).contiguous()
                # collapse from the word_piece representation to the word representation
@@ -480,3 +503,81 @@ class _WordBertModel(nn.Module):
            # 3. the final embedding result
            return outputs

class _BertWordPieceModel(nn.Module):
    r"""
    This module computes word_piece-level results directly.

    """

    def __init__(self, model_dir_or_name: str, layers: str = '-1', pooled_cls: bool = False):
        super().__init__()

        self.tokenzier = BertTokenizer.from_pretrained(model_dir_or_name)
        self.encoder = BertModel.from_pretrained(model_dir_or_name)
        # check that encoder_layer_number is reasonable
        encoder_layer_number = len(self.encoder.encoder.layer)

        if isinstance(layers, list):
            self.layers = [int(l) for l in layers]
        elif isinstance(layers, str):
            self.layers = list(map(int, layers.split(',')))
        else:
            raise TypeError("`layers` only supports str or list[int]")

        for layer in self.layers:
            if layer < 0:
                assert -layer <= encoder_layer_number, f"The layer index:{layer} is out of scope for " \
                                                       f"a bert model with {encoder_layer_number} layers."
            else:
                assert layer <= encoder_layer_number, f"The layer index:{layer} is out of scope for " \
                                                      f"a bert model with {encoder_layer_number} layers."

        self._cls_index = self.tokenzier.cls_index
        self._sep_index = self.tokenzier.sep_index
        self._wordpiece_unknown_index = self.tokenzier.unk_index
        self._wordpiece_pad_index = self.tokenzier.pad_index  # needed when generating word_pieces
        self.pooled_cls = pooled_cls

    def index_datasets(self, *datasets, field_name, add_cls_sep=True):
        r"""
        Use bert's tokenizer to generate a new word_pieces field, add it to the datasets and set it as input. If the
        sequences do not start with [CLS] and end with [SEP], these tokens are added, and the pad value of the
        word_pieces field is set to bert's pad value.

        :param datasets: DataSet objects
        :param field_name: which field to index
        :param bool add_cls_sep: whether to add [CLS] and [SEP] when they are missing at the ends
        :return:
        """

        encode_func = partial(self.tokenzier.encode, add_special_tokens=add_cls_sep)

        for index, dataset in enumerate(datasets):
            try:
                dataset.apply_field(encode_func, field_name=field_name, new_field_name='word_pieces',
                                    is_input=True)
                dataset.set_pad_val('word_pieces', self._wordpiece_pad_index)
            except Exception as e:
                logger.error(f"Exception happens when processing the {index} dataset.")
                raise e

    def forward(self, word_pieces, token_type_ids=None):
        r"""
        :param word_pieces: torch.LongTensor, batch_size x max_len
        :param token_type_ids: torch.LongTensor, batch_size x max_len
        :return: num_layers x batch_size x max_len x hidden_size or num_layers x batch_size x (max_len+2) x hidden_size
        """
        batch_size, max_len = word_pieces.size()

        attn_masks = word_pieces.ne(self._wordpiece_pad_index)
        bert_outputs, pooled_cls = self.encoder(word_pieces, token_type_ids=token_type_ids, attention_mask=attn_masks,
                                                output_all_encoded_layers=True)
        # output_layers = [self.layers]  # len(self.layers) x batch_size x max_word_piece_length x hidden_size
        outputs = bert_outputs[0].new_zeros((len(self.layers), batch_size, max_len, bert_outputs[0].size(-1)))
        for l_index, l in enumerate(self.layers):
            bert_output = bert_outputs[l]
            if l in (len(bert_outputs)-1, -1) and self.pooled_cls:
                bert_output[:, 0] = pooled_cls
            outputs[l_index] = bert_output
        return outputs
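A sketch of how the stacked per-layer output returned by _BertWordPieceModel.forward is typically consumed; BertWordPieceEncoder concatenates the selected layers along the hidden dimension (shapes illustrative)::

    import torch

    num_layers, batch_size, max_len, hidden = 2, 4, 10, 768
    outputs = torch.randn(num_layers, batch_size, max_len, hidden)
    concatenated = torch.cat([*outputs], dim=-1)  # as in BertWordPieceEncoder.forward
    assert concatenated.shape == (batch_size, max_len, num_layers * hidden)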
@@ -0,0 +1,656 @@
"""
.. todo::
    doc
"""

__all__ = [
    "GPT2Embedding",
    "GPT2WordPieceEncoder"
]

import warnings
from functools import partial
from itertools import chain
from collections import OrderedDict

import torch
from torch import nn
import numpy as np

from .contextual_embedding import ContextualEmbedding
from ..core import logger
from ..core.utils import _get_model_device
from ..core.vocabulary import Vocabulary
from ..io.file_utils import PRETRAINED_BERT_MODEL_DIR
from ..modules.tokenizer import GPT2Tokenizer
from ..modules.encoder.gpt2 import GPT2LMHeadModel, GPT2Model

class GPT2Embedding(ContextualEmbedding):
    """
    An Embedding that encodes words with GPT2.

    GPT2Embedding supports automatic downloading of weights; currently supported models:
        en: gpt2
        en-medium: gpt2-medium

    Example::

        >>> import torch
        >>> from fastNLP import Vocabulary
        >>> from fastNLP.embeddings import GPT2Embedding
        >>> vocab = Vocabulary().add_word_lst("The weather is good .".split())
        >>> embed = GPT2Embedding(vocab, model_dir_or_name='en-small', requires_grad=False, layers='4,-2,-1')
        >>> words = torch.LongTensor([[vocab.to_index(word) for word in "The weather is good .".split()]])
        >>> outputs = embed(words)
        >>> outputs.size()
        >>> # torch.Size([1, 5, 3096])
    """
    def __init__(self, vocab: Vocabulary, model_dir_or_name: str = 'en-small', layers: str = '-1',
                 pool_method: str = 'first', dropout=0, requires_grad: bool = True,
                 auto_truncate: bool = False, language_model: bool = False, **kwargs):
        """
        :param ~fastNLP.Vocabulary vocab: the vocabulary
        :param str model_dir_or_name: directory of the model, or a model name. When a directory is given, it should contain
            a vocab file (.txt suffix), a weights file (.bin suffix) and a config file (.json suffix).
        :param str layers: which layers the output embedding comes from; the results of the different layers are concatenated
            along the last dimension in the order given in layers. Comma-separated layer indices, numbered from 0;
            negative numbers index layers from the end.
        :param str pool_method: every word is represented by several word pieces; this controls how a word's representation
            is computed from its word pieces. One of ``last``, ``first``, ``avg``, ``max``.
        :param float dropout: dropout probability on the embedding representation. 0.1 means 10% of the values are randomly zeroed.
        :param bool requires_grad: whether gradients are needed to update the weights.
        :param bool auto_truncate: when a sentence's words split into word pieces exceeding the maximum allowed length
            (usually 512), automatically drop everything beyond the first 510 word pieces and set the 512th word piece to
            [SEP]. The encoded result of the part beyond the limit is zeroed out. Usually only tasks that classify from
            [CLS] alone set auto_truncate to True.
        :param bool language_model: whether to compute gpt2's lm loss; it can be retrieved through get_lm_loss(). Calling
            get_lm_loss after feeding a batch returns that batch's language model loss
        :param **kwargs:
            bool only_use_pretrain_bpe: only use bpe tokens that appear in the pretrained vocab; words that cannot be
                tokenized fall back to unk. Recommended to set True if the embedding does not need updating.
            int min_freq: only effective when only_use_pretrain_bpe is False; words occurring at least this many times
                are added to GPT2's BPE vocab
            bool truncate_embed: whether to keep only the bpe tokens actually used (reduces memory usage and speeds things up)
        """
        super().__init__(vocab, word_dropout=0, dropout=dropout)

        if model_dir_or_name.lower() in PRETRAINED_BERT_MODEL_DIR:
            if 'cn' in model_dir_or_name.lower() and pool_method not in ('first', 'last'):
                logger.warning("For Chinese GPT, pooled_method should choose from 'first', 'last' in order to achieve"
                               " faster speed.")
                warnings.warn("For Chinese GPT, pooled_method should choose from 'first', 'last' in order to achieve"
                              " faster speed.")

        only_use_pretrain_bpe = kwargs.get('only_use_pretrain_bpe', False)
        truncate_embed = kwargs.get('truncate_embed', True)
        min_freq = kwargs.get('min_freq', 2)

        self.lm_loss = language_model
        self.model = _GPT2Model(model_dir_or_name=model_dir_or_name, vocab=vocab, layers=layers,
                                pool_method=pool_method, auto_truncate=auto_truncate, language_model=language_model,
                                only_use_pretrain_bpe=only_use_pretrain_bpe, truncate_embed=truncate_embed,
                                min_freq=min_freq)

        self.requires_grad = requires_grad
        self._embed_size = len(self.model.layers) * self.model.encoder.config.n_embd
    def _delete_model_weights(self):
        del self.model

    def forward(self, words):
        """
        Compute the GPT2 embedding representation of words.

        :param torch.LongTensor words: [batch_size, max_len]
        :return: torch.FloatTensor. batch_size x max_len x (768*len(self.layers))
        """
        outputs = self._get_sent_reprs(words)
        if outputs is not None:
            return self.dropout(outputs)
        outputs = self.model(words)
        outputs = torch.cat([*outputs], dim=-1)

        return self.dropout(outputs)
    def drop_word(self, words):
        """
        :param torch.LongTensor words: batch_size x max_len
        :return:
        """
        if self.word_dropout > 0 and self.training:
            with torch.no_grad():
                mask = torch.full_like(words, fill_value=self.word_dropout, dtype=torch.float, device=words.device)
                mask = torch.bernoulli(mask).eq(1)  # the larger dropout_word is, the more positions are set to 1
                words = words.masked_fill(mask, self._word_unk_index)
        return words

    def get_lm_loss(self, release=True):
        """
        When language_model=True, this interface returns the language model loss of the current batch

        :param bool release: if True, once lm_loss has been retrieved it cannot be retrieved again until the next forward completes
        :return: torch.FloatTensor([])
        """
        if hasattr(self.model, '_lm_loss_value'):
            lm_loss_value = self.model._lm_loss_value
            if release:
                delattr(self.model, '_lm_loss_value')
            return lm_loss_value
        elif self.lm_loss:
            raise RuntimeError("Make sure you have passed a batch into GPT2Embedding before accessing loss.")
        else:
            raise RuntimeError("Initialize your GPT2Embedding with language_model=True.")

class GPT2WordPieceEncoder(nn.Module):
    """
    The GPT2 model; before use, tokenize the data with this model's corresponding Tokenizer.

    GPT2WordPieceEncoder supports automatic downloading of weights; currently supported models:
        en: gpt2
        en-medium: gpt2-medium
    """

    def __init__(self, model_dir_or_name: str = 'en-small', layers: str = '-1',
                 word_dropout=0, dropout=0, requires_grad: bool = True, language_model: bool = False):
        """
        :param str model_dir_or_name: directory of the model, or a model name.
        :param str,list layers: the layers used in the final representation. Comma-separated layer indices; negative
            numbers index layers from the end
        :param float word_dropout: probability of setting a word piece to <|endoftext|>
        :param float dropout: dropout probability on the embedding representation. 0.1 means 10% of the values are randomly zeroed.
        :param bool language_model: whether to use the language model
        :param bool requires_grad: whether gradients are needed.
        """
        super().__init__()

        self.model = _GPT2WordPieceModel(model_dir_or_name=model_dir_or_name, layers=layers, language_model=language_model)
        self._wordpiece_pad_index = self.model._wordpiece_pad_index
        self._embed_size = len(self.model.layers) * self.model.encoder.config.n_embd
        self.requires_grad = requires_grad
        self.dropout_layer = nn.Dropout(dropout)
        self._wordpiece_endoftext_index = self.model._endoftext_index
        self.word_dropout = word_dropout
        self.language_model = language_model

    @property
    def embed_size(self):
        return self._embed_size

    @property
    def embedding_dim(self):
        return self._embed_size

    @property
    def num_embedding(self):
        return self.model.encoder.config.vocab_size
    def index_datasets(self, *datasets, field_name, add_endoftext=False, add_prefix_space=True):
        """
        Use gpt2's tokenizer to generate a new word_pieces field, add it to the datasets and set it as input; the pad
        value of the word_pieces field is set to gpt2's pad value.

        :param ~fastNLP.DataSet datasets: DataSet objects
        :param list[str] field_name: which field the word_pieces field is generated from. Each entry in that field should be a List[str].
        :param bool add_endoftext: add <|endoftext|> at the start of each sentence.
        :param bool add_prefix_space: whether to add a space at the start of each sentence
        :return:
        """
        self.model.index_datasets(*datasets, field_name=field_name, add_endoftext=add_endoftext,
                                  add_prefix_space=add_prefix_space)

    def forward(self, word_pieces, token_type_ids=None):
        """
        Compute the gpt2 embedding representation of words. The word_pieces passed in should contain <|endoftext|> at the start.

        :param word_pieces: batch_size x max_len
        :param token_type_ids: batch_size x max_len,
        :return: torch.FloatTensor.
        """
        outputs = self.model(word_pieces)
        outputs = torch.cat([*outputs], dim=-1)

        return self.dropout_layer(outputs)
    def drop_word(self, words):
        """
        :param torch.LongTensor words: batch_size x max_len
        :return:
        """
        if self.word_dropout > 0 and self.training:
            with torch.no_grad():
                mask = torch.full_like(words, fill_value=self.word_dropout, dtype=torch.float, device=words.device)
                mask = torch.bernoulli(mask).eq(1)  # the larger dropout_word is, the more positions are set to 1
                endoftext_mask = words.ne(self._wordpiece_endoftext_index)
                mask = endoftext_mask.__and__(mask)  # <|endoftext|> positions are not replaced
                words = words.masked_fill(mask, self._wordpiece_endoftext_index)
        return words
    def generate_from_str(self, text='', max_len=40, do_sample=True, num_beams=1, temperature=1, top_k=50, top_p=1.0,
                          repetition_penalty=1.0, length_penalty=1.0):
        """
        :param str text: the beginning of the story
        :param int max_len: how long a sentence to generate
        :param bool do_sample: whether to generate by sampling; with sampling, the same parameters may produce different sentences.
        :param int num_beams: the beam size to use
        :param float temperature: used to adjust the sampling distribution
        :param int top_k: keep only the top_k words of the vocabulary for generation. Range 1-infinity
        :param float top_p: keep the words whose cumulative probability reaches top_p, range 0-1.
        :param float repetition_penalty: penalty for repeated tokens
        :param float length_penalty: penalty for overlong sentences
        :return: list[str]
        """
        if len(text) == 0:
            word_pieces = torch.LongTensor([[self.model.tokenizer.bos_index]])
            start_idx = 1
        else:
            assert isinstance(text, str), "Only string input allowed."
            assert self.language_model, "You must set `language_model=True`."
            word_pieces = self.model.convert_words_to_word_pieces(text, add_prefix_space=True)
            word_pieces = torch.LongTensor([word_pieces])
            start_idx = 0
        device = _get_model_device(self)
        word_pieces = word_pieces.to(device)
        outputs = self.model.encoder.generate(input_ids=word_pieces,
                                              max_length=max_len,
                                              do_sample=do_sample,
                                              num_beams=num_beams,
                                              temperature=temperature,
                                              top_k=top_k,
                                              top_p=top_p,
                                              repetition_penalty=repetition_penalty,
                                              bos_token_id=self.model.tokenizer.bos_index,
                                              pad_token_id=self.model.tokenizer.eos_index,  # use <|endoftext|> instead of pad
                                              eos_token_ids=self.model.tokenizer.eos_index,
                                              length_penalty=length_penalty).squeeze(0)

        output_strs = []
        if outputs.dim() == 1:
            outputs = outputs[None]
        outputs = outputs[:, start_idx:]
        for i in range(len(outputs)):
            str_ = self.model.tokenizer.convert_tokens_to_string(self.model.tokenizer.convert_ids_to_tokens(outputs[i].tolist()))
            output_strs.append(str_)

        return output_strs
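
# A minimal sketch of the generation entry point above; the prompt and sampling
# settings are illustrative:
#
#     encoder = GPT2WordPieceEncoder(model_dir_or_name='en', language_model=True)
#     stories = encoder.generate_from_str('Once upon a time', max_len=40, top_k=50)
#     print(stories[0])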
    def generate(self, word_pieces, max_len=40, do_sample=True, num_beams=1, temperature=1, top_k=50, top_p=1.0,
                 repetition_penalty=1.0, length_penalty=1.0):
        """
        :param word_pieces:
        :param int max_len: how long a sentence to generate
        :param bool do_sample: whether to generate by sampling; with sampling, the same parameters may produce different sentences.
        :param int num_beams: the beam size to use
        :param float temperature: used to adjust the sampling distribution
        :param int top_k: keep only the top_k words of the vocabulary for generation. Range 1-infinity
        :param float top_p: keep the words whose cumulative probability reaches top_p, range 0-1.
        :param float repetition_penalty: penalty for repeated tokens
        :param float length_penalty: penalty for overlong sentences
        :return:
        """
        pass

    def get_lm_loss(self, release=True):
        """
        When language_model=True, this interface returns the language model loss of the current batch

        :param bool release: if True, once lm_loss has been retrieved it cannot be retrieved again until the next forward completes
        :return: torch.FloatTensor([])
        """
        if hasattr(self.model, '_lm_loss_value'):
            lm_loss_value = self.model._lm_loss_value
            if release:
                delattr(self.model, '_lm_loss_value')
            return lm_loss_value
        elif self.language_model:
            raise RuntimeError("Make sure you have passed a batch into GPT2WordPieceEncoder before accessing loss.")
        else:
            raise RuntimeError("Initialize your GPT2WordPieceEncoder with language_model=True.")

class _GPT2Model(nn.Module):
    def __init__(self, model_dir_or_name, vocab, layers, pool_method='first', auto_truncate=True, language_model=False,
                 only_use_pretrain_bpe=False, min_freq=2, truncate_embed=False):
        super().__init__()

        self.tokenzier = GPT2Tokenizer.from_pretrained(model_dir_or_name)
        if language_model:
            self.encoder = GPT2LMHeadModel.from_pretrained(model_dir_or_name)
        else:
            self.encoder = GPT2Model.from_pretrained(model_dir_or_name)

        self.lm_loss = language_model
        self._max_position_embeddings = self.encoder.config.max_position_embeddings
        # check that encoder_layer_number is reasonable
        encoder_layer_number = self.encoder.config.n_layer
        if isinstance(layers, list):
            self.layers = [int(l) for l in layers]
        elif isinstance(layers, str):
            self.layers = list(map(int, layers.split(',')))
        else:
            raise TypeError("`layers` only supports str or list[int]")
        for layer in self.layers:
            if layer < 0:
                assert -layer <= encoder_layer_number, f"The layer index:{layer} is out of scope for " \
                                                       f"a GPT2 model with {encoder_layer_number} layers."
            else:
                assert layer <= encoder_layer_number, f"The layer index:{layer} is out of scope for " \
                                                      f"a GPT2 model with {encoder_layer_number} layers."

        assert pool_method in ('avg', 'max', 'first', 'last')
        self.pool_method = pool_method
        self.auto_truncate = auto_truncate

        # compute the wordpieces for every word in vocab; <s> and </s> need extra handling
        logger.info("Start to generate word pieces for word.")
        # step 1: collect the needed word_pieces, then build the new embed and word_piece_vocab and fill in the values
        word_piece_dict = {'<|endoftext|>': 1}  # word_pieces in use, plus newly added ones
        found_count = 0
        new_add_to_bpe_vocab = 0
        unsegment_count = 0
        for word, index in vocab:
            if index == vocab.padding_idx:  # pad is a special token
                word = '<|endoftext|>'
            elif index == vocab.unknown_idx:
                word = '<|endoftext|>'
            # _words = self.tokenzier.basic_tokenizer._tokenize_chinese_chars(word).split()  # Chinese content is not considered here for now
            word_pieces = []
            word_pieces.extend(self.tokenzier.tokenize(word, add_prefix_space=True))
            if len(word_pieces) == 1:
                if not vocab._is_word_no_create_entry(word):  # the word is in train, but was not found
                    if index not in (vocab.unknown_idx, vocab.padding_idx) and word_pieces[0] == '<|endoftext|>':  # this word is not in the original vocab
                        if vocab.word_count[word] >= min_freq and not vocab._is_word_no_create_entry(
                                word) and not only_use_pretrain_bpe:  # only add words that occur at least this many times
                            word_piece_dict[word] = 1  # add a new entry
                            new_add_to_bpe_vocab += 1
                        unsegment_count += 1
                        continue
            for word_piece in word_pieces:
                word_piece_dict[word_piece] = 1
            found_count += 1

        if unsegment_count > 0:
            if only_use_pretrain_bpe or new_add_to_bpe_vocab == 0:
                logger.info(f"{unsegment_count} words are unsegmented.")
            else:
                logger.info(f"{unsegment_count} words are unsegmented. Among them, {new_add_to_bpe_vocab} added to the BPE vocab.")
        original_embed = self.encoder.get_input_embeddings().weight
        # special tokens need special handling
        if not truncate_embed:  # if we are not truncating, the existing tokens must be kept
            word_piece_dict.update(self.tokenzier.encoder)

        embed = nn.Embedding(len(word_piece_dict), original_embed.size(1))  # the new embed
        new_word_piece_vocab = OrderedDict()

        for index, token in enumerate(['<|endoftext|>']):
            index = word_piece_dict.pop(token, None)
            if index is not None:
                new_word_piece_vocab[token] = len(new_word_piece_vocab)
                embed.weight.data[new_word_piece_vocab[token]] = original_embed[self.tokenzier.encoder[token]]

        for token in word_piece_dict.keys():
            if token not in new_word_piece_vocab:
                new_word_piece_vocab[token] = len(new_word_piece_vocab)
            index = new_word_piece_vocab[token]
            if token in self.tokenzier.encoder:
                embed.weight.data[index] = original_embed[self.tokenzier.encoder[token]]
            else:
                embed.weight.data[index] = original_embed[self.tokenzier.encoder['<|endoftext|>']]

        self.tokenzier._reinit_on_new_vocab(new_word_piece_vocab)
        self.encoder.set_input_embeddings(embed)
        self.encoder.tie_weights()
        self.encoder.config.vocab_size = len(new_word_piece_vocab)

        word_to_wordpieces = []
        word_pieces_lengths = []
        for word, index in vocab:
            if index == vocab.padding_idx:  # pad is a special token
                word = '<|endoftext|>'
            elif index == vocab.unknown_idx:
                word = '<|endoftext|>'
            word_pieces = self.tokenzier.tokenize(word)
            word_pieces = self.tokenzier.convert_tokens_to_ids(word_pieces)
            word_to_wordpieces.append(word_pieces)
            word_pieces_lengths.append(len(word_pieces))
        self._word_pad_index = vocab.padding_idx
        self._endoftext_index = self.tokenzier.encoder.get('<|endoftext|>')
        self._wordpiece_pad_index = self.tokenzier.encoder.get('<|endoftext|>')  # needed when generating word_pieces
        self.word_to_wordpieces = np.array(word_to_wordpieces)
        self.register_buffer('word_pieces_lengths', torch.LongTensor(word_pieces_lengths))
        logger.debug("Successfully generate word pieces.")
    def forward(self, words):
        """
        :param words: torch.LongTensor, batch_size x max_len
        :return: num_layers x batch_size x max_len x hidden_size or num_layers x batch_size x (max_len+2) x hidden_size
        """
        with torch.no_grad():
            batch_size, max_word_len = words.size()
            word_mask = words.ne(self._word_pad_index)  # positions holding a word are 1
            seq_len = word_mask.sum(dim=-1)
            batch_word_pieces_length = self.word_pieces_lengths[words].masked_fill(word_mask.eq(False),
                                                                                   0)  # batch_size x max_len
            word_pieces_lengths = batch_word_pieces_length.sum(dim=-1)  # batch_size
            max_word_piece_length = batch_word_pieces_length.sum(dim=-1).max().item()  # the word piece length (including padding)
            if max_word_piece_length > self._max_position_embeddings:
                if self.auto_truncate:
                    word_pieces_lengths = word_pieces_lengths.masked_fill(
                        word_pieces_lengths > self._max_position_embeddings,
                        self._max_position_embeddings)
                else:
                    raise RuntimeError(
                        "After split words into word pieces, the lengths of word pieces are longer than the "
                        f"maximum allowed sequence length:{self._max_position_embeddings} of GPT2. You can set "
                        f"`auto_truncate=True` for GPT2Embedding to automatically truncate overlong input.")

            word_pieces = words.new_full((batch_size, min(max_word_piece_length, self._max_position_embeddings)),
                                         fill_value=self._wordpiece_pad_index)
            word_labels = word_pieces.clone()
            attn_masks = torch.zeros_like(word_pieces)
            # 1. get the word_piece ids of words, and their span ranges
            word_indexes = words.cpu().numpy()
            for i in range(batch_size):
                word_pieces_i = list(chain(*self.word_to_wordpieces[word_indexes[i, :seq_len[i]]]))
                if self.auto_truncate and len(word_pieces_i) > self._max_position_embeddings:
                    word_pieces_i = word_pieces_i[:self._max_position_embeddings]
                word_pieces[i, :word_pieces_lengths[i]] = torch.LongTensor(word_pieces_i)
                word_labels[i, word_pieces_lengths[i]:].fill_(-100)  # used when computing lm_loss
                attn_masks[i, :word_pieces_lengths[i]].fill_(1)
            # add <|endoftext|>; disabled by default
            # word_pieces[:, 0].fill_(self._endoftext_index)
            batch_indexes = torch.arange(batch_size).to(words)
            # 2. get the hidden results and pool them according to word_pieces
            # all_outputs: [batch_size x max_len x hidden_size, batch_size x max_len x hidden_size, ...]
            if self.lm_loss:
                gpt2_outputs = self.encoder(word_pieces, token_type_ids=None, attention_mask=attn_masks, labels=word_labels,
                                            output_attentions=False)
                gpt2_outputs, self._lm_loss_value = gpt2_outputs[-1], gpt2_outputs[0]  # n_layers x batch_size x max_len x hidden_size
            else:
                gpt2_outputs = self.encoder(word_pieces, token_type_ids=None, attention_mask=attn_masks,
                                            output_attentions=False)[-1]
            outputs = gpt2_outputs[-1].new_zeros(len(self.layers), batch_size, max_word_len,
                                                 gpt2_outputs[-1].size(-1))

            batch_word_pieces_cum_length = batch_word_pieces_length.new_zeros(batch_size, max_word_len + 1)
            batch_word_pieces_cum_length[:, 1:] = batch_word_pieces_length.cumsum(dim=-1)  # batch_size x max_len
            if self.pool_method == 'first':
                batch_word_pieces_cum_length = batch_word_pieces_cum_length[:, :seq_len.max()]
                batch_word_pieces_cum_length.masked_fill_(batch_word_pieces_cum_length.ge(max_word_piece_length), 0)
                _batch_indexes = batch_indexes[:, None].expand((batch_size, batch_word_pieces_cum_length.size(1)))
            elif self.pool_method == 'last':
                batch_word_pieces_cum_length = batch_word_pieces_cum_length[:, :seq_len.max()] - 1
                batch_word_pieces_cum_length.masked_fill_(batch_word_pieces_cum_length.ge(max_word_piece_length), 0)
                _batch_indexes = batch_indexes[:, None].expand((batch_size, batch_word_pieces_cum_length.size(1)))

            for l_index, l in enumerate(self.layers):
                output_layer = gpt2_outputs[l]
                real_word_piece_length = output_layer.size(1)
                if max_word_piece_length > real_word_piece_length:  # if it was actually truncated
                    paddings = output_layer.new_zeros(batch_size,
                                                      max_word_piece_length - real_word_piece_length,
                                                      output_layer.size(2))
                    output_layer = torch.cat((output_layer, paddings), dim=1).contiguous()
                # collapse from the word_piece representation to the word representation
                # truncate_output_layer = output_layer  # remove endoftext  batch_size x len x hidden_size
                if self.pool_method == 'first':
                    tmp = output_layer[_batch_indexes, batch_word_pieces_cum_length]
                    tmp = tmp.masked_fill(word_mask[:, :batch_word_pieces_cum_length.size(1), None].eq(False), 0)
                    outputs[l_index, :, :batch_word_pieces_cum_length.size(1)] = tmp
                elif self.pool_method == 'last':
                    tmp = output_layer[_batch_indexes, batch_word_pieces_cum_length]
                    tmp = tmp.masked_fill(word_mask[:, :batch_word_pieces_cum_length.size(1), None].eq(False), 0)
                    outputs[l_index, :, :batch_word_pieces_cum_length.size(1)] = tmp
                elif self.pool_method == 'max':
                    for i in range(batch_size):
                        for j in range(seq_len[i]):
                            start, end = batch_word_pieces_cum_length[i, j], batch_word_pieces_cum_length[i, j + 1]
                            outputs[l_index, i, j], _ = torch.max(output_layer[i, start:end], dim=-2)
                else:
                    for i in range(batch_size):
                        for j in range(seq_len[i]):
                            start, end = batch_word_pieces_cum_length[i, j], batch_word_pieces_cum_length[i, j + 1]
                            outputs[l_index, i, j] = torch.mean(output_layer[i, start:end], dim=-2)

        # 3. the final embedding result
        return outputs
    def get_lm_loss(self):
        """
        When language_model is True, this interface returns the language model loss of the most recent batch

        :return:
        """
        return self._lm_loss_value

class _GPT2WordPieceModel(nn.Module):
    """
    This module computes word_piece-level results directly.
    """

    def __init__(self, model_dir_or_name: str, layers: str = '-1', language_model: bool = False):
        super().__init__()

        self.tokenizer = GPT2Tokenizer.from_pretrained(model_dir_or_name)
        if language_model:
            self.encoder = GPT2LMHeadModel.from_pretrained(model_dir_or_name)
        else:
            self.encoder = GPT2Model.from_pretrained(model_dir_or_name)

        self.lm_loss = language_model
        # check that encoder_layer_number is reasonable
        encoder_layer_number = self.encoder.config.n_layer

        if isinstance(layers, list):
            self.layers = [int(l) for l in layers]
        elif isinstance(layers, str):
            self.layers = list(map(int, layers.split(',')))
        else:
            raise TypeError("`layers` only supports str or list[int]")

        for layer in self.layers:
            if layer < 0:
                assert -layer <= encoder_layer_number, f"The layer index:{layer} is out of scope for " \
                                                       f"a gpt2 model with {encoder_layer_number} layers."
            else:
                assert layer <= encoder_layer_number, f"The layer index:{layer} is out of scope for " \
                                                      f"a gpt2 model with {encoder_layer_number} layers."

        self._endoftext_index = self.tokenizer.encoder.get('<|endoftext|>')
        self._wordpiece_pad_index = self.tokenizer.encoder.get('<|endoftext|>')  # there is no pad originally; this value stands in. The pad value is unimportant because computation runs left to right
        self._max_position_embeddings = self.encoder.config.max_position_embeddings
    def index_datasets(self, *datasets, field_name, add_endoftext=False, add_prefix_space=True):
        """
        Use gpt2's tokenizer to generate a new word_pieces field, add it to the datasets and set it as input. If
        add_endoftext is True and a sequence does not start with <|endoftext|>, it is added; the pad value of the
        word_pieces field is set to gpt2's pad value.

        :param datasets: DataSet objects
        :param field_name: which field to index
        :param bool add_endoftext: whether to add <|endoftext|> at the start of each sentence
        :param bool add_prefix_space: whether to add a space at the start of each sentence
        :return:
        """
        convert_words_to_word_pieces = partial(self.convert_words_to_word_pieces, add_endoftext=add_endoftext,
                                               add_prefix_space=add_prefix_space)
        for index, dataset in enumerate(datasets):
            try:
                dataset.apply_field(convert_words_to_word_pieces, field_name=field_name, new_field_name='word_pieces',
                                    is_input=True)
                dataset.set_pad_val('word_pieces', self._wordpiece_pad_index)
            except Exception as e:
                logger.error(f"Exception happens when processing the {index} dataset.")
                raise e
    def convert_words_to_word_pieces(self, words, add_endoftext=False, add_prefix_space=True):
        """
        :param list[str],str words: the str data to convert to indices
        :param bool add_endoftext: whether to add <|endoftext|> at the start of the sentence
        :param bool add_prefix_space: whether to add a space at the start of the sentence
        :return:
        """
        word_pieces = []
        if isinstance(words, str):
            words = self.tokenizer.tokenize(words, add_prefix_space=add_prefix_space)
            word_piece_ids = self.tokenizer.convert_tokens_to_ids(words)
            word_pieces.extend(word_piece_ids)
        else:
            for word in words:
                tokens = self.tokenizer.tokenize(word, add_prefix_space=add_prefix_space)
                word_piece_ids = self.tokenizer.convert_tokens_to_ids(tokens)
                word_pieces.extend(word_piece_ids)
        if add_endoftext:
            if word_pieces[0] != self._endoftext_index:
                word_pieces.insert(0, self._endoftext_index)
        if len(word_pieces) > self._max_position_embeddings:
            word_pieces[self._max_position_embeddings - 1] = word_pieces[-1]
            word_pieces = word_pieces[:self._max_position_embeddings]
        return word_pieces
def forward(self, word_pieces, token_type_ids=None): | |||||
""" | |||||
:param word_pieces: torch.LongTensor, batch_size x max_len | |||||
:param token_type_ids: torch.LongTensor, batch_size x max_len | |||||
        :return: num_layers x batch_size x max_len x hidden_size
""" | |||||
batch_size, max_len = word_pieces.size() | |||||
        attn_masks = word_pieces.ne(self._wordpiece_pad_index)  # caution: pad reuses <|endoftext|>, so a genuine <|endoftext|> token is masked out as well
        word_pieces = word_pieces.masked_fill(attn_masks.eq(0), self._endoftext_index)  # replace the pad positions
if self.lm_loss: | |||||
labels = word_pieces.clone() | |||||
labels = labels.masked_fill(labels.eq(self._wordpiece_pad_index), -100) | |||||
gpt_outputs = self.encoder(word_pieces, token_type_ids=token_type_ids, attention_mask=attn_masks, | |||||
output_attentions=False, labels=labels) | |||||
gpt_outputs, self._lm_loss_value = gpt_outputs[-1], gpt_outputs[0] # n_layers x batch_size x max_len x hidden_size | |||||
else: | |||||
gpt_outputs = self.encoder(word_pieces, token_type_ids=token_type_ids, attention_mask=attn_masks, | |||||
output_attentions=False) | |||||
gpt_outputs = gpt_outputs[-1] | |||||
# output_layers = [self.layers] # len(self.layers) x batch_size x max_word_piece_length x hidden_size | |||||
outputs = gpt_outputs[0].new_zeros((len(self.layers), batch_size, max_len, gpt_outputs[0].size(-1))) | |||||
for l_index, l in enumerate(self.layers): | |||||
            outputs[l_index] = gpt_outputs[l]
return outputs | |||||
def get_lm_loss(self): | |||||
""" | |||||
        When ``language_model=True``, this interface returns the language model loss of the most recently processed batch.
        :return:
""" | |||||
return self._lm_loss_value | |||||
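# A minimal usage sketch for _GPT2WordPieceModel (``path/to/gpt2`` is a
# hypothetical local checkpoint directory; nothing below is downloaded)::
#
#     >>> model = _GPT2WordPieceModel('path/to/gpt2', layers='-1', language_model=True)
#     >>> ids = model.convert_words_to_word_pieces(['Hello', 'world'], add_endoftext=True)
#     >>> reprs = model(torch.LongTensor([ids]))  # num_layers x 1 x len(ids) x hidden_size
#     >>> lm_loss = model.get_lm_loss()           # available because language_model=True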
@@ -0,0 +1,538 @@ | |||||
r""" | |||||
.. todo:: | |||||
doc | |||||
""" | |||||
__all__ = [ | |||||
"RobertaEmbedding", | |||||
"RobertaWordPieceEncoder" | |||||
] | |||||
from functools import partial | |||||
import collections | |||||
import warnings | |||||
from itertools import chain | |||||
import numpy as np | |||||
import torch | |||||
import torch.nn as nn | |||||
from .contextual_embedding import ContextualEmbedding | |||||
from ..core import logger, Vocabulary | |||||
from ..modules.encoder.roberta import RobertaModel | |||||
from ..modules.tokenizer import RobertaTokenizer | |||||
class RobertaEmbedding(ContextualEmbedding): | |||||
r""" | |||||
    An Embedding that encodes words with RoBERTa. It is recommended to keep input word sequences within about 430
    words rather than 512 (the exact bound depends on the pretrained model): the pretrained RoBERTa model accepts at
    most 512 tokens, and since the input words are not yet split into word pieces (RobertaEmbedding performs the
    word-piece split itself when words are fed in), the sequence may exceed that limit after splitting.
    RobertaEmbedding supports automatic weight download; currently supported models:
en: roberta-base | |||||
en-large: roberta-large | |||||
Example:: | |||||
>>> import torch | |||||
>>> from fastNLP import Vocabulary | |||||
>>> from fastNLP.embeddings import RobertaEmbedding | |||||
        >>> vocab = Vocabulary().add_word_lst("The weather is good .".split())
        >>> embed = RobertaEmbedding(vocab, model_dir_or_name='en', requires_grad=False, layers='4,-2,-1')
        >>> words = torch.LongTensor([[vocab.to_index(word) for word in "The weather is good .".split()]])
>>> outputs = embed(words) | |||||
>>> outputs.size() | |||||
>>> # torch.Size([1, 5, 2304]) | |||||
""" | |||||
def __init__(self, vocab: Vocabulary, model_dir_or_name: str = 'en-base-uncased', layers: str = '-1', | |||||
pool_method: str = 'first', word_dropout=0, dropout=0, include_cls_sep: bool = False, | |||||
pooled_cls=True, requires_grad: bool = True, auto_truncate: bool = False, **kwargs): | |||||
r""" | |||||
:param ~fastNLP.Vocabulary vocab: 词表 | |||||
:param str model_dir_or_name: 模型所在目录或者模型的名称。当传入模型所在目录时,目录中应该包含一个词表文件 | |||||
(以vocab.json作为后缀名), 权重文件(以.bin作为文件后缀名), 配置文件(以config.json作为后缀名)。 | |||||
:param str,list layers: 输出embedding表示来自于哪些层,不同层的结果按照layers中的顺序在最后一维concat起来。以','隔开层数,层的序号是 | |||||
从0开始,可以以负数去索引倒数几层。layer=0为embedding层(包括wordpiece embedding, position embedding) | |||||
:param str pool_method: 因为在bert中,每个word会被表示为多个word pieces, 当获取一个word的表示的时候,怎样从它的word pieces | |||||
中计算得到它对应的表示。支持 ``last`` , ``first`` , ``avg`` , ``max``。 | |||||
:param float word_dropout: 以多大的概率将一个词替换为unk。这样既可以训练unk也是一定的regularize。 | |||||
:param float dropout: 以多大的概率对embedding的表示进行Dropout。0.1即随机将10%的值置为0。 | |||||
:param bool include_cls_sep: bool,在bert计算句子的表示的时候,需要在前面加上[CLS]和[SEP], 是否在结果中保留这两个内容。 这样 | |||||
会使得word embedding的结果比输入的结果长两个token。如果该值为True,则在使用 :class::StackEmbedding 可能会与其它类型的 | |||||
embedding长度不匹配。 | |||||
:param bool pooled_cls: 返回的<s>是否使用预训练中的BertPool映射一下,仅在include_cls_sep时有效。如果下游任务只取<s>做预测, | |||||
一般该值为True。 | |||||
:param bool requires_grad: 是否需要gradient以更新Bert的权重。 | |||||
:param bool auto_truncate: 当句子words拆分为word pieces长度超过bert最大允许长度(一般为512), 自动截掉拆分后的超过510个 | |||||
word pieces后的内容,并将第512个word piece置为</s>。超过长度的部分的encode结果直接全部置零。一般仅有只使用<s> | |||||
来进行分类的任务将auto_truncate置为True。 | |||||
:param kwargs: | |||||
bool only_use_pretrain_bpe: 仅使用出现在pretrain词表中的bpe,如果该词没法tokenize则使用unk。如果embedding不需要更新 | |||||
建议设置为True。 | |||||
int min_freq: 仅在only_use_pretrain_bpe为False有效,大于等于该次数的词会被新加入BERT的BPE词表中 | |||||
bool truncate_embed: 是否仅保留用到的bpe(这样会减内存占用和加快速度) | |||||
""" | |||||
super().__init__(vocab, word_dropout=word_dropout, dropout=dropout) | |||||
if word_dropout > 0: | |||||
            assert vocab.unknown is not None, "When word_dropout > 0, Vocabulary must contain the unknown token."
self._word_sep_index = -100 | |||||
if '</s>' in vocab: | |||||
self._word_sep_index = vocab['</s>'] | |||||
self._word_cls_index = -100 | |||||
if '<s>' in vocab: | |||||
self._word_cls_index = vocab['<s>'] | |||||
only_use_pretrain_bpe = kwargs.get('only_use_pretrain_bpe', False) | |||||
truncate_embed = kwargs.get('truncate_embed', True) | |||||
min_freq = kwargs.get('min_freq', 2) | |||||
self.model = _RobertaWordModel(model_dir_or_name=model_dir_or_name, vocab=vocab, layers=layers, | |||||
pool_method=pool_method, include_cls_sep=include_cls_sep, | |||||
pooled_cls=pooled_cls, auto_truncate=auto_truncate, min_freq=min_freq, | |||||
only_use_pretrain_bpe=only_use_pretrain_bpe, truncate_embed=truncate_embed) | |||||
self.requires_grad = requires_grad | |||||
self._embed_size = len(self.model.layers) * self.model.encoder.hidden_size | |||||
def _delete_model_weights(self): | |||||
del self.model | |||||
def forward(self, words): | |||||
r""" | |||||
计算words的roberta embedding表示。计算之前会在每句话的开始增加<s>在结束增加</s>, 并根据include_cls_sep判断要不要 | |||||
删除这两个token的表示。 | |||||
:param torch.LongTensor words: [batch_size, max_len] | |||||
:return: torch.FloatTensor. batch_size x max_len x (768*len(self.layers)) | |||||
""" | |||||
words = self.drop_word(words) | |||||
outputs = self._get_sent_reprs(words) | |||||
if outputs is not None: | |||||
return self.dropout(outputs) | |||||
outputs = self.model(words) | |||||
outputs = torch.cat([*outputs], dim=-1) | |||||
return self.dropout(outputs) | |||||
def drop_word(self, words): | |||||
r""" | |||||
按照设定随机将words设置为unknown_index。 | |||||
:param torch.LongTensor words: batch_size x max_len | |||||
:return: | |||||
""" | |||||
if self.word_dropout > 0 and self.training: | |||||
with torch.no_grad(): | |||||
mask = torch.full_like(words, fill_value=self.word_dropout, dtype=torch.float, device=words.device) | |||||
mask = torch.bernoulli(mask).eq(1) # dropout_word越大,越多位置为1 | |||||
pad_mask = words.ne(self._word_pad_index) | |||||
mask = pad_mask.__and__(mask) # pad的位置不为unk | |||||
if self._word_sep_index!=-100: | |||||
not_sep_mask = words.ne(self._word_sep_index) | |||||
mask = mask.__and__(not_sep_mask) | |||||
if self._word_cls_index!=-100: | |||||
not_cls_mask = words.ne(self._word_cls_index) | |||||
mask = mask.__and__(not_cls_mask) | |||||
words = words.masked_fill(mask, self._word_unk_index) | |||||
return words | |||||
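# The word-dropout masking in ``drop_word`` can be illustrated standalone on a
# toy batch (the indices below are made up: 0 = pad, 1 = unk)::
#
#     >>> words = torch.LongTensor([[5, 6, 7, 0]])
#     >>> mask = torch.bernoulli(torch.full_like(words, 0.3, dtype=torch.float)).eq(1)
#     >>> mask = mask & words.ne(0)           # never replace padding
#     >>> words = words.masked_fill(mask, 1)  # sampled positions become unk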
class _RobertaWordModel(nn.Module): | |||||
def __init__(self, model_dir_or_name: str, vocab: Vocabulary, layers: str = '-1', pool_method: str = 'first', | |||||
include_cls_sep: bool = False, pooled_cls: bool = False, auto_truncate: bool = False, min_freq=2, | |||||
only_use_pretrain_bpe=False, truncate_embed=True): | |||||
super().__init__() | |||||
self.tokenzier = RobertaTokenizer.from_pretrained(model_dir_or_name) | |||||
self.encoder = RobertaModel.from_pretrained(model_dir_or_name) | |||||
        # RoBERTa fixes padding_idx to 1 and offsets position ids by padding_idx + 1, hence the -2
self._max_position_embeddings = self.encoder.config.max_position_embeddings - 2 | |||||
# 检查encoder_layer_number是否合理 | |||||
encoder_layer_number = len(self.encoder.encoder.layer) | |||||
if isinstance(layers, list): | |||||
self.layers = [int(l) for l in layers] | |||||
elif isinstance(layers, str): | |||||
self.layers = list(map(int, layers.split(','))) | |||||
else: | |||||
raise TypeError("`layers` only supports str or list[int]") | |||||
for layer in self.layers: | |||||
if layer < 0: | |||||
assert -layer <= encoder_layer_number, f"The layer index:{layer} is out of scope for " \ | |||||
f"a roberta model with {encoder_layer_number} layers." | |||||
else: | |||||
assert layer <= encoder_layer_number, f"The layer index:{layer} is out of scope for " \ | |||||
f"a roberta model with {encoder_layer_number} layers." | |||||
assert pool_method in ('avg', 'max', 'first', 'last') | |||||
self.pool_method = pool_method | |||||
self.include_cls_sep = include_cls_sep | |||||
self.pooled_cls = pooled_cls | |||||
self.auto_truncate = auto_truncate | |||||
# 将所有vocab中word的wordpiece计算出来, 需要额外考虑<s>和</s> | |||||
logger.info("Start to generate word pieces for word.") | |||||
# 第一步统计出需要的word_piece, 然后创建新的embed和word_piece_vocab, 然后填入值 | |||||
word_piece_dict = {'<s>': 1, '</s>': 1} # 用到的word_piece以及新增的 | |||||
found_count = 0 | |||||
new_add_to_bpe_vocab = 0 | |||||
unsegment_count = 0 | |||||
if "<s>" in vocab: | |||||
warnings.warn("<s> detected in your vocabulary. RobertaEmbedding will add <s> and </s> to the begin " | |||||
"and end of the input automatically, make sure you don't add <s> and </s> at the begin" | |||||
" and end.") | |||||
for word, index in vocab: | |||||
if index == vocab.padding_idx: # pad是个特殊的符号 | |||||
word = '<pad>' | |||||
elif index == vocab.unknown_idx: | |||||
word = '<unk>' | |||||
# _words = self.tokenzier.basic_tokenizer._tokenize_chinese_chars(word).split() # 这里暂时不考虑中文内容 | |||||
word_pieces = [] | |||||
            # tokenize the word as if it were not sentence-initial, hence add_prefix_space=True below
word_pieces.extend(self.tokenzier.tokenize(word, add_prefix_space=True)) | |||||
if len(word_pieces) == 1: | |||||
if not vocab._is_word_no_create_entry(word): # 如果是train中的值, 但是却没有找到 | |||||
if index != vocab.unknown_idx and word_pieces[0] == '<unk>': # 说明这个词不在原始的word里面 | |||||
if vocab.word_count[word] >= min_freq and not vocab._is_word_no_create_entry( | |||||
word) and not only_use_pretrain_bpe: # 出现次数大于这个次数才新增 | |||||
word_piece_dict[word] = 1 # 新增一个值 | |||||
new_add_to_bpe_vocab += 1 | |||||
unsegment_count += 1 | |||||
continue | |||||
found_count += 1 | |||||
for word_piece in word_pieces: | |||||
word_piece_dict[word_piece] = 1 | |||||
        original_embed = self.encoder.embeddings.word_embeddings.weight.data
        # special tokens are handled explicitly below
        if not truncate_embed:  # if not truncating, keep every word piece already in the pretrained vocab
            word_piece_dict.update(self.tokenzier.encoder)
        embed = nn.Embedding(len(word_piece_dict), original_embed.size(1))  # the new embedding
new_word_piece_vocab = collections.OrderedDict() | |||||
        for token in ['<s>', '<pad>', '</s>', '<unk>']:
            if word_piece_dict.pop(token, None) is not None:
                new_word_piece_vocab[token] = len(new_word_piece_vocab)
                embed.weight.data[new_word_piece_vocab[token]] = original_embed[self.tokenzier.encoder[token]]
for token in word_piece_dict.keys(): | |||||
if token not in new_word_piece_vocab: | |||||
new_word_piece_vocab[token] = len(new_word_piece_vocab) | |||||
index = new_word_piece_vocab[token] | |||||
if token in self.tokenzier.encoder: | |||||
embed.weight.data[index] = original_embed[self.tokenzier.encoder[token]] | |||||
else: | |||||
embed.weight.data[index] = original_embed[self.tokenzier.encoder['<unk>']] | |||||
self.tokenzier._reinit_on_new_vocab(new_word_piece_vocab) | |||||
self.encoder.embeddings.word_embeddings = embed | |||||
self.encoder.config.vocab_size = len(new_word_piece_vocab) | |||||
if unsegment_count>0: | |||||
if only_use_pretrain_bpe or new_add_to_bpe_vocab==0: | |||||
logger.info(f"{unsegment_count} words are unsegmented.") | |||||
else: | |||||
logger.info(f"{unsegment_count} words are unsegmented. Among them, {new_add_to_bpe_vocab} added to the BPE vocab.") | |||||
word_to_wordpieces = [] | |||||
word_pieces_lengths = [] | |||||
for word, index in vocab: | |||||
if index == vocab.padding_idx: # pad是个特殊的符号 | |||||
word = '<pad>' | |||||
elif index == vocab.unknown_idx: | |||||
word = '<unk>' | |||||
word_pieces = self.tokenzier.tokenize(word) | |||||
word_pieces = self.tokenzier.convert_tokens_to_ids(word_pieces) | |||||
word_to_wordpieces.append(word_pieces) | |||||
word_pieces_lengths.append(len(word_pieces)) | |||||
self._cls_index = self.tokenzier.encoder['<s>'] | |||||
self._sep_index = self.tokenzier.encoder['</s>'] | |||||
self._word_pad_index = vocab.padding_idx | |||||
self._wordpiece_pad_index = self.tokenzier.encoder['<pad>'] # 需要用于生成word_piece | |||||
self.word_to_wordpieces = np.array(word_to_wordpieces) | |||||
self.register_buffer('word_pieces_lengths', torch.LongTensor(word_pieces_lengths)) | |||||
logger.debug("Successfully generate word pieces.") | |||||
def forward(self, words): | |||||
r""" | |||||
:param words: torch.LongTensor, batch_size x max_len | |||||
:return: num_layers x batch_size x max_len x hidden_size或者num_layers x batch_size x (max_len+2) x hidden_size | |||||
""" | |||||
with torch.no_grad(): | |||||
batch_size, max_word_len = words.size() | |||||
word_mask = words.ne(self._word_pad_index) # 为1的地方有word | |||||
seq_len = word_mask.sum(dim=-1) | |||||
batch_word_pieces_length = self.word_pieces_lengths[words].masked_fill(word_mask.eq(False), 0) # batch_size x max_len | |||||
word_pieces_lengths = batch_word_pieces_length.sum(dim=-1) # batch_size | |||||
max_word_piece_length = batch_word_pieces_length.sum(dim=-1).max().item() # 表示word piece的长度(包括padding) | |||||
if max_word_piece_length + 2 > self._max_position_embeddings: | |||||
if self.auto_truncate: | |||||
word_pieces_lengths = word_pieces_lengths.masked_fill( | |||||
word_pieces_lengths + 2 > self._max_position_embeddings, self._max_position_embeddings - 2) | |||||
else: | |||||
                    raise RuntimeError(
                        "After split words into word pieces, the lengths of word pieces are longer than the "
                        f"maximum allowed sequence length:{self._max_position_embeddings} of roberta. You can set "
                        f"`auto_truncate=True` for RobertaEmbedding to automatically truncate overlong input.")
# +2是由于需要加入<s>与</s> | |||||
word_pieces = words.new_full((batch_size, min(max_word_piece_length + 2, self._max_position_embeddings)), | |||||
fill_value=self._wordpiece_pad_index) | |||||
attn_masks = torch.zeros_like(word_pieces) | |||||
# 1. 获取words的word_pieces的id,以及对应的span范围 | |||||
word_indexes = words.cpu().numpy() | |||||
for i in range(batch_size): | |||||
word_pieces_i = list(chain(*self.word_to_wordpieces[word_indexes[i, :seq_len[i]]])) | |||||
if self.auto_truncate and len(word_pieces_i) > self._max_position_embeddings - 2: | |||||
word_pieces_i = word_pieces_i[:self._max_position_embeddings - 2] | |||||
word_pieces[i, 1:word_pieces_lengths[i] + 1] = torch.LongTensor(word_pieces_i) | |||||
attn_masks[i, :word_pieces_lengths[i] + 2].fill_(1) | |||||
word_pieces[:, 0].fill_(self._cls_index) | |||||
batch_indexes = torch.arange(batch_size).to(words) | |||||
word_pieces[batch_indexes, word_pieces_lengths + 1] = self._sep_index | |||||
token_type_ids = torch.zeros_like(word_pieces) | |||||
# 2. 获取hidden的结果,根据word_pieces进行对应的pool计算 | |||||
# all_outputs: [batch_size x max_len x hidden_size, batch_size x max_len x hidden_size, ...] | |||||
bert_outputs, pooled_cls = self.encoder(word_pieces, token_type_ids=token_type_ids, | |||||
attention_mask=attn_masks, | |||||
output_all_encoded_layers=True) | |||||
# output_layers = [self.layers] # len(self.layers) x batch_size x real_word_piece_length x hidden_size | |||||
if self.include_cls_sep: | |||||
s_shift = 1 | |||||
outputs = bert_outputs[-1].new_zeros(len(self.layers), batch_size, max_word_len + 2, | |||||
bert_outputs[-1].size(-1)) | |||||
else: | |||||
s_shift = 0 | |||||
outputs = bert_outputs[-1].new_zeros(len(self.layers), batch_size, max_word_len, | |||||
bert_outputs[-1].size(-1)) | |||||
batch_word_pieces_cum_length = batch_word_pieces_length.new_zeros(batch_size, max_word_len + 1) | |||||
batch_word_pieces_cum_length[:, 1:] = batch_word_pieces_length.cumsum(dim=-1) # batch_size x max_len | |||||
if self.pool_method == 'first': | |||||
batch_word_pieces_cum_length = batch_word_pieces_cum_length[:, :seq_len.max()] | |||||
batch_word_pieces_cum_length.masked_fill_(batch_word_pieces_cum_length.ge(max_word_piece_length), 0) | |||||
_batch_indexes = batch_indexes[:, None].expand((batch_size, batch_word_pieces_cum_length.size(1))) | |||||
elif self.pool_method == 'last': | |||||
batch_word_pieces_cum_length = batch_word_pieces_cum_length[:, 1:seq_len.max() + 1] - 1 | |||||
batch_word_pieces_cum_length.masked_fill_(batch_word_pieces_cum_length.ge(max_word_piece_length), 0) | |||||
_batch_indexes = batch_indexes[:, None].expand((batch_size, batch_word_pieces_cum_length.size(1))) | |||||
for l_index, l in enumerate(self.layers): | |||||
output_layer = bert_outputs[l] | |||||
real_word_piece_length = output_layer.size(1) - 2 | |||||
if max_word_piece_length > real_word_piece_length: # 如果实际上是截取出来的 | |||||
paddings = output_layer.new_zeros(batch_size, | |||||
max_word_piece_length - real_word_piece_length, | |||||
output_layer.size(2)) | |||||
output_layer = torch.cat((output_layer, paddings), dim=1).contiguous() | |||||
# 从word_piece collapse到word的表示 | |||||
truncate_output_layer = output_layer[:, 1:-1] # 删除<s>与</s> batch_size x len x hidden_size | |||||
if self.pool_method == 'first': | |||||
tmp = truncate_output_layer[_batch_indexes, batch_word_pieces_cum_length] | |||||
tmp = tmp.masked_fill(word_mask[:, :batch_word_pieces_cum_length.size(1), None].eq(False), 0) | |||||
outputs[l_index, :, s_shift:batch_word_pieces_cum_length.size(1) + s_shift] = tmp | |||||
elif self.pool_method == 'last': | |||||
tmp = truncate_output_layer[_batch_indexes, batch_word_pieces_cum_length] | |||||
tmp = tmp.masked_fill(word_mask[:, :batch_word_pieces_cum_length.size(1), None].eq(False), 0) | |||||
outputs[l_index, :, s_shift:batch_word_pieces_cum_length.size(1) + s_shift] = tmp | |||||
elif self.pool_method == 'max': | |||||
for i in range(batch_size): | |||||
for j in range(seq_len[i]): | |||||
start, end = batch_word_pieces_cum_length[i, j], batch_word_pieces_cum_length[i, j + 1] | |||||
outputs[l_index, i, j + s_shift], _ = torch.max(truncate_output_layer[i, start:end], dim=-2) | |||||
else: | |||||
for i in range(batch_size): | |||||
for j in range(seq_len[i]): | |||||
start, end = batch_word_pieces_cum_length[i, j], batch_word_pieces_cum_length[i, j + 1] | |||||
outputs[l_index, i, j + s_shift] = torch.mean(truncate_output_layer[i, start:end], dim=-2) | |||||
if self.include_cls_sep: | |||||
if l in (len(bert_outputs) - 1, -1) and self.pooled_cls: | |||||
outputs[l_index, :, 0] = pooled_cls | |||||
else: | |||||
outputs[l_index, :, 0] = output_layer[:, 0] | |||||
outputs[l_index, batch_indexes, seq_len + s_shift] = output_layer[batch_indexes, word_pieces_lengths + s_shift] | |||||
# 3. 最终的embedding结果 | |||||
return outputs | |||||
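# How the four ``pool_method`` options collapse a word's word pieces into a
# single vector can be checked on a toy span (one word made of two word
# pieces, hidden_size=3; a standalone sketch)::
#
#     >>> pieces = torch.tensor([[1., 2., 3.], [4., 5., 6.]])
#     >>> first, last = pieces[0], pieces[-1]  # 'first' / 'last'
#     >>> avg = pieces.mean(dim=-2)            # 'avg' -> tensor([2.5000, 3.5000, 4.5000])
#     >>> mx, _ = pieces.max(dim=-2)           # 'max' -> tensor([4., 5., 6.])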
class RobertaWordPieceEncoder(nn.Module): | |||||
r""" | |||||
    Loads a RoBERTa model; after loading, call the index_datasets method to generate the ``word_pieces`` field in your datasets.
    RobertaWordPieceEncoder supports automatic weight download; currently supported models:
en: roberta-base | |||||
en-large: roberta-large | |||||
""" | |||||
def __init__(self, model_dir_or_name: str = 'en-base-uncased', layers: str = '-1', pooled_cls: bool = False, | |||||
word_dropout=0, dropout=0, requires_grad: bool = True): | |||||
r""" | |||||
:param str model_dir_or_name: 模型所在目录或者模型的名称。默认值为 ``en-base-uncased`` | |||||
:param str layers: 最终结果中的表示。以','隔开层数,可以以负数去索引倒数几层。layer=0为embedding层(包括wordpiece embedding, | |||||
position embedding) | |||||
:param bool pooled_cls: 返回的句子开头的<s>是否使用预训练中的BertPool映射一下。如果下游任务取<s>做预测,一般该值为True。 | |||||
:param float word_dropout: 以多大的概率将一个词替换为unk。这样既可以训练unk也是一定的regularize。 | |||||
:param float dropout: 以多大的概率对embedding的表示进行Dropout。0.1即随机将10%的值置为0。 | |||||
:param bool requires_grad: 是否需要gradient。 | |||||
""" | |||||
super().__init__() | |||||
self.model = _WordPieceRobertaModel(model_dir_or_name=model_dir_or_name, layers=layers, pooled_cls=pooled_cls) | |||||
self._sep_index = self.model._sep_index | |||||
self._cls_index = self.model._cls_index | |||||
self._wordpiece_pad_index = self.model._wordpiece_pad_index | |||||
self._wordpiece_unk_index = self.model._wordpiece_unknown_index | |||||
self._embed_size = len(self.model.layers) * self.model.encoder.hidden_size | |||||
self.requires_grad = requires_grad | |||||
self.word_dropout = word_dropout | |||||
self.dropout_layer = nn.Dropout(dropout) | |||||
@property | |||||
def embed_size(self): | |||||
return self._embed_size | |||||
@property | |||||
def embedding_dim(self): | |||||
return self._embed_size | |||||
@property | |||||
def num_embedding(self): | |||||
return self.model.encoder.config.vocab_size | |||||
def index_datasets(self, *datasets, field_name, add_cls_sep=True, add_prefix_space=True): | |||||
r""" | |||||
        Use the RoBERTa tokenizer to generate a new ``word_pieces`` field in each dataset and set it as input; the
        pad value of the ``word_pieces`` field is set to RoBERTa's pad value.
        :param ~fastNLP.DataSet datasets: one or more DataSet objects
        :param str field_name: the field used to generate word_pieces; each entry should be a List[str]
        :param bool add_cls_sep: whether to add <s> and </s> at the start and end if they are not already there
        :param bool add_prefix_space: whether to add a leading space at the start of the sentence; True during RoBERTa pretraining
:return: | |||||
""" | |||||
self.model.index_datasets(*datasets, field_name=field_name, add_cls_sep=add_cls_sep, add_prefix_space=add_prefix_space) | |||||
def forward(self, word_pieces, token_type_ids=None): | |||||
r""" | |||||
        Computes the RoBERTa embedding of the given word pieces. The input should already contain the <s> and </s> tokens.
        :param word_pieces: batch_size x max_len
        :param token_type_ids: batch_size x max_len, distinguishes the first and the second sentence; RoBERTa makes
            little use of segment ids, so this can usually be omitted (it defaults to all zeros).
        :return: torch.FloatTensor. batch_size x max_len x (768*len(self.layers))
""" | |||||
word_pieces = self.drop_word(word_pieces) | |||||
outputs = self.model(word_pieces) | |||||
outputs = torch.cat([*outputs], dim=-1) | |||||
return self.dropout_layer(outputs) | |||||
def drop_word(self, words): | |||||
r""" | |||||
按照设定随机将words设置为unknown_index。 | |||||
:param torch.LongTensor words: batch_size x max_len | |||||
:return: | |||||
""" | |||||
if self.word_dropout > 0 and self.training: | |||||
with torch.no_grad(): | |||||
not_sep_mask = words.ne(self._sep_index) | |||||
not_cls_mask = words.ne(self._cls_index) | |||||
replaceable_mask = not_sep_mask.__and__(not_cls_mask) | |||||
mask = torch.full_like(words, fill_value=self.word_dropout, dtype=torch.float, device=words.device) | |||||
mask = torch.bernoulli(mask).eq(1) # dropout_word越大,越多位置为1 | |||||
pad_mask = words.ne(self._wordpiece_pad_index) | |||||
mask = pad_mask.__and__(mask).__and__(replaceable_mask) # pad的位置不为unk | |||||
words = words.masked_fill(mask, self._wordpiece_unk_index) | |||||
return words | |||||
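# An end-to-end sketch for RobertaWordPieceEncoder (assumes the 'en' weights
# are available, and ``ds`` is a hypothetical fastNLP DataSet whose
# ``raw_words`` field holds List[str] tokens)::
#
#     >>> encoder = RobertaWordPieceEncoder(model_dir_or_name='en', layers='-1')
#     >>> encoder.index_datasets(ds, field_name='raw_words')  # adds a 'word_pieces' field
#     >>> batch = torch.LongTensor([ds[0]['word_pieces']])    # a batch of one sample
#     >>> reprs = encoder(batch)                              # batch_size x max_len x embed_size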
class _WordPieceRobertaModel(nn.Module): | |||||
def __init__(self, model_dir_or_name: str, layers: str = '-1', pooled_cls: bool=False): | |||||
super().__init__() | |||||
self.tokenzier = RobertaTokenizer.from_pretrained(model_dir_or_name) | |||||
self.encoder = RobertaModel.from_pretrained(model_dir_or_name) | |||||
# 检查encoder_layer_number是否合理 | |||||
encoder_layer_number = len(self.encoder.encoder.layer) | |||||
if isinstance(layers, list): | |||||
self.layers = [int(l) for l in layers] | |||||
elif isinstance(layers, str): | |||||
self.layers = list(map(int, layers.split(','))) | |||||
else: | |||||
raise TypeError("`layers` only supports str or list[int]") | |||||
for layer in self.layers: | |||||
if layer < 0: | |||||
assert -layer <= encoder_layer_number, f"The layer index:{layer} is out of scope for " \ | |||||
f"a RoBERTa model with {encoder_layer_number} layers." | |||||
else: | |||||
assert layer <= encoder_layer_number, f"The layer index:{layer} is out of scope for " \ | |||||
f"a RoBERTa model with {encoder_layer_number} layers." | |||||
self._cls_index = self.tokenzier.encoder['<s>'] | |||||
self._sep_index = self.tokenzier.encoder['</s>'] | |||||
self._wordpiece_pad_index = self.tokenzier.encoder['<pad>'] # 需要用于生成word_piece | |||||
self._wordpiece_unknown_index = self.tokenzier.encoder['<unk>'] | |||||
self.pooled_cls = pooled_cls | |||||
def index_datasets(self, *datasets, field_name, add_cls_sep=True, add_prefix_space=True): | |||||
r""" | |||||
        Use the RoBERTa tokenizer to generate a new ``word_pieces`` field in each dataset and set it as input. If the
        sequence does not start with <s> and end with </s>, they are added; the pad value of the ``word_pieces``
        field is set to RoBERTa's pad value.
        :param datasets: one or more DataSet objects
        :param field_name: the field whose content is indexed
        :param bool add_cls_sep: whether to add the <s> and </s> indices at the start and end of each sentence
        :param bool add_prefix_space: whether to add a leading space at the start of the sentence; True during RoBERTa pretraining
:return: | |||||
""" | |||||
encode_func = partial(self.tokenzier.encode, add_special_tokens=add_cls_sep, add_prefix_space=add_prefix_space) | |||||
for index, dataset in enumerate(datasets): | |||||
try: | |||||
dataset.apply_field(encode_func, field_name=field_name, new_field_name='word_pieces', | |||||
is_input=True) | |||||
dataset.set_pad_val('word_pieces', self._wordpiece_pad_index) | |||||
except Exception as e: | |||||
logger.error(f"Exception happens when processing the {index} dataset.") | |||||
raise e | |||||
def forward(self, word_pieces): | |||||
r""" | |||||
:param word_pieces: torch.LongTensor, batch_size x max_len | |||||
        :return: num_layers x batch_size x max_len x hidden_size
""" | |||||
batch_size, max_len = word_pieces.size() | |||||
attn_masks = word_pieces.ne(self._wordpiece_pad_index) | |||||
roberta_outputs, pooled_cls = self.encoder(word_pieces, token_type_ids=torch.zeros_like(word_pieces), | |||||
attention_mask=attn_masks, | |||||
output_all_encoded_layers=True) | |||||
# output_layers = [self.layers] # len(self.layers) x batch_size x max_word_piece_length x hidden_size | |||||
outputs = roberta_outputs[0].new_zeros((len(self.layers), batch_size, max_len, roberta_outputs[0].size(-1))) | |||||
for l_index, l in enumerate(self.layers): | |||||
roberta_output = roberta_outputs[l] | |||||
            if l in (len(roberta_outputs) - 1, -1) and self.pooled_cls:
roberta_output[:, 0] = pooled_cls | |||||
outputs[l_index] = roberta_output | |||||
return outputs |
@@ -19,7 +19,7 @@ from .embedding import TokenEmbedding
from ..core import logger
from ..core.vocabulary import Vocabulary
from ..io.file_utils import PRETRAIN_STATIC_FILES, _get_embedding_url, cached_path
from ..modules.utils import _get_file_name_base_on_postfix
from fastNLP.io.file_utils import _get_file_name_base_on_postfix
class StaticEmbedding(TokenEmbedding):
@@ -48,6 +48,18 @@ PRETRAINED_BERT_MODEL_DIR = {
    'cn-wwm-ext': "bert-chinese-wwm-ext.zip"
}
PRETRAINED_GPT2_MODEL_DIR = {
    'en': 'gpt2.zip',
    'en-medium': 'gpt2-medium.zip',
    'en-large': 'gpt2-large.zip',
    'en-xl': 'gpt2-xl.zip'
}
PRETRAINED_ROBERTA_MODEL_DIR = {
    'en': 'roberta-base.zip',
    'en-large': 'roberta-large.zip'
}
PRETRAINED_ELMO_MODEL_DIR = {
    'en': 'elmo_en_Medium.zip',
    'en-small': "elmo_en_Small.zip",
@@ -127,14 +139,18 @@ DATASET_DIR = {
PRETRAIN_MAP = {'elmo': PRETRAINED_ELMO_MODEL_DIR,
                "bert": PRETRAINED_BERT_MODEL_DIR,
                "static": PRETRAIN_STATIC_FILES}
                "static": PRETRAIN_STATIC_FILES,
                'gpt2': PRETRAINED_GPT2_MODEL_DIR,
                'roberta': PRETRAINED_ROBERTA_MODEL_DIR}
# used to extend fastNLP's download registry
FASTNLP_EXTEND_DATASET_URL = 'fastnlp_dataset_url.txt'
FASTNLP_EXTEND_EMBEDDING_URL = {'elmo': 'fastnlp_elmo_url.txt',
                                'bert': 'fastnlp_bert_url.txt',
                                'static': 'fastnlp_static_url.txt'
                                }
                                'bert': 'fastnlp_bert_url.txt',
                                'static': 'fastnlp_static_url.txt',
                                'gpt2': 'fastnlp_gpt2_url.txt',
                                'roberta': 'fastnlp_roberta_url.txt'
                                }
def cached_path(url_or_filename: str, cache_dir: str = None, name=None) -> Path:
@@ -273,7 +289,7 @@ def _get_embedding_url(embed_type, name):
        return url
    raise KeyError("There is no {}. Only supports {}.".format(name, list(embed_map.keys())))
else:
    raise KeyError(f"There is no {embed_type}. Only supports bert, elmo, static")
    raise KeyError(f"There is no {embed_type}. Only supports bert, elmo, static, gpt2, roberta")
def _read_extend_url_file(filename, name)->str:
    r"""
@@ -281,7 +297,7 @@ def _read_extend_url_file(filename, name)->str:
    :param str filename: the file to look for under the default cache path
    :param str name: the name of the resource to look for
    :return: str or None
    :return: str,None
    """
    cache_dir = get_cache_path()
    filepath = os.path.join(cache_dir, filename)
@@ -488,3 +504,57 @@ def match_file(dir_name: str, cache_dir: Path) -> str:
        return matched_filenames[-1]
    else:
        raise RuntimeError(f"Duplicate matched files:{matched_filenames}, this should be caused by a bug.")
def _get_bert_dir(model_dir_or_name: str = 'en-base-uncased'): | |||||
if model_dir_or_name.lower() in PRETRAINED_BERT_MODEL_DIR: | |||||
model_url = _get_embedding_url('bert', model_dir_or_name.lower()) | |||||
model_dir = cached_path(model_url, name='embedding') | |||||
# 检查是否存在 | |||||
elif os.path.isdir(os.path.abspath(os.path.expanduser(model_dir_or_name))): | |||||
model_dir = os.path.abspath(os.path.expanduser(model_dir_or_name)) | |||||
else: | |||||
logger.error(f"Cannot recognize BERT dir or name ``{model_dir_or_name}``.") | |||||
raise ValueError(f"Cannot recognize BERT dir or name ``{model_dir_or_name}``.") | |||||
return str(model_dir) | |||||
def _get_gpt2_dir(model_dir_or_name: str = 'en'): | |||||
if model_dir_or_name.lower() in PRETRAINED_GPT2_MODEL_DIR: | |||||
model_url = _get_embedding_url('gpt2', model_dir_or_name.lower()) | |||||
model_dir = cached_path(model_url, name='embedding') | |||||
# 检查是否存在 | |||||
elif os.path.isdir(os.path.abspath(os.path.expanduser(model_dir_or_name))): | |||||
model_dir = os.path.abspath(os.path.expanduser(model_dir_or_name)) | |||||
else: | |||||
logger.error(f"Cannot recognize GPT2 dir or name ``{model_dir_or_name}``.") | |||||
raise ValueError(f"Cannot recognize GPT2 dir or name ``{model_dir_or_name}``.") | |||||
return str(model_dir) | |||||
def _get_roberta_dir(model_dir_or_name: str = 'en'): | |||||
if model_dir_or_name.lower() in PRETRAINED_ROBERTA_MODEL_DIR: | |||||
model_url = _get_embedding_url('roberta', model_dir_or_name.lower()) | |||||
model_dir = cached_path(model_url, name='embedding') | |||||
# 检查是否存在 | |||||
elif os.path.isdir(os.path.abspath(os.path.expanduser(model_dir_or_name))): | |||||
model_dir = os.path.abspath(os.path.expanduser(model_dir_or_name)) | |||||
else: | |||||
logger.error(f"Cannot recognize RoBERTa dir or name ``{model_dir_or_name}``.") | |||||
raise ValueError(f"Cannot recognize RoBERTa dir or name ``{model_dir_or_name}``.") | |||||
return str(model_dir) | |||||
def _get_file_name_base_on_postfix(dir_path, postfix): | |||||
r""" | |||||
在dir_path中寻找后缀为postfix的文件. | |||||
:param dir_path: str, 文件夹 | |||||
:param postfix: 形如".bin", ".json"等 | |||||
:return: str,文件的路径 | |||||
""" | |||||
files = list(filter(lambda filename: filename.endswith(postfix), os.listdir(os.path.join(dir_path)))) | |||||
if len(files) == 0: | |||||
raise FileNotFoundError(f"There is no file endswith {postfix} file in {dir_path}") | |||||
elif len(files) > 1: | |||||
raise FileExistsError(f"There are multiple *{postfix} files in {dir_path}") | |||||
return os.path.join(dir_path, files[0]) |
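# A short sketch of how these helpers compose (``~/weights/roberta`` is a
# hypothetical local directory holding a vocab.json, a .bin weight file and a
# config.json)::
#
#     >>> model_dir = _get_roberta_dir('~/weights/roberta')
#     >>> weights_path = _get_file_name_base_on_postfix(model_dir, '.bin')
#     >>> config_path = _get_file_name_base_on_postfix(model_dir, 'config.json')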
@@ -49,7 +49,16 @@ __all__ = [
    "TimestepDropout",
    'summary'
    'summary',
    "BertTokenizer",
    "BertModel",
    "RobertaTokenizer",
    "RobertaModel",
    "GPT2Model",
    "GPT2Tokenizer"
]
import sys
@@ -61,5 +70,6 @@ from .dropout import TimestepDropout
from .encoder import *
from .utils import summary
from ..doc_utils import doc_process
from .tokenizer import *
doc_process(sys.modules[__name__])
@@ -0,0 +1,109 @@ | |||||
# coding=utf-8 | |||||
__all__ = [ | |||||
"TransformerPast", | |||||
"Past", | |||||
"Decoder" | |||||
] | |||||
import torch | |||||
from torch import nn | |||||
import abc | |||||
import torch.nn.functional as F | |||||
from ...embeddings import StaticEmbedding | |||||
import numpy as np | |||||
from typing import Union, Tuple | |||||
from ...embeddings.utils import get_embeddings | |||||
from torch.nn import LayerNorm | |||||
import math | |||||
class Past: | |||||
def __init__(self): | |||||
pass | |||||
@abc.abstractmethod | |||||
def num_samples(self): | |||||
pass | |||||
@abc.abstractmethod | |||||
def reorder_past(self, indices: torch.LongTensor): | |||||
""" | |||||
        Reorders the internal states of this Past according to ``indices``; the change is made in place.
        :param torch.LongTensor indices:
        :return:
""" | |||||
        raise NotImplementedError
class TransformerPast(Past): | |||||
def __init__(self, encoder_outputs: torch.Tensor = None, encoder_mask: torch.Tensor = None, | |||||
num_decoder_layer: int = 6): | |||||
""" | |||||
        :param encoder_outputs: (batch, src_seq_len, dim)
        :param encoder_mask: (batch, src_seq_len)
        :param num_decoder_layer: number of decoder layers; the per-layer key/value caches below start out as None
""" | |||||
super().__init__() | |||||
self.encoder_outputs = encoder_outputs | |||||
self.encoder_mask = encoder_mask | |||||
self.encoder_key = [None] * num_decoder_layer | |||||
self.encoder_value = [None] * num_decoder_layer | |||||
self.decoder_prev_key = [None] * num_decoder_layer | |||||
self.decoder_prev_value = [None] * num_decoder_layer | |||||
def num_samples(self): | |||||
if self.encoder_outputs is not None: | |||||
return self.encoder_outputs.size(0) | |||||
return None | |||||
def _reorder_state(self, state, indices): | |||||
if type(state) == torch.Tensor: | |||||
state = state.index_select(index=indices, dim=0) | |||||
elif type(state) == list: | |||||
for i in range(len(state)): | |||||
assert state[i] is not None | |||||
state[i] = state[i].index_select(index=indices, dim=0) | |||||
else: | |||||
            raise ValueError(f'State of type {type(state)} is not supported; expected torch.Tensor or list')
return state | |||||
def reorder_past(self, indices: torch.LongTensor): | |||||
self.encoder_outputs = self._reorder_state(self.encoder_outputs, indices) | |||||
self.encoder_mask = self._reorder_state(self.encoder_mask, indices) | |||||
self.encoder_key = self._reorder_state(self.encoder_key, indices) | |||||
self.encoder_value = self._reorder_state(self.encoder_value, indices) | |||||
self.decoder_prev_key = self._reorder_state(self.decoder_prev_key, indices) | |||||
self.decoder_prev_value = self._reorder_state(self.decoder_prev_value, indices) | |||||
return self | |||||
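# What ``reorder_past`` does during beam search, sketched with toy tensors
# (the shapes are illustrative only; the per-layer caches must be filled
# first, since ``_reorder_state`` asserts that list entries are not None)::
#
#     >>> past = TransformerPast(encoder_outputs=torch.arange(6.).view(3, 1, 2),
#     ...                        encoder_mask=torch.ones(3, 1, dtype=torch.bool),
#     ...                        num_decoder_layer=2)
#     >>> for lst in (past.encoder_key, past.encoder_value,
#     ...             past.decoder_prev_key, past.decoder_prev_value):
#     ...     for i in range(2):
#     ...         lst[i] = torch.zeros(3, 1, 4)
#     >>> past.reorder_past(torch.LongTensor([2, 0, 1]))  # sample 2 moves to slot 0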
class Decoder(nn.Module): | |||||
def __init__(self): | |||||
super().__init__() | |||||
@abc.abstractmethod | |||||
def decode(self, *args, **kwargs) -> Tuple[torch.Tensor, Past]: | |||||
""" | |||||
        Called when the model decodes. Returns a batch_size x vocab_size tensor together with the updated Past state.
        One special case needs care: when the given tokens are longer than 1 (i.e. a decoding prefix is supplied),
        check that the Past correctly reflects the decode state of the whole prefix.
        :return: tensor: batch_size x vocab_size, past: Past
""" | |||||
        raise NotImplementedError
@abc.abstractmethod | |||||
def reorder_past(self, indices: torch.LongTensor, past: Past): | |||||
""" | |||||
        Reorders the states in ``past`` according to ``indices``; the change is made in place.
        :param torch.LongTensor indices:
        :param Past past:
        :return:
""" | |||||
        raise NotImplementedError
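# A minimal concrete subclass sketch (hypothetical, for illustration only):
# score the next token with a single linear layer and leave the past untouched.
#
#     class _ToyDecoder(Decoder):
#         def __init__(self, d_model=8, vocab_size=10):
#             super().__init__()
#             self.out = nn.Linear(d_model, vocab_size)
#
#         def decode(self, hidden, past):  # hidden: batch_size x d_model
#             return self.out(hidden), past
#
#         def reorder_past(self, indices, past):
#             return past.reorder_past(indices)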
@@ -30,10 +30,18 @@ __all__ = [
    "MultiHeadAttention",
    "BiAttention",
    "SelfAttention",
    "BertModel",
    "RobertaModel",
    "GPT2Model"
]
from .attention import MultiHeadAttention, BiAttention, SelfAttention
from .bert import BertModel
from .roberta import RobertaModel
from .gpt2 import GPT2Model
from .char_encoder import ConvolutionCharEncoder, LSTMCharEncoder
from .conv_maxpool import ConvMaxpool
from .lstm import LSTM
@@ -4,26 +4,23 @@ r"""undocumented
"""
__all__ = [
    "BertModel"
    "BertModel",
]
import collections
import copy
import json
import math
import os
import unicodedata
import torch
from torch import nn
import numpy as np
from ..utils import _get_file_name_base_on_postfix
from ...io.file_utils import _get_embedding_url, cached_path, PRETRAINED_BERT_MODEL_DIR
from fastNLP.io.file_utils import _get_file_name_base_on_postfix
from ...io.file_utils import _get_bert_dir
from ...core import logger
CONFIG_FILE = 'bert_config.json'
VOCAB_NAME = 'vocab.txt'
BERT_KEY_RENAME_MAP_1 = {
    'gamma': 'weight',
@@ -152,33 +149,22 @@ def swish(x):
ACT2FN = {"gelu": gelu, "relu": torch.nn.functional.relu, "swish": swish}
def _get_bert_dir(model_dir_or_name: str = 'en-base-uncased'):
    if model_dir_or_name.lower() in PRETRAINED_BERT_MODEL_DIR:
        model_url = _get_embedding_url('bert', model_dir_or_name.lower())
        model_dir = cached_path(model_url, name='embedding')
    # check whether a local directory exists
    elif os.path.isdir(os.path.abspath(os.path.expanduser(model_dir_or_name))):
        model_dir = os.path.abspath(os.path.expanduser(model_dir_or_name))
    else:
        logger.error(f"Cannot recognize BERT dir or name ``{model_dir_or_name}``.")
        raise ValueError(f"Cannot recognize BERT dir or name ``{model_dir_or_name}``.")
    return str(model_dir)
class BertLayerNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-12):
        r"""Construct a layernorm module in the TF style (epsilon inside the square root).
        """
        super(BertLayerNorm, self).__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.bias = nn.Parameter(torch.zeros(hidden_size))
        self.variance_epsilon = eps
# class BertLayerNorm(nn.Module):
#     def __init__(self, hidden_size, eps=1e-12):
#         r"""Construct a layernorm module in the TF style (epsilon inside the square root).
#         """
#         super(BertLayerNorm, self).__init__()
#         self.weight = nn.Parameter(torch.ones(hidden_size))
#         self.bias = nn.Parameter(torch.zeros(hidden_size))
#         self.variance_epsilon = eps
#
#     def forward(self, x):
#         u = x.mean(-1, keepdim=True)
#         s = (x - u).pow(2).mean(-1, keepdim=True)
#         x = (x - u) / torch.sqrt(s + self.variance_epsilon)
#         return self.weight * x + self.bias
    def forward(self, x):
        u = x.mean(-1, keepdim=True)
        s = (x - u).pow(2).mean(-1, keepdim=True)
        x = (x - u) / torch.sqrt(s + self.variance_epsilon)
        return self.weight * x + self.bias
BertLayerNorm = torch.nn.LayerNorm
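# torch.nn.LayerNorm computes the same statistics as the hand-written module
# it replaces (mean and biased variance over the last dim, then an affine
# map); a standalone check::
#
#     >>> x = torch.randn(2, 5, 16)
#     >>> ln = torch.nn.LayerNorm(16, eps=1e-12)
#     >>> u = x.mean(-1, keepdim=True)
#     >>> s = (x - u).pow(2).mean(-1, keepdim=True)
#     >>> manual = (x - u) / torch.sqrt(s + 1e-12) * ln.weight + ln.bias
#     >>> assert torch.allclose(ln(x), manual, atol=1e-6)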
class DistilBertEmbeddings(nn.Module):
@@ -245,14 +231,18 @@ class BertEmbeddings(nn.Module):
        self.LayerNorm = BertLayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
    def forward(self, input_ids, token_type_ids=None):
    def forward(self, input_ids, token_type_ids=None, position_ids=None, words_embeddings=None):
        seq_length = input_ids.size(1)
        position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device)
        position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
        if position_ids is None:
            position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device)
            position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
        if token_type_ids is None:
            token_type_ids = torch.zeros_like(input_ids)
        words_embeddings = self.word_embeddings(input_ids)
        if words_embeddings is None:
            words_embeddings = self.word_embeddings(input_ids)
        else:
            assert input_ids.size() == words_embeddings.size()[: -1]
        position_embeddings = self.position_embeddings(position_ids)
        token_type_embeddings = self.token_type_embeddings(token_type_ids)
@@ -514,6 +504,7 @@ class BertModel(nn.Module):
        pooled_output = sequence_output[:, 0]
        if not output_all_encoded_layers:
            encoded_layers = encoded_layers[-1]
        encoded_layers.insert(0, embedding_output)
        return encoded_layers, pooled_output
    @classmethod
@@ -610,436 +601,3 @@ class BertModel(nn.Module):
        logger.info(f"Load pre-trained {model_type} parameters from file {weights_path}.")
        return model
def whitespace_tokenize(text): | |||||
r"""Runs basic whitespace cleaning and splitting on a piece of text.""" | |||||
text = text.strip() | |||||
if not text: | |||||
return [] | |||||
tokens = text.split() | |||||
return tokens | |||||
class WordpieceTokenizer(object): | |||||
r"""Runs WordPiece tokenization.""" | |||||
def __init__(self, vocab, unk_token="[UNK]", max_input_chars_per_word=100): | |||||
self.vocab = vocab | |||||
self.unk_token = unk_token | |||||
self.max_input_chars_per_word = max_input_chars_per_word | |||||
def tokenize(self, text): | |||||
r"""Tokenizes a piece of text into its word pieces. | |||||
This uses a greedy longest-match-first algorithm to perform tokenization | |||||
using the given vocabulary. | |||||
For example: | |||||
input = "unaffable" | |||||
output = ["un", "##aff", "##able"] | |||||
Args: | |||||
text: A single token or whitespace separated tokens. This should have | |||||
already been passed through `BasicTokenizer`. | |||||
Returns: | |||||
A list of wordpiece tokens. | |||||
""" | |||||
output_tokens = [] | |||||
for token in whitespace_tokenize(text): | |||||
chars = list(token) | |||||
if len(chars) > self.max_input_chars_per_word: | |||||
output_tokens.append(self.unk_token) | |||||
continue | |||||
is_bad = False | |||||
start = 0 | |||||
sub_tokens = [] | |||||
while start < len(chars): | |||||
end = len(chars) | |||||
cur_substr = None | |||||
while start < end: | |||||
substr = "".join(chars[start:end]) | |||||
if start > 0: | |||||
substr = "##" + substr | |||||
if substr in self.vocab: | |||||
cur_substr = substr | |||||
break | |||||
end -= 1 | |||||
if cur_substr is None: | |||||
is_bad = True | |||||
break | |||||
sub_tokens.append(cur_substr) | |||||
start = end | |||||
if is_bad: | |||||
output_tokens.append(self.unk_token) | |||||
else: | |||||
output_tokens.extend(sub_tokens) | |||||
        if len(output_tokens) == 0:  # guard against input that is all whitespace/control characters
            return [self.unk_token]
return output_tokens | |||||
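# The greedy longest-match-first behaviour can be reproduced with a tiny
# made-up vocabulary (standalone sketch)::
#
#     >>> vocab = {'un': 0, '##aff': 1, '##able': 2, '[UNK]': 3}
#     >>> wt = WordpieceTokenizer(vocab=vocab)
#     >>> wt.tokenize('unaffable')
#     ['un', '##aff', '##able']
#     >>> wt.tokenize('xyz')   # nothing matches, falls back to the unk token
#     ['[UNK]']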
def load_vocab(vocab_file): | |||||
r"""Loads a vocabulary file into a dictionary.""" | |||||
vocab = collections.OrderedDict() | |||||
index = 0 | |||||
with open(vocab_file, "r", encoding="utf-8") as reader: | |||||
while True: | |||||
token = reader.readline() | |||||
if not token: | |||||
break | |||||
token = token.strip() | |||||
vocab[token] = index | |||||
index += 1 | |||||
return vocab | |||||
class BasicTokenizer(object): | |||||
r"""Runs basic tokenization (punctuation splitting, lower casing, etc.).""" | |||||
def __init__(self, | |||||
do_lower_case=True, | |||||
never_split=("[UNK]", "[SEP]", "[PAD]", "[CLS]", "[MASK]")): | |||||
r"""Constructs a BasicTokenizer. | |||||
Args: | |||||
do_lower_case: Whether to lower case the input. | |||||
""" | |||||
self.do_lower_case = do_lower_case | |||||
self.never_split = never_split | |||||
def tokenize(self, text): | |||||
r"""Tokenizes a piece of text.""" | |||||
text = self._clean_text(text) | |||||
# This was added on November 1st, 2018 for the multilingual and Chinese | |||||
# models. This is also applied to the English models now, but it doesn't | |||||
# matter since the English models were not trained on any Chinese data | |||||
# and generally don't have any Chinese data in them (there are Chinese | |||||
# characters in the vocabulary because Wikipedia does have some Chinese | |||||
# words in the English Wikipedia.). | |||||
text = self._tokenize_chinese_chars(text) | |||||
orig_tokens = whitespace_tokenize(text) | |||||
split_tokens = [] | |||||
for token in orig_tokens: | |||||
if self.do_lower_case and token not in self.never_split: | |||||
token = token.lower() | |||||
token = self._run_strip_accents(token) | |||||
split_tokens.extend(self._run_split_on_punc(token)) | |||||
output_tokens = whitespace_tokenize(" ".join(split_tokens)) | |||||
return output_tokens | |||||
def _run_strip_accents(self, text): | |||||
r"""Strips accents from a piece of text.""" | |||||
text = unicodedata.normalize("NFD", text) | |||||
output = [] | |||||
for char in text: | |||||
cat = unicodedata.category(char) | |||||
if cat == "Mn": | |||||
continue | |||||
output.append(char) | |||||
return "".join(output) | |||||
def _run_split_on_punc(self, text): | |||||
r"""Splits punctuation on a piece of text.""" | |||||
if text in self.never_split: | |||||
return [text] | |||||
chars = list(text) | |||||
i = 0 | |||||
start_new_word = True | |||||
output = [] | |||||
while i < len(chars): | |||||
char = chars[i] | |||||
if _is_punctuation(char): | |||||
output.append([char]) | |||||
start_new_word = True | |||||
else: | |||||
if start_new_word: | |||||
output.append([]) | |||||
start_new_word = False | |||||
output[-1].append(char) | |||||
i += 1 | |||||
return ["".join(x) for x in output] | |||||
def _tokenize_chinese_chars(self, text): | |||||
r"""Adds whitespace around any CJK character.""" | |||||
output = [] | |||||
for char in text: | |||||
cp = ord(char) | |||||
if self._is_chinese_char(cp): | |||||
output.append(" ") | |||||
output.append(char) | |||||
output.append(" ") | |||||
else: | |||||
output.append(char) | |||||
return "".join(output) | |||||
def _is_chinese_char(self, cp): | |||||
r"""Checks whether CP is the codepoint of a CJK character.""" | |||||
# This defines a "chinese character" as anything in the CJK Unicode block: | |||||
# https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block) | |||||
# | |||||
# Note that the CJK Unicode block is NOT all Japanese and Korean characters, | |||||
# despite its name. The modern Korean Hangul alphabet is a different block, | |||||
# as is Japanese Hiragana and Katakana. Those alphabets are used to write | |||||
# space-separated words, so they are not treated specially and handled | |||||
# like the all of the other languages. | |||||
if (((cp >= 0x4E00) and (cp <= 0x9FFF)) or # | |||||
((cp >= 0x3400) and (cp <= 0x4DBF)) or # | |||||
((cp >= 0x20000) and (cp <= 0x2A6DF)) or # | |||||
((cp >= 0x2A700) and (cp <= 0x2B73F)) or # | |||||
((cp >= 0x2B740) and (cp <= 0x2B81F)) or # | |||||
((cp >= 0x2B820) and (cp <= 0x2CEAF)) or | |||||
((cp >= 0xF900) and (cp <= 0xFAFF)) or # | |||||
((cp >= 0x2F800) and (cp <= 0x2FA1F))): # | |||||
return True | |||||
return False | |||||
def _clean_text(self, text): | |||||
r"""Performs invalid character removal and whitespace cleanup on text.""" | |||||
output = [] | |||||
for char in text: | |||||
cp = ord(char) | |||||
if cp == 0 or cp == 0xfffd or _is_control(char): | |||||
continue | |||||
if _is_whitespace(char): | |||||
output.append(" ") | |||||
else: | |||||
output.append(char) | |||||
return "".join(output) | |||||
def _is_whitespace(char): | |||||
r"""Checks whether `chars` is a whitespace character.""" | |||||
    # \t, \n, and \r are technically control characters but we treat them
    # as whitespace since they are generally considered as such.
if char == " " or char == "\t" or char == "\n" or char == "\r": | |||||
return True | |||||
cat = unicodedata.category(char) | |||||
if cat == "Zs": | |||||
return True | |||||
return False | |||||
def _is_control(char): | |||||
r"""Checks whether `chars` is a control character.""" | |||||
# These are technically control characters but we count them as whitespace | |||||
# characters. | |||||
if char == "\t" or char == "\n" or char == "\r": | |||||
return False | |||||
cat = unicodedata.category(char) | |||||
if cat.startswith("C"): | |||||
return True | |||||
return False | |||||
def _is_punctuation(char): | |||||
r"""Checks whether `chars` is a punctuation character.""" | |||||
cp = ord(char) | |||||
# We treat all non-letter/number ASCII as punctuation. | |||||
# Characters such as "^", "$", and "`" are not in the Unicode | |||||
# Punctuation class but we treat them as punctuation anyways, for | |||||
# consistency. | |||||
if (((cp >= 33) and (cp <= 47)) or ((cp >= 58) and (cp <= 64)) or | |||||
((cp >= 91) and (cp <= 96)) or ((cp >= 123) and (cp <= 126))): | |||||
return True | |||||
cat = unicodedata.category(char) | |||||
if cat.startswith("P"): | |||||
return True | |||||
return False | |||||
class BertTokenizer(object): | |||||
r"""Runs end-to-end tokenization: punctuation splitting + wordpiece""" | |||||
def __init__(self, vocab_file, do_lower_case=True, max_len=None, do_basic_tokenize=True, | |||||
never_split=("[UNK]", "[SEP]", "[PAD]", "[CLS]", "[MASK]")): | |||||
r"""Constructs a BertTokenizer. | |||||
Args: | |||||
vocab_file: Path to a one-wordpiece-per-line vocabulary file | |||||
do_lower_case: Whether to lower case the input | |||||
Only has an effect when do_wordpiece_only=False | |||||
do_basic_tokenize: Whether to do basic tokenization before wordpiece. | |||||
max_len: An artificial maximum length to truncate tokenized sequences to; | |||||
Effective maximum length is always the minimum of this | |||||
value (if specified) and the underlying BERT model's | |||||
sequence length. | |||||
never_split: List of tokens which will never be split during tokenization. | |||||
Only has an effect when do_wordpiece_only=False | |||||
""" | |||||
if not os.path.isfile(vocab_file): | |||||
raise ValueError( | |||||
"Can't find a vocabulary file at path '{}'. To load the vocabulary from a Google pretrained " | |||||
"model use `tokenizer = BertTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`".format(vocab_file)) | |||||
self.vocab = load_vocab(vocab_file) | |||||
self.ids_to_tokens = collections.OrderedDict( | |||||
[(ids, tok) for tok, ids in self.vocab.items()]) | |||||
self.do_basic_tokenize = do_basic_tokenize | |||||
if do_basic_tokenize: | |||||
self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case, | |||||
never_split=never_split) | |||||
self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab) | |||||
self.max_len = max_len if max_len is not None else int(1e12) | |||||
def _reinit_on_new_vocab(self, vocab): | |||||
r""" | |||||
在load bert之后,可能会对vocab进行重新排列。重新排列之后调用这个函数重新初始化与vocab相关的性质 | |||||
:param vocab: | |||||
:return: | |||||
""" | |||||
self.vocab = vocab | |||||
self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab) | |||||
def tokenize(self, text): | |||||
split_tokens = [] | |||||
if self.do_basic_tokenize: | |||||
for token in self.basic_tokenizer.tokenize(text): | |||||
for sub_token in self.wordpiece_tokenizer.tokenize(token): | |||||
split_tokens.append(sub_token) | |||||
else: | |||||
split_tokens = self.wordpiece_tokenizer.tokenize(text) | |||||
return split_tokens | |||||
def convert_tokens_to_ids(self, tokens): | |||||
r"""Converts a sequence of tokens into ids using the vocab.""" | |||||
ids = [] | |||||
for token in tokens: | |||||
ids.append(self.vocab[token]) | |||||
if len(ids) > self.max_len: | |||||
logger.warning( | |||||
"Token indices sequence length is longer than the specified maximum " | |||||
" sequence length for this BERT model ({} > {}). Running this" | |||||
" sequence through BERT will result in indexing errors".format(len(ids), self.max_len) | |||||
) | |||||
return ids | |||||
def convert_ids_to_tokens(self, ids): | |||||
r"""Converts a sequence of ids in wordpiece tokens using the vocab.""" | |||||
tokens = [] | |||||
for i in ids: | |||||
tokens.append(self.ids_to_tokens[i]) | |||||
return tokens | |||||
def save_vocabulary(self, vocab_path): | |||||
r"""Save the tokenizer vocabulary to a directory or file.""" | |||||
index = 0 | |||||
if os.path.isdir(vocab_path): | |||||
vocab_file = os.path.join(vocab_path, VOCAB_NAME) | |||||
else: | |||||
vocab_file = vocab_path | |||||
with open(vocab_file, "w", encoding="utf-8") as writer: | |||||
for token, token_index in sorted(self.vocab.items(), key=lambda kv: kv[1]): | |||||
if index != token_index: | |||||
logger.warning("Saving vocabulary to {}: vocabulary indices are not consecutive." | |||||
" Please check that the vocabulary is not corrupted!".format(vocab_file)) | |||||
index = token_index | |||||
writer.write(token + u'\n') | |||||
index += 1 | |||||
return vocab_file | |||||
@classmethod | |||||
def from_pretrained(cls, model_dir_or_name, *inputs, **kwargs): | |||||
r""" | |||||
        Given a model name or path, read the vocab directly.
""" | |||||
model_dir = _get_bert_dir(model_dir_or_name) | |||||
pretrained_model_name_or_path = _get_file_name_base_on_postfix(model_dir, '.txt') | |||||
logger.info("loading vocabulary file {}".format(pretrained_model_name_or_path)) | |||||
max_len = 512 | |||||
kwargs['max_len'] = min(kwargs.get('max_position_embeddings', int(1e12)), max_len) | |||||
# Instantiate tokenizer. | |||||
tokenizer = cls(pretrained_model_name_or_path, *inputs, **kwargs) | |||||
return tokenizer | |||||
class _WordPieceBertModel(nn.Module): | |||||
r""" | |||||
    This module computes word_piece-level results directly.
""" | |||||
    def __init__(self, model_dir_or_name: str, layers: str = '-1', pooled_cls: bool = False):
super().__init__() | |||||
        self.tokenizer = BertTokenizer.from_pretrained(model_dir_or_name)
        self.encoder = BertModel.from_pretrained(model_dir_or_name)
        # check that the requested layer indices are valid
        encoder_layer_number = len(self.encoder.encoder.layer)
        self.layers = list(map(int, layers.split(',')))
        for layer in self.layers:
            if layer < 0:
                assert -layer <= encoder_layer_number, f"The layer index:{layer} is out of scope for " \
                    f"a bert model with {encoder_layer_number} layers."
            else:
                assert layer < encoder_layer_number, f"The layer index:{layer} is out of scope for " \
                    f"a bert model with {encoder_layer_number} layers."
        self._cls_index = self.tokenizer.vocab['[CLS]']
        self._sep_index = self.tokenizer.vocab['[SEP]']
        self._wordpiece_unknown_index = self.tokenizer.vocab['[UNK]']
        self._wordpiece_pad_index = self.tokenizer.vocab['[PAD]']  # needed when generating word_pieces
self.pooled_cls = pooled_cls | |||||
def index_dataset(self, *datasets, field_name, add_cls_sep=True): | |||||
r""" | |||||
使用bert的tokenizer新生成word_pieces列加入到datasets中,并将他们设置为input。如果首尾不是 | |||||
[CLS]与[SEP]会在首尾额外加入[CLS]与[SEP], 且将word_pieces这一列的pad value设置为了bert的pad value。 | |||||
:param datasets: DataSet对象 | |||||
:param field_name: 基于哪一列index | |||||
:return: | |||||
""" | |||||
        def convert_words_to_word_pieces(words):
            word_pieces = []
            for word in words:
                _words = self.tokenizer.basic_tokenizer._tokenize_chinese_chars(word).split()
                tokens = []
                for _word in _words:
                    tokens.extend(self.tokenizer.wordpiece_tokenizer.tokenize(_word))
                word_piece_ids = self.tokenizer.convert_tokens_to_ids(tokens)
                word_pieces.extend(word_piece_ids)
if add_cls_sep: | |||||
if word_pieces[0] != self._cls_index: | |||||
word_pieces.insert(0, self._cls_index) | |||||
if word_pieces[-1] != self._sep_index: | |||||
word_pieces.append(self._sep_index) | |||||
return word_pieces | |||||
for index, dataset in enumerate(datasets): | |||||
try: | |||||
dataset.apply_field(convert_words_to_word_pieces, field_name=field_name, new_field_name='word_pieces', | |||||
is_input=True) | |||||
dataset.set_pad_val('word_pieces', self._wordpiece_pad_index) | |||||
except Exception as e: | |||||
logger.error(f"Exception happens when processing the {index} dataset.") | |||||
raise e | |||||
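    # A minimal usage sketch (illustrative only; `ds` is assumed to be a fastNLP
    # DataSet with a 'words' column, and the model name is an assumption):
    #
    #   model = _WordPieceBertModel('en-base-uncased', layers='-1,-2')
    #   model.index_dataset(ds, field_name='words')  # adds an input column 'word_pieces'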
def forward(self, word_pieces, token_type_ids=None): | |||||
r""" | |||||
:param word_pieces: torch.LongTensor, batch_size x max_len | |||||
:param token_type_ids: torch.LongTensor, batch_size x max_len | |||||
        :return: num_layers x batch_size x max_len x hidden_size or num_layers x batch_size x (max_len+2) x hidden_size
""" | |||||
batch_size, max_len = word_pieces.size() | |||||
attn_masks = word_pieces.ne(self._wordpiece_pad_index) | |||||
bert_outputs, pooled_cls = self.encoder(word_pieces, token_type_ids=token_type_ids, attention_mask=attn_masks, | |||||
output_all_encoded_layers=True) | |||||
# output_layers = [self.layers] # len(self.layers) x batch_size x max_word_piece_length x hidden_size | |||||
outputs = bert_outputs[0].new_zeros((len(self.layers), batch_size, max_len, bert_outputs[0].size(-1))) | |||||
for l_index, l in enumerate(self.layers): | |||||
bert_output = bert_outputs[l] | |||||
if l in (len(bert_outputs)-1, -1) and self.pooled_cls: | |||||
bert_output[:, 0] = pooled_cls | |||||
outputs[l_index] = bert_output | |||||
return outputs |
@@ -0,0 +1,182 @@ | |||||
r"""undocumented | |||||
The code in this file is largely adapted (copy-pasted) from https://github.com/huggingface/pytorch-pretrained-BERT;
if you find it useful, please cite them as well.
""" | |||||
__all__ = [ | |||||
'RobertaModel' | |||||
] | |||||
import torch | |||||
import torch.nn as nn | |||||
from .bert import BertEmbeddings, BertModel, BertConfig | |||||
from fastNLP.io.file_utils import _get_file_name_base_on_postfix | |||||
from ...io.file_utils import _get_roberta_dir | |||||
from ...core import logger | |||||
PRETRAINED_ROBERTA_POSITIONAL_EMBEDDINGS_SIZES = { | |||||
"roberta-base": 512, | |||||
"roberta-large": 512, | |||||
"roberta-large-mnli": 512, | |||||
"distilroberta-base": 512, | |||||
"roberta-base-openai-detector": 512, | |||||
"roberta-large-openai-detector": 512, | |||||
} | |||||
class RobertaEmbeddings(BertEmbeddings): | |||||
""" | |||||
Same as BertEmbeddings with a tiny tweak for positional embeddings indexing. | |||||
""" | |||||
def __init__(self, config): | |||||
super().__init__(config) | |||||
self.padding_idx = 1 | |||||
self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=self.padding_idx) | |||||
self.position_embeddings = nn.Embedding( | |||||
config.max_position_embeddings, config.hidden_size, padding_idx=self.padding_idx | |||||
) | |||||
def forward(self, input_ids, token_type_ids, words_embeddings=None): | |||||
position_ids = self.create_position_ids_from_input_ids(input_ids) | |||||
return super().forward( | |||||
input_ids, token_type_ids=token_type_ids, position_ids=position_ids, words_embeddings=words_embeddings | |||||
) | |||||
def create_position_ids_from_input_ids(self, x): | |||||
""" Replace non-padding symbols with their position numbers. Position numbers begin at | |||||
padding_idx+1. Padding symbols are ignored. This is modified from fairseq's | |||||
`utils.make_positions`. | |||||
:param torch.Tensor x: | |||||
:return torch.Tensor: | |||||
""" | |||||
mask = x.ne(self.padding_idx).long() | |||||
        incremental_indices = torch.cumsum(mask, dim=1) * mask
        return incremental_indices + self.padding_idx
class RobertaModel(BertModel): | |||||
r""" | |||||
undocumented | |||||
""" | |||||
def __init__(self, config): | |||||
super().__init__(config) | |||||
self.embeddings = RobertaEmbeddings(config) | |||||
self.apply(self.init_bert_weights) | |||||
@classmethod | |||||
def from_pretrained(cls, model_dir_or_name, *inputs, **kwargs): | |||||
state_dict = kwargs.get('state_dict', None) | |||||
kwargs.pop('state_dict', None) | |||||
kwargs.pop('cache_dir', None) | |||||
kwargs.pop('from_tf', None) | |||||
# get model dir from name or dir | |||||
pretrained_model_dir = _get_roberta_dir(model_dir_or_name) | |||||
# Load config | |||||
config_file = _get_file_name_base_on_postfix(pretrained_model_dir, 'config.json') | |||||
config = BertConfig.from_json_file(config_file) | |||||
# Load model | |||||
if state_dict is None: | |||||
weights_path = _get_file_name_base_on_postfix(pretrained_model_dir, '.bin') | |||||
state_dict = torch.load(weights_path, map_location='cpu') | |||||
else: | |||||
            logger.error('Cannot load parameters through the `state_dict` argument.')
            raise RuntimeError('Cannot load parameters through the `state_dict` argument.')
# Instantiate model. | |||||
model = cls(config, *inputs, **kwargs) | |||||
missing_keys = [] | |||||
unexpected_keys = [] | |||||
error_msgs = [] | |||||
# Convert old format to new format if needed from a PyTorch state_dict | |||||
old_keys = [] | |||||
new_keys = [] | |||||
for key in state_dict.keys(): | |||||
new_key = None | |||||
if "gamma" in key: | |||||
new_key = key.replace("gamma", "weight") | |||||
if "beta" in key: | |||||
new_key = key.replace("beta", "bias") | |||||
if new_key: | |||||
old_keys.append(key) | |||||
new_keys.append(new_key) | |||||
for old_key, new_key in zip(old_keys, new_keys): | |||||
state_dict[new_key] = state_dict.pop(old_key) | |||||
# copy state_dict so _load_from_state_dict can modify it | |||||
metadata = getattr(state_dict, "_metadata", None) | |||||
state_dict = state_dict.copy() | |||||
if metadata is not None: | |||||
state_dict._metadata = metadata | |||||
# PyTorch's `_load_from_state_dict` does not copy parameters in a module's descendants | |||||
# so we need to apply the function recursively. | |||||
def load(module: nn.Module, prefix=""): | |||||
local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {}) | |||||
module._load_from_state_dict( | |||||
state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs, | |||||
) | |||||
for name, child in module._modules.items(): | |||||
if child is not None: | |||||
load(child, prefix + name + ".") | |||||
# Make sure we are able to load base models as well as derived models (with heads) | |||||
start_prefix = "" | |||||
model_to_load = model | |||||
if not hasattr(model, 'roberta') and any( | |||||
s.startswith('roberta') for s in state_dict.keys() | |||||
): | |||||
start_prefix = 'roberta.' | |||||
if hasattr(model, 'roberta') and not any( | |||||
s.startswith('roberta') for s in state_dict.keys() | |||||
): | |||||
model_to_load = getattr(model, 'roberta') | |||||
load(model_to_load, prefix=start_prefix) | |||||
if model.__class__.__name__ != model_to_load.__class__.__name__: | |||||
base_model_state_dict = model_to_load.state_dict().keys() | |||||
head_model_state_dict_without_base_prefix = [ | |||||
key.split('roberta.')[-1] for key in model.state_dict().keys() | |||||
] | |||||
missing_keys.extend(head_model_state_dict_without_base_prefix - base_model_state_dict) | |||||
if len(missing_keys) > 0: | |||||
logger.info( | |||||
"Weights of {} not initialized from pretrained model: {}".format( | |||||
model.__class__.__name__, missing_keys | |||||
) | |||||
) | |||||
if len(unexpected_keys) > 0: | |||||
logger.info( | |||||
"Weights from pretrained model not used in {}: {}".format( | |||||
model.__class__.__name__, unexpected_keys | |||||
) | |||||
) | |||||
if len(error_msgs) > 0: | |||||
raise RuntimeError( | |||||
"Error(s) in loading state_dict for {}:\n\t{}".format( | |||||
model.__class__.__name__, "\n\t".join(error_msgs) | |||||
) | |||||
) | |||||
        # Set model in evaluation mode to deactivate Dropout modules by default
model.eval() | |||||
logger.info(f"Load pre-trained RoBERTa parameters from file {weights_path}.") | |||||
return model | |||||
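# A minimal usage sketch (illustrative only; assumes `model_dir_or_name` is a
# name or directory accepted by _get_roberta_dir, and `words` is a LongTensor of
# token ids padded with RoBERTa's padding index 1):
#
#   model = RobertaModel.from_pretrained('en')
#   encoded_layers, pooled = model(words, attention_mask=words.ne(1))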
@@ -0,0 +1,452 @@ | |||||
r""" | |||||
""" | |||||
__all__ = [ | |||||
'SequenceGenerator' | |||||
] | |||||
import torch | |||||
from ..decoder.seq2seq_decoder import Decoder | |||||
import torch.nn.functional as F | |||||
from fastNLP.core.utils import _get_model_device | |||||
from functools import partial | |||||
class SequenceGenerator: | |||||
def __init__(self, decoder: Decoder, max_length=20, num_beams=1, | |||||
do_sample=True, temperature=1.0, top_k=50, top_p=1.0, bos_token_id=None, eos_token_id=None, | |||||
repetition_penalty=1, length_penalty=1.0, pad_token_id=0): | |||||
if do_sample: | |||||
self.generate_func = partial(sample_generate, decoder=decoder, max_length=max_length, num_beams=num_beams, | |||||
temperature=temperature, top_k=top_k, top_p=top_p, bos_token_id=bos_token_id, | |||||
eos_token_id=eos_token_id, repetition_penalty=repetition_penalty, | |||||
length_penalty=length_penalty, pad_token_id=pad_token_id) | |||||
else: | |||||
self.generate_func = partial(greedy_generate, decoder=decoder, max_length=max_length, num_beams=num_beams, | |||||
bos_token_id=bos_token_id, eos_token_id=eos_token_id, | |||||
repetition_penalty=repetition_penalty, | |||||
length_penalty=length_penalty, pad_token_id=pad_token_id) | |||||
self.do_sample = do_sample | |||||
self.max_length = max_length | |||||
self.num_beams = num_beams | |||||
self.temperature = temperature | |||||
self.top_k = top_k | |||||
self.top_p = top_p | |||||
self.bos_token_id = bos_token_id | |||||
self.eos_token_id = eos_token_id | |||||
self.repetition_penalty = repetition_penalty | |||||
self.length_penalty = length_penalty | |||||
self.decoder = decoder | |||||
@torch.no_grad() | |||||
def generate(self, tokens=None, past=None): | |||||
""" | |||||
        :param torch.LongTensor tokens: batch_size x length, the tokens to start generation from
        :param past:
        :return:
        """
        # TODO check whether decoding still works directly when tokens has length > 1
return self.generate_func(tokens=tokens, past=past) | |||||
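# A minimal usage sketch (illustrative only; `decoder` is any Decoder that
# implements decode()/reorder_past(), `past` holds the encoder state, and the
# token ids below are assumptions):
#
#   generator = SequenceGenerator(decoder, max_length=30, num_beams=4, do_sample=False,
#                                 bos_token_id=0, eos_token_id=1, pad_token_id=0)
#   token_ids = generator.generate(tokens=None, past=past)  # batch_size x seq_len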
@torch.no_grad() | |||||
def greedy_generate(decoder, tokens=None, past=None, max_length=20, num_beams=1, | |||||
bos_token_id=None, eos_token_id=None, pad_token_id=0, | |||||
repetition_penalty=1, length_penalty=1.0): | |||||
""" | |||||
    Greedy sentence search.
    :param Decoder decoder: a Decoder object
    :param torch.LongTensor tokens: batch_size x len, the input to decode from; if None, generation starts from bos_token_id
    :param Past past: should contain the encoder outputs.
    :param int max_length: maximum length of the generated sentence.
    :param int num_beams: beam size used for decoding.
    :param int bos_token_id: if tokens is None, decoding starts from bos_token_id.
    :param int eos_token_id: the end token; if None, decoding always runs to max_length.
    :param int pad_token_id:
    :param float repetition_penalty: how strongly repeated tokens are penalized.
    :param float length_penalty: a length-based penalty applied to every token except eos.
    :return:
""" | |||||
if num_beams == 1: | |||||
token_ids = _no_beam_search_generate(decoder, tokens, past, max_length, temperature=1, top_k=50, top_p=1, | |||||
bos_token_id=bos_token_id, eos_token_id=eos_token_id, do_sample=False, | |||||
repetition_penalty=repetition_penalty, length_penalty=length_penalty, | |||||
pad_token_id=pad_token_id) | |||||
else: | |||||
token_ids = _beam_search_generate(decoder, tokens, past, max_length, num_beams=num_beams, | |||||
temperature=1, top_k=50, top_p=1, | |||||
bos_token_id=bos_token_id, eos_token_id=eos_token_id, do_sample=False, | |||||
repetition_penalty=repetition_penalty, length_penalty=length_penalty, | |||||
pad_token_id=pad_token_id) | |||||
return token_ids | |||||
@torch.no_grad() | |||||
def sample_generate(decoder, tokens=None, past=None, max_length=20, num_beams=1, temperature=1.0, top_k=50, | |||||
top_p=1.0, bos_token_id=None, eos_token_id=None, pad_token_id=0, repetition_penalty=1.0, | |||||
length_penalty=1.0): | |||||
""" | |||||
    Generate sentences by sampling.
    :param Decoder decoder: a Decoder object
    :param torch.LongTensor tokens: batch_size x len, the input to decode from; if None, generation starts from bos_token_id
    :param Past past: should contain the encoder outputs.
    :param int max_length: maximum length of the generated sentence.
    :param int num_beams: beam size used for decoding.
    :param float temperature: annealing temperature used when sampling.
    :param int top_k: sample only from the top_k candidates.
    :param float top_p: a value between 0 and 1 (nucleus sampling).
    :param int bos_token_id: if tokens is None, decoding starts from bos_token_id.
    :param int eos_token_id: the end token; if None, decoding always runs to max_length.
    :param int pad_token_id: the pad token id.
    :param float repetition_penalty: how strongly repeated tokens are penalized.
    :param float length_penalty: a length-based penalty applied to every token except eos.
    :return:
    """
    # each position is sampled during generation
if num_beams == 1: | |||||
token_ids = _no_beam_search_generate(decoder, tokens, past, max_length, temperature=temperature, | |||||
top_k=top_k, top_p=top_p, | |||||
bos_token_id=bos_token_id, eos_token_id=eos_token_id, do_sample=True, | |||||
repetition_penalty=repetition_penalty, length_penalty=length_penalty, | |||||
pad_token_id=pad_token_id) | |||||
else: | |||||
token_ids = _beam_search_generate(decoder, tokens, past, max_length, num_beams=num_beams, | |||||
temperature=temperature, top_k=top_k, top_p=top_p, | |||||
bos_token_id=bos_token_id, eos_token_id=eos_token_id, do_sample=True, | |||||
repetition_penalty=repetition_penalty, length_penalty=length_penalty, | |||||
pad_token_id=pad_token_id) | |||||
return token_ids | |||||
def _no_beam_search_generate(decoder: Decoder, tokens=None, past=None, max_length=20, temperature=1.0, top_k=50, | |||||
top_p=1.0, bos_token_id=None, eos_token_id=None, do_sample=True, | |||||
repetition_penalty=1.0, length_penalty=1.0, pad_token_id=0): | |||||
device = _get_model_device(decoder) | |||||
if tokens is None: | |||||
if bos_token_id is None: | |||||
raise RuntimeError("You have to specify either `tokens` or `bos_token_id`.") | |||||
if past is None: | |||||
raise RuntimeError("You have to specify either `past` or `tokens`.") | |||||
batch_size = past.num_samples() | |||||
if batch_size is None: | |||||
raise RuntimeError("Cannot infer the number of samples from `past`.") | |||||
tokens = torch.full([batch_size, 1], fill_value=bos_token_id, dtype=torch.long).to(device) | |||||
batch_size = tokens.size(0) | |||||
if past is not None: | |||||
assert past.num_samples() == batch_size, "The number of samples in `tokens` and `past` should match." | |||||
if eos_token_id is None: | |||||
_eos_token_id = float('nan') | |||||
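        # nan compares unequal to every token id, so the eq() check below never matches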
else: | |||||
_eos_token_id = eos_token_id | |||||
# for i in range(tokens.size(1)): | |||||
# scores, past = decoder.decode_one(tokens[:, :i + 1], past) # batch_size x vocab_size, Past | |||||
scores, past = decoder.decode(tokens, past) | |||||
token_ids = tokens.clone() | |||||
cur_len = token_ids.size(1) | |||||
dones = token_ids.new_zeros(batch_size).eq(1) | |||||
# tokens = tokens[:, -1:] | |||||
while cur_len < max_length: | |||||
# scores, past = decoder.decode_one(tokens, past) # batch_size x vocab_size, Past | |||||
scores, past = decoder.decode(tokens, past) # batch_size x vocab_size, Past | |||||
if repetition_penalty != 1.0: | |||||
token_scores = scores.gather(dim=1, index=token_ids) | |||||
lt_zero_mask = token_scores.lt(0).float() | |||||
ge_zero_mask = lt_zero_mask.eq(0).float() | |||||
token_scores = lt_zero_mask * repetition_penalty * token_scores + ge_zero_mask / repetition_penalty * token_scores | |||||
scores.scatter_(dim=1, index=token_ids, src=token_scores) | |||||
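            # i.e. the log-scores of already-generated tokens are pushed down: negative
            # scores are multiplied by the penalty, non-negative ones are divided by it
            # (the scheme from the CTRL paper, https://arxiv.org/abs/1909.05858)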
if eos_token_id is not None and length_penalty != 1.0: | |||||
token_scores = scores / cur_len ** length_penalty # batch_size x vocab_size | |||||
eos_mask = scores.new_ones(scores.size(1)) | |||||
eos_mask[eos_token_id] = 0 | |||||
eos_mask = eos_mask.unsqueeze(0).eq(1) | |||||
            scores = scores.masked_scatter(eos_mask, token_scores)  # i.e. every token except eos has its score scaled up/down
if do_sample: | |||||
if temperature > 0 and temperature != 1: | |||||
scores = scores / temperature | |||||
scores = top_k_top_p_filtering(scores, top_k, top_p, min_tokens_to_keep=2) | |||||
probs = F.softmax(scores, dim=-1) | |||||
                # make sure at least one non-eos token survives
next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1) # batch_size | |||||
else: | |||||
next_tokens = torch.argmax(scores, dim=-1) # batch_size | |||||
            next_tokens = next_tokens.masked_fill(dones, pad_token_id)  # pad the samples that have already finished
tokens = next_tokens.unsqueeze(1) | |||||
token_ids = torch.cat([token_ids, tokens], dim=-1) # batch_size x max_len | |||||
end_mask = next_tokens.eq(_eos_token_id) | |||||
dones = dones.__or__(end_mask) | |||||
cur_len += 1 | |||||
if dones.min() == 1: | |||||
break | |||||
if eos_token_id is not None: | |||||
if cur_len == max_length: | |||||
            token_ids[:, -1].masked_fill_(~dones, eos_token_id)  # if max length is reached without EOS, force the last token to eos
return token_ids | |||||
def _beam_search_generate(decoder: Decoder, tokens=None, past=None, max_length=20, num_beams=4, temperature=1.0, | |||||
top_k=50, top_p=1.0, bos_token_id=None, eos_token_id=None, do_sample=True, | |||||
repetition_penalty=1.0, length_penalty=None, pad_token_id=0) -> torch.LongTensor: | |||||
    # run beam search
device = _get_model_device(decoder) | |||||
if tokens is None: | |||||
if bos_token_id is None: | |||||
raise RuntimeError("You have to specify either `tokens` or `bos_token_id`.") | |||||
if past is None: | |||||
raise RuntimeError("You have to specify either `past` or `tokens`.") | |||||
batch_size = past.num_samples() | |||||
if batch_size is None: | |||||
raise RuntimeError("Cannot infer the number of samples from `past`.") | |||||
tokens = torch.full([batch_size, 1], fill_value=bos_token_id, dtype=torch.long).to(device) | |||||
batch_size = tokens.size(0) | |||||
if past is not None: | |||||
assert past.num_samples() == batch_size, "The number of samples in `tokens` and `past` should match." | |||||
    # for i in range(tokens.size(1) - 1):  # if the input is long, decode it step by step first
    #     scores, past = decoder.decode_one(tokens[:, :i + 1],
    #                                       past)  # (batch_size, vocab_size), Past
    # scores, past = decoder.decode_one(tokens, past)  # the full sentence length must be passed in here
    scores, past = decoder.decode(tokens, past)  # the full sentence length must be passed in here
vocab_size = scores.size(1) | |||||
    assert vocab_size >= num_beams, "num_beams should not exceed the vocabulary size."
if do_sample: | |||||
probs = F.softmax(scores, dim=-1) | |||||
next_tokens = torch.multinomial(probs, num_samples=num_beams) # (batch_size, num_beams) | |||||
logits = probs.log() | |||||
next_scores = logits.gather(dim=1, index=next_tokens) # (batch_size, num_beams) | |||||
else: | |||||
scores = F.log_softmax(scores, dim=-1) # (batch_size, vocab_size) | |||||
        # yields (batch_size, num_beams), (batch_size, num_beams)
next_scores, next_tokens = torch.topk(scores, num_beams, dim=1, largest=True, sorted=True) | |||||
indices = torch.arange(batch_size, dtype=torch.long).to(device) | |||||
indices = indices.repeat_interleave(num_beams) | |||||
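    # e.g. batch_size=2, num_beams=3 -> indices = [0, 0, 0, 1, 1, 1]: each sample's
    # encoder state is replicated once per beam before the first reorder below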
decoder.reorder_past(indices, past) | |||||
tokens = tokens.index_select(dim=0, index=indices) # batch_size * num_beams x length | |||||
    # record the generated tokens, (batch_size', cur_len)
token_ids = torch.cat([tokens, next_tokens.view(-1, 1)], dim=-1) | |||||
dones = [False] * batch_size | |||||
tokens = next_tokens.view(-1, 1) | |||||
beam_scores = next_scores.view(-1) # batch_size * num_beams | |||||
    # track the length of the tokens generated so far
cur_len = token_ids.size(1) | |||||
hypos = [ | |||||
BeamHypotheses(num_beams, max_length, length_penalty, early_stopping=False) for _ in range(batch_size) | |||||
] | |||||
# 0,num_beams, 2*num_beams, ... | |||||
batch_inds_with_numbeams_interval = (torch.arange(batch_size) * num_beams).view(-1, 1).to(token_ids) | |||||
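    # e.g. batch_size=3, num_beams=2 -> [[0], [2], [4]]; adding a within-sample beam
    # index to a row yields the flat index into the (batch_size * num_beams) dimension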
while cur_len < max_length: | |||||
# scores, past = decoder.decode_one(tokens, past) # batch_size * num_beams x vocab_size, Past | |||||
scores, past = decoder.decode(tokens, past) | |||||
if repetition_penalty != 1.0: | |||||
token_scores = scores.gather(dim=1, index=token_ids) | |||||
lt_zero_mask = token_scores.lt(0).float() | |||||
ge_zero_mask = lt_zero_mask.eq(0).float() | |||||
token_scores = lt_zero_mask * repetition_penalty * token_scores + ge_zero_mask / repetition_penalty * token_scores | |||||
scores.scatter_(dim=1, index=token_ids, src=token_scores) | |||||
if do_sample: | |||||
if temperature > 0 and temperature != 1: | |||||
scores = scores / temperature | |||||
            # recall one extra candidate to guard against eos
scores = top_k_top_p_filtering(scores, top_k, top_p, min_tokens_to_keep=num_beams + 1) | |||||
probs = F.softmax(scores, dim=-1) | |||||
            # make sure at least one non-eos token survives
_tokens = torch.multinomial(probs, num_samples=num_beams + 1) # batch_size' x (num_beams+1) | |||||
logits = probs.log() | |||||
            # avoid all winners coming from a single beam, and account for eos being sampled
_scores = logits.gather(dim=1, index=_tokens) # batch_size' x (num_beams+1) | |||||
_scores = _scores + beam_scores[:, None] # batch_size' x (num_beams+1) | |||||
            # from these, select the top 2*num_beams
_scores = _scores.view(batch_size, num_beams * (num_beams + 1)) | |||||
next_scores, ids = _scores.topk(2 * num_beams, dim=1, largest=True, sorted=True) | |||||
_tokens = _tokens.view(batch_size, num_beams * (num_beams + 1)) | |||||
next_tokens = _tokens.gather(dim=1, index=ids) # (batch_size, 2*num_beams) | |||||
from_which_beam = ids // (num_beams + 1) # (batch_size, 2*num_beams) | |||||
else: | |||||
scores = F.log_softmax(scores, dim=-1) # (batch_size * num_beams, vocab_size) | |||||
_scores = scores + beam_scores[:, None] # (batch_size * num_beams, vocab_size) | |||||
_scores = _scores.view(batch_size, -1) # (batch_size, num_beams*vocab_size) | |||||
next_scores, ids = torch.topk(_scores, 2 * num_beams, dim=1, largest=True, sorted=True) | |||||
from_which_beam = ids // vocab_size # (batch_size, 2*num_beams) | |||||
next_tokens = ids % vocab_size # (batch_size, 2*num_beams) | |||||
        # next, assemble the candidates for the following step
        # and decide which ones to keep
next_scores, sorted_inds = next_scores.sort(dim=-1, descending=True) | |||||
next_tokens = next_tokens.gather(dim=1, index=sorted_inds) | |||||
from_which_beam = from_which_beam.gather(dim=1, index=sorted_inds) | |||||
        not_eos_mask = next_tokens.ne(eos_token_id)  # 1 where the token is not eos
        keep_mask = not_eos_mask.cumsum(dim=1).le(num_beams)  # 1 where the candidate should be kept
        keep_mask = not_eos_mask.__and__(keep_mask)  # 1 where the candidate continues to the next search step
        _next_tokens = next_tokens.masked_select(keep_mask).view(-1, 1)
        _from_which_beam = from_which_beam.masked_select(keep_mask).view(batch_size, num_beams)  # which beam each kept token came from
_next_scores = next_scores.masked_select(keep_mask).view(batch_size, num_beams) | |||||
beam_scores = _next_scores.view(-1) | |||||
        # update the past state and regroup token_ids
        reorder_inds = (batch_inds_with_numbeams_interval + _from_which_beam).view(-1)  # flatten to 1-D
decoder.reorder_past(reorder_inds, past) | |||||
flag = True | |||||
if cur_len + 1 == max_length: | |||||
eos_batch_idx = torch.arange(batch_size).to(next_tokens).repeat_interleave(repeats=num_beams, dim=0) | |||||
            eos_beam_ind = torch.arange(num_beams).to(token_ids).repeat(batch_size)  # the within-sample beam indices
            eos_beam_idx = from_which_beam[:, :num_beams].reshape(-1)  # which beam each one was taken from
else: | |||||
            # add sequences within the top num_beams of each batch to the finished set; 1 marks positions that should finish
effective_eos_mask = next_tokens[:, :num_beams].eq(eos_token_id) # batch_size x num_beams | |||||
if effective_eos_mask.sum().gt(0): | |||||
eos_batch_idx, eos_beam_ind = effective_eos_mask.nonzero(as_tuple=True) | |||||
                # from_which_beam is (batch_size, 2*num_beams), hence the factor 2*num_beams
eos_beam_idx = eos_batch_idx * num_beams * 2 + eos_beam_ind | |||||
                eos_beam_idx = from_which_beam.view(-1)[eos_beam_idx]  # recover which beam the eos actually came from
else: | |||||
flag = False | |||||
if flag: | |||||
for batch_idx, beam_ind, beam_idx in zip(eos_batch_idx.tolist(), eos_beam_ind.tolist(), | |||||
eos_beam_idx.tolist()): | |||||
if not dones[batch_idx]: | |||||
score = next_scores[batch_idx, beam_ind].item() | |||||
hypos[batch_idx].add(token_ids[batch_idx * num_beams + beam_idx, :cur_len].clone(), score) | |||||
        # regroup the state of token_ids to match the reordered beams
tokens = _next_tokens | |||||
token_ids = torch.cat([token_ids.index_select(index=reorder_inds, dim=0), tokens], dim=-1) | |||||
for batch_idx in range(batch_size): | |||||
dones[batch_idx] = dones[batch_idx] or hypos[batch_idx].is_done(next_scores[batch_idx, 0].item()) | |||||
cur_len += 1 | |||||
if all(dones): | |||||
break | |||||
# select the best hypotheses | |||||
tgt_len = token_ids.new(batch_size) | |||||
best = [] | |||||
for i, hypotheses in enumerate(hypos): | |||||
best_hyp = max(hypotheses.hyp, key=lambda x: x[0])[1] | |||||
tgt_len[i] = len(best_hyp) + 1 # +1 for the <EOS> symbol | |||||
best.append(best_hyp) | |||||
# generate target batch | |||||
decoded = token_ids.new(batch_size, tgt_len.max().item()).fill_(pad_token_id) | |||||
for i, hypo in enumerate(best): | |||||
decoded[i, :tgt_len[i] - 1] = hypo | |||||
if eos_token_id is not None: | |||||
decoded[i, tgt_len[i] - 1] = eos_token_id | |||||
return decoded | |||||
class BeamHypotheses(object): | |||||
def __init__(self, num_beams, max_length, length_penalty, early_stopping): | |||||
""" | |||||
Initialize n-best list of hypotheses. | |||||
""" | |||||
self.max_length = max_length - 1 # ignoring bos_token | |||||
self.length_penalty = length_penalty | |||||
self.early_stopping = early_stopping | |||||
self.num_beams = num_beams | |||||
self.hyp = [] | |||||
self.worst_score = 1e9 | |||||
def __len__(self): | |||||
""" | |||||
Number of hypotheses in the list. | |||||
""" | |||||
return len(self.hyp) | |||||
def add(self, hyp, sum_logprobs): | |||||
""" | |||||
Add a new hypothesis to the list. | |||||
""" | |||||
score = sum_logprobs / len(hyp) ** self.length_penalty | |||||
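        # length-normalized score, e.g. sum_logprobs=-6.0, len(hyp)=4, length_penalty=1.0 -> -1.5;
        # a length_penalty above 1.0 makes longer hypotheses score relatively better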
if len(self) < self.num_beams or score > self.worst_score: | |||||
self.hyp.append((score, hyp)) | |||||
if len(self) > self.num_beams: | |||||
sorted_scores = sorted([(s, idx) for idx, (s, _) in enumerate(self.hyp)]) | |||||
del self.hyp[sorted_scores[0][1]] | |||||
self.worst_score = sorted_scores[1][0] | |||||
else: | |||||
self.worst_score = min(score, self.worst_score) | |||||
def is_done(self, best_sum_logprobs): | |||||
""" | |||||
        If there are enough hypotheses and none of the hypotheses being generated
        can become better than the worst one in the heap, then we are done with this sentence.
""" | |||||
if len(self) < self.num_beams: | |||||
return False | |||||
elif self.early_stopping: | |||||
return True | |||||
else: | |||||
return self.worst_score >= best_sum_logprobs / self.max_length ** self.length_penalty | |||||
def top_k_top_p_filtering(logits, top_k=0, top_p=1.0, filter_value=-float("Inf"), min_tokens_to_keep=1): | |||||
""" | |||||
根据top_k, top_p的值,将不满足的值置为filter_value的值 | |||||
:param torch.Tensor logits: bsz x vocab_size | |||||
:param int top_k: 如果大于0,则只保留最top_k的词汇的概率,剩下的位置被置为filter_value | |||||
:param int top_p: 根据(http://arxiv.org/abs/1904.09751)设置的筛选方式 | |||||
:param float filter_value: | |||||
:param int min_tokens_to_keep: 每个sample返回的分布中有概率的词不会低于这个值 | |||||
:return: | |||||
""" | |||||
if top_k > 0: | |||||
top_k = min(max(top_k, min_tokens_to_keep), logits.size(-1)) # Safety check | |||||
# Remove all tokens with a probability less than the last token of the top-k | |||||
indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None] | |||||
logits[indices_to_remove] = filter_value | |||||
if top_p < 1.0: | |||||
sorted_logits, sorted_indices = torch.sort(logits, descending=True) | |||||
cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1) | |||||
# Remove tokens with cumulative probability above the threshold (token with 0 are kept) | |||||
sorted_indices_to_remove = cumulative_probs > top_p | |||||
if min_tokens_to_keep > 1: | |||||
# Keep at least min_tokens_to_keep (set to min_tokens_to_keep-1 because we add the first one below) | |||||
sorted_indices_to_remove[..., :min_tokens_to_keep] = 0 | |||||
# Shift the indices to the right to keep also the first token above the threshold | |||||
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone() | |||||
sorted_indices_to_remove[..., 0] = 0 | |||||
# scatter sorted tensors to original indexing | |||||
indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove) | |||||
logits[indices_to_remove] = filter_value | |||||
return logits |
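# A quick sanity sketch of the filtering (illustrative values only; clone() is
# used because the function modifies its input in place):
#
#   logits = torch.tensor([[0.1, 0.2, 0.3, 2.0]])
#   top_k_top_p_filtering(logits.clone(), top_k=2)    # all but the two largest logits become -inf
#   top_k_top_p_filtering(logits.clone(), top_p=0.5)  # keeps the shortest prefix of sorted tokens
#                                                     # whose cumulative probability exceeds top_p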
@@ -0,0 +1,14 @@ | |||||
r""" | |||||
""" | |||||
__all__ = [
'BertTokenizer', | |||||
"GPT2Tokenizer", | |||||
"RobertaTokenizer" | |||||
] | |||||
from .bert_tokenizer import BertTokenizer | |||||
from .gpt2_tokenizer import GPT2Tokenizer | |||||
from .roberta_tokenizer import RobertaTokenizer |
@@ -0,0 +1,447 @@ | |||||
r""" | |||||
""" | |||||
__all__ = [ | |||||
'BertTokenizer' | |||||
] | |||||
import os | |||||
import collections | |||||
import unicodedata | |||||
from ...core import logger | |||||
from fastNLP.io.file_utils import _get_file_name_base_on_postfix | |||||
from ...io.file_utils import _get_bert_dir | |||||
VOCAB_NAME = 'vocab.txt' | |||||
PRETRAINED_INIT_CONFIGURATION = { | |||||
"en": {"do_lower_case": False}, | |||||
"en-base-uncased": {'do_lower_case': True}, | |||||
'en-base-cased': {'do_lower_case':False}, | |||||
"en-large-cased-wwm": {"do_lower_case": False}, | |||||
'en-large-cased': {'do_lower_case':False}, | |||||
'en-large-uncased': {'do_lower_case':True}, | |||||
'en-large-uncased-wwm': {'do_lower_case':True}, | |||||
'cn': {'do_lower_case':True}, | |||||
'cn-base': {'do_lower_case': True}, | |||||
'cn-wwm-ext': {'do_lower_case': True}, | |||||
'multi-base-cased': {'do_lower_case': False}, | |||||
'multi-base-uncased': {'do_lower_case': True}, | |||||
} | |||||
def _is_control(char): | |||||
r"""Checks whether `chars` is a control character.""" | |||||
# These are technically control characters but we count them as whitespace | |||||
# characters. | |||||
if char == "\t" or char == "\n" or char == "\r": | |||||
return False | |||||
cat = unicodedata.category(char) | |||||
if cat.startswith("C"): | |||||
return True | |||||
return False | |||||
def _is_punctuation(char): | |||||
r"""Checks whether `chars` is a punctuation character.""" | |||||
cp = ord(char) | |||||
# We treat all non-letter/number ASCII as punctuation. | |||||
# Characters such as "^", "$", and "`" are not in the Unicode | |||||
# Punctuation class but we treat them as punctuation anyways, for | |||||
# consistency. | |||||
if (((cp >= 33) and (cp <= 47)) or ((cp >= 58) and (cp <= 64)) or | |||||
((cp >= 91) and (cp <= 96)) or ((cp >= 123) and (cp <= 126))): | |||||
return True | |||||
cat = unicodedata.category(char) | |||||
if cat.startswith("P"): | |||||
return True | |||||
return False | |||||
def _is_whitespace(char): | |||||
r"""Checks whether `chars` is a whitespace character.""" | |||||
    # \t, \n, and \r are technically control characters but we treat them
# as whitespace since they are generally considered as such. | |||||
if char == " " or char == "\t" or char == "\n" or char == "\r": | |||||
return True | |||||
cat = unicodedata.category(char) | |||||
if cat == "Zs": | |||||
return True | |||||
return False | |||||
def whitespace_tokenize(text): | |||||
r"""Runs basic whitespace cleaning and splitting on a piece of text.""" | |||||
text = text.strip() | |||||
if not text: | |||||
return [] | |||||
tokens = text.split() | |||||
return tokens | |||||
class BasicTokenizer(object): | |||||
r"""Runs basic tokenization (punctuation splitting, lower casing, etc.).""" | |||||
def __init__(self, | |||||
do_lower_case=True, | |||||
never_split=("[UNK]", "[SEP]", "[PAD]", "[CLS]", "[MASK]")): | |||||
r"""Constructs a BasicTokenizer. | |||||
Args: | |||||
do_lower_case: Whether to lower case the input. | |||||
""" | |||||
self.do_lower_case = do_lower_case | |||||
self.never_split = never_split | |||||
def tokenize(self, text): | |||||
r"""Tokenizes a piece of text.""" | |||||
text = self._clean_text(text) | |||||
# This was added on November 1st, 2018 for the multilingual and Chinese | |||||
# models. This is also applied to the English models now, but it doesn't | |||||
# matter since the English models were not trained on any Chinese data | |||||
# and generally don't have any Chinese data in them (there are Chinese | |||||
# characters in the vocabulary because Wikipedia does have some Chinese | |||||
# words in the English Wikipedia.). | |||||
text = self._tokenize_chinese_chars(text) | |||||
orig_tokens = whitespace_tokenize(text) | |||||
split_tokens = [] | |||||
for token in orig_tokens: | |||||
if self.do_lower_case and token not in self.never_split: | |||||
token = token.lower() | |||||
token = self._run_strip_accents(token) | |||||
split_tokens.extend(self._run_split_on_punc(token)) | |||||
output_tokens = whitespace_tokenize(" ".join(split_tokens)) | |||||
return output_tokens | |||||
def _run_strip_accents(self, text): | |||||
r"""Strips accents from a piece of text.""" | |||||
text = unicodedata.normalize("NFD", text) | |||||
output = [] | |||||
for char in text: | |||||
cat = unicodedata.category(char) | |||||
if cat == "Mn": | |||||
continue | |||||
output.append(char) | |||||
return "".join(output) | |||||
def _run_split_on_punc(self, text): | |||||
r"""Splits punctuation on a piece of text.""" | |||||
if text in self.never_split: | |||||
return [text] | |||||
chars = list(text) | |||||
i = 0 | |||||
start_new_word = True | |||||
output = [] | |||||
while i < len(chars): | |||||
char = chars[i] | |||||
if _is_punctuation(char): | |||||
output.append([char]) | |||||
start_new_word = True | |||||
else: | |||||
if start_new_word: | |||||
output.append([]) | |||||
start_new_word = False | |||||
output[-1].append(char) | |||||
i += 1 | |||||
return ["".join(x) for x in output] | |||||
def _tokenize_chinese_chars(self, text): | |||||
r"""Adds whitespace around any CJK character.""" | |||||
output = [] | |||||
for char in text: | |||||
cp = ord(char) | |||||
if self._is_chinese_char(cp): | |||||
output.append(" ") | |||||
output.append(char) | |||||
output.append(" ") | |||||
else: | |||||
output.append(char) | |||||
return "".join(output) | |||||
def _is_chinese_char(self, cp): | |||||
r"""Checks whether CP is the codepoint of a CJK character.""" | |||||
# This defines a "chinese character" as anything in the CJK Unicode block: | |||||
# https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block) | |||||
# | |||||
# Note that the CJK Unicode block is NOT all Japanese and Korean characters, | |||||
# despite its name. The modern Korean Hangul alphabet is a different block, | |||||
# as is Japanese Hiragana and Katakana. Those alphabets are used to write | |||||
# space-separated words, so they are not treated specially and handled | |||||
    # like all of the other languages.
if (((cp >= 0x4E00) and (cp <= 0x9FFF)) or # | |||||
((cp >= 0x3400) and (cp <= 0x4DBF)) or # | |||||
((cp >= 0x20000) and (cp <= 0x2A6DF)) or # | |||||
((cp >= 0x2A700) and (cp <= 0x2B73F)) or # | |||||
((cp >= 0x2B740) and (cp <= 0x2B81F)) or # | |||||
((cp >= 0x2B820) and (cp <= 0x2CEAF)) or | |||||
((cp >= 0xF900) and (cp <= 0xFAFF)) or # | |||||
((cp >= 0x2F800) and (cp <= 0x2FA1F))): # | |||||
return True | |||||
return False | |||||
def _clean_text(self, text): | |||||
r"""Performs invalid character removal and whitespace cleanup on text.""" | |||||
output = [] | |||||
for char in text: | |||||
cp = ord(char) | |||||
if cp == 0 or cp == 0xfffd or _is_control(char): | |||||
continue | |||||
if _is_whitespace(char): | |||||
output.append(" ") | |||||
else: | |||||
output.append(char) | |||||
return "".join(output) | |||||
def load_vocab(vocab_file): | |||||
r"""Loads a vocabulary file into a dictionary.""" | |||||
vocab = collections.OrderedDict() | |||||
index = 0 | |||||
with open(vocab_file, "r", encoding="utf-8") as reader: | |||||
while True: | |||||
token = reader.readline() | |||||
if not token: | |||||
break | |||||
token = token.strip() | |||||
vocab[token] = index | |||||
index += 1 | |||||
return vocab | |||||
class WordpieceTokenizer(object): | |||||
r"""Runs WordPiece tokenization.""" | |||||
def __init__(self, vocab, unk_token="[UNK]", max_input_chars_per_word=100): | |||||
self.vocab = vocab | |||||
self.unk_token = unk_token | |||||
self.max_input_chars_per_word = max_input_chars_per_word | |||||
def tokenize(self, text): | |||||
r"""Tokenizes a piece of text into its word pieces. | |||||
This uses a greedy longest-match-first algorithm to perform tokenization | |||||
using the given vocabulary. | |||||
For example: | |||||
input = "unaffable" | |||||
output = ["un", "##aff", "##able"] | |||||
Args: | |||||
text: A single token or whitespace separated tokens. This should have | |||||
already been passed through `BasicTokenizer`. | |||||
Returns: | |||||
A list of wordpiece tokens. | |||||
""" | |||||
output_tokens = [] | |||||
for token in whitespace_tokenize(text): | |||||
chars = list(token) | |||||
if len(chars) > self.max_input_chars_per_word: | |||||
output_tokens.append(self.unk_token) | |||||
continue | |||||
is_bad = False | |||||
start = 0 | |||||
sub_tokens = [] | |||||
while start < len(chars): | |||||
end = len(chars) | |||||
cur_substr = None | |||||
while start < end: | |||||
substr = "".join(chars[start:end]) | |||||
if start > 0: | |||||
substr = "##" + substr | |||||
if substr in self.vocab: | |||||
cur_substr = substr | |||||
break | |||||
end -= 1 | |||||
if cur_substr is None: | |||||
is_bad = True | |||||
break | |||||
sub_tokens.append(cur_substr) | |||||
start = end | |||||
if is_bad: | |||||
output_tokens.append(self.unk_token) | |||||
else: | |||||
output_tokens.extend(sub_tokens) | |||||
        if len(output_tokens) == 0:  # guard against input that is all whitespace or newline characters
return [self.unk_token] | |||||
return output_tokens | |||||
class BertTokenizer(object): | |||||
r"""Runs end-to-end tokenization: punctuation splitting + wordpiece""" | |||||
def __init__(self, vocab_file, do_lower_case=True, max_len=None, do_basic_tokenize=True, | |||||
never_split=("[UNK]", "[SEP]", "[PAD]", "[CLS]", "[MASK]")): | |||||
r"""Constructs a BertTokenizer. | |||||
Args: | |||||
vocab_file: Path to a one-wordpiece-per-line vocabulary file | |||||
do_lower_case: Whether to lower case the input | |||||
                Only has an effect when do_basic_tokenize=True
do_basic_tokenize: Whether to do basic tokenization before wordpiece. | |||||
max_len: An artificial maximum length to truncate tokenized sequences to; | |||||
Effective maximum length is always the minimum of this | |||||
value (if specified) and the underlying BERT model's | |||||
sequence length. | |||||
never_split: List of tokens which will never be split during tokenization. | |||||
                Only has an effect when do_basic_tokenize=True
""" | |||||
if not os.path.isfile(vocab_file): | |||||
raise ValueError( | |||||
"Can't find a vocabulary file at path '{}'. To load the vocabulary from a Google pretrained " | |||||
"model use `tokenizer = BertTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`".format(vocab_file)) | |||||
self.vocab = load_vocab(vocab_file) | |||||
self.ids_to_tokens = collections.OrderedDict( | |||||
[(ids, tok) for tok, ids in self.vocab.items()]) | |||||
self.do_basic_tokenize = do_basic_tokenize | |||||
if do_basic_tokenize: | |||||
self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case, | |||||
never_split=never_split) | |||||
self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab) | |||||
self.max_len = max_len if max_len is not None else int(1e12) | |||||
@property | |||||
def unk_index(self): | |||||
return self.vocab['[UNK]'] | |||||
@property | |||||
def pad_index(self): | |||||
return self.vocab['[PAD]'] | |||||
@property | |||||
def cls_index(self): | |||||
return self.vocab['[CLS]'] | |||||
@property | |||||
def sep_index(self): | |||||
return self.vocab['[SEP]'] | |||||
def _reinit_on_new_vocab(self, vocab): | |||||
r""" | |||||
在load bert之后,可能会对vocab进行重新排列。重新排列之后调用这个函数重新初始化与vocab相关的性质 | |||||
:param vocab: | |||||
:return: | |||||
""" | |||||
self.vocab = vocab | |||||
self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab) | |||||
def tokenize(self, text): | |||||
split_tokens = [] | |||||
if self.do_basic_tokenize: | |||||
for token in self.basic_tokenizer.tokenize(text): | |||||
for sub_token in self.wordpiece_tokenizer.tokenize(token): | |||||
split_tokens.append(sub_token) | |||||
else: | |||||
split_tokens = self.wordpiece_tokenizer.tokenize(text) | |||||
return split_tokens | |||||
def convert_tokens_to_ids(self, tokens): | |||||
r"""Converts a sequence of tokens into ids using the vocab.""" | |||||
ids = [] | |||||
for token in tokens: | |||||
ids.append(self.vocab[token]) | |||||
if len(ids) > self.max_len: | |||||
logger.warning( | |||||
"Token indices sequence length is longer than the specified maximum " | |||||
" sequence length for this BERT model ({} > {}). Running this" | |||||
" sequence through BERT will result in indexing errors".format(len(ids), self.max_len) | |||||
) | |||||
return ids | |||||
def convert_ids_to_tokens(self, ids): | |||||
r"""将token ids转换为一句话""" | |||||
tokens = [] | |||||
for i in ids: | |||||
tokens.append(self.ids_to_tokens[i]) | |||||
return self._convert_tokens_to_string(tokens) | |||||
def _convert_tokens_to_string(self, tokens): | |||||
""" Converts a sequence of tokens (string) in a single string. """ | |||||
out_string = " ".join(tokens).replace(" ##", "").strip() | |||||
return out_string | |||||
def save_vocabulary(self, vocab_path): | |||||
r"""Save the tokenizer vocabulary to a directory or file.""" | |||||
index = 0 | |||||
if os.path.isdir(vocab_path): | |||||
vocab_file = os.path.join(vocab_path, VOCAB_NAME) | |||||
else: | |||||
vocab_file = vocab_path | |||||
with open(vocab_file, "w", encoding="utf-8") as writer: | |||||
for token, token_index in sorted(self.vocab.items(), key=lambda kv: kv[1]): | |||||
if index != token_index: | |||||
logger.warning("Saving vocabulary to {}: vocabulary indices are not consecutive." | |||||
" Please check that the vocabulary is not corrupted!".format(vocab_file)) | |||||
index = token_index | |||||
writer.write(token + u'\n') | |||||
index += 1 | |||||
return vocab_file | |||||
@classmethod | |||||
def from_pretrained(cls, model_dir_or_name, *inputs, **kwargs): | |||||
r""" | |||||
        Given a model name or path, read the vocab directly.
""" | |||||
model_dir = _get_bert_dir(model_dir_or_name) | |||||
pretrained_model_name_or_path = _get_file_name_base_on_postfix(model_dir, '.txt') | |||||
logger.info("loading vocabulary file {}".format(pretrained_model_name_or_path)) | |||||
max_len = 512 | |||||
kwargs['max_len'] = min(kwargs.get('max_position_embeddings', int(1e12)), max_len) | |||||
# Instantiate tokenizer. | |||||
if 'do_lower_case' not in kwargs: | |||||
if model_dir_or_name in PRETRAINED_INIT_CONFIGURATION: | |||||
kwargs['do_lower_case'] = PRETRAINED_INIT_CONFIGURATION[model_dir_or_name]['do_lower_case'] | |||||
else: | |||||
                # check 'uncase' first: 'uncased' also contains the substring 'case'
                if 'uncase' in model_dir_or_name:
                    kwargs['do_lower_case'] = True
                elif 'case' in model_dir_or_name:
                    kwargs['do_lower_case'] = False
tokenizer = cls(pretrained_model_name_or_path, *inputs, **kwargs) | |||||
return tokenizer | |||||
def encode(self, text, add_special_tokens=True): | |||||
""" | |||||
        Encode the given text input into indices.
Example:: | |||||
>>> from fastNLP.modules import BertTokenizer | |||||
>>> bert_tokenizer = BertTokenizer.from_pretrained('en') | |||||
>>> print(bert_tokenizer.encode('from')) | |||||
>>> print(bert_tokenizer.encode("This is a demo sentence")) | |||||
>>> print(bert_tokenizer.encode(["This", "is", 'a'])) | |||||
        :param List[str],str text: the input, treated as a single sentence.
        :param bool add_special_tokens: whether to ensure the sentence starts with cls and ends with sep.
:return: | |||||
""" | |||||
word_pieces = [] | |||||
if isinstance(text, str): | |||||
words = text.split() | |||||
elif isinstance(text, list): | |||||
words = text | |||||
else: | |||||
raise TypeError("Only support str or List[str]") | |||||
for word in words: | |||||
_words = self.basic_tokenizer._tokenize_chinese_chars(word).split() | |||||
tokens = [] | |||||
for word in _words: | |||||
tokens.extend(self.wordpiece_tokenizer.tokenize(word)) | |||||
word_piece_ids = self.convert_tokens_to_ids(tokens) | |||||
word_pieces.extend(word_piece_ids) | |||||
if add_special_tokens: | |||||
if word_pieces[0] != self.cls_index: | |||||
word_pieces.insert(0, self.cls_index) | |||||
if word_pieces[-1] != self.sep_index: | |||||
word_pieces.append(self.sep_index) | |||||
return word_pieces |
@@ -0,0 +1,757 @@ | |||||
r"""undocumented | |||||
The code in this file is largely adapted (copy-pasted) from https://github.com/huggingface/pytorch-pretrained-BERT;
if you find it useful, please cite them as well.
""" | |||||
__all__ = [ | |||||
'GPT2Tokenizer' | |||||
] | |||||
from functools import lru_cache | |||||
import json | |||||
import regex as re | |||||
import itertools | |||||
from ...io.file_utils import _get_gpt2_dir | |||||
from ...core import logger | |||||
from fastNLP.io.file_utils import _get_file_name_base_on_postfix | |||||
import os | |||||
PRETRAINED_GPT2_MODEL_DIR = PRETRAINED_BERT_MODEL_DIR = { | |||||
'en-small': 'gpt2-small.zip', | |||||
'en-median': 'gpt2-medium.zip', | |||||
'en': 'gpt2-medium.zip' | |||||
} | |||||
@lru_cache() | |||||
def bytes_to_unicode(): | |||||
""" | |||||
Returns list of utf-8 byte and a mapping to unicode strings. | |||||
    We specifically avoid mapping to whitespace/control characters the bpe code barfs on.
The reversible bpe codes work on unicode strings. | |||||
This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. | |||||
When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. | |||||
    This is a significant percentage of your normal, say, 32K bpe vocab.
To avoid that, we want lookup tables between utf-8 bytes and unicode strings. | |||||
""" | |||||
bs = ( | |||||
list(range(ord("!"), ord("~") + 1)) + list(range(ord("¡"), ord("¬") + 1)) + list(range(ord("®"), ord("ÿ") + 1)) | |||||
) | |||||
cs = bs[:] | |||||
n = 0 | |||||
for b in range(2 ** 8): | |||||
if b not in bs: | |||||
bs.append(b) | |||||
cs.append(2 ** 8 + n) | |||||
n += 1 | |||||
cs = [chr(n) for n in cs] | |||||
return dict(zip(bs, cs)) | |||||
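# An illustration of the fixed mapping above:
#   m = bytes_to_unicode()
#   assert m[ord('a')] == 'a'   # printable bytes map to themselves
#   assert m[0] == chr(256)     # unprintable bytes are shifted to code points above 0xFF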
def get_pairs(word): | |||||
"""Return set of symbol pairs in a word. | |||||
Word is represented as tuple of symbols (symbols being variable-length strings). | |||||
""" | |||||
pairs = set() | |||||
prev_char = word[0] | |||||
for char in word[1:]: | |||||
pairs.add((prev_char, char)) | |||||
prev_char = char | |||||
return pairs | |||||
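# e.g. get_pairs(('a', 'b', 'c', 'd')) == {('a', 'b'), ('b', 'c'), ('c', 'd')}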
VOCAB_FILES_NAMES = { | |||||
"vocab_file": "vocab.json", | |||||
"merges_file": "merges.txt", | |||||
} | |||||
PRETRAINED_VOCAB_FILES_MAP = { | |||||
"vocab_file": { | |||||
"gpt2": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-vocab.json", | |||||
"gpt2-medium": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-medium-vocab.json", | |||||
"gpt2-large": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-large-vocab.json", | |||||
"gpt2-xl": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-xl-vocab.json", | |||||
"distilgpt2": "https://s3.amazonaws.com/models.huggingface.co/bert/distilgpt2-vocab.json", | |||||
}, | |||||
"merges_file": { | |||||
"gpt2": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-merges.txt", | |||||
"gpt2-medium": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-medium-merges.txt", | |||||
"gpt2-large": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-large-merges.txt", | |||||
"gpt2-xl": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-xl-merges.txt", | |||||
"distilgpt2": "https://s3.amazonaws.com/models.huggingface.co/bert/distilgpt2-merges.txt", | |||||
}, | |||||
} | |||||
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = { | |||||
"en-small": 1024, | |||||
'en': 1024, | |||||
"en-medium": 1024, | |||||
"en-large": 1024, | |||||
"en-xl": 1024, | |||||
"en-distilgpt2": 1024, | |||||
} | |||||
PATTERN = re.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""") | |||||
def gpt2_tokenize(text, add_prefix_space=True): | |||||
""" | |||||
:param str text: | |||||
    :param bool add_prefix_space: whether to prepend a space to the text; needed for consistency with how GPT2 was trained
:return: [] | |||||
""" | |||||
    if text == '':
return [] | |||||
if add_prefix_space: | |||||
text = ' ' + text | |||||
tokens = [] | |||||
for token in re.findall(PATTERN, text): | |||||
tokens.append(token) | |||||
return tokens | |||||
class GPT2Tokenizer: | |||||
""" | |||||
GPT-2 BPE tokenizer. Peculiarities: | |||||
- Byte-level Byte-Pair-Encoding | |||||
- Requires a space to start the input string => the encoding and tokenize methods should be called with the | |||||
``add_prefix_space`` flag set to ``True``. | |||||
Otherwise, this tokenizer's ``encode``, ``decode``, and ``tokenize`` methods will not conserve | |||||
the spaces at the beginning of a string: `tokenizer.decode(tokenizer.encode(" Hello")) = "Hello"` | |||||
""" | |||||
vocab_files_names = VOCAB_FILES_NAMES | |||||
pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP | |||||
SPECIAL_TOKENS_ATTRIBUTES = [ | |||||
"bos_token", | |||||
"eos_token", | |||||
"unk_token", | |||||
"pad_token", | |||||
"cls_token", | |||||
"mask_token", | |||||
"sep_token", | |||||
] | |||||
padding_side = "right" | |||||
def __init__( | |||||
self, | |||||
vocab_file, | |||||
merges_file, | |||||
errors="replace", | |||||
unk_token="<|endoftext|>", | |||||
bos_token="<|endoftext|>", | |||||
eos_token="<|endoftext|>", | |||||
**kwargs | |||||
): | |||||
self._bos_token = None | |||||
self._eos_token = None | |||||
self._unk_token = None | |||||
self._sep_token = None | |||||
self._pad_token = None | |||||
self._cls_token = None | |||||
self._mask_token = None | |||||
self._pad_token_type_id = 0 | |||||
self.bos_token = bos_token | |||||
self.eos_token = eos_token | |||||
self.unk_token = unk_token | |||||
self.max_len = int(1e12) | |||||
self.padding_side = kwargs.pop("padding_side", self.padding_side) | |||||
self.added_tokens_encoder = {} | |||||
self.unique_added_tokens_encoder = set() | |||||
self.added_tokens_decoder = {} | |||||
# inputs and kwargs for saving and re-loading (see ``from_pretrained`` and ``save_pretrained``) | |||||
self.init_inputs = () | |||||
self.init_kwargs = {} | |||||
for key, value in kwargs.items(): | |||||
if key in self.SPECIAL_TOKENS_ATTRIBUTES: | |||||
if key == "additional_special_tokens": | |||||
assert isinstance(value, (list, tuple)) and all(isinstance(t, str) for t in value) | |||||
else: | |||||
assert isinstance(value, str) | |||||
setattr(self, key, value) | |||||
self.max_len_single_sentence = ( | |||||
self.max_len | |||||
) # no default special tokens - you can update this value if you add special tokens | |||||
self.max_len_sentences_pair = ( | |||||
self.max_len | |||||
) # no default special tokens - you can update this value if you add special tokens | |||||
with open(vocab_file, encoding="utf-8") as vocab_handle: | |||||
self.encoder = json.load(vocab_handle) | |||||
self.decoder = {v: k for k, v in self.encoder.items()} | |||||
self.errors = errors # how to handle errors in decoding | |||||
self.byte_encoder = bytes_to_unicode() | |||||
self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} | |||||
with open(merges_file, encoding="utf-8") as merges_handle: | |||||
bpe_merges = merges_handle.read().split("\n")[1:-1] | |||||
bpe_merges = [tuple(merge.split()) for merge in bpe_merges] | |||||
self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges)))) | |||||
self.cache = {} | |||||
def _reinit_on_new_vocab(self, vocab): | |||||
        self.encoder = {k: v for k, v in vocab.items()}
        self.decoder = {v: k for k, v in vocab.items()}
self.cache = {} | |||||
@property | |||||
def bos_token(self): | |||||
""" Beginning of sentence token (string). Log an error if used while not having been set. """ | |||||
if self._bos_token is None: | |||||
logger.error("Using bos_token, but it is not set yet.") | |||||
return self._bos_token | |||||
@property | |||||
def eos_token(self): | |||||
""" End of sentence token (string). Log an error if used while not having been set. """ | |||||
if self._eos_token is None: | |||||
logger.error("Using eos_token, but it is not set yet.") | |||||
return self._eos_token | |||||
@property | |||||
def unk_token(self): | |||||
""" Unknown token (string). Log an error if used while not having been set. """ | |||||
if self._unk_token is None: | |||||
logger.error("Using unk_token, but it is not set yet.") | |||||
return self._unk_token | |||||
@property | |||||
def pad_token(self): | |||||
""" Padding token (string). Log an error if used while not having been set. """ | |||||
if self._pad_token is None: | |||||
logger.error("Using pad_token, but it is not set yet.") | |||||
return self._pad_token | |||||
@property | |||||
def cls_token(self): | |||||
""" Classification token (string). E.g. to extract a summary of an input sequence leveraging self-attention along the full depth of the model. Log an error if used while not having been set. """ | |||||
if self._cls_token is None: | |||||
logger.error("Using cls_token, but it is not set yet.") | |||||
return self._cls_token | |||||
@property | |||||
def sep_token(self): | |||||
if self._sep_token is None: | |||||
logger.error("Using sep_token, but it is not set yet.") | |||||
return self._sep_token | |||||
@property | |||||
def mask_token(self): | |||||
""" Mask token (string). E.g. when training a model with masked-language modeling. Log an error if used while not having been set. """ | |||||
if self._mask_token is None: | |||||
logger.error("Using mask_token, but it is not set yet.") | |||||
return self._mask_token | |||||
@bos_token.setter | |||||
def bos_token(self, value): | |||||
self._bos_token = value | |||||
@eos_token.setter | |||||
def eos_token(self, value): | |||||
self._eos_token = value | |||||
@unk_token.setter | |||||
def unk_token(self, value): | |||||
self._unk_token = value | |||||
@pad_token.setter | |||||
def pad_token(self, value): | |||||
self._pad_token = value | |||||
@cls_token.setter | |||||
def cls_token(self, value): | |||||
self._cls_token = value | |||||
@sep_token.setter | |||||
def sep_token(self, value): | |||||
self._sep_token = value | |||||
@mask_token.setter | |||||
def mask_token(self, value): | |||||
self._mask_token = value | |||||
@property | |||||
def bos_index(self): | |||||
""" Id of the beginning of sentence token in the vocabulary. Log an error if used while not having been set. """ | |||||
return self.convert_tokens_to_ids(self.bos_token) | |||||
@property | |||||
def sep_index(self): | |||||
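""" Id of the separation token in the vocabulary. Log an error if used while not having been set. """ | |||||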
return self.convert_tokens_to_ids(self.sep_token) | |||||
@property | |||||
def eos_index(self): | |||||
""" Id of the end of sentence token in the vocabulary. Log an error if used while not having been set. """ | |||||
return self.convert_tokens_to_ids(self.eos_token) | |||||
@property | |||||
def unk_index(self): | |||||
""" Id of the unknown token in the vocabulary. Log an error if used while not having been set. """ | |||||
return self.convert_tokens_to_ids(self.unk_token) | |||||
@property | |||||
def pad_index(self): | |||||
""" Id of the padding token in the vocabulary. Log an error if used while not having been set. """ | |||||
return self.convert_tokens_to_ids(self.pad_token) | |||||
@property | |||||
def pad_token_type_id(self): | |||||
""" Id of the padding token type in the vocabulary.""" | |||||
return self._pad_token_type_id | |||||
@property | |||||
def cls_index(self): | |||||
""" Id of the classification token in the vocabulary. E.g. to extract a summary of an input sequence leveraging self-attention along the full depth of the model. Log an error if used while not having been set. """ | |||||
return self.convert_tokens_to_ids(self.cls_token) | |||||
@property | |||||
def mask_index(self): | |||||
""" Id of the mask token in the vocabulary. E.g. when training a model with masked-language modeling. Log an error if used while not having been set. """ | |||||
return self.convert_tokens_to_ids(self.mask_token) | |||||
@property | |||||
def vocab_size(self): | |||||
return len(self.encoder) | |||||
def bpe(self, token): | |||||
# if the token is not found, it is returned split into individual characters | |||||
if token in self.cache: | |||||
return self.cache[token] | |||||
word = tuple(token) | |||||
pairs = get_pairs(word) # if word is "abcd", pairs is {('a','b'), ('b','c'), ('c','d')} | |||||
if not pairs: | |||||
return token | |||||
while True: | |||||
# first, find the pair with the lowest merge rank (i.e. the most frequent pair) | |||||
bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf"))) | |||||
if bigram not in self.bpe_ranks: | |||||
break | |||||
first, second = bigram | |||||
new_word = [] | |||||
i = 0 | |||||
while i < len(word): | |||||
try: | |||||
j = word.index(first, i) | |||||
except ValueError: | |||||
new_word.extend(word[i:]) | |||||
break | |||||
else: | |||||
new_word.extend(word[i:j]) # copy everything before the first occurrence of `first` | |||||
i = j | |||||
if word[i] == first and i < len(word) - 1 and word[i + 1] == second: | |||||
new_word.append(first + second) | |||||
i += 2 | |||||
else: | |||||
new_word.append(word[i]) | |||||
i += 1 | |||||
new_word = tuple(new_word) | |||||
word = new_word | |||||
if len(word) == 1: | |||||
break | |||||
else: | |||||
pairs = get_pairs(word) | |||||
word = " ".join(word) | |||||
self.cache[token] = word | |||||
return word | |||||
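# A minimal sketch of the greedy merge loop above (illustrative only; the toy ranks are an assumption, not real merges.txt entries): | |||||
#     ranks = {('h', 'e'): 0, ('he', 'r'): 1}   # pretend bpe_ranks | |||||
#     word = ('h', 'e', 'r') | |||||
#     # pass 1: lowest-ranked pair ('h', 'e') merges -> ('he', 'r') | |||||
#     # pass 2: lowest-ranked pair ('he', 'r') merges -> ('her',) | |||||
#     # single symbol left -> loop ends, bpe() returns "her" | |||||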
def _tokenize(self, text, add_prefix_space=False): | |||||
""" Tokenize a string. | |||||
Args: | |||||
- add_prefix_space (boolean, default False): | |||||
Begin the sentence with at least one space to get invariance to word order in GPT-2 (and RoBERTa) tokenizers. | |||||
""" | |||||
bpe_tokens = [] | |||||
for token in gpt2_tokenize(text, add_prefix_space=add_prefix_space): | |||||
token = "".join( | |||||
self.byte_encoder[b] for b in token.encode("utf-8") | |||||
) # Maps all our bytes to unicode strings, avoiding control tokens of the BPE (spaces in our case) | |||||
bpe_tokens.extend(bpe_token for bpe_token in self.bpe(token).split(" ")) | |||||
return bpe_tokens | |||||
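# Note: byte_encoder maps every UTF-8 byte to a printable unicode character (e.g. the space byte 0x20 becomes "Ġ"), | |||||
# so a word like " this" is looked up in the BPE tables as "Ġthis"; hence the Ġ-prefixed entries in merges.txt/vocab.json below. | |||||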
def _convert_token_to_id(self, token): | |||||
""" Converts a token (str) in an id using the vocab. """ | |||||
return self.encoder.get(token, self.encoder.get(self.unk_token)) | |||||
def _convert_id_to_token(self, index): | |||||
"""Converts an index (integer) in a token (str) using the vocab.""" | |||||
return self.decoder.get(index) | |||||
def convert_tokens_to_string(self, tokens): | |||||
""" Converts a sequence of tokens (string) in a single string. """ | |||||
text = "".join(tokens) | |||||
text = bytearray([self.byte_decoder[c] for c in text]).decode("utf-8", errors=self.errors) | |||||
return text | |||||
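# Decoding is the exact inverse of _tokenize for byte-level tokens; a sketch, assuming a loaded tokenizer `t`: | |||||
#     t.convert_tokens_to_string(t._tokenize(" this")) == " this" | |||||
# byte_decoder restores the original bytes, and `errors` controls how invalid UTF-8 sequences are handled. | |||||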
def save_vocabulary(self, save_directory): | |||||
"""Save the tokenizer vocabulary and merge files to a directory.""" | |||||
if not os.path.isdir(save_directory): | |||||
logger.error("Vocabulary path ({}) should be a directory".format(save_directory)) | |||||
return | |||||
vocab_file = os.path.join(save_directory, VOCAB_FILES_NAMES["vocab_file"]) | |||||
merge_file = os.path.join(save_directory, VOCAB_FILES_NAMES["merges_file"]) | |||||
with open(vocab_file, "w", encoding="utf-8") as f: | |||||
f.write(json.dumps(self.encoder, ensure_ascii=False)) | |||||
index = 0 | |||||
with open(merge_file, "w", encoding="utf-8") as writer: | |||||
writer.write("#version: 0.2\n") | |||||
for bpe_tokens, token_index in sorted(self.bpe_ranks.items(), key=lambda kv: kv[1]): | |||||
if index != token_index: | |||||
logger.warning( | |||||
"Saving vocabulary to {}: BPE merge indices are not consecutive." | |||||
" Please check that the tokenizer is not corrupted!".format(merge_file) | |||||
) | |||||
index = token_index | |||||
writer.write(" ".join(bpe_tokens) + "\n") | |||||
index += 1 | |||||
return vocab_file, merge_file | |||||
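# Usage sketch (illustrative; assumes `tokenizer` is a loaded GPT2Tokenizer and `out_dir` is an existing directory): | |||||
#     vocab_file, merge_file = tokenizer.save_vocabulary(out_dir) | |||||
#     reloaded = GPT2Tokenizer(vocab_file=vocab_file, merges_file=merge_file, max_len=1024) | |||||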
@classmethod | |||||
def from_pretrained(cls, model_dir_or_name): | |||||
r""" | |||||
""" | |||||
return cls._from_pretrained(model_dir_or_name) | |||||
# modified so that a directory is always what gets passed in | |||||
@classmethod | |||||
def _from_pretrained(cls, model_dir_or_name): | |||||
""" | |||||
:param str model_dir_or_name: a directory or a shortcut model name | |||||
:return: the instantiated tokenizer | |||||
""" | |||||
# two files are required: vocab.json and the merges file | |||||
model_dir = _get_gpt2_dir(model_dir_or_name) | |||||
# the directory contains four files: vocab.json, merges.txt, config.json and model.bin | |||||
tokenizer_config_file = _get_file_name_base_on_postfix(model_dir, 'config.json') | |||||
with open(tokenizer_config_file, encoding="utf-8") as tokenizer_config_handle: | |||||
init_kwargs = json.load(tokenizer_config_handle) | |||||
if 'max_len' not in init_kwargs: | |||||
init_kwargs['max_len'] = 1024 | |||||
# Set max length if needed | |||||
if model_dir_or_name in PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES: | |||||
# if we're using a pretrained model, ensure the tokenizer | |||||
# won't index sequences longer than the number of positional embeddings | |||||
max_len = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES[model_dir_or_name] | |||||
if max_len is not None and isinstance(max_len, (int, float)): | |||||
init_kwargs["max_len"] = min(init_kwargs.get("max_len", int(1e12)), max_len) | |||||
# add the vocab and merges file paths to init_kwargs | |||||
init_kwargs['vocab_file'] = _get_file_name_base_on_postfix(model_dir, 'vocab.json') | |||||
init_kwargs['merges_file'] = _get_file_name_base_on_postfix(model_dir, 'merges.txt') | |||||
init_inputs = init_kwargs.pop("init_inputs", ()) | |||||
# Instantiate tokenizer. | |||||
try: | |||||
tokenizer = cls(*init_inputs, **init_kwargs) | |||||
except OSError: | |||||
raise OSError( | |||||
"Unable to load vocabulary from file. " | |||||
"Please check that the provided vocabulary is accessible and not corrupted." | |||||
) | |||||
return tokenizer | |||||
def __len__(self): | |||||
""" Size of the full vocabulary with the added tokens """ | |||||
return self.vocab_size + len(self.added_tokens_encoder) | |||||
def tokenize(self, text, add_prefix_space=True): | |||||
""" Converts a string in a sequence of tokens (string), using the tokenizer. | |||||
Split in words for word-based vocabulary or sub-words for sub-word-based | |||||
vocabularies (BPE/SentencePieces/WordPieces). | |||||
Take care of added tokens. | |||||
Args: | |||||
- text: The sequence to be encoded. | |||||
- add_prefix_space (boolean, default True): | |||||
Begin the sentence with at least one space to get invariance to word order in GPT-2 (and RoBERTa) tokenizers. | |||||
""" | |||||
all_special_tokens = self.all_special_tokens | |||||
def lowercase_text(t): | |||||
# convert non-special tokens to lowercase | |||||
escaped_special_toks = [re.escape(s_tok) for s_tok in all_special_tokens] | |||||
pattern = r'(' + r'|'.join(escaped_special_toks) + r')|' + \ | |||||
r'(.+?)' | |||||
return re.sub( | |||||
pattern, | |||||
lambda m: m.groups()[0] or m.groups()[1].lower(), | |||||
t) | |||||
if self.init_kwargs.get('do_lower_case', False): | |||||
text = lowercase_text(text) | |||||
def split_on_token(tok, text): | |||||
result = [] | |||||
split_text = text.split(tok) | |||||
for i, sub_text in enumerate(split_text): | |||||
sub_text = sub_text.strip() | |||||
if i == 0 and not sub_text: | |||||
result += [tok] | |||||
elif i == len(split_text) - 1: | |||||
if sub_text: | |||||
result += [sub_text] | |||||
else: | |||||
pass | |||||
else: | |||||
if sub_text: | |||||
result += [sub_text] | |||||
result += [tok] | |||||
return result | |||||
def split_on_tokens(tok_list, text): | |||||
if not text.strip(): | |||||
return [] | |||||
if not tok_list: | |||||
return self._tokenize(text, add_prefix_space=add_prefix_space) | |||||
tokenized_text = [] | |||||
text_list = [text] | |||||
for tok in tok_list: | |||||
tokenized_text = [] | |||||
for sub_text in text_list: | |||||
if sub_text not in self.added_tokens_encoder \ | |||||
and sub_text not in all_special_tokens: | |||||
tokenized_text += split_on_token(tok, sub_text) | |||||
else: | |||||
tokenized_text += [sub_text] | |||||
text_list = tokenized_text | |||||
return list(itertools.chain.from_iterable( | |||||
(self._tokenize(token, add_prefix_space=add_prefix_space) | |||||
if token not in self.added_tokens_encoder and token not in all_special_tokens | |||||
else [token]) | |||||
for token in tokenized_text)) | |||||
added_tokens = list(self.added_tokens_encoder.keys()) + all_special_tokens | |||||
tokenized_text = split_on_tokens(added_tokens, text) | |||||
return tokenized_text | |||||
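# Added tokens and special tokens are protected from BPE: tokenize() first splits the text around them via | |||||
# split_on_tokens(), and only the remaining spans go through byte-level BPE, so e.g. "<|endoftext|>" survives as one piece. | |||||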
def convert_tokens_to_ids(self, tokens): | |||||
""" Converts a single token, or a sequence of tokens, (str) in a single integer id | |||||
(resp. a sequence of ids), using the vocabulary. | |||||
""" | |||||
if tokens is None: | |||||
return None | |||||
if isinstance(tokens, str): | |||||
return self._convert_token_to_id_with_added_voc(tokens) | |||||
ids = [] | |||||
for token in tokens: | |||||
ids.append(self._convert_token_to_id_with_added_voc(token)) | |||||
return ids | |||||
def _convert_token_to_id_with_added_voc(self, token): | |||||
if token is None: | |||||
return None | |||||
if token in self.added_tokens_encoder: | |||||
return self.added_tokens_encoder[token] | |||||
return self._convert_token_to_id(token) | |||||
def convert_ids_to_tokens(self, ids, skip_special_tokens=False): | |||||
""" Converts a single index or a sequence of indices (integers) in a token " | |||||
(resp.) a sequence of tokens (str), using the vocabulary and added tokens. | |||||
Args: | |||||
skip_special_tokens: Don't decode special tokens (self.all_special_tokens). Default: False | |||||
""" | |||||
if isinstance(ids, int): | |||||
return self._convert_id_to_token(ids) | |||||
tokens = [] | |||||
for index in ids: | |||||
index = int(index) | |||||
if skip_special_tokens and index in self.all_special_ids: | |||||
continue | |||||
tokens.append(self._convert_id_to_token(index)) | |||||
return tokens | |||||
def convert_id_to_tokens(self, token_ids, skip_special_tokens=False, clean_up_tokenization_spaces=True): | |||||
""" | |||||
Converts a sequence of ids (integer) in a string, using the tokenizer and vocabulary | |||||
with options to remove special tokens and clean up tokenization spaces. | |||||
Similar to doing ``self.convert_tokens_to_string(self.convert_ids_to_tokens(token_ids))``. | |||||
Args: | |||||
token_ids: list of tokenized input ids, as returned by the `encode` method. | |||||
skip_special_tokens: if set to True, will remove special tokens in the decoding. | |||||
clean_up_tokenization_spaces: if set to True, will clean up the tokenization spaces. | |||||
""" | |||||
filtered_tokens = self.convert_ids_to_tokens(token_ids, skip_special_tokens=skip_special_tokens) | |||||
# To avoid mixing byte-level and unicode for byte-level BPE | |||||
# we need to build the string separately for added tokens and byte-level tokens | |||||
# cf. https://github.com/huggingface/transformers/issues/1133 | |||||
sub_texts = [] | |||||
current_sub_text = [] | |||||
for token in filtered_tokens: | |||||
if skip_special_tokens and token in self.all_special_tokens: | |||||
continue | |||||
if token in self.added_tokens_encoder: | |||||
if current_sub_text: | |||||
sub_texts.append(self.convert_tokens_to_string(current_sub_text)) | |||||
current_sub_text = [] | |||||
sub_texts.append(token) | |||||
else: | |||||
current_sub_text.append(token) | |||||
if current_sub_text: | |||||
sub_texts.append(self.convert_tokens_to_string(current_sub_text)) | |||||
text = " ".join(sub_texts) | |||||
if clean_up_tokenization_spaces: | |||||
clean_text = self.clean_up_tokenization(text) | |||||
return clean_text | |||||
else: | |||||
return text | |||||
@property | |||||
def special_tokens_map(self): | |||||
""" A dictionary mapping special token class attribute (cls_token, unk_token...) to their | |||||
values ('<unk>', '<cls>'...) | |||||
""" | |||||
set_attr = {} | |||||
for attr in self.SPECIAL_TOKENS_ATTRIBUTES: | |||||
attr_value = getattr(self, "_" + attr) | |||||
if attr_value: | |||||
set_attr[attr] = attr_value | |||||
return set_attr | |||||
@property | |||||
def all_special_tokens(self): | |||||
""" List all the special tokens ('<unk>', '<cls>'...) mapped to class attributes | |||||
(cls_token, unk_token...). | |||||
""" | |||||
all_toks = [] | |||||
set_attr = self.special_tokens_map | |||||
for attr_value in set_attr.values(): | |||||
all_toks = all_toks + (list(attr_value) if isinstance(attr_value, (list, tuple)) else [attr_value]) | |||||
all_toks = list(set(all_toks)) | |||||
return all_toks | |||||
@property | |||||
def all_special_ids(self): | |||||
""" List the vocabulary indices of the special tokens ('<unk>', '<cls>'...) mapped to | |||||
class attributes (cls_token, unk_token...). | |||||
""" | |||||
all_toks = self.all_special_tokens | |||||
all_ids = self.convert_tokens_to_ids(all_toks) | |||||
return all_ids | |||||
@staticmethod | |||||
def clean_up_tokenization(out_string): | |||||
""" Clean up a list of simple English tokenization artifacts like spaces before punctuations and abreviated forms. | |||||
""" | |||||
out_string = ( | |||||
out_string.replace(" .", ".") | |||||
.replace(" ?", "?") | |||||
.replace(" !", "!") | |||||
.replace(" ,", ",") | |||||
.replace(" ' ", "'") | |||||
.replace(" n't", "n't") | |||||
.replace(" 'm", "'m") | |||||
.replace(" do not", " don't") | |||||
.replace(" 's", "'s") | |||||
.replace(" 've", "'ve") | |||||
.replace(" 're", "'re") | |||||
) | |||||
return out_string | |||||
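# For example (a sketch of the replacements above): | |||||
#     clean_up_tokenization("I 'm sure it is n't here .")  ->  "I'm sure it isn't here." | |||||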
def encode(self, text, add_special_tokens=False, add_prefix_space=True): | |||||
""" | |||||
Encode the given text into a list of indices. | |||||
Example:: | |||||
>>> from fastNLP.modules import GPT2Tokenizer | |||||
>>> gpt2_tokenizer = GPT2Tokenizer.from_pretrained('en') | |||||
>>> print(gpt2_tokenizer.encode('from')) | |||||
>>> print(gpt2_tokenizer.encode("This is a demo sentence")) | |||||
>>> print(gpt2_tokenizer.encode(["This", "is", 'a'])) | |||||
:param List[str],str text: the input; a str is treated as a single sentence. | |||||
:param bool add_special_tokens: whether to make the sequence start with cls and end with sep; note that GPT-2 itself defines no cls/sep tokens. | |||||
:return: List[int] | |||||
""" | |||||
if isinstance(text, str): | |||||
words = text.split() | |||||
elif isinstance(text, list): | |||||
words = text | |||||
else: | |||||
raise TypeError("Only support str or List[str]") | |||||
word_pieces = [] | |||||
for word in words: | |||||
tokens = self.tokenize(word, add_prefix_space=add_prefix_space) | |||||
word_piece_ids = self.convert_tokens_to_ids(tokens) | |||||
word_pieces.extend(word_piece_ids) | |||||
if add_special_tokens: | |||||
if self._cls_token is not None and word_pieces[0] != self.cls_index: | |||||
word_pieces.insert(0, self.cls_index) | |||||
if self._sep_token is not None and word_pieces[-1] != self.sep_index: | |||||
word_pieces.append(self.sep_index) | |||||
return word_pieces | |||||
def get_used_merge_pair_vocab(self, token): | |||||
# runs BPE on `token` and records every merge pair that was applied; if the token cannot be merged it is returned split into single characters. Used to build the small test vocab/merges files. | |||||
used_pairs = {} | |||||
word = tuple(token) | |||||
pairs = get_pairs(word) # if word is "abcd", pairs is {('a','b'), ('b','c'), ('c','d')} | |||||
if not pairs: | |||||
return token, used_pairs | |||||
while True: | |||||
# first, find the pair with the lowest merge rank (i.e. the most frequent pair) | |||||
bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf"))) | |||||
if bigram not in self.bpe_ranks: | |||||
break | |||||
used_pairs[bigram] = self.bpe_ranks[bigram] | |||||
first, second = bigram | |||||
new_word = [] | |||||
i = 0 | |||||
while i < len(word): | |||||
try: | |||||
j = word.index(first, i) | |||||
except ValueError: | |||||
new_word.extend(word[i:]) | |||||
break | |||||
else: | |||||
new_word.extend(word[i:j]) # copy everything before the first occurrence of `first` | |||||
i = j | |||||
if word[i] == first and i < len(word) - 1 and word[i + 1] == second: | |||||
new_word.append(first + second) | |||||
i += 2 | |||||
else: | |||||
new_word.append(word[i]) | |||||
i += 1 | |||||
new_word = tuple(new_word) | |||||
word = new_word | |||||
if len(word) == 1: | |||||
break | |||||
else: | |||||
pairs = get_pairs(word) | |||||
word = " ".join(word) | |||||
return word, used_pairs |
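For reference, a minimal usage sketch of the tokenizer above against the small fixture added later in this diff (paths and method names are taken from the code above and the tests below; concrete ids depend on the generated vocab, so no outputs are shown):

    from fastNLP.modules.tokenizer import GPT2Tokenizer

    tokenizer = GPT2Tokenizer.from_pretrained('test/data_for_tests/embedding/small_gpt2')
    ids = tokenizer.encode("this is a demo sentence")   # list of word piece ids
    tokens = tokenizer.convert_ids_to_tokens(ids)       # back to BPE tokens
    text = tokenizer.convert_id_to_tokens(ids)          # back to a cleaned-up string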
@@ -0,0 +1,102 @@ | |||||
r""" | |||||
""" | |||||
__all__ = [ | |||||
"RobertaTokenizer" | |||||
] | |||||
import json | |||||
from .gpt2_tokenizer import GPT2Tokenizer | |||||
from fastNLP.io.file_utils import _get_file_name_base_on_postfix | |||||
from ...io.file_utils import _get_roberta_dir | |||||
PRETRAINED_ROBERTA_POSITIONAL_EMBEDDINGS_SIZES = { | |||||
"roberta-base": 512, | |||||
"roberta-large": 512, | |||||
"roberta-large-mnli": 512, | |||||
"distilroberta-base": 512, | |||||
"roberta-base-openai-detector": 512, | |||||
"roberta-large-openai-detector": 512, | |||||
} | |||||
class RobertaTokenizer(GPT2Tokenizer): | |||||
vocab_files_names = { | |||||
"vocab_file": "vocab.json", | |||||
"merges_file": "merges.txt", | |||||
} | |||||
def __init__( | |||||
self, | |||||
vocab_file, | |||||
merges_file, | |||||
errors="replace", | |||||
bos_token="<s>", | |||||
eos_token="</s>", | |||||
sep_token="</s>", | |||||
cls_token="<s>", | |||||
unk_token="<unk>", | |||||
pad_token="<pad>", | |||||
mask_token="<mask>", | |||||
**kwargs | |||||
): | |||||
super().__init__( | |||||
vocab_file=vocab_file, | |||||
merges_file=merges_file, | |||||
errors=errors, | |||||
bos_token=bos_token, | |||||
eos_token=eos_token, | |||||
unk_token=unk_token, | |||||
sep_token=sep_token, | |||||
cls_token=cls_token, | |||||
pad_token=pad_token, | |||||
mask_token=mask_token, | |||||
**kwargs, | |||||
) | |||||
self.max_len_single_sentence = self.max_len - 2 # take into account special tokens | |||||
self.max_len_sentences_pair = self.max_len - 4 # take into account special tokens | |||||
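# e.g. a single sequence is wrapped as <s> tokens </s> (2 special tokens), while a pair is | |||||
# wrapped as <s> A </s> </s> B </s> (4 special tokens), hence the -2 / -4 above | |||||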
@classmethod | |||||
def from_pretrained(cls, model_dir_or_name, *inputs, **kwargs): | |||||
""" | |||||
:param str model_dir_or_name: a directory or a shortcut model name | |||||
:param kwargs: | |||||
:return: the instantiated tokenizer | |||||
""" | |||||
# two files are required: vocab.json and the merges file | |||||
model_dir = _get_roberta_dir(model_dir_or_name) | |||||
# the directory contains four files: vocab.json, merges.txt, config.json and model.bin | |||||
tokenizer_config_file = _get_file_name_base_on_postfix(model_dir, 'config.json') | |||||
with open(tokenizer_config_file, encoding="utf-8") as tokenizer_config_handle: | |||||
init_kwargs = json.load(tokenizer_config_handle) | |||||
# Set max length if needed | |||||
if model_dir_or_name in PRETRAINED_ROBERTA_POSITIONAL_EMBEDDINGS_SIZES: | |||||
# if we're using a pretrained model, ensure the tokenizer | |||||
# won't index sequences longer than the number of positional embeddings | |||||
max_len = PRETRAINED_ROBERTA_POSITIONAL_EMBEDDINGS_SIZES[model_dir_or_name] | |||||
if max_len is not None and isinstance(max_len, (int, float)): | |||||
init_kwargs["max_len"] = min(init_kwargs.get("max_len", int(1e12)), max_len) | |||||
# add the vocab and merges file paths to init_kwargs | |||||
if 'vocab_file' in kwargs: # use the user-specified vocabulary if given | |||||
init_kwargs['vocab_file'] = kwargs['vocab_file'] | |||||
else: | |||||
init_kwargs['vocab_file'] = _get_file_name_base_on_postfix(model_dir, RobertaTokenizer.vocab_files_names['vocab_file']) | |||||
init_kwargs['merges_file'] = _get_file_name_base_on_postfix(model_dir, RobertaTokenizer.vocab_files_names['merges_file']) | |||||
init_inputs = init_kwargs.pop("init_inputs", ()) | |||||
# Instantiate tokenizer. | |||||
try: | |||||
tokenizer = cls(*init_inputs, **init_kwargs) | |||||
except OSError: | |||||
raise OSError( | |||||
"Unable to load vocabulary from file. " | |||||
"Please check that the provided vocabulary is accessible and not corrupted." | |||||
) | |||||
return tokenizer | |||||
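As with the GPT-2 tokenizer above, a minimal usage sketch against the small RoBERTa fixture added later in this diff (path and import taken from the tests below):

    from fastNLP.modules.tokenizer import RobertaTokenizer

    tokenizer = RobertaTokenizer.from_pretrained('test/data_for_tests/embedding/small_roberta')
    ids = tokenizer.encode("this is a demo sentence")  # byte-level BPE ids; <s>/</s> are only added with add_special_tokens=True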
@@ -8,7 +8,6 @@ __all__ = [ | |||||
"summary" | "summary" | ||||
] | ] | ||||
import os | |||||
from functools import reduce | from functools import reduce | ||||
import torch | import torch | ||||
@@ -133,18 +132,3 @@ def get_dropout_mask(drop_p: float, tensor: torch.Tensor): | |||||
nn.functional.dropout(mask_x, p=drop_p, | nn.functional.dropout(mask_x, p=drop_p, | ||||
training=False, inplace=True) | training=False, inplace=True) | ||||
return mask_x | return mask_x | ||||
def _get_file_name_base_on_postfix(dir_path, postfix): | |||||
r""" | |||||
Find the file in dir_path whose name ends with postfix. | |||||
:param dir_path: str, the directory to search | |||||
:param postfix: a suffix such as ".bin" or ".json" | |||||
:return: str, the path to the file | |||||
""" | |||||
files = list(filter(lambda filename: filename.endswith(postfix), os.listdir(os.path.join(dir_path)))) | |||||
if len(files) == 0: | |||||
raise FileNotFoundError(f"There is no file endswith *{postfix} file in {dir_path}") | |||||
elif len(files) > 1: | |||||
raise FileExistsError(f"There are multiple *{postfix} files in {dir_path}") | |||||
return os.path.join(dir_path, files[0]) |
@@ -687,16 +687,16 @@ def main(): | |||||
if hps.mode == 'train': | if hps.mode == 'train': | ||||
trainset = dataInfo.datasets["train"] | trainset = dataInfo.datasets["train"] | ||||
train_sampler = BucketSampler(batch_size=hps.batch_size, seq_len_field_name=Const.INPUT) | train_sampler = BucketSampler(batch_size=hps.batch_size, seq_len_field_name=Const.INPUT) | ||||
train_batch = DataSetIter(batch_size=hps.batch_size, dataset=trainset, sampler=train_sampler) | |||||
train_batch = DataSetIter(dataset=trainset, batch_size=hps.batch_size, sampler=train_sampler) | |||||
validset = dataInfo.datasets["valid"] | validset = dataInfo.datasets["valid"] | ||||
validset.set_input("text", "summary") | validset.set_input("text", "summary") | ||||
valid_batch = DataSetIter(batch_size=hps.batch_size, dataset=validset) | |||||
valid_batch = DataSetIter(dataset=validset, batch_size=hps.batch_size) | |||||
setup_training(model, train_batch, valid_batch, hps) | setup_training(model, train_batch, valid_batch, hps) | ||||
elif hps.mode == 'test': | elif hps.mode == 'test': | ||||
logger.info("[INFO] Decoding...") | logger.info("[INFO] Decoding...") | ||||
testset = dataInfo.datasets["test"] | testset = dataInfo.datasets["test"] | ||||
testset.set_input("text", "summary") | testset.set_input("text", "summary") | ||||
test_batch = DataSetIter(batch_size=hps.batch_size, dataset=testset) | |||||
test_batch = DataSetIter(dataset=testset, batch_size=hps.batch_size) | |||||
run_test(model, test_batch, hps, limited=hps.limited) | run_test(model, test_batch, hps, limited=hps.limited) | ||||
else: | else: | ||||
logger.error("The 'mode' flag must be one of train/eval/test") | logger.error("The 'mode' flag must be one of train/eval/test") | ||||
@@ -406,18 +406,8 @@ if not options.test: | |||||
logger.info("Number training instances: {}".format(len(train_set))) | logger.info("Number training instances: {}".format(len(train_set))) | ||||
logger.info("Number dev instances: {}".format(len(dev_set))) | logger.info("Number dev instances: {}".format(len(dev_set))) | ||||
train_batch = DataSetIter( | |||||
batch_size=options.batch_size, | |||||
dataset=train_set, | |||||
sampler=train_sampler, | |||||
num_workers=4, | |||||
) | |||||
dev_batch = DataSetIter( | |||||
batch_size=options.batch_size, | |||||
dataset=dev_set, | |||||
sampler=dev_sampler, | |||||
num_workers=4, | |||||
) | |||||
train_batch = DataSetIter(dataset=train_set, batch_size=options.batch_size, sampler=train_sampler, num_workers=4) | |||||
dev_batch = DataSetIter(dataset=dev_set, batch_size=options.batch_size, sampler=dev_sampler, num_workers=4) | |||||
best_f1 = 0.0 | best_f1 = 0.0 | ||||
for epoch in range(int(options.num_epochs)): | for epoch in range(int(options.num_epochs)): | ||||
@@ -279,7 +279,7 @@ class TestCase1(unittest.TestCase): | |||||
data.add_collate_fn(concat_collate_fn) | data.add_collate_fn(concat_collate_fn) | ||||
for batch_x, batch_y in DataSetIter(data, sampler=SequentialSampler(), batch_size=2): | |||||
for batch_x, batch_y in DataSetIter(data, batch_size=2, sampler=SequentialSampler()): | |||||
print("batch_x:", batch_x) | print("batch_x:", batch_x) | ||||
print("batch_y:", batch_y) | print("batch_y:", batch_y) | ||||
# batch_x: {'x': tensor([[0, 1, 3, 0], | # batch_x: {'x': tensor([[0, 1, 3, 0], | ||||
@@ -302,7 +302,7 @@ class TestCase1(unittest.TestCase): | |||||
return b_x, b_y | return b_x, b_y | ||||
data.delete_collate_fn() # remove the previous collate_fn | data.delete_collate_fn() # remove the previous collate_fn | ||||
data.add_collate_fn(ConCollateFn(max_len=3)) | data.add_collate_fn(ConCollateFn(max_len=3)) | ||||
for batch_x, batch_y in DataSetIter(data, sampler=SequentialSampler(), batch_size=2): | |||||
for batch_x, batch_y in DataSetIter(data, batch_size=2, sampler=SequentialSampler()): | |||||
print("batch_x:", batch_x) | print("batch_x:", batch_x) | ||||
print("batch_y:", batch_y) | print("batch_y:", batch_y) | ||||
# batch_x: {'x': tensor([[0, 1, 3], | # batch_x: {'x': tensor([[0, 1, 3], | ||||
@@ -362,10 +362,9 @@ class TestCase1(unittest.TestCase): | |||||
batch_sampler = BatchSampler(ds) | batch_sampler = BatchSampler(ds) | ||||
data_iter = DataSetIter(ds, batch_size=10, sampler=batch_sampler, as_numpy=False, | |||||
num_workers=0, pin_memory=False, drop_last=False, | |||||
timeout=0, worker_init_fn=None, collate_fn=None, | |||||
batch_sampler=batch_sampler) | |||||
data_iter = DataSetIter(ds, batch_size=10, sampler=batch_sampler, as_numpy=False, num_workers=0, | |||||
pin_memory=False, drop_last=False, timeout=0, worker_init_fn=None, | |||||
batch_sampler=batch_sampler) | |||||
num_samples = [len(ds)//2, len(ds)-len(ds)//2] | num_samples = [len(ds)//2, len(ds)-len(ds)//2] | ||||
for idx, (batch_x, batch_y) in enumerate(data_iter): | for idx, (batch_x, batch_y) in enumerate(data_iter): | ||||
self.assertEqual(num_samples[idx], len(batch_x['1'])) | self.assertEqual(num_samples[idx], len(batch_x['1'])) | ||||
@@ -264,7 +264,6 @@ class TestDataSetMethods(unittest.TestCase): | |||||
self.assertEqual(ans.content, [[5, 6]] * 10) | self.assertEqual(ans.content, [[5, 6]] * 10) | ||||
def test_add_null(self): | def test_add_null(self): | ||||
# TODO test failed because 'fastNLP\core\field.py:143: RuntimeError' | |||||
ds = DataSet() | ds = DataSet() | ||||
with self.assertRaises(RuntimeError) as RE: | with self.assertRaises(RuntimeError) as RE: | ||||
ds.add_field('test', []) | ds.add_field('test', []) | ||||
@@ -45,7 +45,6 @@ def _convert_res_to_fastnlp_res(metric_result): | |||||
return allen_result | return allen_result | ||||
class TestConfusionMatrixMetric(unittest.TestCase): | class TestConfusionMatrixMetric(unittest.TestCase): | ||||
def test_ConfusionMatrixMetric1(self): | def test_ConfusionMatrixMetric1(self): | ||||
pred_dict = {"pred": torch.zeros(4,3)} | pred_dict = {"pred": torch.zeros(4,3)} | ||||
@@ -57,21 +56,17 @@ class TestConfusionMatrixMetric(unittest.TestCase): | |||||
def test_ConfusionMatrixMetric2(self): | def test_ConfusionMatrixMetric2(self): | ||||
# (2) with corrupted size | # (2) with corrupted size | ||||
try: | |||||
with self.assertRaises(Exception): | |||||
pred_dict = {"pred": torch.zeros(4, 3, 2)} | pred_dict = {"pred": torch.zeros(4, 3, 2)} | ||||
target_dict = {'target': torch.zeros(4)} | target_dict = {'target': torch.zeros(4)} | ||||
metric = ConfusionMatrixMetric() | metric = ConfusionMatrixMetric() | ||||
metric(pred_dict=pred_dict, target_dict=target_dict, ) | metric(pred_dict=pred_dict, target_dict=target_dict, ) | ||||
print(metric.get_metric()) | print(metric.get_metric()) | ||||
except Exception as e: | |||||
print(e) | |||||
return | |||||
print("No exception catches.") | |||||
def test_ConfusionMatrixMetric3(self): | def test_ConfusionMatrixMetric3(self): | ||||
# (3) the second batch is corrupted size | # (3) the second batch is corrupted size | ||||
try: | |||||
with self.assertRaises(Exception): | |||||
metric = ConfusionMatrixMetric() | metric = ConfusionMatrixMetric() | ||||
pred_dict = {"pred": torch.zeros(4, 3, 2)} | pred_dict = {"pred": torch.zeros(4, 3, 2)} | ||||
target_dict = {'target': torch.zeros(4, 3)} | target_dict = {'target': torch.zeros(4, 3)} | ||||
@@ -82,10 +77,7 @@ class TestConfusionMatrixMetric(unittest.TestCase): | |||||
metric(pred_dict=pred_dict, target_dict=target_dict) | metric(pred_dict=pred_dict, target_dict=target_dict) | ||||
print(metric.get_metric()) | print(metric.get_metric()) | ||||
except Exception as e: | |||||
print(e) | |||||
return | |||||
assert(True, False), "No exception catches." | |||||
def test_ConfusionMatrixMetric4(self): | def test_ConfusionMatrixMetric4(self): | ||||
# (4) check reset | # (4) check reset | ||||
@@ -99,16 +91,12 @@ class TestConfusionMatrixMetric(unittest.TestCase): | |||||
def test_ConfusionMatrixMetric5(self): | def test_ConfusionMatrixMetric5(self): | ||||
# (5) check numpy array is not acceptable | # (5) check numpy array is not acceptable | ||||
try: | |||||
with self.assertRaises(Exception): | |||||
metric = ConfusionMatrixMetric() | metric = ConfusionMatrixMetric() | ||||
pred_dict = {"pred": np.zeros((4, 3, 2))} | pred_dict = {"pred": np.zeros((4, 3, 2))} | ||||
target_dict = {'target': np.zeros((4, 3))} | target_dict = {'target': np.zeros((4, 3))} | ||||
metric(pred_dict=pred_dict, target_dict=target_dict) | metric(pred_dict=pred_dict, target_dict=target_dict) | ||||
except Exception as e: | |||||
print(e) | |||||
return | |||||
self.assertTrue(True, False), "No exception catches." | |||||
def test_ConfusionMatrixMetric6(self): | def test_ConfusionMatrixMetric6(self): | ||||
# (6) check map, match | # (6) check map, match | ||||
metric = ConfusionMatrixMetric(pred='predictions', target='targets') | metric = ConfusionMatrixMetric(pred='predictions', target='targets') | ||||
@@ -119,29 +107,20 @@ class TestConfusionMatrixMetric(unittest.TestCase): | |||||
print(res) | print(res) | ||||
def test_ConfusionMatrixMetric7(self): | def test_ConfusionMatrixMetric7(self): | ||||
# (7) check map, include unused | |||||
try: | |||||
metric = ConfusionMatrixMetric(pred='prediction', target='targets') | |||||
pred_dict = {"prediction": torch.zeros(4, 3, 2), 'unused': 1} | |||||
target_dict = {'targets': torch.zeros(4, 3)} | |||||
metric(pred_dict=pred_dict, target_dict=target_dict) | |||||
except Exception as e: | |||||
print(e) | |||||
return | |||||
self.assertTrue(True, False), "No exception catches." | |||||
# (7) check map, include unused | |||||
metric = ConfusionMatrixMetric(pred='prediction', target='targets') | |||||
pred_dict = {"prediction": torch.zeros(4, 3, 2), 'unused': 1} | |||||
target_dict = {'targets': torch.zeros(4, 3)} | |||||
metric(pred_dict=pred_dict, target_dict=target_dict) | |||||
def test_ConfusionMatrixMetric8(self): | def test_ConfusionMatrixMetric8(self): | ||||
# (8) check _fast_metric | |||||
try: | |||||
# (8) check _fast_metric | |||||
with self.assertRaises(Exception): | |||||
metric = ConfusionMatrixMetric() | metric = ConfusionMatrixMetric() | ||||
pred_dict = {"predictions": torch.zeros(4, 3, 2), "seq_len": torch.ones(3) * 3} | pred_dict = {"predictions": torch.zeros(4, 3, 2), "seq_len": torch.ones(3) * 3} | ||||
target_dict = {'targets': torch.zeros(4, 3)} | target_dict = {'targets': torch.zeros(4, 3)} | ||||
metric(pred_dict=pred_dict, target_dict=target_dict) | metric(pred_dict=pred_dict, target_dict=target_dict) | ||||
print(metric.get_metric()) | print(metric.get_metric()) | ||||
except Exception as e: | |||||
print(e) | |||||
return | |||||
self.assertTrue(True, False), "No exception catches." | |||||
def test_duplicate(self): | def test_duplicate(self): | ||||
# potential bug in 0.4.1: duplicate parameter names must not appear | # potential bug in 0.4.1: duplicate parameter names must not appear | ||||
@@ -151,7 +130,6 @@ class TestConfusionMatrixMetric(unittest.TestCase): | |||||
metric(pred_dict=pred_dict, target_dict=target_dict) | metric(pred_dict=pred_dict, target_dict=target_dict) | ||||
print(metric.get_metric()) | print(metric.get_metric()) | ||||
def test_seq_len(self): | def test_seq_len(self): | ||||
N = 256 | N = 256 | ||||
seq_len = torch.zeros(N).long() | seq_len = torch.zeros(N).long() | ||||
@@ -177,8 +155,6 @@ class TestConfusionMatrixMetric(unittest.TestCase): | |||||
print(metric.get_metric()) | print(metric.get_metric()) | ||||
class TestAccuracyMetric(unittest.TestCase): | class TestAccuracyMetric(unittest.TestCase): | ||||
def test_AccuracyMetric1(self): | def test_AccuracyMetric1(self): | ||||
# (1) only input, targets passed | # (1) only input, targets passed | ||||
@@ -0,0 +1 @@ | |||||
{"architectures": ["GPT2LMHeadModel"], "initializer_range": 0.02, "layer_norm_epsilon": 1e-05, "n_ctx": 20, "n_embd": 16, "n_head": 4, "n_layer": 2, "n_positions": 20, "vocab_size": 64} |
@@ -0,0 +1,39 @@ | |||||
#version: small | |||||
a b | |||||
c e | |||||
e l | |||||
e m | |||||
e n | |||||
en ce | |||||
en t | |||||
h e | |||||
he r | |||||
i s | |||||
o c | |||||
o d | |||||
o t | |||||
ot her | |||||
x t | |||||
Ġ T | |||||
Ġ a | |||||
Ġ d | |||||
Ġ is | |||||
Ġ m | |||||
Ġ s | |||||
Ġ t | |||||
Ġ v | |||||
ĠT h | |||||
ĠTh is | |||||
Ġa n | |||||
Ġan other | |||||
Ġd em | |||||
Ġdem o | |||||
Ġm od | |||||
Ġmod el | |||||
Ġs ent | |||||
Ġsent ence | |||||
Ġt e | |||||
Ġt h | |||||
Ġte xt | |||||
Ġth is | |||||
Ġv oc |
@@ -0,0 +1 @@ | |||||
{"\u0120This": 0, "\u0120is": 1, "\u0120a": 2, "\u0120demo": 3, "\u0120sentence": 4, "\u0120another": 5, "\u0120this": 6, "\u0120text": 7, "a": 8, "\u0120model": 9, "\u0120voc": 10, "ab": 11, "<|endoftext|>": 12, "A": 13, "B": 14, "C": 15, "D": 16, "E": 17, "F": 18, "G": 19, "H": 20, "I": 21, "J": 22, "K": 23, "L": 24, "M": 25, "N": 26, "O": 27, "P": 28, "Q": 29, "R": 30, "S": 31, "T": 32, "U": 33, "V": 34, "W": 35, "X": 36, "Y": 37, "Z": 38, "b": 39, "c": 40, "d": 41, "e": 42, "f": 43, "g": 44, "h": 45, "i": 46, "j": 47, "k": 48, "l": 49, "m": 50, "n": 51, "o": 52, "p": 53, "q": 54, "r": 55, "s": 56, "t": 57, "u": 58, "v": 59, "w": 60, "x": 61, "y": 62, "z": 63} |
@@ -0,0 +1 @@ | |||||
{"architectures": ["RobertaForMaskedLM"], "attention_probs_dropout_prob": 0.1, "finetuning_task": null, "hidden_act": "gelu", "hidden_dropout_prob": 0.1, "hidden_size": 16, "initializer_range": 0.02, "intermediate_size": 20, "layer_norm_eps": 1e-05, "max_position_embeddings": 20, "num_attention_heads": 4, "num_hidden_layers": 2, "num_labels": 2, "output_attentions": false, "output_hidden_states": false, "torchscript": false, "type_vocab_size": 1, "vocab_size": 68} |
@@ -0,0 +1,39 @@ | |||||
#version: tiny | |||||
a b | |||||
c e | |||||
e l | |||||
e m | |||||
e n | |||||
en ce | |||||
en t | |||||
h e | |||||
he r | |||||
i s | |||||
o c | |||||
o d | |||||
o t | |||||
ot her | |||||
x t | |||||
Ġ T | |||||
Ġ a | |||||
Ġ d | |||||
Ġ is | |||||
Ġ m | |||||
Ġ s | |||||
Ġ t | |||||
Ġ v | |||||
ĠT h | |||||
ĠTh is | |||||
Ġa n | |||||
Ġan other | |||||
Ġd em | |||||
Ġdem o | |||||
Ġm od | |||||
Ġmod el | |||||
Ġs ent | |||||
Ġsent ence | |||||
Ġt e | |||||
Ġt h | |||||
Ġte xt | |||||
Ġth is | |||||
Ġv oc |
@@ -0,0 +1 @@ | |||||
{"<s>": 0, "<pad>": 1, "</s>": 2, "<unk>": 3, "<mask>": 4, "A": 5, "B": 6, "C": 7, "D": 8, "E": 9, "F": 10, "G": 11, "H": 12, "I": 13, "J": 14, "K": 15, "L": 16, "M": 17, "N": 18, "O": 19, "P": 20, "Q": 21, "R": 22, "S": 23, "T": 24, "U": 25, "V": 26, "W": 27, "X": 28, "Y": 29, "Z": 30, "a": 31, "b": 32, "c": 33, "d": 34, "e": 35, "f": 36, "g": 37, "h": 38, "i": 39, "j": 40, "k": 41, "l": 42, "m": 43, "n": 44, "o": 45, "p": 46, "q": 47, "r": 48, "s": 49, "t": 50, "u": 51, "v": 52, "w": 53, "x": 54, "y": 55, "z": 56, "\u0120This": 57, "\u0120is": 58, "\u0120a": 59, "\u0120demo": 60, "\u0120sentence": 61, "\u0120another": 62, "\u0120this": 63, "\u0120text": 64, "\u0120model": 65, "\u0120voc": 66, "ab": 67} |
@@ -50,9 +50,11 @@ class TestBertEmbedding(unittest.TestCase): | |||||
# automatically truncate instead of raising an error | # automatically truncate instead of raising an error | ||||
embed = BertEmbedding(vocab, model_dir_or_name='test/data_for_tests/embedding/small_bert', word_dropout=0.1, | embed = BertEmbedding(vocab, model_dir_or_name='test/data_for_tests/embedding/small_bert', word_dropout=0.1, | ||||
only_use_pretrain_bpe=True, auto_truncate=True) | only_use_pretrain_bpe=True, auto_truncate=True) | ||||
words = torch.LongTensor([[2, 3, 4, 0]*129]) | |||||
words = torch.LongTensor([[2, 3, 4, 1]*10, | |||||
[2, 3]+[0]*38]) | |||||
result = embed(words) | result = embed(words) | ||||
self.assertEqual(result.size(), (1, 516, 16)) | |||||
self.assertEqual(result.size(), (2, 40, 16)) | |||||
def test_bert_embedding_2(self): | def test_bert_embedding_2(self): | ||||
# test whether only_use_pretrain_vocab and truncate_embed work correctly | # test whether only_use_pretrain_vocab and truncate_embed work correctly | ||||
@@ -0,0 +1,268 @@ | |||||
import unittest | |||||
import torch | |||||
import os | |||||
from fastNLP.modules.tokenizer.gpt2_tokenizer import GPT2Tokenizer | |||||
from fastNLP.embeddings import GPT2WordPieceEncoder, GPT2Embedding | |||||
from fastNLP import DataSet, Vocabulary | |||||
class TestGPT2Embedding(unittest.TestCase): | |||||
@unittest.skipIf('TRAVIS' in os.environ, "Skip in travis") | |||||
def test_download(self): | |||||
vocab = Vocabulary().add_word_lst("This is a test .".split()) | |||||
embed = GPT2Embedding(vocab, model_dir_or_name='en') | |||||
words = torch.LongTensor([[2, 3, 4, 0]]) | |||||
print(embed(words).size()) | |||||
for pool_method in ['first', 'last', 'max', 'avg']: | |||||
embed = GPT2Embedding(vocab, model_dir_or_name='en', pool_method=pool_method) | |||||
print(embed(words).size()) | |||||
def test_gpt2_embedding(self): | |||||
weight_path = 'test/data_for_tests/embedding/small_gpt2' | |||||
vocab = Vocabulary().add_word_lst("this is a texta sentence".split()) | |||||
embed = GPT2Embedding(vocab, model_dir_or_name=weight_path, word_dropout=0.1) | |||||
requires_grad = embed.requires_grad | |||||
embed.requires_grad = not requires_grad | |||||
embed.train() | |||||
words = torch.LongTensor([[2, 3, 4, 0]]) | |||||
result = embed(words) | |||||
self.assertEqual(result.size(), (1, 4, 16)) | |||||
embed = GPT2Embedding(vocab, model_dir_or_name=weight_path, word_dropout=0.1, | |||||
only_use_pretrain_bpe=False, language_model=True) | |||||
embed.eval() | |||||
words = torch.LongTensor([[2, 3, 4, 0]]) | |||||
result = embed(words) | |||||
self.assertEqual(result.size(), (1, 4, 16)) | |||||
embed.get_lm_loss() | |||||
vocab.add_word("NotInGpt2") | |||||
embed = GPT2Embedding(vocab, model_dir_or_name=weight_path, word_dropout=0.1, | |||||
only_use_pretrain_bpe=False, auto_truncate=True, min_freq=1) | |||||
words = torch.LongTensor([[2, 3, 4, 0]*20]) | |||||
result = embed(words) | |||||
self.assertEqual(result.size(), (1, 80, 16)) | |||||
def test_gpt2_ebembedding_2(self): | |||||
# test whether only_use_pretrain_vocab and truncate_embed work correctly | |||||
Embedding = GPT2Embedding | |||||
weight_path = 'test/data_for_tests/embedding/small_gpt2' | |||||
vocab = Vocabulary().add_word_lst("this is a texta and".split()) | |||||
embed1 = Embedding(vocab, model_dir_or_name=weight_path, layers=list(range(3)), | |||||
only_use_pretrain_bpe=True, truncate_embed=True, min_freq=1) | |||||
# embed_bpe_vocab_size = len(vocab)-1 + 2 # excludes NotInBERT; additionally adds ##a and [CLS] | |||||
# self.assertEqual(embed_bpe_vocab_size, len(embed1.model.tokenzier.vocab)) | |||||
embed2 = Embedding(vocab, model_dir_or_name=weight_path, layers=list(range(3)), | |||||
only_use_pretrain_bpe=True, truncate_embed=False, min_freq=1) | |||||
# embed_bpe_vocab_size = num_word # excludes NotInBERT | |||||
# self.assertEqual(embed_bpe_vocab_size, len(embed2.model.tokenzier.vocab)) | |||||
embed3 = Embedding(vocab, model_dir_or_name=weight_path, layers=list(range(3)), | |||||
only_use_pretrain_bpe=False, truncate_embed=True, min_freq=1) | |||||
# embed_bpe_vocab_size = len(vocab)+2 # newly adds ##a and [CLS] | |||||
# self.assertEqual(embed_bpe_vocab_size, len(embed3.model.tokenzier.vocab)) | |||||
embed4 = Embedding(vocab, model_dir_or_name=weight_path, layers=list(range(3)), | |||||
only_use_pretrain_bpe=False, truncate_embed=False, min_freq=1) | |||||
# embed_bpe_vocab_size = num_word+1 # newly adds ##a | |||||
# self.assertEqual(embed_bpe_vocab_size, len(embed4.model.tokenzier.vocab)) | |||||
# in all of the settings above, the following tensors should be equal | |||||
embed1.eval() | |||||
embed2.eval() | |||||
embed3.eval() | |||||
embed4.eval() | |||||
tensor = torch.LongTensor([[vocab.to_index(w) for w in 'this is a texta and'.split()]]) | |||||
t1 = embed1(tensor) | |||||
t2 = embed2(tensor) | |||||
t3 = embed3(tensor) | |||||
t4 = embed4(tensor) | |||||
self.assertEqual((t1-t2).sum(), 0) | |||||
self.assertEqual((t1-t3).sum(), 0) | |||||
self.assertEqual((t1-t4).sum(), 0) | |||||
def test_gpt2_tokenizer(self): | |||||
from fastNLP.modules.tokenizer import GPT2Tokenizer | |||||
tokenizer = GPT2Tokenizer.from_pretrained('test/data_for_tests/embedding/small_gpt2') | |||||
print(tokenizer.encode("this is a texta a sentence")) | |||||
print(tokenizer.encode('this is')) | |||||
def test_gpt2_embed_eq_gpt2_piece_encoder(self): | |||||
# check that the embedding output is consistent with the word piece encoder output | |||||
weight_path = 'test/data_for_tests/embedding/small_gpt2' | |||||
ds = DataSet({'words': ["this is a texta a sentence".split(), 'this is'.split()]}) | |||||
encoder = GPT2WordPieceEncoder(model_dir_or_name=weight_path) | |||||
encoder.eval() | |||||
encoder.index_datasets(ds, field_name='words') | |||||
word_pieces = torch.LongTensor(ds['word_pieces'].get([0, 1])) | |||||
word_pieces_res = encoder(word_pieces) | |||||
vocab = Vocabulary() | |||||
vocab.from_dataset(ds, field_name='words') | |||||
vocab.index_dataset(ds, field_name='words', new_field_name='words') | |||||
ds.set_input('words') | |||||
words = torch.LongTensor(ds['words'].get([0, 1])) | |||||
embed = GPT2Embedding(vocab, model_dir_or_name=weight_path, pool_method='first') | |||||
embed.eval() | |||||
words_res = embed(words) | |||||
# check that the word pieces line up as expected | |||||
self.assertEqual((word_pieces_res[0, :4]-words_res[0, :4]).sum(), 0) | |||||
self.assertEqual((word_pieces_res[0, 5:]-words_res[0, 4:]).sum(), 0) | |||||
self.assertEqual((word_pieces_res[1, :2]-words_res[1, :2]).sum(), 0) | |||||
class TestGPT2WordPieceEncoder(unittest.TestCase): | |||||
@unittest.skipIf(True, "Only for local debugging") | |||||
def test_eq_transformers(self): | |||||
# test that we can reproduce results matching transformers | |||||
weight_path = '' | |||||
# tokenizer = transformers.GPT2Tokenizer.from_pretrained(weight_path) | |||||
ds = DataSet({'words': ["this this this a is texta model vocab".split(), 'this is'.split()]}) | |||||
import transformers | |||||
input1 = ' '.join(ds[0]['words']) | |||||
input2 = ' '.join(ds[1]['words']) | |||||
tokenizer = transformers.GPT2Tokenizer.from_pretrained(weight_path) | |||||
idx_list1 = tokenizer.encode(input1) | |||||
idx_list2 = tokenizer.encode(input2) | |||||
pad_value = tokenizer.encode('<|endoftext|>')[0] | |||||
tensor = torch.nn.utils.rnn.pad_sequence([torch.LongTensor(idx_list1), | |||||
torch.LongTensor(idx_list2)], | |||||
batch_first=True, | |||||
padding_value=pad_value) | |||||
gpt2 = transformers.GPT2Model.from_pretrained(weight_path, output_hidden_states=True) | |||||
gpt2.eval() | |||||
output, _, trans_hidden_states = gpt2(tensor, attention_mask=tensor.ne(pad_value)) | |||||
encoder = GPT2WordPieceEncoder(model_dir_or_name=weight_path, layers=list(range(13))) | |||||
encoder.eval() | |||||
encoder.index_datasets(ds, field_name='words', add_endoftext=False) | |||||
word_pieces = torch.LongTensor(ds['word_pieces'].get([0, 1])) | |||||
self.assertEqual(idx_list1, ds[0]['word_pieces']) | |||||
self.assertEqual(idx_list2, ds[1]['word_pieces']) | |||||
word_pieces_res = encoder(word_pieces) | |||||
self.assertEqual((torch.cat(trans_hidden_states, dim=-1)-word_pieces_res).sum(), 0) | |||||
@unittest.skipIf(True, "Only for local usage") | |||||
def test_generate_small_gpt2(self): | |||||
# since GPT-2 uses a byte-level BPE tokenizer, small test weights cannot be generated directly; use the procedure below | |||||
weight_path = '' | |||||
tokenizer = GPT2Tokenizer.from_pretrained(weight_path) | |||||
used_pairs = {} | |||||
used_vocab = {} | |||||
# modify these to generate data for more sentences | |||||
sent1 = "This is a demo sentence" | |||||
sent2 = "another demo" | |||||
sent3 = 'this is a texta model vocab' | |||||
all_tokens = [] | |||||
for sent in [sent1, sent2, sent3]: | |||||
tokens = [] | |||||
for word in sent.split(): | |||||
word = ' ' + word | |||||
token = "".join( | |||||
tokenizer.byte_encoder[b] for b in word.encode("utf-8") | |||||
) | |||||
_token, _used_pairs = tokenizer.get_used_merge_pair_vocab(token) | |||||
tokens.extend(_token.split()) | |||||
used_pairs.update(_used_pairs) | |||||
all_tokens.extend(tokens) | |||||
token_ids = tokenizer.convert_tokens_to_ids(tokens) | |||||
used_vocab.update({t:i for t,i in zip(tokens, token_ids)}) | |||||
print(used_pairs) | |||||
import json | |||||
with open('test/data_for_tests/embedding/small_gpt2/vocab.json', 'w') as f: | |||||
new_used_vocab = {} | |||||
for idx, key in enumerate(used_vocab.keys()): | |||||
new_used_vocab[key] = len(new_used_vocab) | |||||
new_used_vocab['<|endoftext|>'] = len(new_used_vocab) | |||||
for i in range(65, 91): | |||||
if chr(i) not in new_used_vocab: | |||||
new_used_vocab[chr(i)] = len(new_used_vocab) | |||||
for i in range(97, 123): | |||||
if chr(i) not in new_used_vocab: | |||||
new_used_vocab[chr(i)] = len(new_used_vocab) | |||||
json.dump(new_used_vocab, f) | |||||
with open('test/data_for_tests/embedding/small_gpt2/merges.txt', 'w') as f: | |||||
f.write('#version: small\n') | |||||
for k, v in sorted(used_pairs.items()): # the lexicographic sort determines the file order | |||||
f.write('{} {}\n'.format(k[0], k[1])) | |||||
new_tokenizer = GPT2Tokenizer.from_pretrained('test/data_for_tests/embedding/small_gpt2') | |||||
new_all_tokens = [] | |||||
for sent in [sent1, sent2, sent3]: | |||||
tokens = new_tokenizer.tokenize(sent, add_prefix_space=True) | |||||
new_all_tokens.extend(tokens) | |||||
print(all_tokens, new_all_tokens) | |||||
self.assertSequenceEqual(all_tokens, new_all_tokens) | |||||
config = { | |||||
"architectures": [ | |||||
"GPT2LMHeadModel" | |||||
], | |||||
"initializer_range": 0.02, | |||||
"layer_norm_epsilon": 1e-05, | |||||
"n_ctx": 20, | |||||
"n_embd": 16, | |||||
"n_head": 4, | |||||
"n_layer": 2, | |||||
"n_positions": 20, | |||||
"vocab_size": len(new_used_vocab) | |||||
} | |||||
with open('test/data_for_tests/embedding/small_gpt2/config.json', 'w') as f: | |||||
json.dump(config, f) | |||||
# the smaller merges.txt and vocab.json are generated by recording the values actually used by the tokenizer | |||||
from fastNLP.modules.encoder.gpt2 import GPT2LMHeadModel, GPT2Config | |||||
config = GPT2Config.from_pretrained('test/data_for_tests/embedding/small_gpt2') | |||||
model = GPT2LMHeadModel(config) | |||||
torch.save(model.state_dict(), 'test/data_for_tests/embedding/small_gpt2/small_pytorch_model.bin') | |||||
print(model(torch.LongTensor([[0,1,2,3]]))) | |||||
def test_gpt2_word_piece_encoder(self): | |||||
# mainly checks that this runs without error | |||||
weight_path = 'test/data_for_tests/embedding/small_gpt2' | |||||
ds = DataSet({'words': ["this is a test sentence".split()]}) | |||||
embed = GPT2WordPieceEncoder(model_dir_or_name=weight_path, word_dropout=0.1) | |||||
embed.index_datasets(ds, field_name='words') | |||||
self.assertTrue(ds.has_field('word_pieces')) | |||||
result = embed(torch.LongTensor([[1, 2, 3, 4]])) | |||||
embed = GPT2WordPieceEncoder(model_dir_or_name=weight_path, word_dropout=0.1, | |||||
language_model=True) | |||||
embed.index_datasets(ds, field_name='words') | |||||
self.assertTrue(ds.has_field('word_pieces')) | |||||
result = embed(torch.LongTensor([[1, 2, 3, 4]])) | |||||
def test_generate(self): | |||||
weight_path = 'test/data_for_tests/embedding/small_gpt2' | |||||
encoder = GPT2WordPieceEncoder(model_dir_or_name=weight_path, language_model=True) | |||||
# check that the various generation settings work | |||||
print(encoder.generate_from_str('this', max_len=20, do_sample=False, num_beams=1, temperature=1, top_k=50, top_p=1.0, | |||||
repetition_penalty=1.0, length_penalty=1.0)) | |||||
print(encoder.generate_from_str('this', max_len=20, do_sample=True, num_beams=3, temperature=1, top_k=50, top_p=1.0, | |||||
repetition_penalty=1.0, length_penalty=1.0)) | |||||
print(encoder.generate_from_str('this', max_len=20, do_sample=True, num_beams=3, temperature=2, top_k=20, top_p=2.0, | |||||
repetition_penalty=2.0, length_penalty=1.5)) |
@@ -0,0 +1,252 @@ | |||||
import unittest | |||||
import torch | |||||
import os | |||||
from fastNLP import DataSet, Vocabulary | |||||
from fastNLP.embeddings.roberta_embedding import RobertaWordPieceEncoder, RobertaEmbedding | |||||
class TestRobertWordPieceEncoder(unittest.TestCase): | |||||
@unittest.skipIf('TRAVIS' in os.environ, "Skip in travis") | |||||
def test_download(self): | |||||
vocab = Vocabulary().add_word_lst("This is a test .".split()) | |||||
embed = RobertaEmbedding(vocab, model_dir_or_name='en') | |||||
words = torch.LongTensor([[2, 3, 4, 0]]) | |||||
print(embed(words).size()) | |||||
for pool_method in ['first', 'last', 'max', 'avg']: | |||||
for include_cls_sep in [True, False]: | |||||
embed = RobertaEmbedding(vocab, model_dir_or_name='en', pool_method=pool_method, | |||||
include_cls_sep=include_cls_sep) | |||||
print(embed(words).size()) | |||||
def test_robert_word_piece_encoder(self): | |||||
# just needs to run without error | |||||
weight_path = 'test/data_for_tests/embedding/small_roberta' | |||||
encoder = RobertaWordPieceEncoder(model_dir_or_name=weight_path, word_dropout=0.1) | |||||
ds = DataSet({'words': ["this is a test . [SEP]".split()]}) | |||||
encoder.index_datasets(ds, field_name='words') | |||||
self.assertTrue(ds.has_field('word_pieces')) | |||||
result = encoder(torch.LongTensor([[1,2,3,4]])) | |||||
def test_roberta_embed_eq_roberta_piece_encoder(self): | |||||
# check that the embedding output is consistent with the word piece encoder output | |||||
weight_path = 'test/data_for_tests/embedding/small_roberta' | |||||
ds = DataSet({'words': ["this is a texta a sentence".split(), 'this is'.split()]}) | |||||
encoder = RobertaWordPieceEncoder(model_dir_or_name=weight_path) | |||||
encoder.eval() | |||||
encoder.index_datasets(ds, field_name='words') | |||||
word_pieces = torch.LongTensor(ds['word_pieces'].get([0, 1])) | |||||
word_pieces_res = encoder(word_pieces) | |||||
vocab = Vocabulary() | |||||
vocab.from_dataset(ds, field_name='words') | |||||
vocab.index_dataset(ds, field_name='words', new_field_name='words') | |||||
ds.set_input('words') | |||||
words = torch.LongTensor(ds['words'].get([0, 1])) | |||||
embed = RobertaEmbedding(vocab, model_dir_or_name=weight_path, | |||||
pool_method='first', include_cls_sep=True, pooled_cls=False) | |||||
embed.eval() | |||||
words_res = embed(words) | |||||
# check that the word pieces line up as expected | |||||
self.assertEqual((word_pieces_res[0, :5]-words_res[0, :5]).sum(), 0) | |||||
self.assertEqual((word_pieces_res[0, 6:]-words_res[0, 5:]).sum(), 0) | |||||
self.assertEqual((word_pieces_res[1, :3]-words_res[1, :3]).sum(), 0) | |||||
@unittest.skipIf(True, "Only for local debugging") | |||||
def test_eq_transformers(self): | |||||
weight_path = '' | |||||
ds = DataSet({'words': ["this is a texta model vocab".split(), 'this is'.split()]}) | |||||
encoder = RobertaWordPieceEncoder(model_dir_or_name=weight_path) | |||||
encoder.eval() | |||||
encoder.index_datasets(ds, field_name='words') | |||||
word_pieces = torch.LongTensor(ds['word_pieces'].get([0, 1])) | |||||
word_pieces_res = encoder(word_pieces) | |||||
import transformers | |||||
input1 = ' '.join(ds[0]['words']) | |||||
input2 = ' '.join(ds[1]['words']) | |||||
tokenizer = transformers.RobertaTokenizer.from_pretrained(weight_path) | |||||
idx_list1 = tokenizer.encode(input1) | |||||
idx_list2 = tokenizer.encode(input2) | |||||
self.assertEqual(idx_list1, ds[0]['word_pieces']) | |||||
self.assertEqual(idx_list2, ds[1]['word_pieces']) | |||||
pad_value = tokenizer.encode('<pad>')[0] | |||||
tensor = torch.nn.utils.rnn.pad_sequence([torch.LongTensor(idx_list1), | |||||
torch.LongTensor(idx_list2)], | |||||
batch_first=True, | |||||
padding_value=pad_value) | |||||
roberta = transformers.RobertaModel.from_pretrained(weight_path, output_hidden_states=True) | |||||
roberta.eval() | |||||
output, pooled_output, hidden_states = roberta(tensor, attention_mask=tensor.ne(pad_value)) | |||||
self.assertEqual((output-word_pieces_res).sum(), 0) | |||||
@unittest.skipIf(True, "Only for local usage") | |||||
def test_generate_small_roberta(self): | |||||
""" | |||||
Since RoBERTa uses the GPT-2 tokenizer, small test weights cannot be generated directly; use the procedure below. | |||||
:return: | |||||
""" | |||||
weight_path = '' | |||||
from fastNLP.modules.tokenizer import RobertaTokenizer | |||||
tokenizer = RobertaTokenizer.from_pretrained(weight_path) | |||||
used_pairs = {} | |||||
used_vocab = {} | |||||
# modify these to generate data for more sentences | |||||
sent1 = "This is a demo sentence" | |||||
sent2 = "another demo" | |||||
sent3 = 'this is a texta model vocab' | |||||
all_tokens = [] | |||||
for sent in [sent1, sent2, sent3]: | |||||
tokens = [] | |||||
for word in sent.split(): | |||||
word = ' ' + word | |||||
token = "".join( | |||||
tokenizer.byte_encoder[b] for b in word.encode("utf-8") | |||||
) | |||||
_token, _used_pairs = tokenizer.get_used_merge_pair_vocab(token) | |||||
tokens.extend(_token.split()) | |||||
used_pairs.update(_used_pairs) | |||||
all_tokens.extend(tokens) | |||||
token_ids = tokenizer.convert_tokens_to_ids(tokens) | |||||
used_vocab.update({t:i for t,i in zip(tokens, token_ids)}) | |||||
import json | |||||
with open('test/data_for_tests/embedding/small_roberta/vocab.json', 'w') as f: | |||||
new_used_vocab = {} | |||||
for token in ['<s>', '<pad>', '</s>', '<unk>', '<mask>']: # <pad> must map to index 1 | |||||
new_used_vocab[token] = len(new_used_vocab) | |||||
for i in range(65, 91): | |||||
if chr(i) not in new_used_vocab: | |||||
new_used_vocab[chr(i)] = len(new_used_vocab) | |||||
for i in range(97, 123): | |||||
if chr(i) not in new_used_vocab: | |||||
new_used_vocab[chr(i)] = len(new_used_vocab) | |||||
for idx, key in enumerate(used_vocab.keys()): | |||||
if key not in new_used_vocab: | |||||
new_used_vocab[key] = len(new_used_vocab) | |||||
json.dump(new_used_vocab, f) | |||||
with open('test/data_for_tests/embedding/small_roberta/merges.txt', 'w') as f: | |||||
f.write('#version: tiny\n') | |||||
for k, v in sorted(used_pairs.items()): # the lexicographic sort determines the file order | |||||
f.write('{} {}\n'.format(k[0], k[1])) | |||||
config = { | |||||
"architectures": [ | |||||
"RobertaForMaskedLM" | |||||
], | |||||
"attention_probs_dropout_prob": 0.1, | |||||
"finetuning_task": None, | |||||
"hidden_act": "gelu", | |||||
"hidden_dropout_prob": 0.1, | |||||
"hidden_size": 16, | |||||
"initializer_range": 0.02, | |||||
"intermediate_size": 20, | |||||
"layer_norm_eps": 1e-05, | |||||
"max_position_embeddings": 20, | |||||
"num_attention_heads": 4, | |||||
"num_hidden_layers": 2, | |||||
"num_labels": 2, | |||||
"output_attentions": False, | |||||
"output_hidden_states": False, | |||||
"torchscript": False, | |||||
"type_vocab_size": 1, | |||||
"vocab_size": len(new_used_vocab) | |||||
} | |||||
with open('test/data_for_tests/embedding/small_roberta/config.json', 'w') as f: | |||||
json.dump(config, f) | |||||

        # Check that the shrunken tokenizer reproduces exactly the tokens produced
        # by the pretrained one on the sample sentences.
        new_tokenizer = RobertaTokenizer.from_pretrained('test/data_for_tests/embedding/small_roberta')
        new_all_tokens = []
        for sent in [sent1, sent2, sent3]:
            tokens = new_tokenizer.tokenize(sent, add_prefix_space=True)
            new_all_tokens.extend(tokens)
        print(all_tokens, new_all_tokens)
        self.assertSequenceEqual(all_tokens, new_all_tokens)

        # The smaller merges.txt and vocab.json above were generated by recording the
        # values actually used by the tokenizer; now create and save matching random weights.
        from fastNLP.modules.encoder.roberta import RobertaModel, BertConfig

        config = BertConfig.from_json_file('test/data_for_tests/embedding/small_roberta/config.json')
        model = RobertaModel(config)
        torch.save(model.state_dict(), 'test/data_for_tests/embedding/small_roberta/small_pytorch_model.bin')
        print(model(torch.LongTensor([[0, 1, 2, 3]])))
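
        # Optional sanity check: reload the saved state dict into a fresh model;
        # load_state_dict raises if the checkpoint ever drifts from the config above.
        reloaded = RobertaModel(config)
        reloaded.load_state_dict(
            torch.load('test/data_for_tests/embedding/small_roberta/small_pytorch_model.bin'))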


class TestRobertaEmbedding(unittest.TestCase):
    def test_roberta_embedding_1(self):
        weight_path = 'test/data_for_tests/embedding/small_roberta'
        vocab = Vocabulary().add_word_lst("this is a test . [SEP] NotInRoberta".split())
        embed = RobertaEmbedding(vocab, model_dir_or_name=weight_path, word_dropout=0.1)
        requires_grad = embed.requires_grad
        embed.requires_grad = not requires_grad
        embed.train()
        words = torch.LongTensor([[2, 3, 4, 1]])
        result = embed(words)
        self.assertEqual(result.size(), (1, 4, 16))

        embed = RobertaEmbedding(vocab, model_dir_or_name=weight_path, word_dropout=0.1,
                                 only_use_pretrain_bpe=True)
        embed.eval()
        words = torch.LongTensor([[2, 3, 4, 1]])
        result = embed(words)
        self.assertEqual(result.size(), (1, 4, 16))

        # Over-long sequences are truncated automatically instead of raising an error.
        embed = RobertaEmbedding(vocab, model_dir_or_name=weight_path, word_dropout=0.1,
                                 only_use_pretrain_bpe=True, auto_truncate=True)
        words = torch.LongTensor([[2, 3, 4, 1] * 10,
                                  [2, 3] + [0] * 38])
        result = embed(words)
        self.assertEqual(result.size(), (2, 40, 16))

    def test_roberta_embedding_2(self):
        # Check that only_use_pretrain_bpe and truncate_embed behave correctly.
        Embedding = RobertaEmbedding
        weight_path = 'test/data_for_tests/embedding/small_roberta'
        vocab = Vocabulary().add_word_lst("this is a texta and".split())
        embed1 = Embedding(vocab, model_dir_or_name=weight_path, layers=list(range(3)),
                           only_use_pretrain_bpe=True, truncate_embed=True, min_freq=1)
        # embed_bpe_vocab_size = len(vocab)-1 + 2  # exclude NotInBERT; add ##a, [CLS]
        # self.assertEqual(embed_bpe_vocab_size, len(embed1.model.tokenzier.vocab))

        embed2 = Embedding(vocab, model_dir_or_name=weight_path, layers=list(range(3)),
                           only_use_pretrain_bpe=True, truncate_embed=False, min_freq=1)
        # embed_bpe_vocab_size = num_word  # exclude NotInBERT
        # self.assertEqual(embed_bpe_vocab_size, len(embed2.model.tokenzier.vocab))

        embed3 = Embedding(vocab, model_dir_or_name=weight_path, layers=list(range(3)),
                           only_use_pretrain_bpe=False, truncate_embed=True, min_freq=1)
        # embed_bpe_vocab_size = len(vocab)+2  # newly added ##a, [CLS]
        # self.assertEqual(embed_bpe_vocab_size, len(embed3.model.tokenzier.vocab))

        embed4 = Embedding(vocab, model_dir_or_name=weight_path, layers=list(range(3)),
                           only_use_pretrain_bpe=False, truncate_embed=False, min_freq=1)
        # embed_bpe_vocab_size = num_word+1  # newly added ##a
        # self.assertEqual(embed_bpe_vocab_size, len(embed4.model.tokenzier.vocab))

        # The following tensors should be identical under all four settings.
        embed1.eval()
        embed2.eval()
        embed3.eval()
        embed4.eval()
        tensor = torch.LongTensor([[vocab.to_index(w) for w in 'this is a texta and'.split()]])
        t1 = embed1(tensor)
        t2 = embed2(tensor)
        t3 = embed3(tensor)
        t4 = embed4(tensor)
        self.assertEqual((t1 - t2).sum(), 0)
        self.assertEqual((t1 - t3).sum(), 0)
        self.assertEqual((t1 - t4).sum(), 0)
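        # All four settings should agree here: every word in the probe sentence maps to
        # subwords in the pretrained BPE vocab, so truncating or extending the embedding
        # matrix never changes the rows that are actually looked up.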

@@ -0,0 +1,19 @@
import unittest

from fastNLP.modules.tokenizer import BertTokenizer


class TestBertTokenizer(unittest.TestCase):
    def test_run(self):
        # Test the two supported input types for encode.
        tokenizer = BertTokenizer.from_pretrained('test/data_for_tests/embedding/small_bert')
        tokens1 = tokenizer.encode("This is a demo")
        tokens2 = tokenizer.encode("This is a demo", add_special_tokens=False)
        tokens3 = tokenizer.encode("This is a demo".split())
        tokens4 = tokenizer.encode("This is a demo".split(), add_special_tokens=False)
        self.assertEqual(len(tokens1) - 2, len(tokens2))
        self.assertEqual(len(tokens3) - 2, len(tokens4))
        self.assertEqual(tokens1[0], tokenizer.cls_index)
        self.assertEqual(tokens1[-1], tokenizer.sep_index)
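        # Note: encode() accepts either a raw string or a pre-split list of words; with
        # the default add_special_tokens=True the ids are wrapped as [CLS] ... [SEP],
        # which is what the length and index asserts above verify.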