@@ -2,6 +2,6 @@ fastNLP.core.callback | |||
===================== | |||
.. automodule:: fastNLP.core.callback | |||
:members: Callback, GradientClipCallback, EarlyStopCallback, FitlogCallback, EvaluateCallback, LRScheduler, ControlC, LRFinder, TensorboardCallback, WarmupCallback, SaveModelCallback, EchoCallback, TesterCallback, CallbackException, EarlyStopError | |||
:members: Callback, GradientClipCallback, EarlyStopCallback, FitlogCallback, EvaluateCallback, LRScheduler, ControlC, LRFinder, TensorboardCallback, WarmupCallback, SaveModelCallback, EchoCallback, CallbackException, EarlyStopError | |||
:inherited-members: | |||
@@ -2,5 +2,5 @@ fastNLP.modules.encoder | |||
======================= | |||
.. automodule:: fastNLP.modules.encoder | |||
:members: ConvolutionCharEncoder, LSTMCharEncoder, ConvMaxpool, LSTM, StarTransformer, TransformerEncoder, VarRNN, VarLSTM, VarGRU, MaxPool, MaxPoolWithMask, AvgPool, AvgPoolWithMask, MultiHeadAttention, BiAttention, SelfAttention | |||
:members: ConvolutionCharEncoder, LSTMCharEncoder, ConvMaxpool, LSTM, StarTransformer, TransformerEncoder, VarRNN, VarLSTM, VarGRU, MaxPool, MaxPoolWithMask, KMaxPool, AvgPool, AvgPoolWithMask, MultiHeadAttention, BiAttention, SelfAttention | |||
@@ -2,7 +2,7 @@ fastNLP.modules | |||
=============== | |||
.. automodule:: fastNLP.modules | |||
:members: ConvolutionCharEncoder, LSTMCharEncoder, ConvMaxpool, LSTM, StarTransformer, TransformerEncoder, VarRNN, VarLSTM, VarGRU, MaxPool, MaxPoolWithMask, AvgPool, AvgPoolWithMask, MultiHeadAttention, MLP, ConditionalRandomField, viterbi_decode, allowed_transitions, TimestepDropout | |||
:members: ConvolutionCharEncoder, LSTMCharEncoder, ConvMaxpool, LSTM, StarTransformer, TransformerEncoder, VarRNN, VarLSTM, VarGRU, MaxPool, MaxPoolWithMask, KMaxPool, AvgPool, AvgPoolWithMask, MultiHeadAttention, MLP, ConditionalRandomField, viterbi_decode, allowed_transitions, TimestepDropout | |||
子模块 | |||
------ | |||
@@ -2,7 +2,7 @@ fastNLP | |||
======= | |||
.. automodule:: fastNLP | |||
:members: Instance, FieldArray, DataSetIter, BatchIter, TorchLoaderIter, Vocabulary, DataSet, Const, Trainer, Tester, Callback, GradientClipCallback, EarlyStopCallback, FitlogCallback, EvaluateCallback, LRScheduler, ControlC, LRFinder, TensorboardCallback, WarmupCallback, SaveModelCallback, EchoCallback, TesterCallback, CallbackException, EarlyStopError, Padder, AutoPadder, EngChar2DPadder, AccuracyMetric, SpanFPreRecMetric, ExtractiveQAMetric, Optimizer, SGD, Adam, AdamW, Sampler, SequentialSampler, BucketSampler, RandomSampler, LossFunc, CrossEntropyLoss, L1Loss, BCELoss, NLLLoss, LossInForward, cache_results, logger | |||
:members: Instance, FieldArray, DataSetIter, BatchIter, TorchLoaderIter, Vocabulary, DataSet, Const, Trainer, Tester, Callback, GradientClipCallback, EarlyStopCallback, FitlogCallback, EvaluateCallback, LRScheduler, ControlC, LRFinder, TensorboardCallback, WarmupCallback, SaveModelCallback, EchoCallback, CallbackException, EarlyStopError, Padder, AutoPadder, EngChar2DPadder, AccuracyMetric, SpanFPreRecMetric, ExtractiveQAMetric, Optimizer, SGD, Adam, AdamW, Sampler, SequentialSampler, BucketSampler, RandomSampler, LossFunc, CrossEntropyLoss, L1Loss, BCELoss, NLLLoss, LossInForward, cache_results, logger | |||
:inherited-members: | |||
子模块 | |||
@@ -36,8 +36,6 @@ __all__ = [ | |||
"TensorboardCallback", | |||
"WarmupCallback", | |||
'SaveModelCallback', | |||
"EchoCallback", | |||
"TesterCallback", | |||
"CallbackException", | |||
"EarlyStopError", | |||
@@ -48,8 +48,6 @@ __all__ = [ | |||
"TensorboardCallback", | |||
"WarmupCallback", | |||
'SaveModelCallback', | |||
"EchoCallback", | |||
"TesterCallback", | |||
"CallbackException", | |||
"EarlyStopError", | |||
@@ -78,8 +76,8 @@ __all__ = [ | |||
from ._logger import logger | |||
from .batch import DataSetIter, BatchIter, TorchLoaderIter | |||
from .callback import Callback, GradientClipCallback, EarlyStopCallback, FitlogCallback, EvaluateCallback, \ | |||
LRScheduler, ControlC, LRFinder, TensorboardCallback, WarmupCallback, SaveModelCallback, EchoCallback, \ | |||
TesterCallback, CallbackException, EarlyStopError | |||
LRScheduler, ControlC, LRFinder, TensorboardCallback, WarmupCallback, SaveModelCallback, CallbackException, \ | |||
EarlyStopError | |||
from .const import Const | |||
from .dataset import DataSet | |||
from .field import FieldArray, Padder, AutoPadder, EngChar2DPadder | |||
@@ -193,13 +193,14 @@ class DataSetIter(BatchIter): | |||
Default: ``None`` | |||
:param bool as_numpy: 若为 ``True`` , 输出batch为 numpy.array. 否则为 :class:`torch.Tensor`. | |||
Default: ``False`` | |||
:param int num_workers: 使用多少个进程来预处理数据 | |||
:param bool pin_memory: 是否将产生的tensor使用pin memory, 可能会加快速度。 | |||
:param bool drop_last: 如果最后一个batch没有batch_size这么多sample,就扔掉最后一个 | |||
:param timeout: | |||
:param timeout: 生成一个batch的timeout值 | |||
:param worker_init_fn: 在每个worker启动时调用该函数,会传入一个值,该值是worker的index。 | |||
:param collate_fn: 用于将样本组合成batch的函数 | |||
""" | |||
assert isinstance(dataset, DataSet) | |||
dataset = DataSetGetter(dataset, as_numpy) | |||
@@ -220,12 +221,26 @@ class DataSetIter(BatchIter): | |||
class TorchLoaderIter(BatchIter): | |||
""" | |||
与DataSetIter类似,但用于pytorch的DataSet对象。通过使用TorchLoaderIter封装pytorch的DataSet,然后将其传入到Trainer中。 | |||
与DataSetIter类似,但用于pytorch的DataSet对象。 | |||
通过使用TorchLoaderIter封装pytorch的DataSet,然后将其传入到Trainer中。 | |||
""" | |||
def __init__(self, dataset, batch_size=1, sampler=None, | |||
num_workers=0, pin_memory=False, drop_last=False, | |||
timeout=0, worker_init_fn=None, collate_fn=None): | |||
""" | |||
:param dataset: :class:`~fastNLP.DataSet` 对象, 数据集 | |||
:param int batch_size: 取出的batch大小 | |||
:param sampler: 规定使用的 :class:`~fastNLP.Sampler` 方式. 若为 ``None`` , 使用 :class:`~fastNLP.SequentialSampler`. | |||
Default: ``None`` | |||
:param int num_workers: 使用多少个进程来预处理数据 | |||
:param bool pin_memory: 是否将产生的tensor使用pin memory, 可能会加快速度。 | |||
:param bool drop_last: 如果最后一个batch没有batch_size这么多sample,就扔掉最后一个 | |||
:param timeout: 生成一个batch的timeout值 | |||
:param worker_init_fn: 在每个worker启动时调用该函数,会传入一个值,该值是worker的index。 | |||
:param collate_fn: 用于将样本组合成batch的函数""" | |||
assert len(dataset) > 0 | |||
ins = dataset[0] | |||
assert len(ins) == 2 and \ | |||
@@ -62,8 +62,6 @@ __all__ = [ | |||
"TensorboardCallback", | |||
"WarmupCallback", | |||
"SaveModelCallback", | |||
"EchoCallback", | |||
"TesterCallback", | |||
"CallbackException", | |||
"EarlyStopError" | |||
@@ -87,12 +85,18 @@ except: | |||
from .dataset import DataSet | |||
from .tester import Tester | |||
from ._logger import logger | |||
from .utils import _check_fp16 | |||
try: | |||
import fitlog | |||
except: | |||
pass | |||
try: | |||
from apex import amp | |||
except: | |||
amp = None | |||
class Callback(object): | |||
""" | |||
@@ -269,14 +273,6 @@ class Callback(object): | |||
:return: | |||
""" | |||
pass | |||
def on_validation(self): | |||
""" | |||
如果Trainer中设置了验证,则会在每次需要验证时调用该函数 | |||
:return: | |||
""" | |||
pass | |||
def on_epoch_end(self): | |||
""" | |||
@@ -470,7 +466,7 @@ class GradientClipCallback(Callback): | |||
if self.step%self.update_every==0: | |||
if self.parameters is None: | |||
if getattr(self.trainer, 'fp16', ''): | |||
from apex import amp | |||
_check_fp16() | |||
self.clip_fun(amp.master_params(self.optimizer), self.clip_value) | |||
self.clip_fun(self.model.parameters(), self.clip_value) | |||
else: | |||
@@ -713,6 +709,8 @@ class ControlC(Callback): | |||
class SmoothValue(object): | |||
"""work for LRFinder""" | |||
def __init__(self, beta: float): | |||
self.beta, self.n, self.mov_avg = beta, 0, 0 | |||
self.smooth = None | |||
@@ -1025,6 +1023,10 @@ class EarlyStopError(CallbackException): | |||
class EchoCallback(Callback): | |||
""" | |||
用于测试分布式训练 | |||
""" | |||
def __init__(self, name, out=sys.stdout): | |||
super(EchoCallback, self).__init__() | |||
self.name = name | |||
@@ -1036,27 +1038,23 @@ class EchoCallback(Callback): | |||
return super(EchoCallback, self).__getattribute__(item) | |||
class TesterCallback(Callback): | |||
class _TesterCallback(Callback): | |||
def __init__(self, data, model, metrics, metric_key=None, batch_size=16, num_workers=None): | |||
super(TesterCallback, self).__init__() | |||
super(_TesterCallback, self).__init__() | |||
if hasattr(model, 'module'): | |||
# for data parallel model | |||
model = model.module | |||
self.tester = Tester(data, model, | |||
metrics=metrics, batch_size=batch_size, | |||
num_workers=num_workers, verbose=0) | |||
# parse metric_key | |||
# increase_better is True. It means the exp result gets better if the indicator increases. | |||
# It is true by default. | |||
self.increase_better = True | |||
if metric_key is not None: | |||
self.increase_better = False if metric_key[0] == "-" else True | |||
self.metric_key = metric_key[1:] if metric_key[0] == "+" or metric_key[0] == "-" else metric_key | |||
self.metric_key, self.increase_better = self._parse_metric_key(metric_key) | |||
else: | |||
self.metric_key = None | |||
self.increase_better = True | |||
self.score = None | |||
def on_validation(self): | |||
def on_valid_begin(self): | |||
cur_score = self.tester.test() | |||
eval_str = "Evaluation at Epoch {}/{}. Step:{}/{}. - {}".format( | |||
self.epoch, self.n_epochs, self.step, self.n_steps, | |||
@@ -1067,17 +1065,28 @@ class TesterCallback(Callback): | |||
self.score = cur_score | |||
return cur_score, is_better | |||
def _get_score(self, metric_dict, key): | |||
@staticmethod | |||
def _get_score(metric_dict, key): | |||
for metric in metric_dict.items(): | |||
if key in metric: | |||
return metric[key] | |||
return None | |||
@staticmethod | |||
def _parse_metric_key(metric_key): | |||
# parse metric_key | |||
# increase_better is True. It means the exp result gets better if the indicator increases. | |||
# It is true by default. | |||
increase_better = False if metric_key[0] == "-" else True | |||
metric_key = metric_key[1:] if metric_key[0] == "+" or metric_key[0] == "-" else metric_key | |||
return metric_key, increase_better | |||
def compare_better(self, a): | |||
if self.score is None: | |||
return True | |||
if self.metric_key is None: | |||
self.metric_key = list(list(self.score.values())[0].keys())[0] | |||
metric_key = list(list(self.score.values())[0].keys())[0] | |||
self.metric_key, self.increase_better = self._parse_metric_key(metric_key) | |||
k = self.metric_key | |||
score = self._get_score(self.score, k) | |||
new_score = self._get_score(a, k) | |||
@@ -1087,7 +1096,3 @@ class TesterCallback(Callback): | |||
return score <= new_score | |||
else: | |||
return score >= new_score | |||
def on_train_end(self): | |||
self.logger.info('Evaluate on training ends.') | |||
self.on_validation() |
@@ -17,21 +17,31 @@ from tqdm import tqdm | |||
from ._logger import logger | |||
from .batch import DataSetIter, BatchIter | |||
from .callback import DistCallbackManager, CallbackException, TesterCallback | |||
from .callback import DistCallbackManager, CallbackException | |||
from .callback import _TesterCallback | |||
from .dataset import DataSet | |||
from .losses import _prepare_losser | |||
from .optimizer import Optimizer | |||
from .utils import _build_args | |||
from .utils import _get_func_signature | |||
from .utils import _move_dict_value_to_device | |||
from .utils import _check_fp16 | |||
try: | |||
from apex import amp | |||
except: | |||
amp = None | |||
__all__ = [ | |||
'get_local_rank', | |||
'DistTrainer', | |||
] | |||
def get_local_rank(): | |||
""" | |||
返回当前进程的 local rank, 0 到 N-1 ,N为当前分布式总进程数 | |||
""" | |||
if 'LOCAL_RANK' in os.environ: | |||
return int(os.environ['LOCAL_RANK']) | |||
from argparse import ArgumentParser | |||
@@ -46,7 +56,10 @@ def get_local_rank(): | |||
class DistTrainer(): | |||
""" | |||
Distributed Trainer that support distributed and mixed precision training | |||
分布式的 Trainer,支持分布式训练和混合精度的训练。具体实现原理请阅读 pytorch 官方文档。 | |||
Note: 使用分布式 Trainer 时会同时有多个进程执行训练代码。因此将单进程的训练代码改为多进程之前, | |||
请仔细检查,确保训练代码中的同步和互斥操作能正确执行(如模型保持,打印日志等) | |||
""" | |||
def __init__(self, train_data, model, optimizer=None, loss=None, | |||
callbacks_all=None, callbacks_master=None, | |||
@@ -55,8 +68,43 @@ class DistTrainer(): | |||
dev_data=None, metrics=None, metric_key=None, | |||
update_every=1, print_every=10, validate_every=-1, | |||
save_every=-1, save_path=None, device='auto', | |||
fp16='', backend=None, init_method=None): | |||
fp16='', backend=None, init_method=None, use_tqdm=True): | |||
""" | |||
:param train_data: 训练集, :class:`~fastNLP.DataSet` 类型。 | |||
:param nn.modules model: 待训练的模型 | |||
:param optimizer: `torch.optim.Optimizer` 优化器。如果为None,则Trainer使用默认的Adam(model.parameters(), lr=4e-3)这个优化器 | |||
:param loss: 使用的 :class:`~fastNLP.core.losses.LossBase` 对象。当为None时,默认使用 :class:`~fastNLP.LossInForward` | |||
:param list callbacks_all: 用于在train过程中起调节作用的回调函数,作用于所有训练进程中。 | |||
可使用的callback参见 :doc:`callback模块 <fastNLP.core.callback>` | |||
:param list callbacks_master: 用于在train过程中起调节作用的回调函数,只作用于其中一个进程( Master 进程)。 | |||
可使用的callback参见 :doc:`callback模块 <fastNLP.core.callback>` | |||
:param int batch_size_per_gpu: 训练时,每个进程的 batch 大小。 | |||
:param int n_epochs: 需要优化迭代多少次。 | |||
:param num_workers: int, 有多少个线程来进行数据pad处理。 | |||
:param drop_last: 如果最后一个batch没有正好为batch_size这么多数据,就扔掉最后一个batch | |||
:param dev_data: 用于做验证的DataSet, :class:`~fastNLP.DataSet` 类型。 | |||
:param metrics: 验证的评估函数。可以只使用一个 :class:`Metric<fastNLP.core.metrics.MetricBase>` , | |||
也可以使用多个 :class:`Metric<fastNLP.core.metrics.MetricBase>` ,通过列表传入。 | |||
如验证时取得了更好的验证结果(如果有多个Metric,以列表中第一个Metric为准),且save_path不为None, | |||
则保存当前模型。Metric种类详见 :doc:`metrics模块 <fastNLP.core.metrics>` 。仅在传入dev_data时有效。 | |||
:param str,None metric_key: :class:`Metric<fastNLP.core.metrics.MetricBase>` 有时会有多个指标, | |||
比如 :class:`~fastNLP.core.metrics.SpanFPreRecMetric` 中包含了'f', 'pre', 'rec'。此时需 | |||
要指定以哪个指标为准。另外有些指标是越小效果越好,比如语言模型的困惑度,这种情况下,在key前面增加一个'-'来表 | |||
明验证时,值越小越好(比如: "-ppl")。仅在传入dev_data时有效。 | |||
:param update_every: int, 多少步更新一次梯度。用于希望累计梯度的场景,比如需要128的batch_size, 但是直接设为128 | |||
会导致内存不足,通过设置batch_size=32, update_every=4达到目的。当optimizer为None时,该参数无效。 | |||
:param int print_every: 多少次反向传播更新tqdm显示的loss; 如果use_tqdm=False, 则多少次反向传播打印loss。 | |||
:param int validate_every: 多少个step在验证集上验证一次; 如果为-1,则每个epoch结束验证一次。仅在传入dev_data时有效。 | |||
:param int save_every: 多少个step保存一次模型,如果为-1,则每个epoch结束保存一次。仅在传入save_path时有效。 | |||
:param str,None save_path: 将模型保存路径,如果路径不存在,将自动创建文件夹。如果为None,则不保存模型。如果dev_data为None,则保存 | |||
最后一次迭代的模型。保存的时候不仅保存了参数,还保存了模型结构。即便使用DataParallel,这里也只保存模型。 | |||
:param str device: 指定 device,可以是 gpu,cpu 或 auto | |||
:param str fp16: 指定半精度训练的优化等级,可为 O1,O2 或 O3,若为空字符串则不使用半精度。 | |||
:param backend: 指定分布式的backend,详情参考 pytorch 文档 | |||
:param init_method 指定分布式的初始化方法,详情参考 pytorch 文档 | |||
:param bool use_tqdm: 是否使用tqdm来显示训练进度; 如果为False,则将loss打印在终端中。 | |||
""" | |||
assert device in ['auto', 'cuda', 'cpu'], "Please set correct device in [auto', 'cuda', 'cpu']" | |||
if device == 'auto': | |||
device = 'cuda' if torch.cuda.is_available() else 'cpu' | |||
@@ -94,7 +142,9 @@ class DistTrainer(): | |||
self.callback_manager = DistCallbackManager( | |||
env={"trainer": self}, callbacks_all=callbacks_all, | |||
callbacks_master=callbacks_master) | |||
self.test_manager = DistCallbackManager(env={'trainer': self}) | |||
self.metric_key = metric_key | |||
self.use_tqdm = use_tqdm | |||
model.to(self.device) | |||
optimizer = self._get_optimizer(optimizer) | |||
@@ -102,11 +152,7 @@ class DistTrainer(): | |||
# init fp16, must before DataParallel init | |||
if len(self.fp16): | |||
assert isinstance(self.fp16, str), "Please set Apex AMP optimization level selected in ['O0', 'O1', 'O2', 'O3']" | |||
try: | |||
from apex import amp | |||
except ImportError: | |||
raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.") | |||
assert torch.backends.cudnn.enabled, "Amp requires cudnn backend to be enabled." | |||
_check_fp16() | |||
assert device == 'cuda', "Amp requires cuda device" | |||
model, optimizer = amp.initialize(model, optimizer, opt_level=self.fp16) | |||
@@ -121,20 +167,21 @@ class DistTrainer(): | |||
self.optimizer = optimizer | |||
self.sampler = DistributedSampler(self.train_data) | |||
self.data_iterator = self._get_data_iter(self.train_data) | |||
self.batch_size = self.world_size * self.batch_size_per_gpu | |||
self.n_steps = self._get_n_steps() | |||
# for evaluation, only run eval on master proc | |||
if dev_data and metrics: | |||
cb = TesterCallback( | |||
cb = _TesterCallback( | |||
dev_data, model, metrics, | |||
batch_size=batch_size_per_gpu, num_workers=num_workers) | |||
self.callback_manager.add_callback([cb], master=True) | |||
self.test_manager.add_callback([cb], master=False) | |||
# Setup logging | |||
dist.barrier() | |||
self.start_time = datetime.now().strftime('%m_%d_%Y-%H_%M') | |||
if self.save_path: | |||
self.cp_save_path = os.path.join(self.save_path, 'checkpoints', self.start_time) | |||
self.cp_save_path = os.path.join(self.save_path, 'checkpoints') | |||
else: | |||
self.cp_save_path = None | |||
@@ -178,9 +225,27 @@ class DistTrainer(): | |||
@property | |||
def is_master(self): | |||
"""是否是主进程""" | |||
return self.rank == 0 | |||
def train(self, on_exception='auto'): | |||
def train(self, load_best_model=True, on_exception='auto'): | |||
""" | |||
使用该函数使Trainer开始训练。 | |||
:param str on_exception: 在训练过程遭遇exception,并被 :py:class:Callback 的on_exception()处理后,是否继续抛出异常。 | |||
支持'ignore','raise', 'auto': 'ignore'将捕获异常,写在Trainer.train()后面的代码将继续运行; 'raise'将异常抛出; | |||
'auto'将ignore以下两种Exception: CallbackException与KeyboardInterrupt, raise其它exception. | |||
:return dict: 返回一个字典类型的数据, | |||
内含以下内容:: | |||
seconds: float, 表示训练时长 | |||
以下三个内容只有在提供了dev_data的情况下会有。 | |||
best_eval: Dict of Dict, 表示evaluation的结果。第一层的key为Metric的名称, | |||
第二层的key为具体的Metric | |||
best_epoch: int,在第几个epoch取得的最佳值 | |||
best_step: int, 在第几个step(batch)更新取得的最佳值 | |||
""" | |||
try: | |||
self.logger.info("###### Training epochs started ######") | |||
self.logger.info('Total epochs: %d'% self.n_epochs) | |||
@@ -222,17 +287,22 @@ class DistTrainer(): | |||
results['seconds'] = round(time.time() - start_time, 2) | |||
self.logger.info("###### Train finished ######") | |||
self.logger.info('Total train time: {} seconds.'. format(results['seconds'])) | |||
return results | |||
if load_best_model and self.cp_save_path and len(self.test_manager.callbacks): | |||
self.load_check_point('best') | |||
finally: | |||
self.close() | |||
pass | |||
dist.barrier() | |||
return results | |||
def _train(self): | |||
if self.fp16: | |||
# skip check, done in __init__() | |||
from apex import amp | |||
if not self.use_tqdm: | |||
from .utils import _pseudo_tqdm as inner_tqdm | |||
else: | |||
inner_tqdm = tqdm | |||
self.step = 0 | |||
self.epoch = 0 | |||
self.pbar = tqdm(total=self.n_steps, postfix='loss:{0:<6.5f}', | |||
self.pbar = inner_tqdm(total=self.n_steps, postfix='loss:{0:<6.5f}', | |||
leave=False, dynamic_ncols=True, disable=not self.is_master) | |||
pbar = self.pbar | |||
avg_loss = 0 | |||
@@ -292,8 +362,8 @@ class DistTrainer(): | |||
if self.validate_every < 0: | |||
self._do_validation() | |||
if self.save_every < 0 and self.cp_save_path: | |||
self.save_check_point() | |||
if self.save_every < 0 and self.cp_save_path: | |||
self.save_check_point() | |||
# lr decay; early stopping | |||
self.callback_manager.on_epoch_end() | |||
# =============== epochs end =================== # | |||
@@ -327,30 +397,52 @@ class DistTrainer(): | |||
loss = self.losser(predict, truth) | |||
if self.update_every > 1: | |||
loss = loss / self.update_every | |||
return loss.mean() | |||
if loss.dim() > 0: | |||
loss = loss.mean() | |||
return loss | |||
def save_check_point(self, only_params=False): | |||
def save_check_point(self, name=None, only_params=False): | |||
"""保存当前模型""" | |||
# only master save models | |||
if self.is_master: | |||
if name is None: | |||
name = 'checkpoint-{}.bin'.format(self.step) | |||
os.makedirs(self.cp_save_path, exist_ok=True) | |||
path = os.path.join(self.cp_save_path, 'checkpoint-{}.bin'.format(self.step)) | |||
path = os.path.join(self.cp_save_path, name) | |||
self.logger.info("Save checkpoint to {}".format(path)) | |||
model_to_save = self.model.module | |||
if only_params: | |||
model_to_save = model_to_save.state_dict() | |||
torch.save(model_to_save, path) | |||
def load_check_point(self, name): | |||
path = os.path.join(self.cp_save_path, name) | |||
self.logger.info('reload best model from %s', path) | |||
model_load = torch.load(path, map_location='cpu') | |||
if not isinstance(model_load, dict): | |||
model_load = model_load.state_dict() | |||
self.model.module.load_state_dict(model_load) | |||
def _do_validation(self): | |||
self.callback_manager.on_valid_begin() | |||
eval_res = self.callback_manager.on_validation() | |||
# do evaluate on all nodes | |||
eval_res = self.test_manager.on_valid_begin() | |||
eval_res = list(filter(lambda x: x is not None, eval_res)) | |||
if len(eval_res): | |||
eval_res, is_better = list(zip(*eval_res)) | |||
else: | |||
eval_res, is_better = None, None | |||
# save better model on master node | |||
if self.is_master and is_better is not None and self.cp_save_path: | |||
for i, better_flag in enumerate(is_better): | |||
if better_flag: | |||
# TODO to support multiple datasets to evaluate | |||
self.save_check_point('best') | |||
break | |||
self.callback_manager.on_valid_end( | |||
eval_res, self.metric_key, self.optimizer, is_better) | |||
dist.barrier() | |||
def close(self): | |||
"""关闭Trainer,销毁进程""" | |||
dist.destroy_process_group() |
@@ -829,12 +829,12 @@ class Trainer(object): | |||
self.best_metric_indicator = indicator_val | |||
else: | |||
if self.increase_better is True: | |||
if indicator_val >= self.best_metric_indicator: | |||
if indicator_val > self.best_metric_indicator: | |||
self.best_metric_indicator = indicator_val | |||
else: | |||
is_better = False | |||
else: | |||
if indicator_val <= self.best_metric_indicator: | |||
if indicator_val < self.best_metric_indicator: | |||
self.best_metric_indicator = indicator_val | |||
else: | |||
is_better = False | |||
@@ -842,6 +842,7 @@ class Trainer(object): | |||
@property | |||
def is_master(self): | |||
"""是否是主进程""" | |||
return True | |||
DEFAULT_CHECK_BATCH_SIZE = 2 | |||
@@ -19,6 +19,10 @@ import torch.nn as nn | |||
from typing import List | |||
from ._logger import logger | |||
from prettytable import PrettyTable | |||
try: | |||
from apex import amp | |||
except: | |||
amp = None | |||
_CheckRes = namedtuple('_CheckRes', ['missing', 'unused', 'duplicated', 'required', 'all_needed', | |||
'varargs']) | |||
@@ -805,3 +809,10 @@ def sub_column(string: str, c: int, c_size: int, title: str) -> str: | |||
if len(string) > avg: | |||
string = string[:(avg - 3)] + "..." | |||
return string | |||
def _check_fp16(): | |||
if amp is None: | |||
raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.") | |||
if not torch.backends.cudnn.enabled: | |||
raise RuntimeError("Amp requires cudnn backend to be enabled.") |
@@ -1,4 +1,6 @@ | |||
"""undocumented""" | |||
"""undocumented | |||
用于辅助生成 fastNLP 文档的代码 | |||
""" | |||
__all__ = [] | |||
@@ -15,6 +17,9 @@ def doc_process(m): | |||
pass | |||
else: | |||
module_name = obj.__module__ | |||
# 识别并标注类和函数在不同层次中的位置 | |||
while 1: | |||
defined_m = sys.modules[module_name] | |||
if "undocumented" not in defined_m.__doc__ and name in defined_m.__all__: | |||
@@ -25,6 +30,8 @@ def doc_process(m): | |||
if module_name == m.__name__: | |||
# print(name, ": not found defined doc.") | |||
break | |||
# 识别并标注基类,只有基类也在 fastNLP 中定义才显示 | |||
if inspect.isclass(obj): | |||
for base in obj.__bases__: | |||
@@ -25,6 +25,8 @@ __all__ = [ | |||
'SSTLoader', | |||
'SST2Loader', | |||
"ChnSentiCorpLoader", | |||
"THUCNewsLoader", | |||
"WeiboSenti100kLoader", | |||
'ConllLoader', | |||
'Conll2003Loader', | |||
@@ -45,6 +47,9 @@ __all__ = [ | |||
"SNLILoader", | |||
"QNLILoader", | |||
"RTELoader", | |||
"XNLILoader", | |||
"BQCorpusLoader", | |||
"LCQMCLoader", | |||
"Pipe", | |||
@@ -54,6 +59,8 @@ __all__ = [ | |||
"SST2Pipe", | |||
"IMDBPipe", | |||
"ChnSentiCorpPipe", | |||
"THUCNewsPipe", | |||
"WeiboSenti100kPipe", | |||
"Conll2003Pipe", | |||
"Conll2003NERPipe", | |||
@@ -83,25 +83,41 @@ PRETRAIN_STATIC_FILES = { | |||
} | |||
DATASET_DIR = { | |||
# Classification, English | |||
'aclImdb': "imdb.zip", | |||
"yelp-review-full": "yelp_review_full.tar.gz", | |||
"yelp-review-polarity": "yelp_review_polarity.tar.gz", | |||
"sst-2": "SST-2.zip", | |||
"sst": "SST.zip", | |||
# Classification, Chinese | |||
"chn-senti-corp": "chn_senti_corp.zip", | |||
"weibo-senti-100k": "WeiboSenti100k.zip", | |||
"thuc-news": "THUCNews.zip", | |||
# Matching, English | |||
"mnli": "MNLI.zip", | |||
"snli": "SNLI.zip", | |||
"qnli": "QNLI.zip", | |||
"sst-2": "SST-2.zip", | |||
"sst": "SST.zip", | |||
"rte": "RTE.zip", | |||
# Matching, Chinese | |||
"cn-xnli": "XNLI.zip", | |||
# Sequence Labeling, Chinese | |||
"msra-ner": "MSRA_NER.zip", | |||
"peopledaily": "peopledaily.zip", | |||
"weibo-ner": "weibo_NER.zip", | |||
# Chinese Word Segmentation | |||
"cws-pku": 'cws_pku.zip', | |||
"cws-cityu": "cws_cityu.zip", | |||
"cws-as": 'cws_as.zip', | |||
"cws-msra": 'cws_msra.zip', | |||
"chn-senti-corp":"chn_senti_corp.zip" | |||
# Summarization, English | |||
"ext-cnndm": "ext-cnndm.zip", | |||
} | |||
PRETRAIN_MAP = {'elmo': PRETRAINED_ELMO_MODEL_DIR, | |||
@@ -373,63 +373,6 @@ class ChnSentiCorpLoader(Loader): | |||
""" | |||
从path中读取数据 | |||
:param path: | |||
:return: | |||
""" | |||
ds = DataSet() | |||
with open(path, 'r', encoding='utf-8') as f: | |||
f.readline() | |||
for line in f: | |||
line = line.strip() | |||
tab_index = line.index('\t') | |||
if tab_index!=-1: | |||
target = line[:tab_index] | |||
raw_chars = line[tab_index+1:] | |||
if raw_chars: | |||
ds.append(Instance(raw_chars=raw_chars, target=target)) | |||
return ds | |||
def download(self)->str: | |||
""" | |||
自动下载数据,该数据取自https://github.com/pengming617/bert_classification/tree/master/data,在 | |||
https://arxiv.org/pdf/1904.09223.pdf与https://arxiv.org/pdf/1906.08101.pdf有使用 | |||
:return: | |||
""" | |||
output_dir = self._get_dataset_path('chn-senti-corp') | |||
return output_dir | |||
class ChnSentiCorpLoader(Loader): | |||
""" | |||
支持读取的数据的格式为,第一行为标题(具体内容会被忽略),之后一行为一个sample,第一个制表符之前被认为是label,第 | |||
一个制表符及之后认为是句子 | |||
Example:: | |||
label raw_chars | |||
1 這間酒店環境和服務態度亦算不錯,但房間空間太小~~ | |||
1 <荐书> 推荐所有喜欢<红楼>的红迷们一定要收藏这本书,要知道... | |||
0 商品的不足暂时还没发现,京东的订单处理速度实在.......周二就打包完成,周五才发货... | |||
读取后的DataSet具有以下的field | |||
.. csv-table:: | |||
:header: "raw_chars", "target" | |||
"這間酒店環境和服務態度亦算不錯,但房間空間太小~~", "1" | |||
"<荐书> 推荐所有喜欢<红楼>...", "1" | |||
"..." | |||
""" | |||
def __init__(self): | |||
super().__init__() | |||
def _load(self, path: str): | |||
""" | |||
从path中读取数据 | |||
:param path: | |||
:return: | |||
""" | |||
@@ -441,7 +384,7 @@ class ChnSentiCorpLoader(Loader): | |||
tab_index = line.index('\t') | |||
if tab_index != -1: | |||
target = line[:tab_index] | |||
raw_chars = line[tab_index + 1:] | |||
raw_chars = line[tab_index+1:] | |||
if raw_chars: | |||
ds.append(Instance(raw_chars=raw_chars, target=target)) | |||
return ds | |||
@@ -486,6 +429,17 @@ class THUCNewsLoader(Loader): | |||
ds.append(Instance(raw_chars=raw_chars, target=target)) | |||
return ds | |||
def download(self) -> str: | |||
""" | |||
自动下载数据,该数据取自 | |||
http://thuctc.thunlp.org/#%E4%B8%AD%E6%96%87%E6%96%87%E6%9C%AC%E5%88%86%E7%B1%BB%E6%95%B0%E6%8D%AE%E9%9B%86THUCNews | |||
:return: | |||
""" | |||
output_dir = self._get_dataset_path('thuc-news') | |||
return output_dir | |||
class WeiboSenti100kLoader(Loader): | |||
""" | |||
@@ -518,3 +472,12 @@ class WeiboSenti100kLoader(Loader): | |||
if raw_chars: | |||
ds.append(Instance(raw_chars=raw_chars, target=target)) | |||
return ds | |||
def download(self) -> str: | |||
""" | |||
自动下载数据,该数据取自 https://github.com/SophonPlus/ChineseNlpCorpus/ | |||
在 https://arxiv.org/abs/1906.08101 有使用 | |||
:return: | |||
""" | |||
output_dir = self._get_dataset_path('weibo-senti-100k') | |||
return output_dir |
@@ -316,6 +316,16 @@ class CTBLoader(Loader): | |||
dataset = self.loader._load(path) | |||
return dataset | |||
def download(self): | |||
""" | |||
由于版权限制,不能提供自动下载功能。可参考 | |||
https://catalog.ldc.upenn.edu/LDC2013T21 | |||
:return: | |||
""" | |||
raise RuntimeError("CTB cannot be downloaded automatically.") | |||
class CNNERLoader(Loader): | |||
def _load(self, path: str): | |||
@@ -13,23 +13,21 @@ from .json import JsonLoader | |||
class CoReferenceLoader(JsonLoader): | |||
""" | |||
原始数据中内容应该为, 每一行为一个json对象,其中doc_key包含文章的种类信息,speakers包含每句话的说话者信息,cluster是指向现实中同一个事物的聚集,sentences是文本信息内容。 | |||
原始数据中内容应该为, 每一行为一个json对象,其中doc_key包含文章的种类信息,speakers包含每句话的说话者信息,cluster是指向现实中同一个事物的聚集,sentences是文本信息内容。 | |||
Example:: | |||
Example:: | |||
{"doc_key":"bc/cctv/00/cctv_001", | |||
"speakers":"[["Speaker1","Speaker1","Speaker1"],["Speaker1","Speaker1","Speaker1"]]", | |||
"clusters":"[[[2,3],[4,5]],[7,8],[18,20]]]", | |||
"sentences":[["I","have","an","apple"],["It","is","good"]] | |||
} | |||
{"doc_key":"bc/cctv/00/cctv_001", | |||
"speakers":"[["Speaker1","Speaker1","Speaker1"],["Speaker1","Speaker1","Speaker1"]]", | |||
"clusters":"[[[2,3],[4,5]],[7,8],[18,20]]]", | |||
"sentences":[["I","have","an","apple"],["It","is","good"]] | |||
} | |||
读取预处理好的Conll2012数据。 | |||
读取预处理好的Conll2012数据。 | |||
""" | |||
""" | |||
def __init__(self, fields=None, dropna=False): | |||
super().__init__(fields, dropna) | |||
# self.fields = {"doc_key":Const.INPUTS(0),"speakers":Const.INPUTS(1), | |||
# "clusters":Const.TARGET,"sentences":Const.INPUTS(2)} | |||
self.fields = {"doc_key": Const.RAW_WORDS(0), "speakers": Const.RAW_WORDS(1), "clusters": Const.RAW_WORDS(2), | |||
"sentences": Const.RAW_WORDS(3)} | |||
@@ -48,3 +46,13 @@ class CoReferenceLoader(JsonLoader): | |||
ins = d | |||
dataset.append(Instance(**ins)) | |||
return dataset | |||
def download(self): | |||
""" | |||
由于版权限制,不能提供自动下载功能。可参考 | |||
https://www.aclweb.org/anthology/W12-4501 | |||
:return: | |||
""" | |||
raise RuntimeError("CoReference cannot be downloaded automatically.") |
@@ -7,7 +7,7 @@ __all__ = [ | |||
"RTELoader", | |||
"QuoraLoader", | |||
"BQCorpusLoader", | |||
"XNLILoader", | |||
"CNXNLILoader", | |||
"LCQMCLoader" | |||
] | |||
@@ -135,12 +135,12 @@ class SNLILoader(JsonLoader): | |||
""" | |||
从指定一个或多个路径中的文件中读取数据,返回 :class:`~fastNLP.io.DataBundle` 。 | |||
读取的field根据ConllLoader初始化时传入的headers决定。 | |||
读取的field根据Loader初始化时传入的field决定。 | |||
:param str paths: 传入一个目录, 将在该目录下寻找snli_1.0_train.jsonl, snli_1.0_dev.jsonl | |||
和snli_1.0_test.jsonl三个文件。 | |||
:return: 返回的:class:`~fastNLP.io.DataBundle` | |||
:return: 返回的 :class:`~fastNLP.io.DataBundle` | |||
""" | |||
_paths = {} | |||
if paths is None: | |||
@@ -222,8 +222,7 @@ class QNLILoader(JsonLoader): | |||
""" | |||
如果您的实验使用到了该数据,请引用 | |||
.. todo:: | |||
补充 | |||
https://arxiv.org/pdf/1809.05053.pdf | |||
:return: | |||
""" | |||
@@ -276,6 +275,13 @@ class RTELoader(Loader): | |||
return ds | |||
def download(self): | |||
""" | |||
如果您的实验使用到了该数据,请引用GLUE Benchmark | |||
https://openreview.net/pdf?id=rJ4km2R5t7 | |||
:return: | |||
""" | |||
return self._get_dataset_path('rte') | |||
@@ -321,10 +327,17 @@ class QuoraLoader(Loader): | |||
return ds | |||
def download(self): | |||
""" | |||
由于版权限制,不能提供自动下载功能。可参考 | |||
https://www.kaggle.com/c/quora-question-pairs/data | |||
:return: | |||
""" | |||
raise RuntimeError("Quora cannot be downloaded automatically.") | |||
class XNLILoader(Loader): | |||
class CNXNLILoader(Loader): | |||
""" | |||
别名: | |||
数据集简介:中文句对NLI(本为multi-lingual的数据集,但是这里只取了中文的数据集)。原句子已被MOSES tokenizer处理 | |||
@@ -341,7 +354,7 @@ class XNLILoader(Loader): | |||
""" | |||
def __init__(self): | |||
super(XNLILoader, self).__init__() | |||
super(CNXNLILoader, self).__init__() | |||
def _load(self, path: str = None): | |||
csv_loader = CSVLoader(sep='\t') | |||
@@ -377,6 +390,16 @@ class XNLILoader(Loader): | |||
data_bundle = DataBundle(datasets=datasets) | |||
return data_bundle | |||
def download(self) -> str: | |||
""" | |||
自动下载数据,该数据取自 https://arxiv.org/abs/1809.05053 | |||
在 https://arxiv.org/pdf/1905.05526.pdf https://arxiv.org/pdf/1901.10125.pdf | |||
https://arxiv.org/pdf/1809.05053.pdf 有使用 | |||
:return: | |||
""" | |||
output_dir = self._get_dataset_path('cn-xnli') | |||
return output_dir | |||
class BQCorpusLoader(Loader): | |||
""" | |||
@@ -413,6 +436,16 @@ class BQCorpusLoader(Loader): | |||
ds.append(Instance(raw_chars1=raw_chars1, raw_chars2=raw_chars2, target=target)) | |||
return ds | |||
def download(self): | |||
""" | |||
由于版权限制,不能提供自动下载功能。可参考 | |||
https://github.com/ymcui/Chinese-BERT-wwm | |||
:return: | |||
""" | |||
raise RuntimeError("BQCorpus cannot be downloaded automatically.") | |||
class LCQMCLoader(Loader): | |||
""" | |||
@@ -451,16 +484,14 @@ class LCQMCLoader(Loader): | |||
ds.append(Instance(raw_chars1=raw_chars1, raw_chars2=raw_chars2, target=target)) | |||
return ds | |||
''' | |||
def download(self)->str: | |||
def download(self): | |||
""" | |||
自动下载数据,该数据取自论文 LCQMC: A Large-scale Chinese Question Matching Corpus. | |||
InProceedings of the 27thInternational Conference on Computational Linguistics. 1952–1962. | |||
由于版权限制,不能提供自动下载功能。可参考 | |||
https://github.com/ymcui/Chinese-BERT-wwm | |||
:return: | |||
""" | |||
output_dir = self._get_dataset_path('chn-senti-corp') | |||
return output_dir | |||
''' | |||
raise RuntimeError("LCQMC cannot be downloaded automatically.") | |||
@@ -0,0 +1,63 @@ | |||
"""undocumented""" | |||
__all__ = [ | |||
"ExtCNNDMLoader" | |||
] | |||
import os | |||
from typing import Union, Dict | |||
from ..data_bundle import DataBundle | |||
from ..utils import check_loader_paths | |||
from .json import JsonLoader | |||
class ExtCNNDMLoader(JsonLoader): | |||
""" | |||
读取之后的DataSet中的field情况为 | |||
.. csv-table:: | |||
:header: "text", "summary", "label", "publication" | |||
["I got new tires from them and... ","..."], ["The new tires...","..."], [0, 1], "cnndm" | |||
["Don't waste your time. We had two...","..."], ["Time is precious","..."], [1], "cnndm" | |||
["..."], ["..."], [], "cnndm" | |||
""" | |||
def __init__(self, fields=None): | |||
fields = fields or {"text": None, "summary": None, "label": None, "publication": None} | |||
super(ExtCNNDMLoader, self).__init__(fields=fields) | |||
def load(self, paths: Union[str, Dict[str, str]] = None): | |||
""" | |||
从指定一个或多个路径中的文件中读取数据,返回 :class:`~fastNLP.io.DataBundle` 。 | |||
读取的field根据ExtCNNDMLoader初始化时传入的headers决定。 | |||
:param str paths: 传入一个目录, 将在该目录下寻找train.label.jsonl, dev.label.jsonl | |||
test.label.jsonl三个文件(该目录还应该需要有一个名字为vocab的文件,在 :class:`~fastNLP.io.ExtCNNDMPipe` | |||
当中需要用到)。 | |||
:return: 返回 :class:`~fastNLP.io.DataBundle` | |||
""" | |||
if paths is None: | |||
paths = self.download() | |||
paths = check_loader_paths(paths) | |||
if ('train' in paths) and ('test' not in paths): | |||
paths['test'] = paths['train'] | |||
paths.pop('train') | |||
datasets = {name: self._load(path) for name, path in paths.items()} | |||
data_bundle = DataBundle(datasets=datasets) | |||
return data_bundle | |||
def download(self): | |||
""" | |||
如果你使用了这个数据,请引用 | |||
https://arxiv.org/pdf/1506.03340.pdf | |||
:return: | |||
""" | |||
output_dir = self._get_dataset_path('ext-cnndm') | |||
return output_dir |
@@ -18,6 +18,8 @@ __all__ = [ | |||
"SST2Pipe", | |||
"IMDBPipe", | |||
"ChnSentiCorpPipe", | |||
"THUCNewsPipe", | |||
"WeiboSenti100kPipe", | |||
"Conll2003NERPipe", | |||
"OntoNotesNERPipe", | |||
@@ -42,7 +44,7 @@ __all__ = [ | |||
"CoReferencePipe" | |||
] | |||
from .classification import YelpFullPipe, YelpPolarityPipe, SSTPipe, SST2Pipe, IMDBPipe, ChnSentiCorpPipe | |||
from .classification import YelpFullPipe, YelpPolarityPipe, SSTPipe, SST2Pipe, IMDBPipe, ChnSentiCorpPipe, THUCNewsPipe, WeiboSenti100kPipe | |||
from .conll import Conll2003NERPipe, OntoNotesNERPipe, MsraNERPipe, WeiboNERPipe, PeopleDailyPipe | |||
from .matching import MatchingBertPipe, RTEBertPipe, SNLIBertPipe, QuoraBertPipe, QNLIBertPipe, MNLIBertPipe, \ | |||
MatchingPipe, RTEPipe, SNLIPipe, QuoraPipe, QNLIPipe, MNLIPipe | |||
@@ -97,11 +97,22 @@ class YelpFullPipe(_CLSPipe): | |||
处理YelpFull的数据, 处理之后DataSet中的内容如下 | |||
.. csv-table:: 下面是使用YelpFullPipe处理后的DataSet所具备的field | |||
:header: "raw_words", "words", "target", "seq_len" | |||
:header: "raw_words", "target", "words", "seq_len" | |||
"I got 'new' tires from them and within...", 0 ,"[7, 110, 22, 107, 22, 499, 59, 140, 3,...]", 160 | |||
" Don't waste your time. We had two dif... ", 0, "[277, 17, 278, 38, 30, 112, 24, 85, 27...", 40 | |||
"...", ., "[...]", . | |||
"It 's a ...", "[4, 2, 10, ...]", 0, 10 | |||
"Offers that ...", "[20, 40, ...]", 1, 21 | |||
"...", "[...]", ., . | |||
dataset的print_field_meta()函数输出的各个field的被设置成input和target的情况为:: | |||
+-------------+-----------+--------+-------+---------+ | |||
| field_names | raw_words | target | words | seq_len | | |||
+-------------+-----------+--------+-------+---------+ | |||
| is_input | False | False | True | True | | |||
| is_target | False | True | False | False | | |||
| ignore_type | | False | False | False | | |||
| pad_value | | 0 | 0 | 0 | | |||
+-------------+-----------+--------+-------+---------+ | |||
""" | |||
@@ -193,11 +204,22 @@ class YelpPolarityPipe(_CLSPipe): | |||
处理YelpPolarity的数据, 处理之后DataSet中的内容如下 | |||
.. csv-table:: 下面是使用YelpFullPipe处理后的DataSet所具备的field | |||
:header: "raw_words", "words", "target", "seq_len" | |||
:header: "raw_words", "target", "words", "seq_len" | |||
"It 's a ...", "[4, 2, 10, ...]", 0, 10 | |||
"Offers that ...", "[20, 40, ...]", 1, 21 | |||
"...", "[...]", ., . | |||
"I got 'new' tires from them and within...", 0 ,"[7, 110, 22, 107, 22, 499, 59, 140, 3,...]", 160 | |||
" Don't waste your time. We had two dif... ", 0, "[277, 17, 278, 38, 30, 112, 24, 85, 27...", 40 | |||
"...", ., "[...]", . | |||
dataset的print_field_meta()函数输出的各个field的被设置成input和target的情况为:: | |||
+-------------+-----------+--------+-------+---------+ | |||
| field_names | raw_words | target | words | seq_len | | |||
+-------------+-----------+--------+-------+---------+ | |||
| is_input | False | False | True | True | | |||
| is_target | False | True | False | False | | |||
| ignore_type | | False | False | False | | |||
| pad_value | | 0 | 0 | 0 | | |||
+-------------+-----------+--------+-------+---------+ | |||
""" | |||
@@ -211,6 +233,19 @@ class YelpPolarityPipe(_CLSPipe): | |||
self.lower = lower | |||
def process(self, data_bundle): | |||
""" | |||
传入的DataSet应该具备如下的结构 | |||
.. csv-table:: | |||
:header: "raw_words", "target" | |||
"I got 'new' tires from them and... ", "1" | |||
"Don't waste your time. We had two...", "1" | |||
"...", "..." | |||
:param data_bundle: | |||
:return: | |||
""" | |||
# 复制一列words | |||
data_bundle = _add_words_field(data_bundle, lower=self.lower) | |||
@@ -244,9 +279,20 @@ class SSTPipe(_CLSPipe): | |||
.. csv-table:: 下面是使用SSTPipe处理后的DataSet所具备的field | |||
:header: "raw_words", "words", "target", "seq_len" | |||
"It 's a ...", "[4, 2, 10, ...]", 0, 16 | |||
"Offers that ...", "[20, 40, ...]", 1, 18 | |||
"...", "[...]", ., . | |||
"It 's a lovely film with lovely perfor...", 1, "[187, 6, 5, 132, 120, 70, 132, 188, 25...", 13 | |||
"No one goes unindicted here , which is...", 0, "[191, 126, 192, 193, 194, 4, 195, 17, ...", 13 | |||
"...", ., "[...]", . | |||
dataset的print_field_meta()函数输出的各个field的被设置成input和target的情况为:: | |||
+-------------+-----------+--------+-------+---------+ | |||
| field_names | raw_words | target | words | seq_len | | |||
+-------------+-----------+--------+-------+---------+ | |||
| is_input | False | False | True | True | | |||
| is_target | False | True | False | False | | |||
| ignore_type | | False | False | False | | |||
| pad_value | | 0 | 0 | 0 | | |||
+-------------+-----------+--------+-------+---------+ | |||
""" | |||
@@ -278,11 +324,11 @@ class SSTPipe(_CLSPipe): | |||
""" | |||
对DataBundle中的数据进行预处理。输入的DataSet应该至少拥有raw_words这一列,且内容类似与 | |||
.. csv-table:: | |||
.. csv-table:: 下面是使用SSTLoader读取的DataSet所具备的field | |||
:header: "raw_words" | |||
"(3 (2 It) (4 (4 (2 's) (4 (3 (2 a)..." | |||
"(4 (4 (2 Offers) (3 (3 (2 that) (3 (3 rare)..." | |||
"(2 (3 (3 Effective) (2 but)) (1 (1 too-tepid)..." | |||
"(3 (3 (2 If) (3 (2 you) (3 (2 sometimes) ..." | |||
"..." | |||
:param ~fastNLP.io.DataBundle data_bundle: 需要处理的DataBundle对象 | |||
@@ -335,12 +381,23 @@ class SST2Pipe(_CLSPipe): | |||
加载SST2的数据, 处理完成之后DataSet将拥有以下的field | |||
.. csv-table:: | |||
:header: "raw_words", "words", "target", "seq_len" | |||
:header: "raw_words", "target", "words", "seq_len" | |||
"it 's a charming and... ", "[3, 4, 5, 6, 7,...]", 1, 43 | |||
"unflinchingly bleak and...", "[10, 11, 7,...]", 1, 21 | |||
"it 's a charming and often affecting j... ", 1, "[19, 9, 6, 111, 5, 112, 113, 114, 3]", 9 | |||
"unflinchingly bleak and desperate", 0, "[115, 116, 5, 117]", 4 | |||
"...", "...", ., . | |||
dataset的print_field_meta()函数输出的各个field的被设置成input和target的情况为:: | |||
+-------------+-----------+--------+-------+---------+ | |||
| field_names | raw_words | target | words | seq_len | | |||
+-------------+-----------+--------+-------+---------+ | |||
| is_input | False | False | True | True | | |||
| is_target | False | True | False | False | | |||
| ignore_type | | False | False | False | | |||
| pad_value | | 0 | 0 | 0 | | |||
+-------------+-----------+--------+-------+---------+ | |||
""" | |||
def __init__(self, lower=False, tokenizer='spacy'): | |||
@@ -357,11 +414,11 @@ class SST2Pipe(_CLSPipe): | |||
可以处理的DataSet应该具备如下的结构 | |||
.. csv-table:: | |||
:header: "raw_words", "target" | |||
:header: "raw_words", "target" | |||
"it 's a charming and... ", 1 | |||
"unflinchingly bleak and...", 1 | |||
"...", "..." | |||
"it 's a charming and often affecting...", "1" | |||
"unflinchingly bleak and...", "0" | |||
"..." | |||
:param data_bundle: | |||
:return: | |||
@@ -420,15 +477,26 @@ class IMDBPipe(_CLSPipe): | |||
经过本Pipe处理后DataSet将如下 | |||
.. csv-table:: 输出DataSet的field | |||
:header: "raw_words", "words", "target", "seq_len" | |||
:header: "raw_words", "target", "words", "seq_len" | |||
"Bromwell High is a cartoon ... ", "[3, 5, 6, 9, ...]", 0, 20 | |||
"Story of a man who has ...", "[20, 43, 9, 10, ...]", 1, 31 | |||
"...", "[...]", ., . | |||
"Bromwell High is a cartoon ... ", 0, "[3, 5, 6, 9, ...]", 20 | |||
"Story of a man who has ...", 1, "[20, 43, 9, 10, ...]", 31 | |||
"...", ., "[...]", . | |||
其中raw_words为str类型,是原文; words是转换为index的输入; target是转换为index的目标值; | |||
words列被设置为input; target列被设置为target。 | |||
dataset的print_field_meta()函数输出的各个field的被设置成input和target的情况为:: | |||
+-------------+-----------+--------+-------+---------+ | |||
| field_names | raw_words | target | words | seq_len | | |||
+-------------+-----------+--------+-------+---------+ | |||
| is_input | False | False | True | True | | |||
| is_target | False | True | False | False | | |||
| ignore_type | | False | False | False | | |||
| pad_value | | 0 | 0 | 0 | | |||
+-------------+-----------+--------+-------+---------+ | |||
""" | |||
def __init__(self, lower: bool = False, tokenizer: str = 'spacy'): | |||
@@ -493,13 +561,23 @@ class ChnSentiCorpPipe(Pipe): | |||
处理之后的DataSet有以下的结构 | |||
.. csv-table:: | |||
:header: "raw_chars", "chars", "target", "seq_len" | |||
:header: "raw_chars", "target", "chars", "seq_len" | |||
"這間酒店環境和服務態度亦算不錯,但房間空間太小~~", "[2, 3, 4, 5, ...]", 1, 31 | |||
"<荐书> 推荐所有喜欢<红楼>...", "[10, 21, ....]", 1, 25 | |||
"這間酒店環境和服務態度亦算不錯,但房間空間太小~~", 1, "[2, 3, 4, 5, ...]", 31 | |||
"<荐书> 推荐所有喜欢<红楼>...", 1, "[10, 21, ....]", 25 | |||
"..." | |||
其中chars, seq_len是input,target是target | |||
dataset的print_field_meta()函数输出的各个field的被设置成input和target的情况为:: | |||
+-------------+-----------+--------+-------+---------+ | |||
| field_names | raw_chars | target | chars | seq_len | | |||
+-------------+-----------+--------+-------+---------+ | |||
| is_input | False | True | True | True | | |||
| is_target | False | True | False | False | | |||
| ignore_type | | False | False | False | | |||
| pad_value | | 0 | 0 | 0 | | |||
+-------------+-----------+--------+-------+---------+ | |||
""" | |||
def __init__(self, bigrams=False, trigrams=False): | |||
@@ -590,12 +668,22 @@ class THUCNewsPipe(_CLSPipe): | |||
处理之后的DataSet有以下的结构 | |||
.. csv-table:: | |||
:header: "raw_chars", "chars", "target", "seq_len" | |||
:header: "raw_chars", "target", "chars", "seq_len" | |||
"马晓旭意外受伤让国奥警惕 无奈大雨格外青睐殷家军记者傅亚雨沈阳报道...", "[409, 1197, 2146, 213, ...]", 0, 746 | |||
"马晓旭意外受伤让国奥警惕 无奈大雨格外青睐殷家军记者傅亚雨沈阳报道...", 0, "[409, 1197, 2146, 213, ...]", 746 | |||
"..." | |||
其中chars, seq_len是input,target是target | |||
dataset的print_field_meta()函数输出的各个field的被设置成input和target的情况为:: | |||
+-------------+-----------+--------+-------+---------+ | |||
| field_names | raw_chars | target | chars | seq_len | | |||
+-------------+-----------+--------+-------+---------+ | |||
| is_input | False | True | True | True | | |||
| is_target | False | True | False | False | | |||
| ignore_type | | False | False | False | | |||
| pad_value | | 0 | 0 | 0 | | |||
+-------------+-----------+--------+-------+---------+ | |||
:param bool bigrams: 是否增加一列bigrams. bigrams的构成是['复', '旦', '大', '学', ...]->["复旦", "旦大", ...]。如果 | |||
设置为True,返回的DataSet将有一列名为bigrams, 且已经转换为了index并设置为input,对应的vocab可以通过 | |||
@@ -691,12 +779,22 @@ class WeiboSenti100kPipe(_CLSPipe): | |||
处理之后的DataSet有以下的结构 | |||
.. csv-table:: | |||
:header: "raw_chars", "chars", "target", "seq_len" | |||
:header: "raw_chars", "target", "chars", "seq_len" | |||
"六一出生的?好讽刺…… //@祭春姬:他爸爸是外星人吧 //@面孔小高:现在的孩子都怎么了 [怒][怒][怒]", "[0, 690, 18, ...]", 0, 56 | |||
"六一出生的?好讽刺…… //@祭春姬:他爸爸是外星人吧 //@面孔小高:现在的孩子都怎么了 [怒][怒][怒]", 0, "[0, 690, 18, ...]", 56 | |||
"..." | |||
其中chars, seq_len是input,target是target | |||
dataset的print_field_meta()函数输出的各个field的被设置成input和target的情况为:: | |||
+-------------+-----------+--------+-------+---------+ | |||
| field_names | raw_chars | target | chars | seq_len | | |||
+-------------+-----------+--------+-------+---------+ | |||
| is_input | False | True | True | True | | |||
| is_target | False | True | False | False | | |||
| ignore_type | | False | False | False | | |||
| pad_value | | 0 | 0 | 0 | | |||
+-------------+-----------+--------+-------+---------+ | |||
:param bool bigrams: 是否增加一列bigrams. bigrams的构成是['复', '旦', '大', '学', ...]->["复旦", "旦大", ...]。如果 | |||
设置为True,返回的DataSet将有一列名为bigrams, 且已经转换为了index并设置为input,对应的vocab可以通过 | |||
@@ -87,15 +87,26 @@ class Conll2003NERPipe(_NERPipe): | |||
经过该Pipe过后,DataSet中的内容如下所示 | |||
.. csv-table:: Following is a demo layout of DataSet returned by Conll2003Loader | |||
:header: "raw_words", "words", "target", "seq_len" | |||
:header: "raw_words", "target", "words", "seq_len" | |||
"[Nadim, Ladki]", "[2, 3]", "[1, 2]", 2 | |||
"[AL-AIN, United, Arab, ...]", "[4, 5, 6,...]", "[3, 4,...]", 6 | |||
"[Nadim, Ladki]", "[1, 2]", "[2, 3]", 2 | |||
"[AL-AIN, United, Arab, ...]", "[3, 4,...]", "[4, 5, 6,...]", 6 | |||
"[...]", "[...]", "[...]", . | |||
raw_words列为List[str], 是未转换的原始数据; words列为List[int],是转换为index的输入数据; target列是List[int],是转换为index的 | |||
target。返回的DataSet中被设置为input有words, target, seq_len; 设置为target有target。 | |||
dataset的print_field_meta()函数输出的各个field的被设置成input和target的情况为:: | |||
+-------------+-----------+--------+-------+---------+ | |||
| field_names | raw_words | target | words | seq_len | | |||
+-------------+-----------+--------+-------+---------+ | |||
| is_input | False | True | True | True | | |||
| is_target | False | True | False | True | | |||
| ignore_type | | False | False | False | | |||
| pad_value | | 0 | 0 | 0 | | |||
+-------------+-----------+--------+-------+---------+ | |||
""" | |||
def process_from_file(self, paths) -> DataBundle: | |||
@@ -112,17 +123,28 @@ class Conll2003NERPipe(_NERPipe): | |||
class Conll2003Pipe(Pipe): | |||
r""" | |||
""" | |||
经过该Pipe后,DataSet中的内容如下 | |||
.. csv-table:: | |||
:header: "raw_words" , "words", "pos", "chunk", "ner", "seq_len" | |||
:header: "raw_words" , "pos", "chunk", "ner", "words", "seq_len" | |||
"[Nadim, Ladki]", "[2, 3]", "[0, 0]", "[1, 2]", "[1, 2]", 2 | |||
"[AL-AIN, United, Arab, ...]", "[4, 5, 6,...]", "[1, 2...]", "[3, 4...]", "[3, 4...]", 6 | |||
"[Nadim, Ladki]", "[0, 0]", "[1, 2]", "[1, 2]", "[2, 3]", 2 | |||
"[AL-AIN, United, Arab, ...]", "[1, 2...]", "[3, 4...]", "[3, 4...]", "[4, 5, 6,...]", 6 | |||
"[...]", "[...]", "[...]", "[...]", "[...]", . | |||
其中words, seq_len是input; pos, chunk, ner, seq_len是target | |||
dataset的print_field_meta()函数输出的各个field的被设置成input和target的情况为:: | |||
+-------------+-----------+-------+-------+-------+-------+---------+ | |||
| field_names | raw_words | pos | chunk | ner | words | seq_len | | |||
+-------------+-----------+-------+-------+-------+-------+---------+ | |||
| is_input | False | False | False | False | True | True | | |||
| is_target | False | True | True | True | False | True | | |||
| ignore_type | | False | False | False | False | False | | |||
| pad_value | | 0 | 0 | 0 | 0 | 0 | | |||
+-------------+-----------+-------+-------+-------+-------+---------+ | |||
""" | |||
def __init__(self, chunk_encoding_type='bioes', ner_encoding_type='bioes', lower: bool = False): | |||
@@ -202,15 +224,26 @@ class OntoNotesNERPipe(_NERPipe): | |||
处理OntoNotes的NER数据,处理之后DataSet中的field情况为 | |||
.. csv-table:: | |||
:header: "raw_words", "words", "target", "seq_len" | |||
:header: "raw_words", "target", "words", "seq_len" | |||
"[Nadim, Ladki]", "[2, 3]", "[1, 2]", 2 | |||
"[AL-AIN, United, Arab, ...]", "[4, 5, 6,...]", "[3, 4]", 6 | |||
"[Nadim, Ladki]", "[1, 2]", "[2, 3]", 2 | |||
"[AL-AIN, United, Arab, ...]", "[3, 4]", "[4, 5, 6,...]", 6 | |||
"[...]", "[...]", "[...]", . | |||
raw_words列为List[str], 是未转换的原始数据; words列为List[int],是转换为index的输入数据; target列是List[int],是转换为index的 | |||
target。返回的DataSet中被设置为input有words, target, seq_len; 设置为target有target。 | |||
dataset的print_field_meta()函数输出的各个field的被设置成input和target的情况为:: | |||
+-------------+-----------+--------+-------+---------+ | |||
| field_names | raw_words | target | words | seq_len | | |||
+-------------+-----------+--------+-------+---------+ | |||
| is_input | False | True | True | True | | |||
| is_target | False | True | False | True | | |||
| ignore_type | | False | False | False | | |||
| pad_value | | 0 | 0 | 0 | | |||
+-------------+-----------+--------+-------+---------+ | |||
""" | |||
def process_from_file(self, paths): | |||
@@ -306,15 +339,26 @@ class MsraNERPipe(_CNNERPipe): | |||
处理MSRA-NER的数据,处理之后的DataSet的field情况为 | |||
.. csv-table:: | |||
:header: "raw_chars", "chars", "target", "seq_len" | |||
:header: "raw_chars", "target", "chars", "seq_len" | |||
"[相, 比, 之, 下,...]", "[2, 3, 4, 5, ...]", "[0, 0, 0, 0, ...]", 11 | |||
"[青, 岛, 海, 牛, 队, 和, ...]", "[10, 21, ....]", "[1, 2, 3, ...]", 21 | |||
"[相, 比, 之, 下,...]", "[0, 0, 0, 0, ...]", "[2, 3, 4, 5, ...]", 11 | |||
"[青, 岛, 海, 牛, 队, 和, ...]", "[1, 2, 3, ...]", "[10, 21, ....]", 21 | |||
"[...]", "[...]", "[...]", . | |||
raw_chars列为List[str], 是未转换的原始数据; chars列为List[int],是转换为index的输入数据; target列是List[int],是转换为index的 | |||
target。返回的DataSet中被设置为input有chars, target, seq_len; 设置为target有target。 | |||
dataset的print_field_meta()函数输出的各个field的被设置成input和target的情况为:: | |||
+-------------+-----------+--------+-------+---------+ | |||
| field_names | raw_chars | target | chars | seq_len | | |||
+-------------+-----------+--------+-------+---------+ | |||
| is_input | False | True | True | True | | |||
| is_target | False | True | False | True | | |||
| ignore_type | | False | False | False | | |||
| pad_value | | 0 | 0 | 0 | | |||
+-------------+-----------+--------+-------+---------+ | |||
""" | |||
def process_from_file(self, paths=None) -> DataBundle: | |||
@@ -327,14 +371,26 @@ class PeopleDailyPipe(_CNNERPipe): | |||
处理people daily的ner的数据,处理之后的DataSet的field情况为 | |||
.. csv-table:: | |||
:header: "raw_chars", "chars", "target", "seq_len" | |||
:header: "raw_chars", "target", "chars", "seq_len" | |||
"[相, 比, 之, 下,...]", "[2, 3, 4, 5, ...]", "[0, 0, 0, 0, ...]", 11 | |||
"[青, 岛, 海, 牛, 队, 和, ...]", "[10, 21, ....]", "[1, 2, 3, ...]", 21 | |||
"[相, 比, 之, 下,...]", "[0, 0, 0, 0, ...]", "[2, 3, 4, 5, ...]", 11 | |||
"[青, 岛, 海, 牛, 队, 和, ...]", "[1, 2, 3, ...]", "[10, 21, ....]", 21 | |||
"[...]", "[...]", "[...]", . | |||
raw_chars列为List[str], 是未转换的原始数据; chars列为List[int],是转换为index的输入数据; target列是List[int],是转换为index的 | |||
target。返回的DataSet中被设置为input有chars, target, seq_len; 设置为target有target。 | |||
dataset的print_field_meta()函数输出的各个field的被设置成input和target的情况为:: | |||
+-------------+-----------+--------+-------+---------+ | |||
| field_names | raw_chars | target | chars | seq_len | | |||
+-------------+-----------+--------+-------+---------+ | |||
| is_input | False | True | True | True | | |||
| is_target | False | True | False | True | | |||
| ignore_type | | False | False | False | | |||
| pad_value | | 0 | 0 | 0 | | |||
+-------------+-----------+--------+-------+---------+ | |||
""" | |||
def process_from_file(self, paths=None) -> DataBundle: | |||
@@ -349,13 +405,24 @@ class WeiboNERPipe(_CNNERPipe): | |||
.. csv-table:: | |||
:header: "raw_chars", "chars", "target", "seq_len" | |||
"[相, 比, 之, 下,...]", "[2, 3, 4, 5, ...]", "[0, 0, 0, 0, ...]", 11 | |||
"[青, 岛, 海, 牛, 队, 和, ...]", "[10, 21, ....]", "[1, 2, 3, ...]", 21 | |||
"['老', '百', '姓']", "[4, 3, 3]", "[38, 39, 40]", 3 | |||
"['心']", "[0]", "[41]", 1 | |||
"[...]", "[...]", "[...]", . | |||
raw_chars列为List[str], 是未转换的原始数据; chars列为List[int],是转换为index的输入数据; target列是List[int],是转换为index的 | |||
target。返回的DataSet中被设置为input有chars, target, seq_len; 设置为target有target。 | |||
dataset的print_field_meta()函数输出的各个field的被设置成input和target的情况为:: | |||
+-------------+-----------+--------+-------+---------+ | |||
| field_names | raw_chars | target | chars | seq_len | | |||
+-------------+-----------+--------+-------+---------+ | |||
| is_input | False | True | True | True | | |||
| is_target | False | True | False | True | | |||
| ignore_type | | False | False | False | | |||
| pad_value | | 0 | 0 | 0 | | |||
+-------------+-----------+--------+-------+---------+ | |||
""" | |||
def process_from_file(self, paths=None) -> DataBundle: | |||
@@ -18,9 +18,29 @@ from ...core.const import Const | |||
class CoReferencePipe(Pipe): | |||
""" | |||
对Coreference resolution问题进行处理,得到文章种类/说话者/字符级信息/序列长度。 | |||
处理完成后数据包含文章类别、speaker信息、句子信息、句子对应的index、char、句子长度、target: | |||
.. csv-table:: | |||
:header: "words1", "words2","words3","words4","chars","seq_len","target" | |||
"bc", "[[0,0],[1,1]]","[['I','am'],[]]","[[1,2],[]]","[[[1],[2,3]],[]]","[2,3]","[[[2,3],[6,7]],[[10,12],[20,22]]]" | |||
"[...]", "[...]","[...]","[...]","[...]","[...]","[...]" | |||
dataset的print_field_meta()函数输出的各个field的被设置成input和target的情况为:: | |||
+-------------+-----------+--------+-------+---------+ | |||
| field_names | raw_chars | target | chars | seq_len | | |||
+-------------+-----------+--------+-------+---------+ | |||
| is_input | False | True | True | True | | |||
| is_target | False | True | False | True | | |||
| ignore_type | | False | False | False | | |||
| pad_value | | 0 | 0 | 0 | | |||
+-------------+-----------+--------+-------+---------+ | |||
""" | |||
def __init__(self,config): | |||
def __init__(self, config): | |||
super().__init__() | |||
self.config = config | |||
@@ -35,14 +55,6 @@ class CoReferencePipe(Pipe): | |||
"bc/cctv/00/cctv_0000_1", "[['Speaker#1', 'peaker#1'],[]]","[['He','is'],[]]","[[[2,3],[6,7]],[[10,12],[20,22]]]" | |||
"[...]", "[...]","[...]","[...]" | |||
处理完成后数据包含文章类别、speaker信息、句子信息、句子对应的index、char、句子长度、target: | |||
.. csv-table:: | |||
:header: "words1", "words2","words3","words4","chars","seq_len","target" | |||
"bc", "[[0,0],[1,1]]","[['I','am'],[]]","[[1,2],[]]","[[[1],[2,3]],[]]","[2,3]","[[[2,3],[6,7]],[[10,12],[20,22]]]" | |||
"[...]", "[...]","[...]","[...]","[...]","[...]","[...]" | |||
:param data_bundle: | |||
:return: | |||
@@ -138,13 +138,22 @@ class CWSPipe(Pipe): | |||
对CWS数据进行预处理, 处理之后的数据,具备以下的结构 | |||
.. csv-table:: | |||
:header: "raw_words", "chars", "target", "bigrams", "trigrams", "seq_len" | |||
:header: "raw_words", "chars", "target", "seq_len" | |||
"共同 创造 美好...", "[2, 3, 4...]", "[0, 2, 0, 2,...]", "[10, 4, 1,...]","[6, 4, 1,...]", 13 | |||
"2001年 新年 钟声...", "[8, 9, 9, 7, ...]", "[0, 1, 1, 1, 2...]", "[11, 12, ...]","[3, 9, ...]", 20 | |||
"...", "[...]","[...]", "[...]","[...]", . | |||
"共同 创造 美好...", "[2, 3, 4...]", "[0, 2, 0, 2,...]", 13 | |||
"2001年 新年 钟声...", "[8, 9, 9, 7, ...]", "[0, 1, 1, 1, 2...]", 20 | |||
"...", "[...]","[...]", . | |||
其中bigrams仅当bigrams列为True的时候存在 | |||
dataset的print_field_meta()函数输出的各个field的被设置成input和target的情况为:: | |||
+-------------+-----------+-------+--------+---------+ | |||
| field_names | raw_words | chars | target | seq_len | | |||
+-------------+-----------+-------+--------+---------+ | |||
| is_input | False | True | True | True | | |||
| is_target | False | False | True | True | | |||
| ignore_type | | False | False | False | | |||
| pad_value | | 0 | 0 | 0 | | |||
+-------------+-----------+-------+--------+---------+ | |||
""" | |||
@@ -7,7 +7,7 @@ __all__ = [ | |||
"QuoraBertPipe", | |||
"QNLIBertPipe", | |||
"MNLIBertPipe", | |||
"XNLIBertPipe", | |||
"CNXNLIBertPipe", | |||
"BQCorpusBertPipe", | |||
"LCQMCBertPipe", | |||
"MatchingPipe", | |||
@@ -16,7 +16,7 @@ __all__ = [ | |||
"QuoraPipe", | |||
"QNLIPipe", | |||
"MNLIPipe", | |||
"XNLIPipe", | |||
"CNXNLIPipe", | |||
"BQCorpusPipe", | |||
"LCQMCPipe", | |||
] | |||
@@ -25,7 +25,7 @@ import warnings | |||
from .pipe import Pipe | |||
from .utils import get_tokenizer | |||
from ..loader.matching import SNLILoader, MNLILoader, QNLILoader, RTELoader, QuoraLoader, BQCorpusLoader, XNLILoader, LCQMCLoader | |||
from ..loader.matching import SNLILoader, MNLILoader, QNLILoader, RTELoader, QuoraLoader, BQCorpusLoader, CNXNLILoader, LCQMCLoader | |||
from ...core.const import Const | |||
from ...core.vocabulary import Vocabulary | |||
from ...core._logger import logger | |||
@@ -37,16 +37,27 @@ class MatchingBertPipe(Pipe): | |||
Matching任务的Bert pipe,输出的DataSet将包含以下的field | |||
.. csv-table:: | |||
:header: "raw_words1", "raw_words2", "words", "target", "seq_len" | |||
:header: "raw_words1", "raw_words2", "target", "words", "seq_len" | |||
"The new rights are...", "Everyone really likes..", "[2, 3, 4, 5, ...]", 1, 10 | |||
"This site includes a...", "The Government Executive...", "[11, 12, 13,...]", 0, 5 | |||
"...", "...", "[...]", ., . | |||
"The new rights are...", "Everyone really likes..", 1, "[2, 3, 4, 5, ...]", 10 | |||
"This site includes a...", "The Government Executive...", 0, "[11, 12, 13,...]", 5 | |||
"...", "...", ., "[...]", . | |||
words列是将raw_words1(即premise), raw_words2(即hypothesis)使用"[SEP]"链接起来转换为index的。 | |||
words列被设置为input,target列被设置为target和input(设置为input以方便在forward函数中计算loss, | |||
如果不在forward函数中计算loss也不影响,fastNLP将根据forward函数的形参名进行传参). | |||
dataset的print_field_meta()函数输出的各个field的被设置成input和target的情况为:: | |||
+-------------+------------+------------+--------+-------+---------+ | |||
| field_names | raw_words1 | raw_words2 | target | words | seq_len | | |||
+-------------+------------+------------+--------+-------+---------+ | |||
| is_input | False | False | False | True | True | | |||
| is_target | False | False | True | False | False | | |||
| ignore_type | | | False | False | False | | |||
| pad_value | | | 0 | 0 | 0 | | |||
+-------------+------------+------------+--------+-------+---------+ | |||
""" | |||
def __init__(self, lower=False, tokenizer: str = 'raw'): | |||
@@ -75,6 +86,18 @@ class MatchingBertPipe(Pipe): | |||
return data_bundle | |||
def process(self, data_bundle): | |||
""" | |||
输入的data_bundle中的dataset需要具有以下结构: | |||
.. csv-table:: | |||
:header: "raw_words1", "raw_words2", "target" | |||
"Dana Reeve, the widow of the actor...", "Christopher Reeve had an...", "not_entailment" | |||
"...","..." | |||
:param data_bundle: | |||
:return: | |||
""" | |||
for dataset in data_bundle.datasets.values(): | |||
if dataset.has_field(Const.TARGET): | |||
dataset.drop(lambda x: x[Const.TARGET] == '-') | |||
@@ -178,15 +201,27 @@ class MatchingPipe(Pipe): | |||
Matching任务的Pipe。输出的DataSet将包含以下的field | |||
.. csv-table:: | |||
:header: "raw_words1", "raw_words2", "words1", "words2", "target", "seq_len1", "seq_len2" | |||
:header: "raw_words1", "raw_words2", "target", "words1", "words2", "seq_len1", "seq_len2" | |||
"The new rights are...", "Everyone really likes..", "[2, 3, 4, 5, ...]", "[10, 20, 6]", 1, 10, 13 | |||
"This site includes a...", "The Government Executive...", "[11, 12, 13,...]", "[2, 7, ...]", 0, 6, 7 | |||
"...", "...", "[...]", "[...]", ., ., . | |||
"The new rights are...", "Everyone really likes..", 1, "[2, 3, 4, 5, ...]", "[10, 20, 6]", 10, 13 | |||
"This site includes a...", "The Government Executive...", 0, "[11, 12, 13,...]", "[2, 7, ...]", 6, 7 | |||
"...", "...", ., "[...]", "[...]", ., . | |||
words1是premise,words2是hypothesis。其中words1,words2,seq_len1,seq_len2被设置为input;target被设置为target | |||
和input(设置为input以方便在forward函数中计算loss,如果不在forward函数中计算loss也不影响,fastNLP将根据forward函数 | |||
的形参名进行传参)。 | |||
dataset的print_field_meta()函数输出的各个field的被设置成input和target的情况为:: | |||
+-------------+------------+------------+--------+--------+--------+----------+----------+ | |||
| field_names | raw_words1 | raw_words2 | target | words1 | words2 | seq_len1 | seq_len2 | | |||
+-------------+------------+------------+--------+--------+--------+----------+----------+ | |||
| is_input | False | False | False | True | True | True | True | | |||
| is_target | False | False | True | False | False | False | False | | |||
| ignore_type | | | False | False | False | False | False | | |||
| pad_value | | | 0 | 0 | 0 | 0 | 0 | | |||
+-------------+------------+------------+--------+--------+--------+----------+----------+ | |||
""" | |||
def __init__(self, lower=False, tokenizer: str = 'raw'): | |||
@@ -319,10 +354,10 @@ class LCQMCPipe(MatchingPipe): | |||
return data_bundle | |||
class XNLIPipe(MatchingPipe): | |||
def process_from_file(self, paths = None): | |||
data_bundle = XNLILoader().load(paths) | |||
data_bundle = GranularizePipe(task = 'XNLI').process(data_bundle) | |||
class CNXNLIPipe(MatchingPipe): | |||
def process_from_file(self, paths=None): | |||
data_bundle = CNXNLILoader().load(paths) | |||
data_bundle = GranularizePipe(task='XNLI').process(data_bundle) | |||
data_bundle = RenamePipe().process(data_bundle) #使中文数据的field | |||
data_bundle = self.process(data_bundle) | |||
data_bundle = RenamePipe().process(data_bundle) | |||
@@ -438,9 +473,9 @@ class BQCorpusBertPipe(MatchingBertPipe): | |||
return data_bundle | |||
class XNLIBertPipe(MatchingBertPipe): | |||
class CNXNLIBertPipe(MatchingBertPipe): | |||
def process_from_file(self, paths = None): | |||
data_bundle = XNLILoader().load(paths) | |||
data_bundle = CNXNLILoader().load(paths) | |||
data_bundle = GranularizePipe(task='XNLI').process(data_bundle) | |||
data_bundle = RenamePipe(task='cn-nli-bert').process(data_bundle) | |||
data_bundle = self.process(data_bundle) | |||
@@ -1,15 +1,14 @@ | |||
"""undocumented""" | |||
import os | |||
import numpy as np | |||
from .pipe import Pipe | |||
from .utils import get_tokenizer, _indexize, _add_words_field, _drop_empty_instance | |||
from ..loader.json import JsonLoader | |||
from .utils import _drop_empty_instance | |||
from ..loader.summarization import ExtCNNDMLoader | |||
from ..data_bundle import DataBundle | |||
from ..loader.classification import IMDBLoader, YelpFullLoader, SSTLoader, SST2Loader, YelpPolarityLoader | |||
from ...core.const import Const | |||
from ...core.dataset import DataSet | |||
from ...core.instance import Instance | |||
from ...core.vocabulary import Vocabulary | |||
from ...core._logger import logger | |||
WORD_PAD = "[PAD]" | |||
@@ -18,7 +17,6 @@ DOMAIN_UNK = "X" | |||
TAG_UNK = "X" | |||
class ExtCNNDMPipe(Pipe): | |||
""" | |||
对CNN/Daily Mail数据进行适用于extractive summarization task的预处理,预处理之后的数据,具备以下结构: | |||
@@ -27,13 +25,13 @@ class ExtCNNDMPipe(Pipe): | |||
:header: "text", "summary", "label", "publication", "text_wd", "words", "seq_len", "target" | |||
""" | |||
def __init__(self, vocab_size, vocab_path, sent_max_len, doc_max_timesteps, domain=False): | |||
def __init__(self, vocab_size, sent_max_len, doc_max_timesteps, vocab_path=None, domain=False): | |||
""" | |||
:param vocab_size: int, 词表大小 | |||
:param vocab_path: str, 外部词表路径 | |||
:param sent_max_len: int, 句子最大长度,不足的句子将padding,超出的将截断 | |||
:param doc_max_timesteps: int, 文章最多句子个数,不足的将padding,超出的将截断 | |||
:param vocab_path: str, 外部词表路径 | |||
:param domain: bool, 是否需要建立domain词表 | |||
""" | |||
self.vocab_size = vocab_size | |||
@@ -42,8 +40,7 @@ class ExtCNNDMPipe(Pipe): | |||
self.doc_max_timesteps = doc_max_timesteps | |||
self.domain = domain | |||
def process(self, db: DataBundle): | |||
def process(self, data_bundle: DataBundle): | |||
""" | |||
传入的DataSet应该具备如下的结构 | |||
@@ -64,24 +61,28 @@ class ExtCNNDMPipe(Pipe): | |||
[[""],...,[""]], [[],...,[]], [], [] | |||
""" | |||
db.apply(lambda x: _lower_text(x['text']), new_field_name='text') | |||
db.apply(lambda x: _lower_text(x['summary']), new_field_name='summary') | |||
db.apply(lambda x: _split_list(x['text']), new_field_name='text_wd') | |||
db.apply(lambda x: _convert_label(x["label"], len(x["text"])), new_field_name=Const.TARGET) | |||
if self.vocab_path is None: | |||
error_msg = 'vocab file is not defined!' | |||
logger.error(error_msg) | |||
raise RuntimeError(error_msg) | |||
data_bundle.apply(lambda x: _lower_text(x['text']), new_field_name='text') | |||
data_bundle.apply(lambda x: _lower_text(x['summary']), new_field_name='summary') | |||
data_bundle.apply(lambda x: _split_list(x['text']), new_field_name='text_wd') | |||
data_bundle.apply(lambda x: _convert_label(x["label"], len(x["text"])), new_field_name=Const.TARGET) | |||
db.apply(lambda x: _pad_sent(x["text_wd"], self.sent_max_len), new_field_name=Const.INPUT) | |||
data_bundle.apply(lambda x: _pad_sent(x["text_wd"], self.sent_max_len), new_field_name=Const.INPUT) | |||
# db.apply(lambda x: _token_mask(x["text_wd"], self.sent_max_len), new_field_name="pad_token_mask") | |||
# pad document | |||
db.apply(lambda x: _pad_doc(x[Const.INPUT], self.sent_max_len, self.doc_max_timesteps), new_field_name=Const.INPUT) | |||
db.apply(lambda x: _sent_mask(x[Const.INPUT], self.doc_max_timesteps), new_field_name=Const.INPUT_LEN) | |||
db.apply(lambda x: _pad_label(x[Const.TARGET], self.doc_max_timesteps), new_field_name=Const.TARGET) | |||
data_bundle.apply(lambda x: _pad_doc(x[Const.INPUT], self.sent_max_len, self.doc_max_timesteps), new_field_name=Const.INPUT) | |||
data_bundle.apply(lambda x: _sent_mask(x[Const.INPUT], self.doc_max_timesteps), new_field_name=Const.INPUT_LEN) | |||
data_bundle.apply(lambda x: _pad_label(x[Const.TARGET], self.doc_max_timesteps), new_field_name=Const.TARGET) | |||
db = _drop_empty_instance(db, "label") | |||
data_bundle = _drop_empty_instance(data_bundle, "label") | |||
# set input and target | |||
db.set_input(Const.INPUT, Const.INPUT_LEN) | |||
db.set_target(Const.TARGET, Const.INPUT_LEN) | |||
data_bundle.set_input(Const.INPUT, Const.INPUT_LEN) | |||
data_bundle.set_target(Const.TARGET, Const.INPUT_LEN) | |||
# print("[INFO] Load existing vocab from %s!" % self.vocab_path) | |||
word_list = [] | |||
@@ -96,47 +97,52 @@ class ExtCNNDMPipe(Pipe): | |||
vocabs = Vocabulary(max_size=self.vocab_size, padding=WORD_PAD, unknown=WORD_UNK) | |||
vocabs.add_word_lst(word_list) | |||
vocabs.build_vocab() | |||
db.set_vocab(vocabs, "vocab") | |||
data_bundle.set_vocab(vocabs, "vocab") | |||
if self.domain == True: | |||
if self.domain is True: | |||
domaindict = Vocabulary(padding=None, unknown=DOMAIN_UNK) | |||
domaindict.from_dataset(db.get_dataset("train"), field_name="publication") | |||
db.set_vocab(domaindict, "domain") | |||
return db | |||
domaindict.from_dataset(data_bundle.get_dataset("train"), field_name="publication") | |||
data_bundle.set_vocab(domaindict, "domain") | |||
return data_bundle | |||
def process_from_file(self, paths=None): | |||
""" | |||
:param paths: dict or string | |||
:return: DataBundle | |||
""" | |||
db = DataBundle() | |||
if isinstance(paths, dict): | |||
for key, value in paths.items(): | |||
db.set_dataset(JsonLoader(fields={"text":None, "summary":None, "label":None, "publication":None})._load(value), key) | |||
else: | |||
db.set_dataset(JsonLoader(fields={"text":None, "summary":None, "label":None, "publication":None})._load(paths), 'test') | |||
self.process(db) | |||
:param paths: dict or string | |||
:return: DataBundle | |||
""" | |||
loader = ExtCNNDMLoader() | |||
if self.vocab_path is None: | |||
if paths is None: | |||
paths = loader.download() | |||
if not os.path.isdir(paths): | |||
error_msg = 'vocab file is not defined!' | |||
logger.error(error_msg) | |||
raise RuntimeError(error_msg) | |||
self.vocab_path = os.path.join(paths, 'vocab') | |||
db = loader.load(paths=paths) | |||
db = self.process(db) | |||
for ds in db.datasets.values(): | |||
db.get_vocab("vocab").index_dataset(ds, field_name=Const.INPUT, new_field_name=Const.INPUT) | |||
return db | |||
def _lower_text(text_list): | |||
return [text.lower() for text in text_list] | |||
def _split_list(text_list): | |||
return [text.split() for text in text_list] | |||
def _convert_label(label, sent_len): | |||
np_label = np.zeros(sent_len, dtype=int) | |||
if label != []: | |||
np_label[np.array(label)] = 1 | |||
return np_label.tolist() | |||
def _pad_sent(text_wd, sent_max_len): | |||
pad_text_wd = [] | |||
for sent_wd in text_wd: | |||
@@ -148,6 +154,7 @@ def _pad_sent(text_wd, sent_max_len): | |||
pad_text_wd.append(sent_wd) | |||
return pad_text_wd | |||
def _token_mask(text_wd, sent_max_len): | |||
token_mask_list = [] | |||
for sent_wd in text_wd: | |||
@@ -159,6 +166,7 @@ def _token_mask(text_wd, sent_max_len): | |||
token_mask_list.append(mask) | |||
return token_mask_list | |||
def _pad_label(label, doc_max_timesteps): | |||
text_len = len(label) | |||
if text_len < doc_max_timesteps: | |||
@@ -167,6 +175,7 @@ def _pad_label(label, doc_max_timesteps): | |||
pad_label = label[:doc_max_timesteps] | |||
return pad_label | |||
def _pad_doc(text_wd, sent_max_len, doc_max_timesteps): | |||
text_len = len(text_wd) | |||
if text_len < doc_max_timesteps: | |||
@@ -176,6 +185,7 @@ def _pad_doc(text_wd, sent_max_len, doc_max_timesteps): | |||
pad_text = text_wd[:doc_max_timesteps] | |||
return pad_text | |||
def _sent_mask(text_wd, doc_max_timesteps): | |||
text_len = len(text_wd) | |||
if text_len < doc_max_timesteps: | |||
@@ -22,6 +22,9 @@ class BaseModel(torch.nn.Module): | |||
class NaiveClassifier(BaseModel): | |||
""" | |||
一个简单的分类器例子,可用于各种测试 | |||
""" | |||
def __init__(self, in_feature_dim, out_feature_dim): | |||
super(NaiveClassifier, self).__init__() | |||
self.mlp = MLP([in_feature_dim, in_feature_dim, out_feature_dim]) | |||
@@ -1,5 +1,6 @@ | |||
"""undocumented | |||
Variational RNN 的 Pytorch 实现 | |||
Variational RNN 及相关模型的 fastNLP实现,相关论文参考: | |||
`A Theoretically Grounded Application of Dropout in Recurrent Neural Networks (Yarin Gal and Zoubin Ghahramani, 2016) <https://arxiv.org/abs/1512.05287>`_ | |||
""" | |||
__all__ = [ | |||
@@ -227,6 +228,7 @@ class VarRNNBase(nn.Module): | |||
class VarLSTM(VarRNNBase): | |||
""" | |||
Variational Dropout LSTM. | |||
相关论文参考:`A Theoretically Grounded Application of Dropout in Recurrent Neural Networks (Yarin Gal and Zoubin Ghahramani, 2016) <https://arxiv.org/abs/1512.05287>`_ | |||
""" | |||
@@ -253,7 +255,8 @@ class VarLSTM(VarRNNBase): | |||
class VarRNN(VarRNNBase): | |||
""" | |||
Variational Dropout RNN. | |||
相关论文参考:`A Theoretically Grounded Application of Dropout in Recurrent Neural Networks (Yarin Gal and Zoubin Ghahramani, 2016) <https://arxiv.org/abs/1512.05287>`_ | |||
""" | |||
def __init__(self, *args, **kwargs): | |||
@@ -279,7 +282,8 @@ class VarRNN(VarRNNBase): | |||
class VarGRU(VarRNNBase): | |||
""" | |||
Variational Dropout GRU. | |||
相关论文参考:`A Theoretically Grounded Application of Dropout in Recurrent Neural Networks (Yarin Gal and Zoubin Ghahramani, 2016) <https://arxiv.org/abs/1512.05287>`_ | |||
""" | |||
def __init__(self, *args, **kwargs): | |||
@@ -1,39 +1,35 @@ | |||
import os | |||
import tempfile | |||
import unittest | |||
import numpy as np | |||
import torch | |||
import os | |||
import shutil | |||
from fastNLP.core.callback import EarlyStopCallback, GradientClipCallback, LRScheduler, ControlC, \ | |||
LRFinder, TensorboardCallback | |||
from fastNLP import AccuracyMetric | |||
from fastNLP import BCELoss | |||
from fastNLP import DataSet | |||
from fastNLP import Instance | |||
from fastNLP import BCELoss | |||
from fastNLP import AccuracyMetric | |||
from fastNLP import SGD | |||
from fastNLP import Trainer | |||
from fastNLP.models.base_model import NaiveClassifier | |||
from fastNLP.core.callback import EarlyStopError | |||
from fastNLP.core.callback import EarlyStopCallback, GradientClipCallback, LRScheduler, ControlC, \ | |||
LRFinder, TensorboardCallback | |||
from fastNLP.core.callback import EvaluateCallback, FitlogCallback, SaveModelCallback | |||
from fastNLP.core.callback import WarmupCallback | |||
import tempfile | |||
from fastNLP.models.base_model import NaiveClassifier | |||
def prepare_env(): | |||
def prepare_fake_dataset(): | |||
mean = np.array([-3, -3]) | |||
cov = np.array([[1, 0], [0, 1]]) | |||
class_A = np.random.multivariate_normal(mean, cov, size=(1000,)) | |||
mean = np.array([3, 3]) | |||
cov = np.array([[1, 0], [0, 1]]) | |||
class_B = np.random.multivariate_normal(mean, cov, size=(1000,)) | |||
data_set = DataSet([Instance(x=[float(item[0]), float(item[1])], y=[0.0]) for item in class_A] + | |||
[Instance(x=[float(item[0]), float(item[1])], y=[1.0]) for item in class_B]) | |||
return data_set | |||
mean = np.array([-3, -3]) | |||
cov = np.array([[1, 0], [0, 1]]) | |||
class_A = np.random.multivariate_normal(mean, cov, size=(1000,)) | |||
mean = np.array([3, 3]) | |||
cov = np.array([[1, 0], [0, 1]]) | |||
class_B = np.random.multivariate_normal(mean, cov, size=(1000,)) | |||
data_set = DataSet([Instance(x=[float(item[0]), float(item[1])], y=[0.0]) for item in class_A] + | |||
[Instance(x=[float(item[0]), float(item[1])], y=[1.0]) for item in class_B]) | |||
data_set = prepare_fake_dataset() | |||
data_set.set_input("x") | |||
data_set.set_target("y") | |||
model = NaiveClassifier(2, 1) | |||
@@ -43,11 +39,11 @@ def prepare_env(): | |||
class TestCallback(unittest.TestCase): | |||
def setUp(self): | |||
self.tempdir = tempfile.mkdtemp() | |||
def tearDown(self): | |||
pass | |||
# shutil.rmtree(self.tempdir) | |||
def test_gradient_clip(self): | |||
data_set, model = prepare_env() | |||
trainer = Trainer(data_set, model, optimizer=SGD(lr=0.1), loss=BCELoss(pred="predict", target="y"), | |||
@@ -100,7 +96,7 @@ class TestCallback(unittest.TestCase): | |||
path = os.path.join("./", 'tensorboard_logs_{}'.format(trainer.start_time)) | |||
if os.path.exists(path): | |||
shutil.rmtree(path) | |||
def test_readonly_property(self): | |||
from fastNLP.core.callback import Callback | |||
passed_epochs = [] | |||
@@ -123,19 +119,19 @@ class TestCallback(unittest.TestCase): | |||
check_code_level=2) | |||
trainer.train() | |||
assert passed_epochs == list(range(1, total_epochs + 1)) | |||
def test_evaluate_callback(self): | |||
data_set, model = prepare_env() | |||
from fastNLP import Tester | |||
tester = Tester(data=data_set, model=model, metrics=AccuracyMetric(pred="predict", target="y")) | |||
evaluate_callback = EvaluateCallback(data_set, tester) | |||
trainer = Trainer(data_set, model, optimizer=SGD(lr=0.1), loss=BCELoss(pred="predict", target="y"), | |||
batch_size=32, n_epochs=5, print_every=50, dev_data=data_set, | |||
metrics=AccuracyMetric(pred="predict", target="y"), use_tqdm=False, | |||
callbacks=evaluate_callback, check_code_level=2) | |||
trainer.train() | |||
def test_fitlog_callback(self): | |||
import fitlog | |||
fitlog.set_log_dir(self.tempdir) | |||
@@ -143,13 +139,13 @@ class TestCallback(unittest.TestCase): | |||
from fastNLP import Tester | |||
tester = Tester(data=data_set, model=model, metrics=AccuracyMetric(pred="predict", target="y")) | |||
fitlog_callback = FitlogCallback(data_set, tester) | |||
trainer = Trainer(data_set, model, optimizer=SGD(lr=0.1), loss=BCELoss(pred="predict", target="y"), | |||
batch_size=32, n_epochs=5, print_every=50, dev_data=data_set, | |||
metrics=AccuracyMetric(pred="predict", target="y"), use_tqdm=True, | |||
callbacks=fitlog_callback, check_code_level=2) | |||
trainer.train() | |||
def test_save_model_callback(self): | |||
data_set, model = prepare_env() | |||
top = 3 | |||
@@ -159,10 +155,10 @@ class TestCallback(unittest.TestCase): | |||
metrics=AccuracyMetric(pred="predict", target="y"), use_tqdm=True, | |||
callbacks=save_model_callback, check_code_level=2) | |||
trainer.train() | |||
timestamp = os.listdir(self.tempdir)[0] | |||
self.assertEqual(len(os.listdir(os.path.join(self.tempdir, timestamp))), top) | |||
def test_warmup_callback(self): | |||
data_set, model = prepare_env() | |||
warmup_callback = WarmupCallback() | |||
@@ -171,3 +167,50 @@ class TestCallback(unittest.TestCase): | |||
metrics=AccuracyMetric(pred="predict", target="y"), use_tqdm=True, | |||
callbacks=warmup_callback, check_code_level=2) | |||
trainer.train() | |||
def test_early_stop_callback(self): | |||
""" | |||
需要观察是否真的 EarlyStop | |||
""" | |||
data_set, model = prepare_env() | |||
trainer = Trainer(data_set, model, optimizer=SGD(lr=0.1), loss=BCELoss(pred="predict", target="y"), | |||
batch_size=2, n_epochs=10, print_every=5, dev_data=data_set, | |||
metrics=AccuracyMetric(pred="predict", target="y"), use_tqdm=True, | |||
callbacks=EarlyStopCallback(1), check_code_level=2) | |||
trainer.train() | |||
def test_control_C(): | |||
# 用于测试 ControlC , 再两次训练时用 Control+C 进行退出,如果最后不显示 "Test failed!" 则通过测试 | |||
from fastNLP import ControlC, Callback | |||
import time | |||
line1 = "\n\n\n\n\n*************************" | |||
line2 = "*************************\n\n\n\n\n" | |||
class Wait(Callback): | |||
def on_epoch_end(self): | |||
time.sleep(5) | |||
data_set, model = prepare_env() | |||
print(line1 + "Test starts!" + line2) | |||
trainer = Trainer(data_set, model, optimizer=SGD(lr=0.1), loss=BCELoss(pred="predict", target="y"), | |||
batch_size=32, n_epochs=20, dev_data=data_set, | |||
metrics=AccuracyMetric(pred="predict", target="y"), use_tqdm=True, | |||
callbacks=[Wait(), ControlC(False)], check_code_level=2) | |||
trainer.train() | |||
print(line1 + "Program goes on ..." + line2) | |||
trainer = Trainer(data_set, model, optimizer=SGD(lr=0.1), loss=BCELoss(pred="predict", target="y"), | |||
batch_size=32, n_epochs=20, dev_data=data_set, | |||
metrics=AccuracyMetric(pred="predict", target="y"), use_tqdm=True, | |||
callbacks=[Wait(), ControlC(True)], check_code_level=2) | |||
trainer.train() | |||
print(line1 + "Test failed!" + line2) | |||
if __name__ == "__main__": | |||
test_control_C() |
@@ -1,33 +1,36 @@ | |||
import os | |||
import shutil | |||
import subprocess | |||
import unittest | |||
from argparse import ArgumentParser | |||
import numpy as np | |||
import torch.cuda | |||
from fastNLP import AccuracyMetric | |||
from fastNLP import CrossEntropyLoss, BCELoss | |||
from fastNLP import DataSet | |||
from fastNLP import Instance | |||
from fastNLP import CrossEntropyLoss, BCELoss | |||
from fastNLP import SGD | |||
from fastNLP.core.callback import EchoCallback | |||
from fastNLP.core.dist_trainer import DistTrainer, get_local_rank | |||
from fastNLP.models.base_model import NaiveClassifier | |||
import shutil | |||
import os | |||
import subprocess | |||
from argparse import ArgumentParser | |||
from fastNLP.core.callback import EchoCallback | |||
from fastNLP import AccuracyMetric | |||
def prepare_fake_dataset(): | |||
mean = np.array([-3, -3]) | |||
cov = np.array([[1, 0], [0, 1]]) | |||
class_A = np.random.multivariate_normal(mean, cov, size=(1000,)) | |||
mean = np.array([3, 3]) | |||
cov = np.array([[1, 0], [0, 1]]) | |||
class_B = np.random.multivariate_normal(mean, cov, size=(1000,)) | |||
data_set = DataSet([Instance(x=[float(item[0]), float(item[1])], y=0) for item in class_A] + | |||
[Instance(x=[float(item[0]), float(item[1])], y=1) for item in class_B]) | |||
return data_set | |||
def prepare_fake_dataset2(*args, size=100): | |||
ys = np.random.randint(4, size=100, dtype=np.int64) | |||
data = {'y': ys} | |||
@@ -35,32 +38,35 @@ def prepare_fake_dataset2(*args, size=100): | |||
data[arg] = np.random.randn(size, 5) | |||
return DataSet(data=data) | |||
def set_rng_seed(seed): | |||
np.random.seed(seed) | |||
def prepare_env(): | |||
def prepare_fake_dataset(): | |||
mean = np.array([-3, -3]) | |||
cov = np.array([[1, 0], [0, 1]]) | |||
class_A = np.random.multivariate_normal(mean, cov, size=(1000,)) | |||
mean = np.array([3, 3]) | |||
cov = np.array([[1, 0], [0, 1]]) | |||
class_B = np.random.multivariate_normal(mean, cov, size=(1000,)) | |||
data_set = DataSet([Instance(x=[float(item[0]), float(item[1])], y=[0.0]) for item in class_A] + | |||
[Instance(x=[float(item[0]), float(item[1])], y=[1.0]) for item in class_B]) | |||
return data_set | |||
data_set = prepare_fake_dataset() | |||
data_set.set_input("x") | |||
data_set.set_target("y") | |||
model = NaiveClassifier(2, 1) | |||
return data_set, model | |||
class TestDistTrainer(unittest.TestCase): | |||
save_path = './save_cp' | |||
def run1(self): | |||
# test distributed training | |||
print('local rank', get_local_rank()) | |||
@@ -68,9 +74,9 @@ class TestDistTrainer(unittest.TestCase): | |||
data_set = prepare_fake_dataset() | |||
data_set.set_input("x", flag=True) | |||
data_set.set_target("y", flag=True) | |||
model = NaiveClassifier(2, 2) | |||
trainer = DistTrainer( | |||
model=model, train_data=data_set, optimizer=SGD(lr=0.1), | |||
loss=CrossEntropyLoss(pred="predict", target="y"), | |||
@@ -82,7 +88,7 @@ class TestDistTrainer(unittest.TestCase): | |||
""" | |||
if trainer.is_master and os.path.exists(self.save_path): | |||
shutil.rmtree(self.save_path) | |||
def run2(self): | |||
# test fp16 with distributed training | |||
print('local rank', get_local_rank()) | |||
@@ -90,9 +96,9 @@ class TestDistTrainer(unittest.TestCase): | |||
data_set = prepare_fake_dataset() | |||
data_set.set_input("x", flag=True) | |||
data_set.set_target("y", flag=True) | |||
model = NaiveClassifier(2, 2) | |||
trainer = DistTrainer( | |||
model=model, train_data=data_set, optimizer=SGD(lr=0.1), | |||
loss=CrossEntropyLoss(pred="predict", target="y"), | |||
@@ -105,7 +111,7 @@ class TestDistTrainer(unittest.TestCase): | |||
""" | |||
if trainer.is_master and os.path.exists(self.save_path): | |||
shutil.rmtree(self.save_path) | |||
def run3(self): | |||
set_rng_seed(100) | |||
data_set, model = prepare_env() | |||
@@ -117,26 +123,28 @@ class TestDistTrainer(unittest.TestCase): | |||
callbacks_master=[EchoCallback('callbacks_master')] | |||
) | |||
trainer.train() | |||
def run4(self): | |||
set_rng_seed(100) | |||
data_set, model = prepare_env() | |||
train_set, dev_set = data_set.split(0.3) | |||
model = NaiveClassifier(2, 1) | |||
trainer = DistTrainer( | |||
train_set, model, optimizer=SGD(lr=0.1), | |||
loss=BCELoss(pred="predict", target="y"), | |||
batch_size_per_gpu=32, n_epochs=3, print_every=50, dev_data=dev_set, | |||
metrics=AccuracyMetric(pred="predict", target="y"), validate_every=-1, save_path=None, | |||
metrics=AccuracyMetric(pred="predict", target="y"), validate_every=-1, save_path=self.save_path, | |||
) | |||
trainer.train() | |||
""" | |||
# 应该正确运行 | |||
""" | |||
if trainer.is_master and os.path.exists(self.save_path): | |||
shutil.rmtree(self.save_path) | |||
def run_dist(self, run_id): | |||
if torch.cuda.is_available(): | |||
ngpu = min(2, torch.cuda.device_count()) | |||
@@ -145,23 +153,24 @@ class TestDistTrainer(unittest.TestCase): | |||
'--nproc_per_node', str(ngpu), path, '--test', str(run_id)] | |||
print(' '.join(cmd)) | |||
subprocess.check_call(cmd) | |||
def test_normal_run(self): | |||
self.run_dist(1) | |||
def no_test_fp16(self): | |||
self.run_dist(2) | |||
def test_callback(self): | |||
self.run_dist(3) | |||
def test_dev_data(self): | |||
self.run_dist(4) | |||
if __name__ == '__main__': | |||
runner = TestDistTrainer() | |||
parser = ArgumentParser() | |||
parser.add_argument('--test', type=int) | |||
args, _ = parser.parse_known_args() | |||
if args.test and hasattr(runner, 'run%s'%args.test): | |||
getattr(runner, 'run%s'%args.test)() | |||
if args.test and hasattr(runner, 'run%s' % args.test): | |||
getattr(runner, 'run%s' % args.test)() |
@@ -5,7 +5,7 @@ import os | |||
from fastNLP.io import DataBundle | |||
from fastNLP.io.loader.matching import RTELoader, QNLILoader, SNLILoader, QuoraLoader, MNLILoader, \ | |||
BQCorpusLoader, XNLILoader, LCQMCLoader | |||
BQCorpusLoader, CNXNLILoader, LCQMCLoader | |||
@unittest.skipIf('TRAVIS' in os.environ, "Skip in travis") | |||
@@ -31,7 +31,7 @@ class TestMatchingLoad(unittest.TestCase): | |||
'MNLI': ('test/data_for_tests/io/MNLI', MNLILoader, (5, 5, 5, 5, 6), True), | |||
'Quora': ('test/data_for_tests/io/Quora', QuoraLoader, (2, 2, 2), False), | |||
'BQCorpus': ('test/data_for_tests/io/BQCorpus', BQCorpusLoader, (5, 5, 5), False), | |||
'XNLI': ('test/data_for_tests/io/XNLI', XNLILoader, (6, 7, 6), False), | |||
'XNLI': ('test/data_for_tests/io/XNLI', CNXNLILoader, (6, 7, 6), False), | |||
'LCQMC': ('test/data_for_tests/io/LCQMC', LCQMCLoader, (5, 6, 6), False), | |||
} | |||
for k, v in data_set_dict.items(): | |||
@@ -4,9 +4,9 @@ import os | |||
from fastNLP.io import DataBundle | |||
from fastNLP.io.pipe.matching import SNLIPipe, RTEPipe, QNLIPipe, QuoraPipe, MNLIPipe, \ | |||
XNLIPipe, BQCorpusPipe, LCQMCPipe | |||
CNXNLIPipe, BQCorpusPipe, LCQMCPipe | |||
from fastNLP.io.pipe.matching import SNLIBertPipe, RTEBertPipe, QNLIBertPipe, QuoraBertPipe, MNLIBertPipe, \ | |||
XNLIBertPipe, BQCorpusBertPipe, LCQMCBertPipe | |||
CNXNLIBertPipe, BQCorpusBertPipe, LCQMCBertPipe | |||
@unittest.skipIf('TRAVIS' in os.environ, "Skip in travis") | |||
@@ -38,7 +38,7 @@ class TestRunMatchingPipe(unittest.TestCase): | |||
'QNLI': ('test/data_for_tests/io/QNLI', QNLIPipe, QNLIBertPipe, (5, 5, 5), (372, 2), True), | |||
'MNLI': ('test/data_for_tests/io/MNLI', MNLIPipe, MNLIBertPipe, (5, 5, 5, 5, 6), (459, 3), True), | |||
'BQCorpus': ('test/data_for_tests/io/BQCorpus', BQCorpusPipe, BQCorpusBertPipe, (5, 5, 5), (32, 2), False), | |||
'XNLI': ('test/data_for_tests/io/XNLI', XNLIPipe, XNLIBertPipe, (6, 7, 6), (37, 3), False), | |||
'XNLI': ('test/data_for_tests/io/XNLI', CNXNLIPipe, CNXNLIBertPipe, (6, 7, 6), (37, 3), False), | |||
'LCQMC': ('test/data_for_tests/io/LCQMC', LCQMCPipe, LCQMCBertPipe, (5, 6, 6), (36, 2), False), | |||
} | |||
for k, v in data_set_dict.items(): | |||
@@ -1,59 +1,69 @@ | |||
#!/usr/bin/python | |||
# -*- coding: utf-8 -*- | |||
# __author__="Danqing Wang" | |||
# | |||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||
# you may not use this file except in compliance with the License. | |||
# You may obtain a copy of the License at | |||
# | |||
# http://www.apache.org/licenses/LICENSE-2.0 | |||
# | |||
# Unless required by applicable law or agreed to in writing, software | |||
# distributed under the License is distributed on an "AS IS" BASIS, | |||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
# ============================================================================== | |||
import unittest | |||
import os | |||
# import sys | |||
# | |||
# sys.path.append("../../../") | |||
from fastNLP.io import DataBundle | |||
from fastNLP.io.pipe.summarization import ExtCNNDMPipe | |||
class TestRunExtCNNDMPipe(unittest.TestCase): | |||
def test_load(self): | |||
data_set_dict = { | |||
'CNNDM': {"train": 'test/data_for_tests/cnndm.jsonl'}, | |||
} | |||
vocab_size = 100000 | |||
VOCAL_FILE = 'test/data_for_tests/cnndm.vocab' | |||
sent_max_len = 100 | |||
doc_max_timesteps = 50 | |||
dbPipe = ExtCNNDMPipe(vocab_size=vocab_size, | |||
vocab_path=VOCAL_FILE, | |||
sent_max_len=sent_max_len, | |||
doc_max_timesteps=doc_max_timesteps) | |||
dbPipe2 = ExtCNNDMPipe(vocab_size=vocab_size, | |||
vocab_path=VOCAL_FILE, | |||
sent_max_len=sent_max_len, | |||
doc_max_timesteps=doc_max_timesteps, | |||
domain=True) | |||
for k, v in data_set_dict.items(): | |||
db = dbPipe.process_from_file(v) | |||
db2 = dbPipe2.process_from_file(v) | |||
# print(db2.get_dataset("train")) | |||
self.assertTrue(isinstance(db, DataBundle)) | |||
self.assertTrue(isinstance(db2, DataBundle)) | |||
#!/usr/bin/python | |||
# -*- coding: utf-8 -*- | |||
# __author__="Danqing Wang" | |||
# | |||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||
# you may not use this file except in compliance with the License. | |||
# You may obtain a copy of the License at | |||
# | |||
# http://www.apache.org/licenses/LICENSE-2.0 | |||
# | |||
# Unless required by applicable law or agreed to in writing, software | |||
# distributed under the License is distributed on an "AS IS" BASIS, | |||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
# ============================================================================== | |||
import unittest | |||
import os | |||
from fastNLP.io import DataBundle | |||
from fastNLP.io.pipe.summarization import ExtCNNDMPipe | |||
class TestRunExtCNNDMPipe(unittest.TestCase): | |||
def test_load(self): | |||
data_dir = 'test/data_for_tests/io/cnndm' | |||
vocab_size = 100000 | |||
VOCAL_FILE = 'test/data_for_tests/io/cnndm/vocab' | |||
sent_max_len = 100 | |||
doc_max_timesteps = 50 | |||
dbPipe = ExtCNNDMPipe(vocab_size=vocab_size, | |||
vocab_path=VOCAL_FILE, | |||
sent_max_len=sent_max_len, | |||
doc_max_timesteps=doc_max_timesteps) | |||
dbPipe2 = ExtCNNDMPipe(vocab_size=vocab_size, | |||
vocab_path=VOCAL_FILE, | |||
sent_max_len=sent_max_len, | |||
doc_max_timesteps=doc_max_timesteps, | |||
domain=True) | |||
db = dbPipe.process_from_file(data_dir) | |||
db2 = dbPipe2.process_from_file(data_dir) | |||
self.assertTrue(isinstance(db, DataBundle)) | |||
self.assertTrue(isinstance(db2, DataBundle)) | |||
dbPipe3 = ExtCNNDMPipe(vocab_size=vocab_size, | |||
sent_max_len=sent_max_len, | |||
doc_max_timesteps=doc_max_timesteps, | |||
domain=True) | |||
db3 = dbPipe3.process_from_file(data_dir) | |||
self.assertTrue(isinstance(db3, DataBundle)) | |||
with self.assertRaises(RuntimeError): | |||
dbPipe4 = ExtCNNDMPipe(vocab_size=vocab_size, | |||
sent_max_len=sent_max_len, | |||
doc_max_timesteps=doc_max_timesteps) | |||
db4 = dbPipe4.process_from_file(os.path.join(data_dir, 'train.cnndm.jsonl')) | |||
dbPipe5 = ExtCNNDMPipe(vocab_size=vocab_size, | |||
vocab_path=VOCAL_FILE, | |||
sent_max_len=sent_max_len, | |||
doc_max_timesteps=doc_max_timesteps,) | |||
db5 = dbPipe5.process_from_file(os.path.join(data_dir, 'train.cnndm.jsonl')) | |||
self.assertIsInstance(db5, DataBundle) | |||
@@ -0,0 +1,25 @@ | |||
import os | |||
import unittest | |||
from fastNLP.io import ModelSaver, ModelLoader | |||
from fastNLP.models import CNNText | |||
class TestModelIO(unittest.TestCase): | |||
def test_save_and_load(self): | |||
model = CNNText((10, 10), 2) | |||
saver = ModelSaver('tmp') | |||
loader = ModelLoader() | |||
saver.save_pytorch(model) | |||
new_cnn = CNNText((10, 10), 2) | |||
loader.load_pytorch(new_cnn, 'tmp') | |||
new_model = loader.load_pytorch_model('tmp') | |||
for i in range(10): | |||
for j in range(10): | |||
self.assertEqual(model.embed.embed.weight[i, j], new_cnn.embed.embed.weight[i, j]) | |||
self.assertEqual(model.embed.embed.weight[i, j], new_model["embed.embed.weight"][i, j]) | |||
os.system('rm tmp') |
@@ -1,9 +1,20 @@ | |||
import unittest | |||
import torch | |||
from fastNLP.modules.utils import get_dropout_mask | |||
from fastNLP.models import CNNText | |||
from fastNLP.modules.utils import get_dropout_mask, summary | |||
class TestUtil(unittest.TestCase): | |||
def test_get_dropout_mask(self): | |||
tensor = torch.randn(3, 4) | |||
mask = get_dropout_mask(0.3, tensor) | |||
self.assertSequenceEqual(mask.size(), torch.Size([3, 4])) | |||
self.assertSequenceEqual(mask.size(), torch.Size([3, 4])) | |||
def test_summary(self): | |||
model = CNNText(embed=(4, 4), num_classes=2, kernel_nums=(9,5), kernel_sizes=(1,3)) | |||
# 4 * 4 + 4 * (9 * 1 + 5 * 3) + 2 * (9 + 5 + 1) = 142 | |||
self.assertSequenceEqual((142, 142, 0), summary(model)) | |||
model.embed.requires_grad = False | |||
self.assertSequenceEqual((142, 126, 16), summary(model)) |