@@ -2,6 +2,6 @@ fastNLP.core.callback
 =====================
 .. automodule:: fastNLP.core.callback
-    :members: Callback, GradientClipCallback, EarlyStopCallback, FitlogCallback, EvaluateCallback, LRScheduler, ControlC, LRFinder, TensorboardCallback, WarmupCallback, SaveModelCallback, EchoCallback, TesterCallback, CallbackException, EarlyStopError
+    :members: Callback, GradientClipCallback, EarlyStopCallback, FitlogCallback, EvaluateCallback, LRScheduler, ControlC, LRFinder, TensorboardCallback, WarmupCallback, SaveModelCallback, EchoCallback, CallbackException, EarlyStopError
     :inherited-members:
@@ -2,5 +2,5 @@ fastNLP.modules.encoder
 =======================
 .. automodule:: fastNLP.modules.encoder
-    :members: ConvolutionCharEncoder, LSTMCharEncoder, ConvMaxpool, LSTM, StarTransformer, TransformerEncoder, VarRNN, VarLSTM, VarGRU, MaxPool, MaxPoolWithMask, AvgPool, AvgPoolWithMask, MultiHeadAttention, BiAttention, SelfAttention
+    :members: ConvolutionCharEncoder, LSTMCharEncoder, ConvMaxpool, LSTM, StarTransformer, TransformerEncoder, VarRNN, VarLSTM, VarGRU, MaxPool, MaxPoolWithMask, KMaxPool, AvgPool, AvgPoolWithMask, MultiHeadAttention, BiAttention, SelfAttention
@@ -2,7 +2,7 @@ fastNLP.modules
 ===============
 .. automodule:: fastNLP.modules
-    :members: ConvolutionCharEncoder, LSTMCharEncoder, ConvMaxpool, LSTM, StarTransformer, TransformerEncoder, VarRNN, VarLSTM, VarGRU, MaxPool, MaxPoolWithMask, AvgPool, AvgPoolWithMask, MultiHeadAttention, MLP, ConditionalRandomField, viterbi_decode, allowed_transitions, TimestepDropout
+    :members: ConvolutionCharEncoder, LSTMCharEncoder, ConvMaxpool, LSTM, StarTransformer, TransformerEncoder, VarRNN, VarLSTM, VarGRU, MaxPool, MaxPoolWithMask, KMaxPool, AvgPool, AvgPoolWithMask, MultiHeadAttention, MLP, ConditionalRandomField, viterbi_decode, allowed_transitions, TimestepDropout

 Submodules
 ----------
@@ -2,7 +2,7 @@ fastNLP
 =======
 .. automodule:: fastNLP
-    :members: Instance, FieldArray, DataSetIter, BatchIter, TorchLoaderIter, Vocabulary, DataSet, Const, Trainer, Tester, Callback, GradientClipCallback, EarlyStopCallback, FitlogCallback, EvaluateCallback, LRScheduler, ControlC, LRFinder, TensorboardCallback, WarmupCallback, SaveModelCallback, EchoCallback, TesterCallback, CallbackException, EarlyStopError, Padder, AutoPadder, EngChar2DPadder, AccuracyMetric, SpanFPreRecMetric, ExtractiveQAMetric, Optimizer, SGD, Adam, AdamW, Sampler, SequentialSampler, BucketSampler, RandomSampler, LossFunc, CrossEntropyLoss, L1Loss, BCELoss, NLLLoss, LossInForward, cache_results, logger
+    :members: Instance, FieldArray, DataSetIter, BatchIter, TorchLoaderIter, Vocabulary, DataSet, Const, Trainer, Tester, Callback, GradientClipCallback, EarlyStopCallback, FitlogCallback, EvaluateCallback, LRScheduler, ControlC, LRFinder, TensorboardCallback, WarmupCallback, SaveModelCallback, EchoCallback, CallbackException, EarlyStopError, Padder, AutoPadder, EngChar2DPadder, AccuracyMetric, SpanFPreRecMetric, ExtractiveQAMetric, Optimizer, SGD, Adam, AdamW, Sampler, SequentialSampler, BucketSampler, RandomSampler, LossFunc, CrossEntropyLoss, L1Loss, BCELoss, NLLLoss, LossInForward, cache_results, logger
     :inherited-members:

 Submodules
@@ -36,8 +36,6 @@ __all__ = [
     "TensorboardCallback",
     "WarmupCallback",
     'SaveModelCallback',
-    "EchoCallback",
-    "TesterCallback",
     "CallbackException",
     "EarlyStopError",
@@ -48,8 +48,6 @@ __all__ = [
     "TensorboardCallback",
     "WarmupCallback",
     'SaveModelCallback',
-    "EchoCallback",
-    "TesterCallback",
     "CallbackException",
     "EarlyStopError",
@@ -78,8 +76,8 @@ __all__ = [
 from ._logger import logger
 from .batch import DataSetIter, BatchIter, TorchLoaderIter
 from .callback import Callback, GradientClipCallback, EarlyStopCallback, FitlogCallback, EvaluateCallback, \
-    LRScheduler, ControlC, LRFinder, TensorboardCallback, WarmupCallback, SaveModelCallback, EchoCallback, \
-    TesterCallback, CallbackException, EarlyStopError
+    LRScheduler, ControlC, LRFinder, TensorboardCallback, WarmupCallback, SaveModelCallback, CallbackException, \
+    EarlyStopError
 from .const import Const
 from .dataset import DataSet
 from .field import FieldArray, Padder, AutoPadder, EngChar2DPadder
@@ -193,13 +193,14 @@ class DataSetIter(BatchIter):
             Default: ``None``
         :param bool as_numpy: if ``True``, each batch is output as numpy.array, otherwise as :class:`torch.Tensor`.
             Default: ``False``
         :param int num_workers: number of worker processes used to preprocess the data
         :param bool pin_memory: whether to place the produced tensors in pinned memory, which may speed things up.
         :param bool drop_last: if the last batch holds fewer than batch_size samples, drop it
-        :param timeout:
+        :param timeout: timeout value for producing one batch
         :param worker_init_fn: called once in each worker at startup, receiving the worker's index as argument.
+        :param collate_fn: function used to combine samples into a batch
         """
         assert isinstance(dataset, DataSet)
         dataset = DataSetGetter(dataset, as_numpy)
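For context, a minimal sketch of how the documented parameters are used; the toy dataset below is illustrative, not part of the patch::

    from fastNLP import DataSet, DataSetIter

    ds = DataSet({'x': [[1, 2], [3, 4]], 'y': [0, 1]})
    ds.set_input('x')
    ds.set_target('y')

    # timeout and collate_fn are forwarded to the underlying torch DataLoader
    for batch_x, batch_y in DataSetIter(ds, batch_size=2, num_workers=0, timeout=0):
        print(batch_x['x'], batch_y['y'])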
@@ -220,12 +221,26 @@ class DataSetIter(BatchIter):
 class TorchLoaderIter(BatchIter):
     """
-    与DataSetIter类似,但用于pytorch的DataSet对象。通过使用TorchLoaderIter封装pytorch的DataSet,然后将其传入到Trainer中。
+    Similar to DataSetIter, but for pytorch Dataset objects.
+    Wrap the pytorch Dataset with TorchLoaderIter, then pass it to the Trainer.
     """

     def __init__(self, dataset, batch_size=1, sampler=None,
                  num_workers=0, pin_memory=False, drop_last=False,
                  timeout=0, worker_init_fn=None, collate_fn=None):
+        """
+        :param dataset: a :class:`~fastNLP.DataSet` object, the dataset
+        :param int batch_size: size of each batch
+        :param sampler: the :class:`~fastNLP.Sampler` to use. If ``None``, :class:`~fastNLP.SequentialSampler` is used.
+            Default: ``None``
+        :param int num_workers: number of worker processes used to preprocess the data
+        :param bool pin_memory: whether to place the produced tensors in pinned memory, which may speed things up.
+        :param bool drop_last: if the last batch holds fewer than batch_size samples, drop it
+        :param timeout: timeout value for producing one batch
+        :param worker_init_fn: called once in each worker at startup, receiving the worker's index as argument.
+        :param collate_fn: function used to combine samples into a batch"""
         assert len(dataset) > 0
         ins = dataset[0]
         assert len(ins) == 2 and \
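A sketch of the wrapper in use, assuming (as the assert above checks on the first instance) that each item of the pytorch Dataset is a pair; the dataset class here is hypothetical::

    import torch
    from fastNLP import TorchLoaderIter

    class PairDataset(torch.utils.data.Dataset):
        # each item: (input dict, target dict), matching the assert above
        def __init__(self):
            self.data = [({'x': torch.ones(3)}, {'y': torch.tensor(1)}) for _ in range(8)]
        def __getitem__(self, idx):
            return self.data[idx]
        def __len__(self):
            return len(self.data)

    for batch_x, batch_y in TorchLoaderIter(PairDataset(), batch_size=4):
        print(batch_x['x'].shape, batch_y['y'])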
@@ -62,8 +62,6 @@ __all__ = [
     "TensorboardCallback",
     "WarmupCallback",
     "SaveModelCallback",
-    "EchoCallback",
-    "TesterCallback",
     "CallbackException",
     "EarlyStopError"
@@ -87,12 +85,18 @@ except:
 from .dataset import DataSet
 from .tester import Tester
 from ._logger import logger
+from .utils import _check_fp16

 try:
     import fitlog
 except:
     pass

+try:
+    from apex import amp
+except:
+    amp = None
+

 class Callback(object):
     """
@@ -269,14 +273,6 @@ class Callback(object): | |||||
:return: | :return: | ||||
""" | """ | ||||
pass | pass | ||||
def on_validation(self): | |||||
""" | |||||
如果Trainer中设置了验证,则会在每次需要验证时调用该函数 | |||||
:return: | |||||
""" | |||||
pass | |||||
def on_epoch_end(self): | def on_epoch_end(self): | ||||
""" | """ | ||||
@@ -470,7 +466,7 @@ class GradientClipCallback(Callback):
         if self.step%self.update_every==0:
             if self.parameters is None:
                 if getattr(self.trainer, 'fp16', ''):
-                    from apex import amp
+                    _check_fp16()
                     self.clip_fun(amp.master_params(self.optimizer), self.clip_value)
                 self.clip_fun(self.model.parameters(), self.clip_value)
             else:
@@ -713,6 +709,8 @@ class ControlC(Callback):

 class SmoothValue(object):
+    """work for LRFinder"""
+
     def __init__(self, beta: float):
         self.beta, self.n, self.mov_avg = beta, 0, 0
         self.smooth = None
@@ -1025,6 +1023,10 @@ class EarlyStopError(CallbackException):

 class EchoCallback(Callback):
+    """
+    Used to test distributed training
+    """
+
     def __init__(self, name, out=sys.stdout):
         super(EchoCallback, self).__init__()
         self.name = name
@@ -1036,27 +1038,23 @@ class EchoCallback(Callback):
         return super(EchoCallback, self).__getattribute__(item)


-class TesterCallback(Callback):
+class _TesterCallback(Callback):
     def __init__(self, data, model, metrics, metric_key=None, batch_size=16, num_workers=None):
-        super(TesterCallback, self).__init__()
+        super(_TesterCallback, self).__init__()
         if hasattr(model, 'module'):
             # for data parallel model
             model = model.module
         self.tester = Tester(data, model,
                              metrics=metrics, batch_size=batch_size,
                              num_workers=num_workers, verbose=0)
-        # parse metric_key
-        # increase_better is True. It means the exp result gets better if the indicator increases.
-        # It is true by default.
-        self.increase_better = True
         if metric_key is not None:
-            self.increase_better = False if metric_key[0] == "-" else True
-            self.metric_key = metric_key[1:] if metric_key[0] == "+" or metric_key[0] == "-" else metric_key
+            self.metric_key, self.increase_better = self._parse_metric_key(metric_key)
         else:
             self.metric_key = None
+            self.increase_better = True
         self.score = None

-    def on_validation(self):
+    def on_valid_begin(self):
         cur_score = self.tester.test()
         eval_str = "Evaluation at Epoch {}/{}. Step:{}/{}. - {}".format(
             self.epoch, self.n_epochs, self.step, self.n_steps,
@@ -1067,17 +1065,28 @@ class TesterCallback(Callback):
             self.score = cur_score
         return cur_score, is_better

-    def _get_score(self, metric_dict, key):
+    @staticmethod
+    def _get_score(metric_dict, key):
         for metric in metric_dict.items():
             if key in metric:
                 return metric[key]
         return None

+    @staticmethod
+    def _parse_metric_key(metric_key):
+        # parse metric_key
+        # increase_better is True. It means the exp result gets better if the indicator increases.
+        # It is true by default.
+        increase_better = False if metric_key[0] == "-" else True
+        metric_key = metric_key[1:] if metric_key[0] == "+" or metric_key[0] == "-" else metric_key
+        return metric_key, increase_better
+
     def compare_better(self, a):
         if self.score is None:
             return True
         if self.metric_key is None:
-            self.metric_key = list(list(self.score.values())[0].keys())[0]
+            metric_key = list(list(self.score.values())[0].keys())[0]
+            self.metric_key, self.increase_better = self._parse_metric_key(metric_key)
         k = self.metric_key
         score = self._get_score(self.score, k)
         new_score = self._get_score(a, k)
@@ -1087,7 +1096,3 @@ class TesterCallback(Callback):
             return score <= new_score
         else:
             return score >= new_score
-
-    def on_train_end(self):
-        self.logger.info('Evaluate on training ends.')
-        self.on_validation()
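The extracted _parse_metric_key makes the sign convention easy to verify in isolation; for example (metric names hypothetical)::

    # '-' marks a smaller-is-better indicator, '+' (or no prefix) a larger-is-better one
    assert _TesterCallback._parse_metric_key('-ppl') == ('ppl', False)
    assert _TesterCallback._parse_metric_key('+acc') == ('acc', True)
    assert _TesterCallback._parse_metric_key('f') == ('f', True)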
@@ -17,21 +17,31 @@ from tqdm import tqdm
 from ._logger import logger
 from .batch import DataSetIter, BatchIter
-from .callback import DistCallbackManager, CallbackException, TesterCallback
+from .callback import DistCallbackManager, CallbackException
+from .callback import _TesterCallback
 from .dataset import DataSet
 from .losses import _prepare_losser
 from .optimizer import Optimizer
 from .utils import _build_args
 from .utils import _get_func_signature
 from .utils import _move_dict_value_to_device
+from .utils import _check_fp16
+
+try:
+    from apex import amp
+except:
+    amp = None

 __all__ = [
     'get_local_rank',
     'DistTrainer',
 ]

 def get_local_rank():
+    """
+    Return the local rank of the current process, from 0 to N-1, where N is the total number of distributed processes
+    """
     if 'LOCAL_RANK' in os.environ:
         return int(os.environ['LOCAL_RANK'])
     from argparse import ArgumentParser
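A sketch of the launch pattern this helper supports (script name hypothetical): torch.distributed.launch sets LOCAL_RANK in the environment, or passes --local_rank, which the argparse fallback below handles::

    # inside train_script.py, launched with:
    #   python -m torch.distributed.launch --nproc_per_node=2 train_script.py
    import torch
    from fastNLP.core.dist_trainer import get_local_rank

    torch.cuda.set_device(get_local_rank())  # bind this process to its GPU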
@@ -46,7 +56,10 @@ def get_local_rank():

 class DistTrainer():
     """
-    Distributed Trainer that support distributed and mixed precision training
+    Distributed Trainer that supports distributed and mixed precision training; see the official pytorch documentation for the underlying mechanics.
+
+    Note: with a distributed Trainer, multiple processes execute the training code simultaneously. Before converting
+    single-process training code to multi-process, check carefully that synchronization and mutual exclusion (e.g. model saving, logging) are handled correctly.
     """

     def __init__(self, train_data, model, optimizer=None, loss=None,
                  callbacks_all=None, callbacks_master=None,
@@ -55,8 +68,43 @@ class DistTrainer():
                  dev_data=None, metrics=None, metric_key=None,
                  update_every=1, print_every=10, validate_every=-1,
                  save_every=-1, save_path=None, device='auto',
-                 fp16='', backend=None, init_method=None):
+                 fp16='', backend=None, init_method=None, use_tqdm=True):
+        """
+        :param train_data: the training set, a :class:`~fastNLP.DataSet`.
+        :param nn.modules model: the model to train
+        :param optimizer: a `torch.optim.Optimizer` optimizer. If None, the Trainer uses the default Adam(model.parameters(), lr=4e-3)
+        :param loss: the :class:`~fastNLP.core.losses.LossBase` object to use. If None, :class:`~fastNLP.LossInForward` is used by default
+        :param list callbacks_all: callbacks that hook into the training loop, applied in all training processes.
+            See :doc:`the callback module <fastNLP.core.callback>` for the available callbacks
+        :param list callbacks_master: callbacks that hook into the training loop, applied in only one process (the Master process).
+            See :doc:`the callback module <fastNLP.core.callback>` for the available callbacks
+        :param int batch_size_per_gpu: batch size per process during training.
+        :param int n_epochs: number of optimization epochs.
+        :param num_workers: int, number of threads used to pad the data.
+        :param drop_last: if the last batch holds fewer than batch_size samples, drop it
+        :param dev_data: DataSet used for validation, a :class:`~fastNLP.DataSet`.
+        :param metrics: evaluation metrics for validation. Either a single :class:`Metric<fastNLP.core.metrics.MetricBase>`
+            or several :class:`Metric<fastNLP.core.metrics.MetricBase>` passed in a list. If validation achieves a better
+            result (with multiple Metrics, the first one in the list decides) and save_path is not None, the current model
+            is saved. See :doc:`the metrics module <fastNLP.core.metrics>` for the available Metrics. Only effective when dev_data is given.
+        :param str,None metric_key: a :class:`Metric<fastNLP.core.metrics.MetricBase>` may report several indicators,
+            e.g. :class:`~fastNLP.core.metrics.SpanFPreRecMetric` contains 'f', 'pre' and 'rec'; in that case the deciding
+            indicator must be specified. Some indicators are better when smaller, such as a language model's perplexity;
+            for those, prefix the key with '-' to mark smaller-is-better (e.g. "-ppl"). Only effective when dev_data is given.
+        :param update_every: int, number of steps between gradient updates. Intended for gradient accumulation: e.g. when a
+            batch_size of 128 would exhaust memory, set batch_size=32, update_every=4 to the same effect. Has no effect when optimizer is None.
+        :param int print_every: how many backward steps between updates of the loss shown by tqdm; if use_tqdm=False, how
+            many backward steps between printed losses.
+        :param int validate_every: validate on the dev set every this many steps; if -1, validate at the end of each epoch.
+            Only effective when dev_data is given.
+        :param int save_every: save the model every this many steps; if -1, save at the end of each epoch. Only effective when save_path is given.
+        :param str,None save_path: path under which to save the model; missing directories are created automatically. If None,
+            the model is not saved. If dev_data is None, the model of the last iteration is saved. Both the parameters and the
+            model structure are saved; even with DataParallel, only the model itself is saved here.
+        :param str device: the device to use: gpu, cpu or auto
+        :param str fp16: the optimization level for half-precision training, one of O1, O2 or O3; an empty string disables half precision.
+        :param backend: the distributed backend; see the pytorch documentation for details
+        :param init_method: the initialization method for distributed training; see the pytorch documentation for details
+        :param bool use_tqdm: whether to use tqdm to show training progress; if False, the loss is printed to the terminal.
+        """
         assert device in ['auto', 'cuda', 'cpu'], "Please set correct device in [auto', 'cuda', 'cpu']"
         if device == 'auto':
             device = 'cuda' if torch.cuda.is_available() else 'cpu'
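A minimal sketch of wiring up the extended signature (train_ds, dev_ds and model are hypothetical)::

    from fastNLP import AccuracyMetric
    from fastNLP.core.dist_trainer import DistTrainer

    trainer = DistTrainer(train_data=train_ds, model=model,
                          dev_data=dev_ds, metrics=AccuracyMetric(),
                          batch_size_per_gpu=32, n_epochs=3,
                          save_path='./ckpts',
                          fp16='',        # 'O1' would require apex
                          use_tqdm=True)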
@@ -94,7 +142,9 @@ class DistTrainer():
         self.callback_manager = DistCallbackManager(
             env={"trainer": self}, callbacks_all=callbacks_all,
             callbacks_master=callbacks_master)
+        self.test_manager = DistCallbackManager(env={'trainer': self})
         self.metric_key = metric_key
+        self.use_tqdm = use_tqdm

         model.to(self.device)
         optimizer = self._get_optimizer(optimizer)
@@ -102,11 +152,7 @@ class DistTrainer():
         # init fp16, must before DataParallel init
         if len(self.fp16):
             assert isinstance(self.fp16, str), "Please set Apex AMP optimization level selected in ['O0', 'O1', 'O2', 'O3']"
-            try:
-                from apex import amp
-            except ImportError:
-                raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
-            assert torch.backends.cudnn.enabled, "Amp requires cudnn backend to be enabled."
+            _check_fp16()
             assert device == 'cuda', "Amp requires cuda device"
             model, optimizer = amp.initialize(model, optimizer, opt_level=self.fp16)
@@ -121,20 +167,21 @@ class DistTrainer():
         self.optimizer = optimizer
         self.sampler = DistributedSampler(self.train_data)
         self.data_iterator = self._get_data_iter(self.train_data)
+        self.batch_size = self.world_size * self.batch_size_per_gpu
         self.n_steps = self._get_n_steps()

         # for evaluation, only run eval on master proc
         if dev_data and metrics:
-            cb = TesterCallback(
+            cb = _TesterCallback(
                 dev_data, model, metrics,
                 batch_size=batch_size_per_gpu, num_workers=num_workers)
-            self.callback_manager.add_callback([cb], master=True)
+            self.test_manager.add_callback([cb], master=False)

         # Setup logging
         dist.barrier()
         self.start_time = datetime.now().strftime('%m_%d_%Y-%H_%M')
         if self.save_path:
-            self.cp_save_path = os.path.join(self.save_path, 'checkpoints', self.start_time)
+            self.cp_save_path = os.path.join(self.save_path, 'checkpoints')
         else:
             self.cp_save_path = None
@@ -178,9 +225,27 @@ class DistTrainer():
     @property
     def is_master(self):
+        """whether this is the master process"""
         return self.rank == 0

-    def train(self, on_exception='auto'):
+    def train(self, load_best_model=True, on_exception='auto'):
+        """
+        Call this method to start training.
+
+        :param str on_exception: whether an exception raised during training should be re-raised after it has been handled
+            by the on_exception() of the :py:class:Callback. One of 'ignore', 'raise', 'auto': 'ignore' swallows the
+            exception, and the code after Trainer.train() keeps running; 'raise' re-raises the exception;
+            'auto' swallows only CallbackException and KeyboardInterrupt, and re-raises any other exception.
+        :return dict: a dict with the following content::
+
+            seconds: float, the training time
+            The following three entries are present only when dev_data was provided.
+            best_eval: Dict of Dict, the evaluation results. The first-level key is the Metric's name,
+                the second-level key is the concrete indicator
+            best_epoch: int, the epoch in which the best value was obtained
+            best_step: int, the step (batch update) in which the best value was obtained
+        """
         try:
             self.logger.info("###### Training epochs started ######")
             self.logger.info('Total epochs: %d'% self.n_epochs)
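Under the new contract, train() also returns its results dict on the exception path instead of losing it; a usage sketch (trainer as constructed above)::

    results = trainer.train(load_best_model=True)
    print(results['seconds'])
    # only present when dev_data was given:
    if 'best_eval' in results:
        print(results['best_eval'], results['best_epoch'], results['best_step'])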
@@ -222,17 +287,22 @@ class DistTrainer():
             results['seconds'] = round(time.time() - start_time, 2)
             self.logger.info("###### Train finished ######")
             self.logger.info('Total train time: {} seconds.'. format(results['seconds']))
-            return results
+            if load_best_model and self.cp_save_path and len(self.test_manager.callbacks):
+                self.load_check_point('best')
         finally:
-            self.close()
+            pass
+        dist.barrier()
+        return results

     def _train(self):
-        if self.fp16:
-            # skip check, done in __init__()
-            from apex import amp
+        if not self.use_tqdm:
+            from .utils import _pseudo_tqdm as inner_tqdm
+        else:
+            inner_tqdm = tqdm
         self.step = 0
         self.epoch = 0
-        self.pbar = tqdm(total=self.n_steps, postfix='loss:{0:<6.5f}',
+        self.pbar = inner_tqdm(total=self.n_steps, postfix='loss:{0:<6.5f}',
                          leave=False, dynamic_ncols=True, disable=not self.is_master)
         pbar = self.pbar
         avg_loss = 0
@@ -292,8 +362,8 @@ class DistTrainer():
             if self.validate_every < 0:
                 self._do_validation()

-        if self.save_every < 0 and self.cp_save_path:
-            self.save_check_point()
+            if self.save_every < 0 and self.cp_save_path:
+                self.save_check_point()
             # lr decay; early stopping
             self.callback_manager.on_epoch_end()
         # =============== epochs end =================== #
@@ -327,30 +397,52 @@ class DistTrainer():
             loss = self.losser(predict, truth)
             if self.update_every > 1:
                 loss = loss / self.update_every
-        return loss.mean()
+        if loss.dim() > 0:
+            loss = loss.mean()
+        return loss

-    def save_check_point(self, only_params=False):
+    def save_check_point(self, name=None, only_params=False):
+        """save the current model"""
         # only master save models
         if self.is_master:
+            if name is None:
+                name = 'checkpoint-{}.bin'.format(self.step)
             os.makedirs(self.cp_save_path, exist_ok=True)
-            path = os.path.join(self.cp_save_path, 'checkpoint-{}.bin'.format(self.step))
+            path = os.path.join(self.cp_save_path, name)
             self.logger.info("Save checkpoint to {}".format(path))
             model_to_save = self.model.module
             if only_params:
                 model_to_save = model_to_save.state_dict()
             torch.save(model_to_save, path)

+    def load_check_point(self, name):
+        path = os.path.join(self.cp_save_path, name)
+        self.logger.info('reload best model from %s', path)
+        model_load = torch.load(path, map_location='cpu')
+        if not isinstance(model_load, dict):
+            model_load = model_load.state_dict()
+        self.model.module.load_state_dict(model_load)
+
     def _do_validation(self):
         self.callback_manager.on_valid_begin()
-        eval_res = self.callback_manager.on_validation()
+        # do evaluate on all nodes
+        eval_res = self.test_manager.on_valid_begin()
         eval_res = list(filter(lambda x: x is not None, eval_res))
         if len(eval_res):
             eval_res, is_better = list(zip(*eval_res))
         else:
             eval_res, is_better = None, None
+        # save better model on master node
+        if self.is_master and is_better is not None and self.cp_save_path:
+            for i, better_flag in enumerate(is_better):
+                if better_flag:
+                    # TODO to support multiple datasets to evaluate
+                    self.save_check_point('best')
+                    break
         self.callback_manager.on_valid_end(
             eval_res, self.metric_key, self.optimizer, is_better)
         dist.barrier()

     def close(self):
+        """shut down the Trainer and destroy the process group"""
         dist.destroy_process_group()
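The checkpoint naming now composes as follows (paths hypothetical); 'best' is written by _do_validation() on the master node and reloaded at the end of train() when load_best_model=True::

    # with save_path='./ckpts', files land under './ckpts/checkpoints'
    trainer.save_check_point()        # writes checkpoint-<step>.bin
    trainer.save_check_point('best')  # fixed name, overwritten on each improvement
    trainer.load_check_point('best')  # restores weights into trainer.model.module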
@@ -829,12 +829,12 @@ class Trainer(object):
                 self.best_metric_indicator = indicator_val
             else:
                 if self.increase_better is True:
-                    if indicator_val >= self.best_metric_indicator:
+                    if indicator_val > self.best_metric_indicator:
                         self.best_metric_indicator = indicator_val
                     else:
                         is_better = False
                 else:
-                    if indicator_val <= self.best_metric_indicator:
+                    if indicator_val < self.best_metric_indicator:
                         self.best_metric_indicator = indicator_val
                     else:
                         is_better = False
@@ -842,6 +842,7 @@ class Trainer(object):
     @property
     def is_master(self):
+        """whether this is the master process"""
         return True

 DEFAULT_CHECK_BATCH_SIZE = 2
@@ -19,6 +19,10 @@ import torch.nn as nn
 from typing import List
 from ._logger import logger
 from prettytable import PrettyTable
+try:
+    from apex import amp
+except:
+    amp = None

 _CheckRes = namedtuple('_CheckRes', ['missing', 'unused', 'duplicated', 'required', 'all_needed',
                                      'varargs'])
@@ -805,3 +809,10 @@ def sub_column(string: str, c: int, c_size: int, title: str) -> str:
     if len(string) > avg:
         string = string[:(avg - 3)] + "..."
     return string
+
+
+def _check_fp16():
+    if amp is None:
+        raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
+    if not torch.backends.cudnn.enabled:
+        raise RuntimeError("Amp requires cudnn backend to be enabled.")
@@ -1,4 +1,6 @@
-"""undocumented"""
+"""undocumented
+Helper code used to generate the fastNLP documentation
+"""

 __all__ = []
@@ -15,6 +17,9 @@ def doc_process(m):
                 pass
             else:
                 module_name = obj.__module__
+
+                # identify and annotate where classes and functions sit in the module hierarchy
                 while 1:
                     defined_m = sys.modules[module_name]
                     if "undocumented" not in defined_m.__doc__ and name in defined_m.__all__:
@@ -25,6 +30,8 @@ def doc_process(m):
                     if module_name == m.__name__:
                         # print(name, ": not found defined doc.")
                         break
+
+            # identify and annotate base classes; a base is shown only if it is also defined in fastNLP
             if inspect.isclass(obj):
                 for base in obj.__bases__:
@@ -25,6 +25,8 @@ __all__ = [
     'SSTLoader',
     'SST2Loader',
     "ChnSentiCorpLoader",
+    "THUCNewsLoader",
+    "WeiboSenti100kLoader",

     'ConllLoader',
     'Conll2003Loader',
@@ -45,6 +47,9 @@ __all__ = [
     "SNLILoader",
     "QNLILoader",
     "RTELoader",
+    "XNLILoader",
+    "BQCorpusLoader",
+    "LCQMCLoader",

     "Pipe",
@@ -54,6 +59,8 @@ __all__ = [
     "SST2Pipe",
     "IMDBPipe",
     "ChnSentiCorpPipe",
+    "THUCNewsPipe",
+    "WeiboSenti100kPipe",

     "Conll2003Pipe",
     "Conll2003NERPipe",
@@ -83,25 +83,41 @@ PRETRAIN_STATIC_FILES = {
 }

 DATASET_DIR = {
+    # Classification, English
     'aclImdb': "imdb.zip",
     "yelp-review-full": "yelp_review_full.tar.gz",
     "yelp-review-polarity": "yelp_review_polarity.tar.gz",
+    "sst-2": "SST-2.zip",
+    "sst": "SST.zip",
+
+    # Classification, Chinese
+    "chn-senti-corp": "chn_senti_corp.zip",
+    "weibo-senti-100k": "WeiboSenti100k.zip",
+    "thuc-news": "THUCNews.zip",
+
+    # Matching, English
     "mnli": "MNLI.zip",
     "snli": "SNLI.zip",
     "qnli": "QNLI.zip",
-    "sst-2": "SST-2.zip",
-    "sst": "SST.zip",
     "rte": "RTE.zip",
+
+    # Matching, Chinese
+    "cn-xnli": "XNLI.zip",
+
+    # Sequence Labeling, Chinese
     "msra-ner": "MSRA_NER.zip",
     "peopledaily": "peopledaily.zip",
     "weibo-ner": "weibo_NER.zip",
+
+    # Chinese Word Segmentation
     "cws-pku": 'cws_pku.zip',
     "cws-cityu": "cws_cityu.zip",
     "cws-as": 'cws_as.zip',
     "cws-msra": 'cws_msra.zip',
-    "chn-senti-corp":"chn_senti_corp.zip"
+
+    # Summarization, English
+    "ext-cnndm": "ext-cnndm.zip",
 }

 PRETRAIN_MAP = {'elmo': PRETRAINED_ELMO_MODEL_DIR,
@@ -373,63 +373,6 @@ class ChnSentiCorpLoader(Loader):
         """
         Read data from path

-        :param path:
-        :return:
-        """
-        ds = DataSet()
-        with open(path, 'r', encoding='utf-8') as f:
-            f.readline()
-            for line in f:
-                line = line.strip()
-                tab_index = line.index('\t')
-                if tab_index!=-1:
-                    target = line[:tab_index]
-                    raw_chars = line[tab_index+1:]
-                    if raw_chars:
-                        ds.append(Instance(raw_chars=raw_chars, target=target))
-        return ds
-
-    def download(self)->str:
-        """
-        Download the data automatically. The data is taken from https://github.com/pengming617/bert_classification/tree/master/data
-        and is used in https://arxiv.org/pdf/1904.09223.pdf and https://arxiv.org/pdf/1906.08101.pdf
-
-        :return:
-        """
-        output_dir = self._get_dataset_path('chn-senti-corp')
-        return output_dir
-
-
-class ChnSentiCorpLoader(Loader):
-    """
-    The supported data format is: the first line is a header (its content is ignored); each following line is one sample,
-    where everything before the first tab is taken as the label and everything from the first tab onward as the sentence
-
-    Example::
-
-        label raw_chars
-        1 這間酒店環境和服務態度亦算不錯,但房間空間太小~~
-        1 <荐书> 推荐所有喜欢<红楼>的红迷们一定要收藏这本书,要知道...
-        0 商品的不足暂时还没发现,京东的订单处理速度实在.......周二就打包完成,周五才发货...
-
-    The resulting DataSet has the following fields
-
-    .. csv-table::
-        :header: "raw_chars", "target"
-
-        "這間酒店環境和服務態度亦算不錯,但房間空間太小~~", "1"
-        "<荐书> 推荐所有喜欢<红楼>...", "1"
-        "..."
-
-    """
-
-    def __init__(self):
-        super().__init__()
-
-    def _load(self, path: str):
-        """
-        Read data from path
-
         :param path:
         :return:
         """
@@ -441,7 +384,7 @@ class ChnSentiCorpLoader(Loader):
                 tab_index = line.index('\t')
                 if tab_index != -1:
                     target = line[:tab_index]
-                    raw_chars = line[tab_index + 1:]
+                    raw_chars = line[tab_index+1:]
                     if raw_chars:
                         ds.append(Instance(raw_chars=raw_chars, target=target))
         return ds
@@ -486,6 +429,17 @@ class THUCNewsLoader(Loader):
                 ds.append(Instance(raw_chars=raw_chars, target=target))
         return ds

+    def download(self) -> str:
+        """
+        Download the data automatically. The data is taken from
+        http://thuctc.thunlp.org/#%E4%B8%AD%E6%96%87%E6%96%87%E6%9C%AC%E5%88%86%E7%B1%BB%E6%95%B0%E6%8D%AE%E9%9B%86THUCNews
+
+        :return:
+        """
+        output_dir = self._get_dataset_path('thuc-news')
+        return output_dir
+

 class WeiboSenti100kLoader(Loader):
     """
@@ -518,3 +472,12 @@ class WeiboSenti100kLoader(Loader):
                     if raw_chars:
                         ds.append(Instance(raw_chars=raw_chars, target=target))
         return ds
+
+    def download(self) -> str:
+        """
+        Download the data automatically. The data is taken from https://github.com/SophonPlus/ChineseNlpCorpus/
+        and is used in https://arxiv.org/abs/1906.08101
+
+        :return:
+        """
+        output_dir = self._get_dataset_path('weibo-senti-100k')
+        return output_dir
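A sketch of the download-then-load flow these new download() methods enable (the cache location is managed internally by _get_dataset_path)::

    from fastNLP.io import WeiboSenti100kLoader

    loader = WeiboSenti100kLoader()
    data_dir = loader.download()         # fetches and caches 'weibo-senti-100k'
    data_bundle = loader.load(data_dir)  # DataBundle with raw_chars/target fields
    print(data_bundle)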
@@ -316,6 +316,16 @@ class CTBLoader(Loader):
         dataset = self.loader._load(path)
         return dataset

+    def download(self):
+        """
+        Due to copyright restrictions, no automatic download is provided. See
+        https://catalog.ldc.upenn.edu/LDC2013T21
+
+        :return:
+        """
+        raise RuntimeError("CTB cannot be downloaded automatically.")
+

 class CNNERLoader(Loader):
     def _load(self, path: str):
@@ -13,23 +13,21 @@ from .json import JsonLoader

 class CoReferenceLoader(JsonLoader):
     """
-In the raw data, each line should be one json object, where doc_key carries the document's category information, speakers the speaker of each sentence, clusters the groups of mentions referring to the same real-world entity, and sentences the text itself.
-
-Example::
-
-    {"doc_key":"bc/cctv/00/cctv_001",
-    "speakers":"[["Speaker1","Speaker1","Speaker1"],["Speaker1","Speaker1","Speaker1"]]",
-    "clusters":"[[[2,3],[4,5]],[7,8],[18,20]]]",
-    "sentences":[["I","have","an","apple"],["It","is","good"]]
-    }
-
-Reads preprocessed Conll2012 data.
-
-    """
+    In the raw data, each line should be one json object, where doc_key carries the document's category information,
+    speakers the speaker of each sentence, clusters the groups of mentions referring to the same real-world entity,
+    and sentences the text itself.
+
+    Example::
+
+        {"doc_key":"bc/cctv/00/cctv_001",
+         "speakers":"[["Speaker1","Speaker1","Speaker1"],["Speaker1","Speaker1","Speaker1"]]",
+         "clusters":"[[[2,3],[4,5]],[7,8],[18,20]]]",
+         "sentences":[["I","have","an","apple"],["It","is","good"]]
+        }
+
+    Reads preprocessed Conll2012 data.
+    """

     def __init__(self, fields=None, dropna=False):
         super().__init__(fields, dropna)
-        # self.fields = {"doc_key":Const.INPUTS(0),"speakers":Const.INPUTS(1),
-        #                "clusters":Const.TARGET,"sentences":Const.INPUTS(2)}
         self.fields = {"doc_key": Const.RAW_WORDS(0), "speakers": Const.RAW_WORDS(1), "clusters": Const.RAW_WORDS(2),
                        "sentences": Const.RAW_WORDS(3)}
@@ -48,3 +46,13 @@ class CoReferenceLoader(JsonLoader):
                 ins = d
             dataset.append(Instance(**ins))
         return dataset
+
+    def download(self):
+        """
+        Due to copyright restrictions, no automatic download is provided. See
+        https://www.aclweb.org/anthology/W12-4501
+
+        :return:
+        """
+        raise RuntimeError("CoReference cannot be downloaded automatically.")
@@ -7,7 +7,7 @@ __all__ = [
     "RTELoader",
     "QuoraLoader",
     "BQCorpusLoader",
-    "XNLILoader",
+    "CNXNLILoader",
     "LCQMCLoader"
 ]
@@ -135,12 +135,12 @@ class SNLILoader(JsonLoader):
         """
         Read data from the file(s) at one or more given paths and return a :class:`~fastNLP.io.DataBundle`.

-        The fields read are determined by the headers passed to ConllLoader at initialization.
+        The fields read are determined by the field passed to the Loader at initialization.

         :param str paths: a directory in which the three files snli_1.0_train.jsonl, snli_1.0_dev.jsonl
             and snli_1.0_test.jsonl are expected.
-        :return: the returned:class:`~fastNLP.io.DataBundle`
+        :return: the returned :class:`~fastNLP.io.DataBundle`
         """
         _paths = {}
         if paths is None:
@@ -222,8 +222,7 @@ class QNLILoader(JsonLoader):
         """
         If your experiments use this dataset, please cite

-        .. todo::
-            to be added
+        https://arxiv.org/pdf/1809.05053.pdf

         :return:
         """
@@ -276,6 +275,13 @@ class RTELoader(Loader):
         return ds

     def download(self):
+        """
+        If your experiments use this dataset, please cite the GLUE Benchmark
+        https://openreview.net/pdf?id=rJ4km2R5t7
+
+        :return:
+        """
         return self._get_dataset_path('rte')
@@ -321,10 +327,17 @@ class QuoraLoader(Loader):
         return ds

     def download(self):
+        """
+        Due to copyright restrictions, no automatic download is provided. See
+        https://www.kaggle.com/c/quora-question-pairs/data
+
+        :return:
+        """
         raise RuntimeError("Quora cannot be downloaded automatically.")


-class XNLILoader(Loader):
+class CNXNLILoader(Loader):
     """
     Alias:
     Dataset summary: Chinese sentence-pair NLI (originally a multilingual dataset; only the Chinese part is taken here). The original sentences have been processed with the MOSES tokenizer
@@ -341,7 +354,7 @@ class XNLILoader(Loader):
     """

     def __init__(self):
-        super(XNLILoader, self).__init__()
+        super(CNXNLILoader, self).__init__()

     def _load(self, path: str = None):
         csv_loader = CSVLoader(sep='\t')
@@ -377,6 +390,16 @@ class XNLILoader(Loader):
         data_bundle = DataBundle(datasets=datasets)
         return data_bundle

+    def download(self) -> str:
+        """
+        Download the data automatically. The data is taken from https://arxiv.org/abs/1809.05053
+        and is used in https://arxiv.org/pdf/1905.05526.pdf, https://arxiv.org/pdf/1901.10125.pdf
+        and https://arxiv.org/pdf/1809.05053.pdf
+
+        :return:
+        """
+        output_dir = self._get_dataset_path('cn-xnli')
+        return output_dir
+

 class BQCorpusLoader(Loader):
     """
@@ -413,6 +436,16 @@ class BQCorpusLoader(Loader):
             ds.append(Instance(raw_chars1=raw_chars1, raw_chars2=raw_chars2, target=target))
         return ds

+    def download(self):
+        """
+        Due to copyright restrictions, no automatic download is provided. See
+        https://github.com/ymcui/Chinese-BERT-wwm
+
+        :return:
+        """
+        raise RuntimeError("BQCorpus cannot be downloaded automatically.")
+

 class LCQMCLoader(Loader):
     """
@@ -451,16 +484,14 @@ class LCQMCLoader(Loader):
             ds.append(Instance(raw_chars1=raw_chars1, raw_chars2=raw_chars2, target=target))
         return ds

-    '''
-    def download(self)->str:
+    def download(self):
         """
-        Download the data automatically. The data comes from the paper LCQMC: A Large-scale Chinese Question Matching Corpus.
-        In Proceedings of the 27th International Conference on Computational Linguistics. 1952-1962.
+        Due to copyright restrictions, no automatic download is provided. See
+        https://github.com/ymcui/Chinese-BERT-wwm

         :return:
         """
-        output_dir = self._get_dataset_path('chn-senti-corp')
-        return output_dir
-    '''
+        raise RuntimeError("LCQMC cannot be downloaded automatically.")
@@ -0,0 +1,63 @@
+"""undocumented"""
+
+__all__ = [
+    "ExtCNNDMLoader"
+]
+
+import os
+from typing import Union, Dict
+
+from ..data_bundle import DataBundle
+from ..utils import check_loader_paths
+from .json import JsonLoader
+
+
+class ExtCNNDMLoader(JsonLoader):
+    """
+    The DataSet obtained after reading has the following fields
+
+    .. csv-table::
+        :header: "text", "summary", "label", "publication"
+
+        ["I got new tires from them and... ","..."], ["The new tires...","..."], [0, 1], "cnndm"
+        ["Don't waste your time. We had two...","..."], ["Time is precious","..."], [1], "cnndm"
+        ["..."], ["..."], [], "cnndm"
+    """
+
+    def __init__(self, fields=None):
+        fields = fields or {"text": None, "summary": None, "label": None, "publication": None}
+        super(ExtCNNDMLoader, self).__init__(fields=fields)
+
+    def load(self, paths: Union[str, Dict[str, str]] = None):
+        """
+        Read data from the file(s) at one or more given paths and return a :class:`~fastNLP.io.DataBundle`.
+
+        The fields read are determined by the headers passed to ExtCNNDMLoader at initialization.
+
+        :param str paths: a directory in which the three files train.label.jsonl, dev.label.jsonl and
+            test.label.jsonl are expected (the directory also needs a file named vocab, which is used by
+            :class:`~fastNLP.io.ExtCNNDMPipe`).
+        :return: the :class:`~fastNLP.io.DataBundle`
+        """
+        if paths is None:
+            paths = self.download()
+        paths = check_loader_paths(paths)
+        if ('train' in paths) and ('test' not in paths):
+            paths['test'] = paths['train']
+            paths.pop('train')
+
+        datasets = {name: self._load(path) for name, path in paths.items()}
+        data_bundle = DataBundle(datasets=datasets)
+        return data_bundle
+
+    def download(self):
+        """
+        If you use this dataset, please cite
+        https://arxiv.org/pdf/1506.03340.pdf
+
+        :return:
+        """
+        output_dir = self._get_dataset_path('ext-cnndm')
+        return output_dir
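A usage sketch for the new loader; the module path is assumed, so adjust it to wherever the file actually lands in the loader package::

    from fastNLP.io.loader.summarization import ExtCNNDMLoader  # module path assumed

    loader = ExtCNNDMLoader()
    data_bundle = loader.load()  # with no paths given, download() fetches 'ext-cnndm' first
    print(data_bundle.get_dataset('train')[0])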
@@ -18,6 +18,8 @@ __all__ = [
     "SST2Pipe",
     "IMDBPipe",
     "ChnSentiCorpPipe",
+    "THUCNewsPipe",
+    "WeiboSenti100kPipe",

     "Conll2003NERPipe",
     "OntoNotesNERPipe",
@@ -42,7 +44,7 @@ __all__ = [
     "CoReferencePipe"
 ]

-from .classification import YelpFullPipe, YelpPolarityPipe, SSTPipe, SST2Pipe, IMDBPipe, ChnSentiCorpPipe
+from .classification import YelpFullPipe, YelpPolarityPipe, SSTPipe, SST2Pipe, IMDBPipe, ChnSentiCorpPipe, THUCNewsPipe, WeiboSenti100kPipe
 from .conll import Conll2003NERPipe, OntoNotesNERPipe, MsraNERPipe, WeiboNERPipe, PeopleDailyPipe
 from .matching import MatchingBertPipe, RTEBertPipe, SNLIBertPipe, QuoraBertPipe, QNLIBertPipe, MNLIBertPipe, \
     MatchingPipe, RTEPipe, SNLIPipe, QuoraPipe, QNLIPipe, MNLIPipe
@@ -97,11 +97,22 @@ class YelpFullPipe(_CLSPipe):
     Processes YelpFull data; after processing, the DataSet contains the following

     .. csv-table:: the fields of a DataSet processed by YelpFullPipe
-        :header: "raw_words", "words", "target", "seq_len"
+        :header: "raw_words", "target", "words", "seq_len"

-        "It 's a ...", "[4, 2, 10, ...]", 0, 10
-        "Offers that ...", "[20, 40, ...]", 1, 21
-        "...", "[...]", ., .
+        "I got 'new' tires from them and within...", 0, "[7, 110, 22, 107, 22, 499, 59, 140, 3,...]", 160
+        " Don't waste your time. We had two dif... ", 0, "[277, 17, 278, 38, 30, 112, 24, 85, 27...", 40
+        "...", ., "[...]", .
+
+    The input/target settings of each field, as reported by the dataset's print_field_meta(), are::
+
+        +-------------+-----------+--------+-------+---------+
+        | field_names | raw_words | target | words | seq_len |
+        +-------------+-----------+--------+-------+---------+
+        |  is_input   |   False   | False  | True  |  True   |
+        |  is_target  |   False   | True   | False |  False  |
+        | ignore_type |           | False  | False |  False  |
+        |  pad_value  |           |   0    |   0   |    0    |
+        +-------------+-----------+--------+-------+---------+
     """
@@ -193,11 +204,22 @@ class YelpPolarityPipe(_CLSPipe):
     Processes YelpPolarity data; after processing, the DataSet contains the following

     .. csv-table:: the fields of a DataSet processed by YelpPolarityPipe
-        :header: "raw_words", "words", "target", "seq_len"
+        :header: "raw_words", "target", "words", "seq_len"

-        "It 's a ...", "[4, 2, 10, ...]", 0, 10
-        "Offers that ...", "[20, 40, ...]", 1, 21
-        "...", "[...]", ., .
+        "I got 'new' tires from them and within...", 0, "[7, 110, 22, 107, 22, 499, 59, 140, 3,...]", 160
+        " Don't waste your time. We had two dif... ", 0, "[277, 17, 278, 38, 30, 112, 24, 85, 27...", 40
+        "...", ., "[...]", .
+
+    The input/target settings of each field, as reported by the dataset's print_field_meta(), are::
+
+        +-------------+-----------+--------+-------+---------+
+        | field_names | raw_words | target | words | seq_len |
+        +-------------+-----------+--------+-------+---------+
+        |  is_input   |   False   | False  | True  |  True   |
+        |  is_target  |   False   | True   | False |  False  |
+        | ignore_type |           | False  | False |  False  |
+        |  pad_value  |           |   0    |   0   |    0    |
+        +-------------+-----------+--------+-------+---------+
     """
@@ -211,6 +233,19 @@ class YelpPolarityPipe(_CLSPipe):
         self.lower = lower

     def process(self, data_bundle):
+        """
+        The DataSet passed in should be structured as follows
+
+        .. csv-table::
+            :header: "raw_words", "target"
+
+            "I got 'new' tires from them and... ", "1"
+            "Don't waste your time. We had two...", "1"
+            "...", "..."
+
+        :param data_bundle:
+        :return:
+        """
         # copy the raw words into a words column
         data_bundle = _add_words_field(data_bundle, lower=self.lower)
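A sketch of driving one of these classification pipes end to end; the data path is hypothetical, and process_from_file chains the matching loader with process()::

    from fastNLP.io.pipe.classification import YelpPolarityPipe

    data_bundle = YelpPolarityPipe(lower=True).process_from_file('/path/to/yelp_polarity')
    print(data_bundle)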
@@ -244,9 +279,20 @@ class SSTPipe(_CLSPipe):
     .. csv-table:: the fields of a DataSet processed by SSTPipe
         :header: "raw_words", "words", "target", "seq_len"

-        "It 's a ...", "[4, 2, 10, ...]", 0, 16
-        "Offers that ...", "[20, 40, ...]", 1, 18
-        "...", "[...]", ., .
+        "It 's a lovely film with lovely perfor...", 1, "[187, 6, 5, 132, 120, 70, 132, 188, 25...", 13
+        "No one goes unindicted here , which is...", 0, "[191, 126, 192, 193, 194, 4, 195, 17, ...", 13
+        "...", ., "[...]", .
+
+    The input/target settings of each field, as reported by the dataset's print_field_meta(), are::
+
+        +-------------+-----------+--------+-------+---------+
+        | field_names | raw_words | target | words | seq_len |
+        +-------------+-----------+--------+-------+---------+
+        |  is_input   |   False   | False  | True  |  True   |
+        |  is_target  |   False   | True   | False |  False  |
+        | ignore_type |           | False  | False |  False  |
+        |  pad_value  |           |   0    |   0   |    0    |
+        +-------------+-----------+--------+-------+---------+
     """
@@ -278,11 +324,11 @@ class SSTPipe(_CLSPipe):
         """
         Preprocesses the data in the DataBundle. The input DataSet should have at least a raw_words column, with content similar to

-        .. csv-table::
+        .. csv-table:: the fields of a DataSet read by SSTLoader
             :header: "raw_words"

-            "(3 (2 It) (4 (4 (2 's) (4 (3 (2 a)..."
-            "(4 (4 (2 Offers) (3 (3 (2 that) (3 (3 rare)..."
+            "(2 (3 (3 Effective) (2 but)) (1 (1 too-tepid)..."
+            "(3 (3 (2 If) (3 (2 you) (3 (2 sometimes) ..."
             "..."

         :param ~fastNLP.io.DataBundle data_bundle: the DataBundle to process
@@ -335,12 +381,23 @@ class SST2Pipe(_CLSPipe):
     Loads SST2 data; after processing, the DataSet has the following fields

     .. csv-table::
-        :header: "raw_words", "words", "target", "seq_len"
+        :header: "raw_words", "target", "words", "seq_len"

-        "it 's a charming and... ", "[3, 4, 5, 6, 7,...]", 1, 43
-        "unflinchingly bleak and...", "[10, 11, 7,...]", 1, 21
+        "it 's a charming and often affecting j... ", 1, "[19, 9, 6, 111, 5, 112, 113, 114, 3]", 9
+        "unflinchingly bleak and desperate", 0, "[115, 116, 5, 117]", 4
         "...", "...", ., .
+
+    The input/target settings of each field, as reported by the dataset's print_field_meta(), are::
+
+        +-------------+-----------+--------+-------+---------+
+        | field_names | raw_words | target | words | seq_len |
+        +-------------+-----------+--------+-------+---------+
+        |  is_input   |   False   | False  | True  |  True   |
+        |  is_target  |   False   | True   | False |  False  |
+        | ignore_type |           | False  | False |  False  |
+        |  pad_value  |           |   0    |   0   |    0    |
+        +-------------+-----------+--------+-------+---------+
     """

     def __init__(self, lower=False, tokenizer='spacy'):
def __init__(self, lower=False, tokenizer='spacy'): | def __init__(self, lower=False, tokenizer='spacy'): | ||||
@@ -357,11 +414,11 @@ class SST2Pipe(_CLSPipe): | |||||
可以处理的DataSet应该具备如下的结构 | 可以处理的DataSet应该具备如下的结构 | ||||
.. csv-table:: | .. csv-table:: | ||||
:header: "raw_words", "target" | |||||
:header: "raw_words", "target" | |||||
"it 's a charming and... ", 1 | |||||
"unflinchingly bleak and...", 1 | |||||
"...", "..." | |||||
"it 's a charming and often affecting...", "1" | |||||
"unflinchingly bleak and...", "0" | |||||
"..." | |||||
:param data_bundle: | :param data_bundle: | ||||
:return: | :return: | ||||
@@ -420,15 +477,26 @@ class IMDBPipe(_CLSPipe):
     After this Pipe, the DataSet looks as follows

     .. csv-table:: the fields of the output DataSet
-        :header: "raw_words", "words", "target", "seq_len"
+        :header: "raw_words", "target", "words", "seq_len"

-        "Bromwell High is a cartoon ... ", "[3, 5, 6, 9, ...]", 0, 20
-        "Story of a man who has ...", "[20, 43, 9, 10, ...]", 1, 31
-        "...", "[...]", ., .
+        "Bromwell High is a cartoon ... ", 0, "[3, 5, 6, 9, ...]", 20
+        "Story of a man who has ...", 1, "[20, 43, 9, 10, ...]", 31
+        "...", ., "[...]", .

     where raw_words is a str holding the original text; words is the input converted to indices; target is the target value
     converted to an index; the words column is set as input and the target column as target.
+
+    The input/target settings of each field, as reported by the dataset's print_field_meta(), are::
+
+        +-------------+-----------+--------+-------+---------+
+        | field_names | raw_words | target | words | seq_len |
+        +-------------+-----------+--------+-------+---------+
+        |  is_input   |   False   | False  | True  |  True   |
+        |  is_target  |   False   | True   | False |  False  |
+        | ignore_type |           | False  | False |  False  |
+        |  pad_value  |           |   0    |   0   |    0    |
+        +-------------+-----------+--------+-------+---------+
     """

     def __init__(self, lower: bool = False, tokenizer: str = 'spacy'):
@@ -493,13 +561,23 @@ class ChnSentiCorpPipe(Pipe): | |||||
The processed DataSet has the following structure | The processed DataSet has the following structure | ||||
.. csv-table:: | .. csv-table:: | ||||
:header: "raw_chars", "chars", "target", "seq_len" | |||||
:header: "raw_chars", "target", "chars", "seq_len" | |||||
"這間酒店環境和服務態度亦算不錯,但房間空間太小~~", "[2, 3, 4, 5, ...]", 1, 31 | |||||
"<荐书> 推荐所有喜欢<红楼>...", "[10, 21, ....]", 1, 25 | |||||
"這間酒店環境和服務態度亦算不錯,但房間空間太小~~", 1, "[2, 3, 4, 5, ...]", 31 | |||||
"<荐书> 推荐所有喜欢<红楼>...", 1, "[10, 21, ....]", 25 | |||||
"..." | "..." | ||||
Here chars and seq_len are input, and target is target | Here chars and seq_len are input, and target is target | ||||
The input/target status of each field, as printed by the dataset's print_field_meta() function, is:: | |||||
+-------------+-----------+--------+-------+---------+ | |||||
| field_names | raw_chars | target | chars | seq_len | | |||||
+-------------+-----------+--------+-------+---------+ | |||||
| is_input | False | True | True | True | | |||||
| is_target | False | True | False | False | | |||||
| ignore_type | | False | False | False | | |||||
| pad_value | | 0 | 0 | 0 | | |||||
+-------------+-----------+--------+-------+---------+ | |||||
""" | """ | ||||
def __init__(self, bigrams=False, trigrams=False): | def __init__(self, bigrams=False, trigrams=False): | ||||
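A sketch of the bigrams/trigrams options taken by this constructor (the get_vocab('bigrams') name follows the pattern described in these docstrings; auto-download is an assumption)::

    from fastNLP.io import ChnSentiCorpPipe

    data_bundle = ChnSentiCorpPipe(bigrams=True, trigrams=False).process_from_file()
    bigram_vocab = data_bundle.get_vocab('bigrams')   # vocab for the extra, index-encoded bigrams column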
@@ -590,12 +668,22 @@ class THUCNewsPipe(_CLSPipe): | |||||
The processed DataSet has the following structure | The processed DataSet has the following structure | ||||
.. csv-table:: | .. csv-table:: | ||||
:header: "raw_chars", "chars", "target", "seq_len" | |||||
:header: "raw_chars", "target", "chars", "seq_len" | |||||
"马晓旭意外受伤让国奥警惕 无奈大雨格外青睐殷家军记者傅亚雨沈阳报道...", "[409, 1197, 2146, 213, ...]", 0, 746 | |||||
"马晓旭意外受伤让国奥警惕 无奈大雨格外青睐殷家军记者傅亚雨沈阳报道...", 0, "[409, 1197, 2146, 213, ...]", 746 | |||||
"..." | "..." | ||||
Here chars and seq_len are input, and target is target | Here chars and seq_len are input, and target is target | ||||
The input/target status of each field, as printed by the dataset's print_field_meta() function, is:: | |||||
+-------------+-----------+--------+-------+---------+ | |||||
| field_names | raw_chars | target | chars | seq_len | | |||||
+-------------+-----------+--------+-------+---------+ | |||||
| is_input | False | True | True | True | | |||||
| is_target | False | True | False | False | | |||||
| ignore_type | | False | False | False | | |||||
| pad_value | | 0 | 0 | 0 | | |||||
+-------------+-----------+--------+-------+---------+ | |||||
:param bool bigrams: whether to add a bigrams column. Bigrams are built as ['复', '旦', '大', '学', ...] -> ["复旦", "旦大", ...]. If | :param bool bigrams: whether to add a bigrams column. Bigrams are built as ['复', '旦', '大', '学', ...] -> ["复旦", "旦大", ...]. If | ||||
set to True, the returned DataSet will have a column named bigrams, already converted to indices and set as input; the corresponding vocab can be obtained via | set to True, the returned DataSet will have a column named bigrams, already converted to indices and set as input; the corresponding vocab can be obtained via | ||||
@@ -691,12 +779,22 @@ class WeiboSenti100kPipe(_CLSPipe): | |||||
The processed DataSet has the following structure | The processed DataSet has the following structure | ||||
.. csv-table:: | .. csv-table:: | ||||
:header: "raw_chars", "chars", "target", "seq_len" | |||||
:header: "raw_chars", "target", "chars", "seq_len" | |||||
"六一出生的?好讽刺…… //@祭春姬:他爸爸是外星人吧 //@面孔小高:现在的孩子都怎么了 [怒][怒][怒]", "[0, 690, 18, ...]", 0, 56 | |||||
"六一出生的?好讽刺…… //@祭春姬:他爸爸是外星人吧 //@面孔小高:现在的孩子都怎么了 [怒][怒][怒]", 0, "[0, 690, 18, ...]", 56 | |||||
"..." | "..." | ||||
Here chars and seq_len are input, and target is target | Here chars and seq_len are input, and target is target | ||||
The input/target status of each field, as printed by the dataset's print_field_meta() function, is:: | |||||
+-------------+-----------+--------+-------+---------+ | |||||
| field_names | raw_chars | target | chars | seq_len | | |||||
+-------------+-----------+--------+-------+---------+ | |||||
| is_input | False | True | True | True | | |||||
| is_target | False | True | False | False | | |||||
| ignore_type | | False | False | False | | |||||
| pad_value | | 0 | 0 | 0 | | |||||
+-------------+-----------+--------+-------+---------+ | |||||
:param bool bigrams: whether to add a bigrams column. Bigrams are built as ['复', '旦', '大', '学', ...] -> ["复旦", "旦大", ...]. If | :param bool bigrams: whether to add a bigrams column. Bigrams are built as ['复', '旦', '大', '学', ...] -> ["复旦", "旦大", ...]. If | ||||
set to True, the returned DataSet will have a column named bigrams, already converted to indices and set as input; the corresponding vocab can be obtained via | set to True, the returned DataSet will have a column named bigrams, already converted to indices and set as input; the corresponding vocab can be obtained via | ||||
@@ -87,15 +87,26 @@ class Conll2003NERPipe(_NERPipe): | |||||
After this Pipe, the content of the DataSet is as shown below | After this Pipe, the content of the DataSet is as shown below | ||||
.. csv-table:: Following is a demo layout of DataSet returned by Conll2003Loader | .. csv-table:: Following is a demo layout of DataSet returned by Conll2003Loader | ||||
:header: "raw_words", "words", "target", "seq_len" | |||||
:header: "raw_words", "target", "words", "seq_len" | |||||
"[Nadim, Ladki]", "[2, 3]", "[1, 2]", 2 | |||||
"[AL-AIN, United, Arab, ...]", "[4, 5, 6,...]", "[3, 4,...]", 6 | |||||
"[Nadim, Ladki]", "[1, 2]", "[2, 3]", 2 | |||||
"[AL-AIN, United, Arab, ...]", "[3, 4,...]", "[4, 5, 6,...]", 6 | |||||
"[...]", "[...]", "[...]", . | "[...]", "[...]", "[...]", . | ||||
The raw_words column is List[str], the untransformed raw data; the words column is List[int], the input converted to indices; the target column is List[int], the | The raw_words column is List[str], the untransformed raw data; the words column is List[int], the input converted to indices; the target column is List[int], the | ||||
target converted to indices. In the returned DataSet, words, target and seq_len are set as input, and target is set as target. | target converted to indices. In the returned DataSet, words, target and seq_len are set as input, and target is set as target. | ||||
The input/target status of each field, as printed by the dataset's print_field_meta() function, is:: | |||||
+-------------+-----------+--------+-------+---------+ | |||||
| field_names | raw_words | target | words | seq_len | | |||||
+-------------+-----------+--------+-------+---------+ | |||||
| is_input | False | True | True | True | | |||||
| is_target | False | True | False | True | | |||||
| ignore_type | | False | False | False | | |||||
| pad_value | | 0 | 0 | 0 | | |||||
+-------------+-----------+--------+-------+---------+ | |||||
""" | """ | ||||
def process_from_file(self, paths) -> DataBundle: | def process_from_file(self, paths) -> DataBundle: | ||||
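A usage sketch matching the table above (the data path is a placeholder, and the encoding_type argument is assumed from the _NERPipe base class)::

    from fastNLP.io import Conll2003NERPipe

    data_bundle = Conll2003NERPipe(encoding_type='bioes').process_from_file('path/to/conll2003')
    print(data_bundle.get_dataset('train'))   # raw_words, target, words, seq_len
    print(data_bundle.get_vocab('target'))    # the BIOES tag vocabulary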
@@ -112,17 +123,28 @@ class Conll2003NERPipe(_NERPipe): | |||||
class Conll2003Pipe(Pipe): | class Conll2003Pipe(Pipe): | ||||
r""" | |||||
""" | |||||
After this Pipe, the DataSet contains the following | After this Pipe, the DataSet contains the following | ||||
.. csv-table:: | .. csv-table:: | ||||
:header: "raw_words" , "words", "pos", "chunk", "ner", "seq_len" | |||||
:header: "raw_words" , "pos", "chunk", "ner", "words", "seq_len" | |||||
"[Nadim, Ladki]", "[2, 3]", "[0, 0]", "[1, 2]", "[1, 2]", 2 | |||||
"[AL-AIN, United, Arab, ...]", "[4, 5, 6,...]", "[1, 2...]", "[3, 4...]", "[3, 4...]", 6 | |||||
"[Nadim, Ladki]", "[0, 0]", "[1, 2]", "[1, 2]", "[2, 3]", 2 | |||||
"[AL-AIN, United, Arab, ...]", "[1, 2...]", "[3, 4...]", "[3, 4...]", "[4, 5, 6,...]", 6 | |||||
"[...]", "[...]", "[...]", "[...]", "[...]", . | "[...]", "[...]", "[...]", "[...]", "[...]", . | ||||
Here words and seq_len are input; pos, chunk, ner and seq_len are target | Here words and seq_len are input; pos, chunk, ner and seq_len are target | ||||
The input/target status of each field, as printed by the dataset's print_field_meta() function, is:: | |||||
+-------------+-----------+-------+-------+-------+-------+---------+ | |||||
| field_names | raw_words | pos | chunk | ner | words | seq_len | | |||||
+-------------+-----------+-------+-------+-------+-------+---------+ | |||||
| is_input | False | False | False | False | True | True | | |||||
| is_target | False | True | True | True | False | True | | |||||
| ignore_type | | False | False | False | False | False | | |||||
| pad_value | | 0 | 0 | 0 | 0 | 0 | | |||||
+-------------+-----------+-------+-------+-------+-------+---------+ | |||||
""" | """ | ||||
def __init__(self, chunk_encoding_type='bioes', ner_encoding_type='bioes', lower: bool = False): | def __init__(self, chunk_encoding_type='bioes', ner_encoding_type='bioes', lower: bool = False): | ||||
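Based on the constructor signature above, a hedged sketch (the import path and data path are assumptions)::

    from fastNLP.io.pipe.conll import Conll2003Pipe

    pipe = Conll2003Pipe(chunk_encoding_type='bioes', ner_encoding_type='bio', lower=True)
    data_bundle = pipe.process_from_file('path/to/conll2003')   # yields pos/chunk/ner targets per the table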
@@ -202,15 +224,26 @@ class OntoNotesNERPipe(_NERPipe): | |||||
Processes the OntoNotes NER data; after processing, the fields of the DataSet are | Processes the OntoNotes NER data; after processing, the fields of the DataSet are | ||||
.. csv-table:: | .. csv-table:: | ||||
:header: "raw_words", "words", "target", "seq_len" | |||||
:header: "raw_words", "target", "words", "seq_len" | |||||
"[Nadim, Ladki]", "[2, 3]", "[1, 2]", 2 | |||||
"[AL-AIN, United, Arab, ...]", "[4, 5, 6,...]", "[3, 4]", 6 | |||||
"[Nadim, Ladki]", "[1, 2]", "[2, 3]", 2 | |||||
"[AL-AIN, United, Arab, ...]", "[3, 4]", "[4, 5, 6,...]", 6 | |||||
"[...]", "[...]", "[...]", . | "[...]", "[...]", "[...]", . | ||||
The raw_words column is List[str], the untransformed raw data; the words column is List[int], the input converted to indices; the target column is List[int], the | The raw_words column is List[str], the untransformed raw data; the words column is List[int], the input converted to indices; the target column is List[int], the | ||||
target converted to indices. In the returned DataSet, words, target and seq_len are set as input, and target is set as target. | target converted to indices. In the returned DataSet, words, target and seq_len are set as input, and target is set as target. | ||||
The input/target status of each field, as printed by the dataset's print_field_meta() function, is:: | |||||
+-------------+-----------+--------+-------+---------+ | |||||
| field_names | raw_words | target | words | seq_len | | |||||
+-------------+-----------+--------+-------+---------+ | |||||
| is_input | False | True | True | True | | |||||
| is_target | False | True | False | True | | |||||
| ignore_type | | False | False | False | | |||||
| pad_value | | 0 | 0 | 0 | | |||||
+-------------+-----------+--------+-------+---------+ | |||||
""" | """ | ||||
def process_from_file(self, paths): | def process_from_file(self, paths): | ||||
@@ -306,15 +339,26 @@ class MsraNERPipe(_CNNERPipe): | |||||
Processes the MSRA-NER data; after processing, the fields of the DataSet are | Processes the MSRA-NER data; after processing, the fields of the DataSet are | ||||
.. csv-table:: | .. csv-table:: | ||||
:header: "raw_chars", "chars", "target", "seq_len" | |||||
:header: "raw_chars", "target", "chars", "seq_len" | |||||
"[相, 比, 之, 下,...]", "[2, 3, 4, 5, ...]", "[0, 0, 0, 0, ...]", 11 | |||||
"[青, 岛, 海, 牛, 队, 和, ...]", "[10, 21, ....]", "[1, 2, 3, ...]", 21 | |||||
"[相, 比, 之, 下,...]", "[0, 0, 0, 0, ...]", "[2, 3, 4, 5, ...]", 11 | |||||
"[青, 岛, 海, 牛, 队, 和, ...]", "[1, 2, 3, ...]", "[10, 21, ....]", 21 | |||||
"[...]", "[...]", "[...]", . | "[...]", "[...]", "[...]", . | ||||
The raw_chars column is List[str], the untransformed raw data; the chars column is List[int], the input converted to indices; the target column is List[int], the | The raw_chars column is List[str], the untransformed raw data; the chars column is List[int], the input converted to indices; the target column is List[int], the | ||||
target converted to indices. In the returned DataSet, chars, target and seq_len are set as input, and target is set as target. | target converted to indices. In the returned DataSet, chars, target and seq_len are set as input, and target is set as target. | ||||
The input/target status of each field, as printed by the dataset's print_field_meta() function, is:: | |||||
+-------------+-----------+--------+-------+---------+ | |||||
| field_names | raw_chars | target | chars | seq_len | | |||||
+-------------+-----------+--------+-------+---------+ | |||||
| is_input | False | True | True | True | | |||||
| is_target | False | True | False | True | | |||||
| ignore_type | | False | False | False | | |||||
| pad_value | | 0 | 0 | 0 | | |||||
+-------------+-----------+--------+-------+---------+ | |||||
""" | """ | ||||
def process_from_file(self, paths=None) -> DataBundle: | def process_from_file(self, paths=None) -> DataBundle: | ||||
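Given the signature above, a minimal sketch (auto-download when paths is None is an assumption)::

    from fastNLP.io import MsraNERPipe

    data_bundle = MsraNERPipe().process_from_file()       # paths=None
    data_bundle.get_dataset('train').print_field_meta()   # prints the table documented above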
@@ -327,14 +371,26 @@ class PeopleDailyPipe(_CNNERPipe): | |||||
Processes the People's Daily NER data; after processing, the fields of the DataSet are | Processes the People's Daily NER data; after processing, the fields of the DataSet are | ||||
.. csv-table:: | .. csv-table:: | ||||
:header: "raw_chars", "chars", "target", "seq_len" | |||||
:header: "raw_chars", "target", "chars", "seq_len" | |||||
"[相, 比, 之, 下,...]", "[2, 3, 4, 5, ...]", "[0, 0, 0, 0, ...]", 11 | |||||
"[青, 岛, 海, 牛, 队, 和, ...]", "[10, 21, ....]", "[1, 2, 3, ...]", 21 | |||||
"[相, 比, 之, 下,...]", "[0, 0, 0, 0, ...]", "[2, 3, 4, 5, ...]", 11 | |||||
"[青, 岛, 海, 牛, 队, 和, ...]", "[1, 2, 3, ...]", "[10, 21, ....]", 21 | |||||
"[...]", "[...]", "[...]", . | "[...]", "[...]", "[...]", . | ||||
The raw_chars column is List[str], the untransformed raw data; the chars column is List[int], the input converted to indices; the target column is List[int], the | The raw_chars column is List[str], the untransformed raw data; the chars column is List[int], the input converted to indices; the target column is List[int], the | ||||
target converted to indices. In the returned DataSet, chars, target and seq_len are set as input, and target is set as target. | target converted to indices. In the returned DataSet, chars, target and seq_len are set as input, and target is set as target. | ||||
The input/target status of each field, as printed by the dataset's print_field_meta() function, is:: | |||||
+-------------+-----------+--------+-------+---------+ | |||||
| field_names | raw_chars | target | chars | seq_len | | |||||
+-------------+-----------+--------+-------+---------+ | |||||
| is_input | False | True | True | True | | |||||
| is_target | False | True | False | True | | |||||
| ignore_type | | False | False | False | | |||||
| pad_value | | 0 | 0 | 0 | | |||||
+-------------+-----------+--------+-------+---------+ | |||||
""" | """ | ||||
def process_from_file(self, paths=None) -> DataBundle: | def process_from_file(self, paths=None) -> DataBundle: | ||||
@@ -349,13 +405,24 @@ class WeiboNERPipe(_CNNERPipe): | |||||
.. csv-table:: | .. csv-table:: | ||||
:header: "raw_chars", "chars", "target", "seq_len" | :header: "raw_chars", "chars", "target", "seq_len" | ||||
"[相, 比, 之, 下,...]", "[2, 3, 4, 5, ...]", "[0, 0, 0, 0, ...]", 11 | |||||
"[青, 岛, 海, 牛, 队, 和, ...]", "[10, 21, ....]", "[1, 2, 3, ...]", 21 | |||||
"['老', '百', '姓']", "[4, 3, 3]", "[38, 39, 40]", 3 | |||||
"['心']", "[0]", "[41]", 1 | |||||
"[...]", "[...]", "[...]", . | "[...]", "[...]", "[...]", . | ||||
The raw_chars column is List[str], the untransformed raw data; the chars column is List[int], the input converted to indices; the target column is List[int], the | The raw_chars column is List[str], the untransformed raw data; the chars column is List[int], the input converted to indices; the target column is List[int], the | ||||
target converted to indices. In the returned DataSet, chars, target and seq_len are set as input, and target is set as target. | target converted to indices. In the returned DataSet, chars, target and seq_len are set as input, and target is set as target. | ||||
The input/target status of each field, as printed by the dataset's print_field_meta() function, is:: | |||||
+-------------+-----------+--------+-------+---------+ | |||||
| field_names | raw_chars | target | chars | seq_len | | |||||
+-------------+-----------+--------+-------+---------+ | |||||
| is_input | False | True | True | True | | |||||
| is_target | False | True | False | True | | |||||
| ignore_type | | False | False | False | | |||||
| pad_value | | 0 | 0 | 0 | | |||||
+-------------+-----------+--------+-------+---------+ | |||||
""" | """ | ||||
def process_from_file(self, paths=None) -> DataBundle: | def process_from_file(self, paths=None) -> DataBundle: | ||||
@@ -18,9 +18,29 @@ from ...core.const import Const | |||||
class CoReferencePipe(Pipe): | class CoReferencePipe(Pipe): | ||||
""" | """ | ||||
Processes the coreference resolution task, yielding article genre / speaker / character-level information / sequence lengths. | Processes the coreference resolution task, yielding article genre / speaker / character-level information / sequence lengths. | ||||
After processing, the data contains the article genre, speaker information, sentence information, the sentences' corresponding indices, chars, sentence lengths and the target: | |||||
.. csv-table:: | |||||
:header: "words1", "words2","words3","words4","chars","seq_len","target" | |||||
"bc", "[[0,0],[1,1]]","[['I','am'],[]]","[[1,2],[]]","[[[1],[2,3]],[]]","[2,3]","[[[2,3],[6,7]],[[10,12],[20,22]]]" | |||||
"[...]", "[...]","[...]","[...]","[...]","[...]","[...]" | |||||
The input/target status of each field, as printed by the dataset's print_field_meta() function, is:: | |||||
+-------------+-----------+--------+-------+---------+ | |||||
| field_names | raw_chars | target | chars | seq_len | | |||||
+-------------+-----------+--------+-------+---------+ | |||||
| is_input | False | True | True | True | | |||||
| is_target | False | True | False | True | | |||||
| ignore_type | | False | False | False | | |||||
| pad_value | | 0 | 0 | 0 | | |||||
+-------------+-----------+--------+-------+---------+ | |||||
""" | """ | ||||
def __init__(self,config): | |||||
def __init__(self, config): | |||||
super().__init__() | super().__init__() | ||||
self.config = config | self.config = config | ||||
@@ -35,14 +55,6 @@ class CoReferencePipe(Pipe): | |||||
"bc/cctv/00/cctv_0000_1", "[['Speaker#1', 'peaker#1'],[]]","[['He','is'],[]]","[[[2,3],[6,7]],[[10,12],[20,22]]]" | "bc/cctv/00/cctv_0000_1", "[['Speaker#1', 'peaker#1'],[]]","[['He','is'],[]]","[[[2,3],[6,7]],[[10,12],[20,22]]]" | ||||
"[...]", "[...]","[...]","[...]" | "[...]", "[...]","[...]","[...]" | ||||
After processing, the data contains the article genre, speaker information, sentence information, the sentences' corresponding indices, chars, sentence lengths and the target: | |||||
.. csv-table:: | |||||
:header: "words1", "words2","words3","words4","chars","seq_len","target" | |||||
"bc", "[[0,0],[1,1]]","[['I','am'],[]]","[[1,2],[]]","[[[1],[2,3]],[]]","[2,3]","[[[2,3],[6,7]],[[10,12],[20,22]]]" | |||||
"[...]", "[...]","[...]","[...]","[...]","[...]","[...]" | |||||
:param data_bundle: | :param data_bundle: | ||||
:return: | :return: | ||||
@@ -138,13 +138,22 @@ class CWSPipe(Pipe): | |||||
Preprocesses the CWS data; after processing, the data has the following structure | Preprocesses the CWS data; after processing, the data has the following structure | ||||
.. csv-table:: | .. csv-table:: | ||||
:header: "raw_words", "chars", "target", "bigrams", "trigrams", "seq_len" | |||||
:header: "raw_words", "chars", "target", "seq_len" | |||||
"共同 创造 美好...", "[2, 3, 4...]", "[0, 2, 0, 2,...]", "[10, 4, 1,...]","[6, 4, 1,...]", 13 | |||||
"2001年 新年 钟声...", "[8, 9, 9, 7, ...]", "[0, 1, 1, 1, 2...]", "[11, 12, ...]","[3, 9, ...]", 20 | |||||
"...", "[...]","[...]", "[...]","[...]", . | |||||
"共同 创造 美好...", "[2, 3, 4...]", "[0, 2, 0, 2,...]", 13 | |||||
"2001年 新年 钟声...", "[8, 9, 9, 7, ...]", "[0, 1, 1, 1, 2...]", 20 | |||||
"...", "[...]","[...]", . | |||||
The bigrams column exists only when the bigrams option is True | |||||
The input/target status of each field, as printed by the dataset's print_field_meta() function, is:: | |||||
+-------------+-----------+-------+--------+---------+ | |||||
| field_names | raw_words | chars | target | seq_len | | |||||
+-------------+-----------+-------+--------+---------+ | |||||
| is_input | False | True | True | True | | |||||
| is_target | False | False | True | True | | |||||
| ignore_type | | False | False | False | | |||||
| pad_value | | 0 | 0 | 0 | | |||||
+-------------+-----------+-------+--------+---------+ | |||||
""" | """ | ||||
@@ -7,7 +7,7 @@ __all__ = [ | |||||
"QuoraBertPipe", | "QuoraBertPipe", | ||||
"QNLIBertPipe", | "QNLIBertPipe", | ||||
"MNLIBertPipe", | "MNLIBertPipe", | ||||
"XNLIBertPipe", | |||||
"CNXNLIBertPipe", | |||||
"BQCorpusBertPipe", | "BQCorpusBertPipe", | ||||
"LCQMCBertPipe", | "LCQMCBertPipe", | ||||
"MatchingPipe", | "MatchingPipe", | ||||
@@ -16,7 +16,7 @@ __all__ = [ | |||||
"QuoraPipe", | "QuoraPipe", | ||||
"QNLIPipe", | "QNLIPipe", | ||||
"MNLIPipe", | "MNLIPipe", | ||||
"XNLIPipe", | |||||
"CNXNLIPipe", | |||||
"BQCorpusPipe", | "BQCorpusPipe", | ||||
"LCQMCPipe", | "LCQMCPipe", | ||||
] | ] | ||||
@@ -25,7 +25,7 @@ import warnings | |||||
from .pipe import Pipe | from .pipe import Pipe | ||||
from .utils import get_tokenizer | from .utils import get_tokenizer | ||||
from ..loader.matching import SNLILoader, MNLILoader, QNLILoader, RTELoader, QuoraLoader, BQCorpusLoader, XNLILoader, LCQMCLoader | |||||
from ..loader.matching import SNLILoader, MNLILoader, QNLILoader, RTELoader, QuoraLoader, BQCorpusLoader, CNXNLILoader, LCQMCLoader | |||||
from ...core.const import Const | from ...core.const import Const | ||||
from ...core.vocabulary import Vocabulary | from ...core.vocabulary import Vocabulary | ||||
from ...core._logger import logger | from ...core._logger import logger | ||||
@@ -37,16 +37,27 @@ class MatchingBertPipe(Pipe): | |||||
Bert pipe for Matching tasks; the output DataSet will contain the following fields | Bert pipe for Matching tasks; the output DataSet will contain the following fields | ||||
.. csv-table:: | .. csv-table:: | ||||
:header: "raw_words1", "raw_words2", "words", "target", "seq_len" | |||||
:header: "raw_words1", "raw_words2", "target", "words", "seq_len" | |||||
"The new rights are...", "Everyone really likes..", "[2, 3, 4, 5, ...]", 1, 10 | |||||
"This site includes a...", "The Government Executive...", "[11, 12, 13,...]", 0, 5 | |||||
"...", "...", "[...]", ., . | |||||
"The new rights are...", "Everyone really likes..", 1, "[2, 3, 4, 5, ...]", 10 | |||||
"This site includes a...", "The Government Executive...", 0, "[11, 12, 13,...]", 5 | |||||
"...", "...", ., "[...]", . | |||||
The words column joins raw_words1 (the premise) and raw_words2 (the hypothesis) with "[SEP]" and converts the result to indices (made concrete in the sketch after this hunk). | The words column joins raw_words1 (the premise) and raw_words2 (the hypothesis) with "[SEP]" and converts the result to indices (made concrete in the sketch after this hunk). | ||||
The words column is set as input; the target column is set as both target and input (it is set as input so the loss can be computed in the forward function; | The words column is set as input; the target column is set as both target and input (it is set as input so the loss can be computed in the forward function; | ||||
if the loss is not computed in forward this does no harm, since fastNLP passes arguments according to the forward function's parameter names). | if the loss is not computed in forward this does no harm, since fastNLP passes arguments according to the forward function's parameter names). | ||||
The input/target status of each field, as printed by the dataset's print_field_meta() function, is:: | |||||
+-------------+------------+------------+--------+-------+---------+ | |||||
| field_names | raw_words1 | raw_words2 | target | words | seq_len | | |||||
+-------------+------------+------------+--------+-------+---------+ | |||||
| is_input | False | False | False | True | True | | |||||
| is_target | False | False | True | False | False | | |||||
| ignore_type | | | False | False | False | | |||||
| pad_value | | | 0 | 0 | 0 | | |||||
+-------------+------------+------------+--------+-------+---------+ | |||||
""" | """ | ||||
def __init__(self, lower=False, tokenizer: str = 'raw'): | def __init__(self, lower=False, tokenizer: str = 'raw'): | ||||
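To make the "[SEP]" remark concrete, a plain-Python sketch of how the words column is assembled (the vocabulary lookup that turns tokens into indices is elided)::

    premise = "The new rights are nice enough".split()
    hypothesis = "Everyone really likes the newest benefits".split()
    words = premise + ['[SEP]'] + hypothesis   # later mapped to indices by the words vocabulary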
@@ -75,6 +86,18 @@ class MatchingBertPipe(Pipe): | |||||
return data_bundle | return data_bundle | ||||
def process(self, data_bundle): | def process(self, data_bundle): | ||||
""" | |||||
The datasets in the input data_bundle must have the following structure: | |||||
.. csv-table:: | |||||
:header: "raw_words1", "raw_words2", "target" | |||||
"Dana Reeve, the widow of the actor...", "Christopher Reeve had an...", "not_entailment" | |||||
"...","..." | |||||
:param data_bundle: | |||||
:return: | |||||
""" | |||||
for dataset in data_bundle.datasets.values(): | for dataset in data_bundle.datasets.values(): | ||||
if dataset.has_field(Const.TARGET): | if dataset.has_field(Const.TARGET): | ||||
dataset.drop(lambda x: x[Const.TARGET] == '-') | dataset.drop(lambda x: x[Const.TARGET] == '-') | ||||
@@ -178,15 +201,27 @@ class MatchingPipe(Pipe): | |||||
Pipe for Matching tasks. The output DataSet will contain the following fields | Pipe for Matching tasks. The output DataSet will contain the following fields | ||||
.. csv-table:: | .. csv-table:: | ||||
:header: "raw_words1", "raw_words2", "words1", "words2", "target", "seq_len1", "seq_len2" | |||||
:header: "raw_words1", "raw_words2", "target", "words1", "words2", "seq_len1", "seq_len2" | |||||
"The new rights are...", "Everyone really likes..", "[2, 3, 4, 5, ...]", "[10, 20, 6]", 1, 10, 13 | |||||
"This site includes a...", "The Government Executive...", "[11, 12, 13,...]", "[2, 7, ...]", 0, 6, 7 | |||||
"...", "...", "[...]", "[...]", ., ., . | |||||
"The new rights are...", "Everyone really likes..", 1, "[2, 3, 4, 5, ...]", "[10, 20, 6]", 10, 13 | |||||
"This site includes a...", "The Government Executive...", 0, "[11, 12, 13,...]", "[2, 7, ...]", 6, 7 | |||||
"...", "...", ., "[...]", "[...]", ., . | |||||
words1 is the premise and words2 is the hypothesis. words1, words2, seq_len1 and seq_len2 are set as input; target is set as both target | words1 is the premise and words2 is the hypothesis. words1, words2, seq_len1 and seq_len2 are set as input; target is set as both target | ||||
and input (set as input so the loss can conveniently be computed in the forward function; if the loss is not computed in forward this does no harm, since fastNLP | and input (set as input so the loss can conveniently be computed in the forward function; if the loss is not computed in forward this does no harm, since fastNLP | ||||
passes arguments according to the forward function's parameter names). | passes arguments according to the forward function's parameter names). | ||||
The input/target status of each field, as printed by the dataset's print_field_meta() function, is:: | |||||
+-------------+------------+------------+--------+--------+--------+----------+----------+ | |||||
| field_names | raw_words1 | raw_words2 | target | words1 | words2 | seq_len1 | seq_len2 | | |||||
+-------------+------------+------------+--------+--------+--------+----------+----------+ | |||||
| is_input | False | False | False | True | True | True | True | | |||||
| is_target | False | False | True | False | False | False | False | | |||||
| ignore_type | | | False | False | False | False | False | | |||||
| pad_value | | | 0 | 0 | 0 | 0 | 0 | | |||||
+-------------+------------+------------+--------+--------+--------+----------+----------+ | |||||
""" | """ | ||||
def __init__(self, lower=False, tokenizer: str = 'raw'): | def __init__(self, lower=False, tokenizer: str = 'raw'): | ||||
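A usage sketch via one of the concrete subclasses (SNLIPipe is only an illustration; auto-download is an assumption)::

    from fastNLP.io import SNLIPipe

    data_bundle = SNLIPipe(lower=True, tokenizer='raw').process_from_file()
    dev_ds = data_bundle.get_dataset('dev')   # raw_words1, raw_words2, target, words1, words2, seq_len1, seq_len2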
@@ -319,10 +354,10 @@ class LCQMCPipe(MatchingPipe): | |||||
return data_bundle | return data_bundle | ||||
class XNLIPipe(MatchingPipe): | |||||
def process_from_file(self, paths = None): | |||||
data_bundle = XNLILoader().load(paths) | |||||
data_bundle = GranularizePipe(task = 'XNLI').process(data_bundle) | |||||
class CNXNLIPipe(MatchingPipe): | |||||
def process_from_file(self, paths=None): | |||||
data_bundle = CNXNLILoader().load(paths) | |||||
data_bundle = GranularizePipe(task='XNLI').process(data_bundle) | |||||
data_bundle = RenamePipe().process(data_bundle)  # rename the Chinese data's field | data_bundle = RenamePipe().process(data_bundle)  # rename the Chinese data's field | ||||
data_bundle = self.process(data_bundle) | data_bundle = self.process(data_bundle) | ||||
data_bundle = RenamePipe().process(data_bundle) | data_bundle = RenamePipe().process(data_bundle) | ||||
@@ -438,9 +473,9 @@ class BQCorpusBertPipe(MatchingBertPipe): | |||||
return data_bundle | return data_bundle | ||||
class XNLIBertPipe(MatchingBertPipe): | |||||
class CNXNLIBertPipe(MatchingBertPipe): | |||||
def process_from_file(self, paths=None): | def process_from_file(self, paths=None): | ||||
data_bundle = XNLILoader().load(paths) | |||||
data_bundle = CNXNLILoader().load(paths) | |||||
data_bundle = GranularizePipe(task='XNLI').process(data_bundle) | data_bundle = GranularizePipe(task='XNLI').process(data_bundle) | ||||
data_bundle = RenamePipe(task='cn-nli-bert').process(data_bundle) | data_bundle = RenamePipe(task='cn-nli-bert').process(data_bundle) | ||||
data_bundle = self.process(data_bundle) | data_bundle = self.process(data_bundle) | ||||
@@ -1,15 +1,14 @@ | |||||
"""undocumented""" | """undocumented""" | ||||
import os | |||||
import numpy as np | import numpy as np | ||||
from .pipe import Pipe | from .pipe import Pipe | ||||
from .utils import get_tokenizer, _indexize, _add_words_field, _drop_empty_instance | |||||
from ..loader.json import JsonLoader | |||||
from .utils import _drop_empty_instance | |||||
from ..loader.summarization import ExtCNNDMLoader | |||||
from ..data_bundle import DataBundle | from ..data_bundle import DataBundle | ||||
from ..loader.classification import IMDBLoader, YelpFullLoader, SSTLoader, SST2Loader, YelpPolarityLoader | |||||
from ...core.const import Const | from ...core.const import Const | ||||
from ...core.dataset import DataSet | |||||
from ...core.instance import Instance | |||||
from ...core.vocabulary import Vocabulary | from ...core.vocabulary import Vocabulary | ||||
from ...core._logger import logger | |||||
WORD_PAD = "[PAD]" | WORD_PAD = "[PAD]" | ||||
@@ -18,7 +17,6 @@ DOMAIN_UNK = "X" | |||||
TAG_UNK = "X" | TAG_UNK = "X" | ||||
class ExtCNNDMPipe(Pipe): | class ExtCNNDMPipe(Pipe): | ||||
""" | """ | ||||
Preprocesses the CNN/Daily Mail data for the extractive summarization task; the preprocessed data has the following structure: | Preprocesses the CNN/Daily Mail data for the extractive summarization task; the preprocessed data has the following structure: | ||||
@@ -27,13 +25,13 @@ class ExtCNNDMPipe(Pipe): | |||||
:header: "text", "summary", "label", "publication", "text_wd", "words", "seq_len", "target" | :header: "text", "summary", "label", "publication", "text_wd", "words", "seq_len", "target" | ||||
""" | """ | ||||
def __init__(self, vocab_size, vocab_path, sent_max_len, doc_max_timesteps, domain=False): | |||||
def __init__(self, vocab_size, sent_max_len, doc_max_timesteps, vocab_path=None, domain=False): | |||||
""" | """ | ||||
:param vocab_size: int, vocabulary size | :param vocab_size: int, vocabulary size | ||||
:param vocab_path: str, path to an external vocabulary file | |||||
:param sent_max_len: int, maximum sentence length; shorter sentences are padded, longer ones truncated | :param sent_max_len: int, maximum sentence length; shorter sentences are padded, longer ones truncated | ||||
:param doc_max_timesteps: int, maximum number of sentences per document; shorter documents are padded, longer ones truncated | :param doc_max_timesteps: int, maximum number of sentences per document; shorter documents are padded, longer ones truncated | ||||
:param vocab_path: str, path to an external vocabulary file | |||||
:param domain: bool, whether to build a domain vocabulary | :param domain: bool, whether to build a domain vocabulary | ||||
""" | """ | ||||
self.vocab_size = vocab_size | self.vocab_size = vocab_size | ||||
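With vocab_path now optional, a sketch of the new constructor order (values and the import path are placeholders; per the process_from_file change below, a missing vocab_path falls back to the 'vocab' file in the downloaded data directory)::

    from fastNLP.io.pipe.summarization import ExtCNNDMPipe   # module path assumed

    pipe = ExtCNNDMPipe(vocab_size=100000, sent_max_len=100, doc_max_timesteps=50)
    data_bundle = pipe.process_from_file()   # downloads CNN/DM and reads its bundled vocab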
@@ -42,8 +40,7 @@ class ExtCNNDMPipe(Pipe): | |||||
self.doc_max_timesteps = doc_max_timesteps | self.doc_max_timesteps = doc_max_timesteps | ||||
self.domain = domain | self.domain = domain | ||||
def process(self, db: DataBundle): | |||||
def process(self, data_bundle: DataBundle): | |||||
""" | """ | ||||
The incoming DataSet should have the following structure | The incoming DataSet should have the following structure | ||||
@@ -64,24 +61,28 @@ class ExtCNNDMPipe(Pipe): | |||||
[[""],...,[""]], [[],...,[]], [], [] | [[""],...,[""]], [[],...,[]], [], [] | ||||
""" | """ | ||||
db.apply(lambda x: _lower_text(x['text']), new_field_name='text') | |||||
db.apply(lambda x: _lower_text(x['summary']), new_field_name='summary') | |||||
db.apply(lambda x: _split_list(x['text']), new_field_name='text_wd') | |||||
db.apply(lambda x: _convert_label(x["label"], len(x["text"])), new_field_name=Const.TARGET) | |||||
if self.vocab_path is None: | |||||
error_msg = 'vocab file is not defined!' | |||||
logger.error(error_msg) | |||||
raise RuntimeError(error_msg) | |||||
data_bundle.apply(lambda x: _lower_text(x['text']), new_field_name='text') | |||||
data_bundle.apply(lambda x: _lower_text(x['summary']), new_field_name='summary') | |||||
data_bundle.apply(lambda x: _split_list(x['text']), new_field_name='text_wd') | |||||
data_bundle.apply(lambda x: _convert_label(x["label"], len(x["text"])), new_field_name=Const.TARGET) | |||||
db.apply(lambda x: _pad_sent(x["text_wd"], self.sent_max_len), new_field_name=Const.INPUT) | |||||
data_bundle.apply(lambda x: _pad_sent(x["text_wd"], self.sent_max_len), new_field_name=Const.INPUT) | |||||
# db.apply(lambda x: _token_mask(x["text_wd"], self.sent_max_len), new_field_name="pad_token_mask") | # db.apply(lambda x: _token_mask(x["text_wd"], self.sent_max_len), new_field_name="pad_token_mask") | ||||
# pad document | # pad document | ||||
db.apply(lambda x: _pad_doc(x[Const.INPUT], self.sent_max_len, self.doc_max_timesteps), new_field_name=Const.INPUT) | |||||
db.apply(lambda x: _sent_mask(x[Const.INPUT], self.doc_max_timesteps), new_field_name=Const.INPUT_LEN) | |||||
db.apply(lambda x: _pad_label(x[Const.TARGET], self.doc_max_timesteps), new_field_name=Const.TARGET) | |||||
data_bundle.apply(lambda x: _pad_doc(x[Const.INPUT], self.sent_max_len, self.doc_max_timesteps), new_field_name=Const.INPUT) | |||||
data_bundle.apply(lambda x: _sent_mask(x[Const.INPUT], self.doc_max_timesteps), new_field_name=Const.INPUT_LEN) | |||||
data_bundle.apply(lambda x: _pad_label(x[Const.TARGET], self.doc_max_timesteps), new_field_name=Const.TARGET) | |||||
db = _drop_empty_instance(db, "label") | |||||
data_bundle = _drop_empty_instance(data_bundle, "label") | |||||
# set input and target | # set input and target | ||||
db.set_input(Const.INPUT, Const.INPUT_LEN) | |||||
db.set_target(Const.TARGET, Const.INPUT_LEN) | |||||
data_bundle.set_input(Const.INPUT, Const.INPUT_LEN) | |||||
data_bundle.set_target(Const.TARGET, Const.INPUT_LEN) | |||||
# print("[INFO] Load existing vocab from %s!" % self.vocab_path) | # print("[INFO] Load existing vocab from %s!" % self.vocab_path) | ||||
word_list = [] | word_list = [] | ||||
@@ -96,47 +97,52 @@ class ExtCNNDMPipe(Pipe): | |||||
vocabs = Vocabulary(max_size=self.vocab_size, padding=WORD_PAD, unknown=WORD_UNK) | vocabs = Vocabulary(max_size=self.vocab_size, padding=WORD_PAD, unknown=WORD_UNK) | ||||
vocabs.add_word_lst(word_list) | vocabs.add_word_lst(word_list) | ||||
vocabs.build_vocab() | vocabs.build_vocab() | ||||
db.set_vocab(vocabs, "vocab") | |||||
data_bundle.set_vocab(vocabs, "vocab") | |||||
if self.domain == True: | |||||
if self.domain is True: | |||||
domaindict = Vocabulary(padding=None, unknown=DOMAIN_UNK) | domaindict = Vocabulary(padding=None, unknown=DOMAIN_UNK) | ||||
domaindict.from_dataset(db.get_dataset("train"), field_name="publication") | |||||
db.set_vocab(domaindict, "domain") | |||||
return db | |||||
domaindict.from_dataset(data_bundle.get_dataset("train"), field_name="publication") | |||||
data_bundle.set_vocab(domaindict, "domain") | |||||
return data_bundle | |||||
def process_from_file(self, paths=None): | def process_from_file(self, paths=None): | ||||
""" | """ | ||||
:param paths: dict or string | |||||
:return: DataBundle | |||||
""" | |||||
db = DataBundle() | |||||
if isinstance(paths, dict): | |||||
for key, value in paths.items(): | |||||
db.set_dataset(JsonLoader(fields={"text":None, "summary":None, "label":None, "publication":None})._load(value), key) | |||||
else: | |||||
db.set_dataset(JsonLoader(fields={"text":None, "summary":None, "label":None, "publication":None})._load(paths), 'test') | |||||
self.process(db) | |||||
:param paths: dict or string | |||||
:return: DataBundle | |||||
""" | |||||
loader = ExtCNNDMLoader() | |||||
if self.vocab_path is None: | |||||
if paths is None: | |||||
paths = loader.download() | |||||
if not os.path.isdir(paths): | |||||
error_msg = 'vocab file is not defined!' | |||||
logger.error(error_msg) | |||||
raise RuntimeError(error_msg) | |||||
self.vocab_path = os.path.join(paths, 'vocab') | |||||
db = loader.load(paths=paths) | |||||
db = self.process(db) | |||||
for ds in db.datasets.values(): | for ds in db.datasets.values(): | ||||
db.get_vocab("vocab").index_dataset(ds, field_name=Const.INPUT, new_field_name=Const.INPUT) | db.get_vocab("vocab").index_dataset(ds, field_name=Const.INPUT, new_field_name=Const.INPUT) | ||||
return db | return db | ||||
def _lower_text(text_list): | def _lower_text(text_list): | ||||
return [text.lower() for text in text_list] | return [text.lower() for text in text_list] | ||||
def _split_list(text_list): | def _split_list(text_list): | ||||
return [text.split() for text in text_list] | return [text.split() for text in text_list] | ||||
def _convert_label(label, sent_len): | def _convert_label(label, sent_len): | ||||
np_label = np.zeros(sent_len, dtype=int) | np_label = np.zeros(sent_len, dtype=int) | ||||
if label != []: | if label != []: | ||||
np_label[np.array(label)] = 1 | np_label[np.array(label)] = 1 | ||||
return np_label.tolist() | return np_label.tolist() | ||||
def _pad_sent(text_wd, sent_max_len): | def _pad_sent(text_wd, sent_max_len): | ||||
pad_text_wd = [] | pad_text_wd = [] | ||||
for sent_wd in text_wd: | for sent_wd in text_wd: | ||||
@@ -148,6 +154,7 @@ def _pad_sent(text_wd, sent_max_len): | |||||
pad_text_wd.append(sent_wd) | pad_text_wd.append(sent_wd) | ||||
return pad_text_wd | return pad_text_wd | ||||
def _token_mask(text_wd, sent_max_len): | def _token_mask(text_wd, sent_max_len): | ||||
token_mask_list = [] | token_mask_list = [] | ||||
for sent_wd in text_wd: | for sent_wd in text_wd: | ||||
@@ -159,6 +166,7 @@ def _token_mask(text_wd, sent_max_len): | |||||
token_mask_list.append(mask) | token_mask_list.append(mask) | ||||
return token_mask_list | return token_mask_list | ||||
def _pad_label(label, doc_max_timesteps): | def _pad_label(label, doc_max_timesteps): | ||||
text_len = len(label) | text_len = len(label) | ||||
if text_len < doc_max_timesteps: | if text_len < doc_max_timesteps: | ||||
@@ -167,6 +175,7 @@ def _pad_label(label, doc_max_timesteps): | |||||
pad_label = label[:doc_max_timesteps] | pad_label = label[:doc_max_timesteps] | ||||
return pad_label | return pad_label | ||||
def _pad_doc(text_wd, sent_max_len, doc_max_timesteps): | def _pad_doc(text_wd, sent_max_len, doc_max_timesteps): | ||||
text_len = len(text_wd) | text_len = len(text_wd) | ||||
if text_len < doc_max_timesteps: | if text_len < doc_max_timesteps: | ||||
@@ -176,6 +185,7 @@ def _pad_doc(text_wd, sent_max_len, doc_max_timesteps): | |||||
pad_text = text_wd[:doc_max_timesteps] | pad_text = text_wd[:doc_max_timesteps] | ||||
return pad_text | return pad_text | ||||
def _sent_mask(text_wd, doc_max_timesteps): | def _sent_mask(text_wd, doc_max_timesteps): | ||||
text_len = len(text_wd) | text_len = len(text_wd) | ||||
if text_len < doc_max_timesteps: | if text_len < doc_max_timesteps: | ||||
@@ -22,6 +22,9 @@ class BaseModel(torch.nn.Module): | |||||
class NaiveClassifier(BaseModel): | class NaiveClassifier(BaseModel): | ||||
""" | |||||
A simple example classifier, usable for all kinds of tests | |||||
""" | |||||
def __init__(self, in_feature_dim, out_feature_dim): | def __init__(self, in_feature_dim, out_feature_dim): | ||||
super(NaiveClassifier, self).__init__() | super(NaiveClassifier, self).__init__() | ||||
self.mlp = MLP([in_feature_dim, in_feature_dim, out_feature_dim]) | self.mlp = MLP([in_feature_dim, in_feature_dim, out_feature_dim]) | ||||
@@ -1,5 +1,6 @@ | |||||
"""undocumented | """undocumented | ||||
PyTorch implementation of Variational RNN | |||||
fastNLP implementation of Variational RNN and related models; see the paper: | |||||
`A Theoretically Grounded Application of Dropout in Recurrent Neural Networks (Yarin Gal and Zoubin Ghahramani, 2016) <https://arxiv.org/abs/1512.05287>`_ | |||||
""" | """ | ||||
__all__ = [ | __all__ = [ | ||||
@@ -227,6 +228,7 @@ class VarRNNBase(nn.Module): | |||||
class VarLSTM(VarRNNBase): | class VarLSTM(VarRNNBase): | ||||
""" | """ | ||||
Variational Dropout LSTM. | Variational Dropout LSTM. | ||||
Reference paper: `A Theoretically Grounded Application of Dropout in Recurrent Neural Networks (Yarin Gal and Zoubin Ghahramani, 2016) <https://arxiv.org/abs/1512.05287>`_ | |||||
""" | """ | ||||
@@ -253,7 +255,8 @@ class VarLSTM(VarRNNBase): | |||||
class VarRNN(VarRNNBase): | class VarRNN(VarRNNBase): | ||||
""" | """ | ||||
Variational Dropout RNN. | Variational Dropout RNN. | ||||
Reference paper: `A Theoretically Grounded Application of Dropout in Recurrent Neural Networks (Yarin Gal and Zoubin Ghahramani, 2016) <https://arxiv.org/abs/1512.05287>`_ | |||||
""" | """ | ||||
def __init__(self, *args, **kwargs): | def __init__(self, *args, **kwargs): | ||||
@@ -279,7 +282,8 @@ class VarRNN(VarRNNBase): | |||||
class VarGRU(VarRNNBase): | class VarGRU(VarRNNBase): | ||||
""" | """ | ||||
Variational Dropout GRU. | Variational Dropout GRU. | ||||
Reference paper: `A Theoretically Grounded Application of Dropout in Recurrent Neural Networks (Yarin Gal and Zoubin Ghahramani, 2016) <https://arxiv.org/abs/1512.05287>`_ | |||||
""" | """ | ||||
def __init__(self, *args, **kwargs): | def __init__(self, *args, **kwargs): | ||||
@@ -1,39 +1,35 @@ | |||||
import os | |||||
import tempfile | |||||
import unittest | import unittest | ||||
import numpy as np | import numpy as np | ||||
import torch | import torch | ||||
import os | |||||
import shutil | |||||
from fastNLP.core.callback import EarlyStopCallback, GradientClipCallback, LRScheduler, ControlC, \ | |||||
LRFinder, TensorboardCallback | |||||
from fastNLP import AccuracyMetric | |||||
from fastNLP import BCELoss | |||||
from fastNLP import DataSet | from fastNLP import DataSet | ||||
from fastNLP import Instance | from fastNLP import Instance | ||||
from fastNLP import BCELoss | |||||
from fastNLP import AccuracyMetric | |||||
from fastNLP import SGD | from fastNLP import SGD | ||||
from fastNLP import Trainer | from fastNLP import Trainer | ||||
from fastNLP.models.base_model import NaiveClassifier | |||||
from fastNLP.core.callback import EarlyStopError | |||||
from fastNLP.core.callback import EarlyStopCallback, GradientClipCallback, LRScheduler, ControlC, \ | |||||
LRFinder, TensorboardCallback | |||||
from fastNLP.core.callback import EvaluateCallback, FitlogCallback, SaveModelCallback | from fastNLP.core.callback import EvaluateCallback, FitlogCallback, SaveModelCallback | ||||
from fastNLP.core.callback import WarmupCallback | from fastNLP.core.callback import WarmupCallback | ||||
import tempfile | |||||
from fastNLP.models.base_model import NaiveClassifier | |||||
def prepare_env(): | def prepare_env(): | ||||
def prepare_fake_dataset(): | |||||
mean = np.array([-3, -3]) | |||||
cov = np.array([[1, 0], [0, 1]]) | |||||
class_A = np.random.multivariate_normal(mean, cov, size=(1000,)) | |||||
mean = np.array([3, 3]) | |||||
cov = np.array([[1, 0], [0, 1]]) | |||||
class_B = np.random.multivariate_normal(mean, cov, size=(1000,)) | |||||
data_set = DataSet([Instance(x=[float(item[0]), float(item[1])], y=[0.0]) for item in class_A] + | |||||
[Instance(x=[float(item[0]), float(item[1])], y=[1.0]) for item in class_B]) | |||||
return data_set | |||||
mean = np.array([-3, -3]) | |||||
cov = np.array([[1, 0], [0, 1]]) | |||||
class_A = np.random.multivariate_normal(mean, cov, size=(1000,)) | |||||
mean = np.array([3, 3]) | |||||
cov = np.array([[1, 0], [0, 1]]) | |||||
class_B = np.random.multivariate_normal(mean, cov, size=(1000,)) | |||||
data_set = DataSet([Instance(x=[float(item[0]), float(item[1])], y=[0.0]) for item in class_A] + | |||||
[Instance(x=[float(item[0]), float(item[1])], y=[1.0]) for item in class_B]) | |||||
data_set = prepare_fake_dataset() | |||||
data_set.set_input("x") | data_set.set_input("x") | ||||
data_set.set_target("y") | data_set.set_target("y") | ||||
model = NaiveClassifier(2, 1) | model = NaiveClassifier(2, 1) | ||||
@@ -43,11 +39,11 @@ def prepare_env(): | |||||
class TestCallback(unittest.TestCase): | class TestCallback(unittest.TestCase): | ||||
def setUp(self): | def setUp(self): | ||||
self.tempdir = tempfile.mkdtemp() | self.tempdir = tempfile.mkdtemp() | ||||
def tearDown(self): | def tearDown(self): | ||||
pass | pass | ||||
# shutil.rmtree(self.tempdir) | # shutil.rmtree(self.tempdir) | ||||
def test_gradient_clip(self): | def test_gradient_clip(self): | ||||
data_set, model = prepare_env() | data_set, model = prepare_env() | ||||
trainer = Trainer(data_set, model, optimizer=SGD(lr=0.1), loss=BCELoss(pred="predict", target="y"), | trainer = Trainer(data_set, model, optimizer=SGD(lr=0.1), loss=BCELoss(pred="predict", target="y"), | ||||
@@ -100,7 +96,7 @@ class TestCallback(unittest.TestCase): | |||||
path = os.path.join("./", 'tensorboard_logs_{}'.format(trainer.start_time)) | path = os.path.join("./", 'tensorboard_logs_{}'.format(trainer.start_time)) | ||||
if os.path.exists(path): | if os.path.exists(path): | ||||
shutil.rmtree(path) | shutil.rmtree(path) | ||||
def test_readonly_property(self): | def test_readonly_property(self): | ||||
from fastNLP.core.callback import Callback | from fastNLP.core.callback import Callback | ||||
passed_epochs = [] | passed_epochs = [] | ||||
@@ -123,19 +119,19 @@ class TestCallback(unittest.TestCase): | |||||
check_code_level=2) | check_code_level=2) | ||||
trainer.train() | trainer.train() | ||||
assert passed_epochs == list(range(1, total_epochs + 1)) | assert passed_epochs == list(range(1, total_epochs + 1)) | ||||
def test_evaluate_callback(self): | def test_evaluate_callback(self): | ||||
data_set, model = prepare_env() | data_set, model = prepare_env() | ||||
from fastNLP import Tester | from fastNLP import Tester | ||||
tester = Tester(data=data_set, model=model, metrics=AccuracyMetric(pred="predict", target="y")) | tester = Tester(data=data_set, model=model, metrics=AccuracyMetric(pred="predict", target="y")) | ||||
evaluate_callback = EvaluateCallback(data_set, tester) | evaluate_callback = EvaluateCallback(data_set, tester) | ||||
trainer = Trainer(data_set, model, optimizer=SGD(lr=0.1), loss=BCELoss(pred="predict", target="y"), | trainer = Trainer(data_set, model, optimizer=SGD(lr=0.1), loss=BCELoss(pred="predict", target="y"), | ||||
batch_size=32, n_epochs=5, print_every=50, dev_data=data_set, | batch_size=32, n_epochs=5, print_every=50, dev_data=data_set, | ||||
metrics=AccuracyMetric(pred="predict", target="y"), use_tqdm=False, | metrics=AccuracyMetric(pred="predict", target="y"), use_tqdm=False, | ||||
callbacks=evaluate_callback, check_code_level=2) | callbacks=evaluate_callback, check_code_level=2) | ||||
trainer.train() | trainer.train() | ||||
def test_fitlog_callback(self): | def test_fitlog_callback(self): | ||||
import fitlog | import fitlog | ||||
fitlog.set_log_dir(self.tempdir) | fitlog.set_log_dir(self.tempdir) | ||||
@@ -143,13 +139,13 @@ class TestCallback(unittest.TestCase): | |||||
from fastNLP import Tester | from fastNLP import Tester | ||||
tester = Tester(data=data_set, model=model, metrics=AccuracyMetric(pred="predict", target="y")) | tester = Tester(data=data_set, model=model, metrics=AccuracyMetric(pred="predict", target="y")) | ||||
fitlog_callback = FitlogCallback(data_set, tester) | fitlog_callback = FitlogCallback(data_set, tester) | ||||
trainer = Trainer(data_set, model, optimizer=SGD(lr=0.1), loss=BCELoss(pred="predict", target="y"), | trainer = Trainer(data_set, model, optimizer=SGD(lr=0.1), loss=BCELoss(pred="predict", target="y"), | ||||
batch_size=32, n_epochs=5, print_every=50, dev_data=data_set, | batch_size=32, n_epochs=5, print_every=50, dev_data=data_set, | ||||
metrics=AccuracyMetric(pred="predict", target="y"), use_tqdm=True, | metrics=AccuracyMetric(pred="predict", target="y"), use_tqdm=True, | ||||
callbacks=fitlog_callback, check_code_level=2) | callbacks=fitlog_callback, check_code_level=2) | ||||
trainer.train() | trainer.train() | ||||
def test_save_model_callback(self): | def test_save_model_callback(self): | ||||
data_set, model = prepare_env() | data_set, model = prepare_env() | ||||
top = 3 | top = 3 | ||||
@@ -159,10 +155,10 @@ class TestCallback(unittest.TestCase): | |||||
metrics=AccuracyMetric(pred="predict", target="y"), use_tqdm=True, | metrics=AccuracyMetric(pred="predict", target="y"), use_tqdm=True, | ||||
callbacks=save_model_callback, check_code_level=2) | callbacks=save_model_callback, check_code_level=2) | ||||
trainer.train() | trainer.train() | ||||
timestamp = os.listdir(self.tempdir)[0] | timestamp = os.listdir(self.tempdir)[0] | ||||
self.assertEqual(len(os.listdir(os.path.join(self.tempdir, timestamp))), top) | self.assertEqual(len(os.listdir(os.path.join(self.tempdir, timestamp))), top) | ||||
def test_warmup_callback(self): | def test_warmup_callback(self): | ||||
data_set, model = prepare_env() | data_set, model = prepare_env() | ||||
warmup_callback = WarmupCallback() | warmup_callback = WarmupCallback() | ||||
@@ -171,3 +167,50 @@ class TestCallback(unittest.TestCase): | |||||
metrics=AccuracyMetric(pred="predict", target="y"), use_tqdm=True, | metrics=AccuracyMetric(pred="predict", target="y"), use_tqdm=True, | ||||
callbacks=warmup_callback, check_code_level=2) | callbacks=warmup_callback, check_code_level=2) | ||||
trainer.train() | trainer.train() | ||||
def test_early_stop_callback(self): | |||||
""" | |||||
One must watch whether EarlyStop actually triggers | |||||
""" | |||||
data_set, model = prepare_env() | |||||
trainer = Trainer(data_set, model, optimizer=SGD(lr=0.1), loss=BCELoss(pred="predict", target="y"), | |||||
batch_size=2, n_epochs=10, print_every=5, dev_data=data_set, | |||||
metrics=AccuracyMetric(pred="predict", target="y"), use_tqdm=True, | |||||
callbacks=EarlyStopCallback(1), check_code_level=2) | |||||
trainer.train() | |||||
def test_control_C(): | |||||
# Tests ControlC: exit with Control+C during each of the two training runs; if "Test failed!" is not printed at the end, the test passes | |||||
from fastNLP import ControlC, Callback | |||||
import time | |||||
line1 = "\n\n\n\n\n*************************" | |||||
line2 = "*************************\n\n\n\n\n" | |||||
class Wait(Callback): | |||||
def on_epoch_end(self): | |||||
time.sleep(5) | |||||
data_set, model = prepare_env() | |||||
print(line1 + "Test starts!" + line2) | |||||
trainer = Trainer(data_set, model, optimizer=SGD(lr=0.1), loss=BCELoss(pred="predict", target="y"), | |||||
batch_size=32, n_epochs=20, dev_data=data_set, | |||||
metrics=AccuracyMetric(pred="predict", target="y"), use_tqdm=True, | |||||
callbacks=[Wait(), ControlC(False)], check_code_level=2) | |||||
trainer.train() | |||||
print(line1 + "Program goes on ..." + line2) | |||||
trainer = Trainer(data_set, model, optimizer=SGD(lr=0.1), loss=BCELoss(pred="predict", target="y"), | |||||
batch_size=32, n_epochs=20, dev_data=data_set, | |||||
metrics=AccuracyMetric(pred="predict", target="y"), use_tqdm=True, | |||||
callbacks=[Wait(), ControlC(True)], check_code_level=2) | |||||
trainer.train() | |||||
print(line1 + "Test failed!" + line2) | |||||
if __name__ == "__main__": | |||||
test_control_C() |
@@ -1,33 +1,36 @@ | |||||
import os | |||||
import shutil | |||||
import subprocess | |||||
import unittest | import unittest | ||||
from argparse import ArgumentParser | |||||
import numpy as np | import numpy as np | ||||
import torch.cuda | import torch.cuda | ||||
from fastNLP import AccuracyMetric | |||||
from fastNLP import CrossEntropyLoss, BCELoss | |||||
from fastNLP import DataSet | from fastNLP import DataSet | ||||
from fastNLP import Instance | from fastNLP import Instance | ||||
from fastNLP import CrossEntropyLoss, BCELoss | |||||
from fastNLP import SGD | from fastNLP import SGD | ||||
from fastNLP.core.callback import EchoCallback | |||||
from fastNLP.core.dist_trainer import DistTrainer, get_local_rank | from fastNLP.core.dist_trainer import DistTrainer, get_local_rank | ||||
from fastNLP.models.base_model import NaiveClassifier | from fastNLP.models.base_model import NaiveClassifier | ||||
import shutil | |||||
import os | |||||
import subprocess | |||||
from argparse import ArgumentParser | |||||
from fastNLP.core.callback import EchoCallback | |||||
from fastNLP import AccuracyMetric | |||||
def prepare_fake_dataset(): | def prepare_fake_dataset(): | ||||
mean = np.array([-3, -3]) | mean = np.array([-3, -3]) | ||||
cov = np.array([[1, 0], [0, 1]]) | cov = np.array([[1, 0], [0, 1]]) | ||||
class_A = np.random.multivariate_normal(mean, cov, size=(1000,)) | class_A = np.random.multivariate_normal(mean, cov, size=(1000,)) | ||||
mean = np.array([3, 3]) | mean = np.array([3, 3]) | ||||
cov = np.array([[1, 0], [0, 1]]) | cov = np.array([[1, 0], [0, 1]]) | ||||
class_B = np.random.multivariate_normal(mean, cov, size=(1000,)) | class_B = np.random.multivariate_normal(mean, cov, size=(1000,)) | ||||
data_set = DataSet([Instance(x=[float(item[0]), float(item[1])], y=0) for item in class_A] + | data_set = DataSet([Instance(x=[float(item[0]), float(item[1])], y=0) for item in class_A] + | ||||
[Instance(x=[float(item[0]), float(item[1])], y=1) for item in class_B]) | [Instance(x=[float(item[0]), float(item[1])], y=1) for item in class_B]) | ||||
return data_set | return data_set | ||||
def prepare_fake_dataset2(*args, size=100): | def prepare_fake_dataset2(*args, size=100): | ||||
ys = np.random.randint(4, size=100, dtype=np.int64) | ys = np.random.randint(4, size=100, dtype=np.int64) | ||||
data = {'y': ys} | data = {'y': ys} | ||||
@@ -35,32 +38,35 @@ def prepare_fake_dataset2(*args, size=100): | |||||
data[arg] = np.random.randn(size, 5) | data[arg] = np.random.randn(size, 5) | ||||
return DataSet(data=data) | return DataSet(data=data) | ||||
def set_rng_seed(seed): | def set_rng_seed(seed): | ||||
np.random.seed(seed) | np.random.seed(seed) | ||||
def prepare_env(): | def prepare_env(): | ||||
def prepare_fake_dataset(): | def prepare_fake_dataset(): | ||||
mean = np.array([-3, -3]) | mean = np.array([-3, -3]) | ||||
cov = np.array([[1, 0], [0, 1]]) | cov = np.array([[1, 0], [0, 1]]) | ||||
class_A = np.random.multivariate_normal(mean, cov, size=(1000,)) | class_A = np.random.multivariate_normal(mean, cov, size=(1000,)) | ||||
mean = np.array([3, 3]) | mean = np.array([3, 3]) | ||||
cov = np.array([[1, 0], [0, 1]]) | cov = np.array([[1, 0], [0, 1]]) | ||||
class_B = np.random.multivariate_normal(mean, cov, size=(1000,)) | class_B = np.random.multivariate_normal(mean, cov, size=(1000,)) | ||||
data_set = DataSet([Instance(x=[float(item[0]), float(item[1])], y=[0.0]) for item in class_A] + | data_set = DataSet([Instance(x=[float(item[0]), float(item[1])], y=[0.0]) for item in class_A] + | ||||
[Instance(x=[float(item[0]), float(item[1])], y=[1.0]) for item in class_B]) | [Instance(x=[float(item[0]), float(item[1])], y=[1.0]) for item in class_B]) | ||||
return data_set | return data_set | ||||
data_set = prepare_fake_dataset() | data_set = prepare_fake_dataset() | ||||
data_set.set_input("x") | data_set.set_input("x") | ||||
data_set.set_target("y") | data_set.set_target("y") | ||||
model = NaiveClassifier(2, 1) | model = NaiveClassifier(2, 1) | ||||
return data_set, model | return data_set, model | ||||
class TestDistTrainer(unittest.TestCase): | class TestDistTrainer(unittest.TestCase): | ||||
save_path = './save_cp' | save_path = './save_cp' | ||||
def run1(self): | def run1(self): | ||||
# test distributed training | # test distributed training | ||||
print('local rank', get_local_rank()) | print('local rank', get_local_rank()) | ||||
@@ -68,9 +74,9 @@ class TestDistTrainer(unittest.TestCase): | |||||
data_set = prepare_fake_dataset() | data_set = prepare_fake_dataset() | ||||
data_set.set_input("x", flag=True) | data_set.set_input("x", flag=True) | ||||
data_set.set_target("y", flag=True) | data_set.set_target("y", flag=True) | ||||
model = NaiveClassifier(2, 2) | model = NaiveClassifier(2, 2) | ||||
trainer = DistTrainer( | trainer = DistTrainer( | ||||
model=model, train_data=data_set, optimizer=SGD(lr=0.1), | model=model, train_data=data_set, optimizer=SGD(lr=0.1), | ||||
loss=CrossEntropyLoss(pred="predict", target="y"), | loss=CrossEntropyLoss(pred="predict", target="y"), | ||||
@@ -82,7 +88,7 @@ class TestDistTrainer(unittest.TestCase): | |||||
""" | """ | ||||
if trainer.is_master and os.path.exists(self.save_path): | if trainer.is_master and os.path.exists(self.save_path): | ||||
shutil.rmtree(self.save_path) | shutil.rmtree(self.save_path) | ||||
def run2(self): | def run2(self): | ||||
# test fp16 with distributed training | # test fp16 with distributed training | ||||
print('local rank', get_local_rank()) | print('local rank', get_local_rank()) | ||||
@@ -90,9 +96,9 @@ class TestDistTrainer(unittest.TestCase): | |||||
data_set = prepare_fake_dataset() | data_set = prepare_fake_dataset() | ||||
data_set.set_input("x", flag=True) | data_set.set_input("x", flag=True) | ||||
data_set.set_target("y", flag=True) | data_set.set_target("y", flag=True) | ||||
model = NaiveClassifier(2, 2) | model = NaiveClassifier(2, 2) | ||||
trainer = DistTrainer( | trainer = DistTrainer( | ||||
model=model, train_data=data_set, optimizer=SGD(lr=0.1), | model=model, train_data=data_set, optimizer=SGD(lr=0.1), | ||||
loss=CrossEntropyLoss(pred="predict", target="y"), | loss=CrossEntropyLoss(pred="predict", target="y"), | ||||
@@ -105,7 +111,7 @@ class TestDistTrainer(unittest.TestCase): | |||||
""" | """ | ||||
if trainer.is_master and os.path.exists(self.save_path): | if trainer.is_master and os.path.exists(self.save_path): | ||||
shutil.rmtree(self.save_path) | shutil.rmtree(self.save_path) | ||||
def run3(self): | def run3(self): | ||||
set_rng_seed(100) | set_rng_seed(100) | ||||
data_set, model = prepare_env() | data_set, model = prepare_env() | ||||
@@ -117,26 +123,28 @@ class TestDistTrainer(unittest.TestCase): | |||||
callbacks_master=[EchoCallback('callbacks_master')] | callbacks_master=[EchoCallback('callbacks_master')] | ||||
) | ) | ||||
trainer.train() | trainer.train() | ||||
def run4(self): | def run4(self): | ||||
set_rng_seed(100) | set_rng_seed(100) | ||||
data_set, model = prepare_env() | data_set, model = prepare_env() | ||||
train_set, dev_set = data_set.split(0.3) | train_set, dev_set = data_set.split(0.3) | ||||
model = NaiveClassifier(2, 1) | model = NaiveClassifier(2, 1) | ||||
trainer = DistTrainer( | trainer = DistTrainer( | ||||
train_set, model, optimizer=SGD(lr=0.1), | train_set, model, optimizer=SGD(lr=0.1), | ||||
loss=BCELoss(pred="predict", target="y"), | loss=BCELoss(pred="predict", target="y"), | ||||
batch_size_per_gpu=32, n_epochs=3, print_every=50, dev_data=dev_set, | batch_size_per_gpu=32, n_epochs=3, print_every=50, dev_data=dev_set, | ||||
metrics=AccuracyMetric(pred="predict", target="y"), validate_every=-1, save_path=None, | |||||
metrics=AccuracyMetric(pred="predict", target="y"), validate_every=-1, save_path=self.save_path, | |||||
) | ) | ||||
trainer.train() | trainer.train() | ||||
""" | """ | ||||
# should run correctly | # should run correctly | ||||
""" | """ | ||||
if trainer.is_master and os.path.exists(self.save_path): | |||||
shutil.rmtree(self.save_path) | |||||
def run_dist(self, run_id): | def run_dist(self, run_id): | ||||
if torch.cuda.is_available(): | if torch.cuda.is_available(): | ||||
ngpu = min(2, torch.cuda.device_count()) | ngpu = min(2, torch.cuda.device_count()) | ||||
@@ -145,23 +153,24 @@ class TestDistTrainer(unittest.TestCase): | |||||
'--nproc_per_node', str(ngpu), path, '--test', str(run_id)] | '--nproc_per_node', str(ngpu), path, '--test', str(run_id)] | ||||
print(' '.join(cmd)) | print(' '.join(cmd)) | ||||
subprocess.check_call(cmd) | subprocess.check_call(cmd) | ||||
def test_normal_run(self): | def test_normal_run(self): | ||||
self.run_dist(1) | self.run_dist(1) | ||||
def no_test_fp16(self): | def no_test_fp16(self): | ||||
self.run_dist(2) | self.run_dist(2) | ||||
def test_callback(self): | def test_callback(self): | ||||
self.run_dist(3) | self.run_dist(3) | ||||
def test_dev_data(self): | def test_dev_data(self): | ||||
self.run_dist(4) | self.run_dist(4) | ||||
if __name__ == '__main__': | if __name__ == '__main__': | ||||
runner = TestDistTrainer() | runner = TestDistTrainer() | ||||
parser = ArgumentParser() | parser = ArgumentParser() | ||||
parser.add_argument('--test', type=int) | parser.add_argument('--test', type=int) | ||||
args, _ = parser.parse_known_args() | args, _ = parser.parse_known_args() | ||||
if args.test and hasattr(runner, 'run%s'%args.test): | |||||
getattr(runner, 'run%s'%args.test)() | |||||
if args.test and hasattr(runner, 'run%s' % args.test): | |||||
getattr(runner, 'run%s' % args.test)() |
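Note that the test_* entry points above do not train in-process: run_dist() re-launches this same file under torch.distributed.launch, and each child process parses --test to dispatch to the matching run<N> method. A minimal, hedged sketch of that relaunch (the file name test_dist_trainer.py is an assumption; the test itself resolves the path with os.path.abspath(__file__)):

    import subprocess

    # Spawn 2 worker processes; each re-enters this file, parses --test,
    # and calls run4(), the dev-data case exercised by test_dev_data().
    cmd = ['python', '-m', 'torch.distributed.launch',
           '--nproc_per_node', '2', 'test_dist_trainer.py', '--test', '4']
    subprocess.check_call(cmd)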
@@ -5,7 +5,7 @@ import os
 from fastNLP.io import DataBundle
 from fastNLP.io.loader.matching import RTELoader, QNLILoader, SNLILoader, QuoraLoader, MNLILoader, \
-    BQCorpusLoader, XNLILoader, LCQMCLoader
+    BQCorpusLoader, CNXNLILoader, LCQMCLoader

 @unittest.skipIf('TRAVIS' in os.environ, "Skip in travis")
@@ -31,7 +31,7 @@ class TestMatchingLoad(unittest.TestCase):
             'MNLI': ('test/data_for_tests/io/MNLI', MNLILoader, (5, 5, 5, 5, 6), True),
             'Quora': ('test/data_for_tests/io/Quora', QuoraLoader, (2, 2, 2), False),
             'BQCorpus': ('test/data_for_tests/io/BQCorpus', BQCorpusLoader, (5, 5, 5), False),
-            'XNLI': ('test/data_for_tests/io/XNLI', XNLILoader, (6, 7, 6), False),
+            'XNLI': ('test/data_for_tests/io/XNLI', CNXNLILoader, (6, 7, 6), False),
             'LCQMC': ('test/data_for_tests/io/LCQMC', LCQMCLoader, (5, 6, 6), False),
         }
         for k, v in data_set_dict.items():
@@ -4,9 +4,9 @@ import os
 from fastNLP.io import DataBundle
 from fastNLP.io.pipe.matching import SNLIPipe, RTEPipe, QNLIPipe, QuoraPipe, MNLIPipe, \
-    XNLIPipe, BQCorpusPipe, LCQMCPipe
+    CNXNLIPipe, BQCorpusPipe, LCQMCPipe
 from fastNLP.io.pipe.matching import SNLIBertPipe, RTEBertPipe, QNLIBertPipe, QuoraBertPipe, MNLIBertPipe, \
-    XNLIBertPipe, BQCorpusBertPipe, LCQMCBertPipe
+    CNXNLIBertPipe, BQCorpusBertPipe, LCQMCBertPipe

 @unittest.skipIf('TRAVIS' in os.environ, "Skip in travis")
@@ -38,7 +38,7 @@ class TestRunMatchingPipe(unittest.TestCase):
             'QNLI': ('test/data_for_tests/io/QNLI', QNLIPipe, QNLIBertPipe, (5, 5, 5), (372, 2), True),
             'MNLI': ('test/data_for_tests/io/MNLI', MNLIPipe, MNLIBertPipe, (5, 5, 5, 5, 6), (459, 3), True),
             'BQCorpus': ('test/data_for_tests/io/BQCorpus', BQCorpusPipe, BQCorpusBertPipe, (5, 5, 5), (32, 2), False),
-            'XNLI': ('test/data_for_tests/io/XNLI', XNLIPipe, XNLIBertPipe, (6, 7, 6), (37, 3), False),
+            'XNLI': ('test/data_for_tests/io/XNLI', CNXNLIPipe, CNXNLIBertPipe, (6, 7, 6), (37, 3), False),
             'LCQMC': ('test/data_for_tests/io/LCQMC', LCQMCPipe, LCQMCBertPipe, (5, 6, 6), (36, 2), False),
         }
         for k, v in data_set_dict.items():
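Both files swap the generic XNLI classes for their Chinese-specific counterparts (CNXNLILoader, CNXNLIPipe, CNXNLIBertPipe) while keeping the same fixture paths. A hedged usage sketch mirroring what the test loop does for the 'XNLI' entries, assuming the no-argument constructors used here:

    from fastNLP.io.loader.matching import CNXNLILoader
    from fastNLP.io.pipe.matching import CNXNLIPipe

    # Raw splits only (the test table expects sizes (6, 7, 6)) ...
    data_bundle = CNXNLILoader().load('test/data_for_tests/io/XNLI')
    # ... or load and preprocess in one step, yielding a DataBundle.
    processed = CNXNLIPipe().process_from_file('test/data_for_tests/io/XNLI')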
@@ -1,59 +1,69 @@
-#!/usr/bin/python
-# -*- coding: utf-8 -*-
-# __author__="Danqing Wang"
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-import unittest
-import os
-# import sys
-#
-# sys.path.append("../../../")
-from fastNLP.io import DataBundle
-from fastNLP.io.pipe.summarization import ExtCNNDMPipe
-
-class TestRunExtCNNDMPipe(unittest.TestCase):
-    def test_load(self):
-        data_set_dict = {
-            'CNNDM': {"train": 'test/data_for_tests/cnndm.jsonl'},
-        }
-        vocab_size = 100000
-        VOCAL_FILE = 'test/data_for_tests/cnndm.vocab'
-        sent_max_len = 100
-        doc_max_timesteps = 50
-        dbPipe = ExtCNNDMPipe(vocab_size=vocab_size,
-                              vocab_path=VOCAL_FILE,
-                              sent_max_len=sent_max_len,
-                              doc_max_timesteps=doc_max_timesteps)
-        dbPipe2 = ExtCNNDMPipe(vocab_size=vocab_size,
-                               vocab_path=VOCAL_FILE,
-                               sent_max_len=sent_max_len,
-                               doc_max_timesteps=doc_max_timesteps,
-                               domain=True)
-        for k, v in data_set_dict.items():
-            db = dbPipe.process_from_file(v)
-            db2 = dbPipe2.process_from_file(v)
-            # print(db2.get_dataset("train"))
-            self.assertTrue(isinstance(db, DataBundle))
-            self.assertTrue(isinstance(db2, DataBundle))
+#!/usr/bin/python
+# -*- coding: utf-8 -*-
+# __author__="Danqing Wang"
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+import unittest
+import os
+
+from fastNLP.io import DataBundle
+from fastNLP.io.pipe.summarization import ExtCNNDMPipe
+
+class TestRunExtCNNDMPipe(unittest.TestCase):
+    def test_load(self):
+        data_dir = 'test/data_for_tests/io/cnndm'
+        vocab_size = 100000
+        VOCAL_FILE = 'test/data_for_tests/io/cnndm/vocab'
+        sent_max_len = 100
+        doc_max_timesteps = 50
+        dbPipe = ExtCNNDMPipe(vocab_size=vocab_size,
+                              vocab_path=VOCAL_FILE,
+                              sent_max_len=sent_max_len,
+                              doc_max_timesteps=doc_max_timesteps)
+        dbPipe2 = ExtCNNDMPipe(vocab_size=vocab_size,
+                               vocab_path=VOCAL_FILE,
+                               sent_max_len=sent_max_len,
+                               doc_max_timesteps=doc_max_timesteps,
+                               domain=True)
+        db = dbPipe.process_from_file(data_dir)
+        db2 = dbPipe2.process_from_file(data_dir)
+
+        self.assertTrue(isinstance(db, DataBundle))
+        self.assertTrue(isinstance(db2, DataBundle))
+
+        dbPipe3 = ExtCNNDMPipe(vocab_size=vocab_size,
+                               sent_max_len=sent_max_len,
+                               doc_max_timesteps=doc_max_timesteps,
+                               domain=True)
+        db3 = dbPipe3.process_from_file(data_dir)
+        self.assertTrue(isinstance(db3, DataBundle))
+
+        with self.assertRaises(RuntimeError):
+            dbPipe4 = ExtCNNDMPipe(vocab_size=vocab_size,
+                                   sent_max_len=sent_max_len,
+                                   doc_max_timesteps=doc_max_timesteps)
+            db4 = dbPipe4.process_from_file(os.path.join(data_dir, 'train.cnndm.jsonl'))
+
+        dbPipe5 = ExtCNNDMPipe(vocab_size=vocab_size,
+                               vocab_path=VOCAL_FILE,
+                               sent_max_len=sent_max_len,
+                               doc_max_timesteps=doc_max_timesteps,)
+        db5 = dbPipe5.process_from_file(os.path.join(data_dir, 'train.cnndm.jsonl'))
+        self.assertIsInstance(db5, DataBundle)
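The rewrite drives ExtCNNDMPipe from a fixture directory instead of a per-split dict and pins down the vocabulary rules: with vocab_path the vocabulary is read from a file, without it domain=True builds one from the data, and omitting both is expected to raise RuntimeError. A condensed, hedged sketch of those three configurations, using the paths from the test:

    from fastNLP.io.pipe.summarization import ExtCNNDMPipe

    data_dir = 'test/data_for_tests/io/cnndm'          # fixture splits
    vocab_file = 'test/data_for_tests/io/cnndm/vocab'  # pre-built vocabulary
    common = dict(vocab_size=100000, sent_max_len=100, doc_max_timesteps=50)

    # 1) Vocabulary loaded from an explicit file.
    db = ExtCNNDMPipe(vocab_path=vocab_file, **common).process_from_file(data_dir)
    # 2) No vocab_path: domain=True builds the vocabulary from the data.
    db = ExtCNNDMPipe(domain=True, **common).process_from_file(data_dir)
    # 3) Neither vocab_path nor domain: no way to build a vocabulary, so
    #    process_from_file() is expected to raise RuntimeError.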
@@ -0,0 +1,25 @@
+import os
+import unittest
+
+from fastNLP.io import ModelSaver, ModelLoader
+from fastNLP.models import CNNText
+
+class TestModelIO(unittest.TestCase):
+    def test_save_and_load(self):
+        model = CNNText((10, 10), 2)
+        saver = ModelSaver('tmp')
+        loader = ModelLoader()
+        saver.save_pytorch(model)
+
+        new_cnn = CNNText((10, 10), 2)
+        loader.load_pytorch(new_cnn, 'tmp')
+        new_model = loader.load_pytorch_model('tmp')
+
+        for i in range(10):
+            for j in range(10):
+                self.assertEqual(model.embed.embed.weight[i, j], new_cnn.embed.embed.weight[i, j])
+                self.assertEqual(model.embed.embed.weight[i, j], new_model["embed.embed.weight"][i, j])
+
+        os.system('rm tmp')
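This new test covers both loading styles side by side: load_pytorch() copies saved parameters into an existing model, while load_pytorch_model() appears to return the saved object itself, which the test indexes by parameter name (suggesting save_pytorch() stores a parameter dict by default). A hedged sketch of the same round trip outside unittest:

    from fastNLP.io import ModelSaver, ModelLoader
    from fastNLP.models import CNNText

    model = CNNText((10, 10), 2)           # (vocab_size, embed_dim), num_classes
    ModelSaver('tmp').save_pytorch(model)  # write parameters to ./tmp

    clone = CNNText((10, 10), 2)
    ModelLoader().load_pytorch(clone, 'tmp')         # fill an existing instance
    state = ModelLoader().load_pytorch_model('tmp')  # the saved parameters themselves
    assert (state['embed.embed.weight'] == clone.embed.embed.weight).all()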
@@ -1,9 +1,20 @@
 import unittest

 import torch

-from fastNLP.modules.utils import get_dropout_mask
+from fastNLP.models import CNNText
+from fastNLP.modules.utils import get_dropout_mask, summary

 class TestUtil(unittest.TestCase):
     def test_get_dropout_mask(self):
         tensor = torch.randn(3, 4)
         mask = get_dropout_mask(0.3, tensor)
-        self.assertSequenceEqual(mask.size(), torch.Size([3, 4]))
+        self.assertSequenceEqual(mask.size(), torch.Size([3, 4]))
+
+    def test_summary(self):
+        model = CNNText(embed=(4, 4), num_classes=2, kernel_nums=(9, 5), kernel_sizes=(1, 3))
+        # 4 * 4 + 4 * (9 * 1 + 5 * 3) + 2 * (9 + 5 + 1) = 142
+        self.assertSequenceEqual((142, 142, 0), summary(model))
+        model.embed.requires_grad = False
+        self.assertSequenceEqual((142, 126, 16), summary(model))
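The triples asserted above read as (total, trainable, non-trainable) parameter counts. The inline comment's arithmetic, spelled out under that reading:

    embed = 4 * 4                # 16: vocabulary of 4 words, embedding dim 4
    convs = 4 * (9 * 1 + 5 * 3)  # 96: 9 width-1 and 5 width-3 kernels over dim 4
    fc = 2 * (9 + 5 + 1)         # 30: 2 classes x (9 + 5 pooled features + bias)
    total = embed + convs + fc   # 142

    frozen = embed               # 16, once model.embed.requires_grad = False
    assert (total, total - frozen, frozen) == (142, 126, 16)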