From 372991c03a65c02c680e64fe54064da972188831 Mon Sep 17 00:00:00 2001 From: benbijituo Date: Fri, 20 Sep 2019 11:23:50 +0800 Subject: [PATCH 01/16] =?UTF-8?q?=E8=A1=A5=E5=85=85=E4=BA=86=E4=B8=A4?= =?UTF-8?q?=E4=B8=AA=E6=95=B0=E6=8D=AE=E9=9B=86=E7=9A=84download?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- fastNLP/io/file_utils.py | 4 +++- fastNLP/io/loader/classification.py | 9 +++++++++ fastNLP/io/loader/matching.py | 10 ++++++++++ 3 files changed, 22 insertions(+), 1 deletion(-) diff --git a/fastNLP/io/file_utils.py b/fastNLP/io/file_utils.py index 6661397b..022af0ac 100644 --- a/fastNLP/io/file_utils.py +++ b/fastNLP/io/file_utils.py @@ -89,6 +89,7 @@ DATASET_DIR = { "mnli": "MNLI.zip", "snli": "SNLI.zip", "qnli": "QNLI.zip", + "xnli": "XNLI.zip", "sst-2": "SST-2.zip", "sst": "SST.zip", "rte": "RTE.zip", @@ -101,7 +102,8 @@ DATASET_DIR = { "cws-as": 'cws_as.zip', "cws-msra": 'cws_msra.zip', - "chn-senti-corp":"chn_senti_corp.zip" + "chn-senti-corp" : "chn_senti_corp.zip", + "weibo-senti-100k" : "WeiboSenti100k.zip" } PRETRAIN_MAP = {'elmo': PRETRAINED_ELMO_MODEL_DIR, diff --git a/fastNLP/io/loader/classification.py b/fastNLP/io/loader/classification.py index 51660db5..ca9b6107 100644 --- a/fastNLP/io/loader/classification.py +++ b/fastNLP/io/loader/classification.py @@ -518,3 +518,12 @@ class WeiboSenti100kLoader(Loader): if raw_chars: ds.append(Instance(raw_chars=raw_chars, target=target)) return ds + + def download(self) -> str: + """ + 自动下载数据,该数据取自 https://github.com/SophonPlus/ChineseNlpCorpus/ + 在 https://arxiv.org/abs/1906.08101 有使用 + :return: + """ + output_dir = self._get_dataset_path('weibo-senti-100k') + return output_dir diff --git a/fastNLP/io/loader/matching.py b/fastNLP/io/loader/matching.py index df60618b..b9724126 100644 --- a/fastNLP/io/loader/matching.py +++ b/fastNLP/io/loader/matching.py @@ -377,6 +377,16 @@ class XNLILoader(Loader): data_bundle = DataBundle(datasets=datasets) return data_bundle + def download(self) -> str: + """ + 自动下载数据,该数据取自 https://arxiv.org/abs/1809.05053 + 在 https://arxiv.org/pdf/1905.05526.pdf https://arxiv.org/pdf/1901.10125.pdf + https://arxiv.org/pdf/1809.05053.pdf 有使用 + :return: + """ + output_dir = self._get_dataset_path('xnli') + return output_dir + class BQCorpusLoader(Loader): """ From 02cfc9f421e7bbb9e727d92c323ead6f471f00e7 Mon Sep 17 00:00:00 2001 From: yunfan Date: Fri, 20 Sep 2019 15:44:53 +0800 Subject: [PATCH 02/16] [add] docstring in batch, dist_trainer; [update] dist_trainer, callback --- fastNLP/core/batch.py | 21 +++++- fastNLP/core/callback.py | 49 ++++++------ fastNLP/core/dist_trainer.py | 139 +++++++++++++++++++++++++++++------ fastNLP/core/trainer.py | 1 + fastNLP/core/utils.py | 11 +++ 5 files changed, 170 insertions(+), 51 deletions(-) diff --git a/fastNLP/core/batch.py b/fastNLP/core/batch.py index 4ee1916a..f2e34c52 100644 --- a/fastNLP/core/batch.py +++ b/fastNLP/core/batch.py @@ -193,13 +193,14 @@ class DataSetIter(BatchIter): Default: ``None`` :param bool as_numpy: 若为 ``True`` , 输出batch为 numpy.array. 否则为 :class:`torch.Tensor`. 
- + Default: ``False`` :param int num_workers: 使用多少个进程来预处理数据 :param bool pin_memory: 是否将产生的tensor使用pin memory, 可能会加快速度。 :param bool drop_last: 如果最后一个batch没有batch_size这么多sample,就扔掉最后一个 - :param timeout: + :param timeout: 生成一个batch的timeout值 :param worker_init_fn: 在每个worker启动时调用该函数,会传入一个值,该值是worker的index。 + :param collate_fn: 用于将样本组合成batch的函数 """ assert isinstance(dataset, DataSet) dataset = DataSetGetter(dataset, as_numpy) @@ -220,12 +221,26 @@ class DataSetIter(BatchIter): class TorchLoaderIter(BatchIter): """ - 与DataSetIter类似,但用于pytorch的DataSet对象。通过使用TorchLoaderIter封装pytorch的DataSet,然后将其传入到Trainer中。 + 与DataSetIter类似,但用于pytorch的DataSet对象。 + 通过使用TorchLoaderIter封装pytorch的DataSet,然后将其传入到Trainer中。 """ def __init__(self, dataset, batch_size=1, sampler=None, num_workers=0, pin_memory=False, drop_last=False, timeout=0, worker_init_fn=None, collate_fn=None): + """ + + :param dataset: :class:`~fastNLP.DataSet` 对象, 数据集 + :param int batch_size: 取出的batch大小 + :param sampler: 规定使用的 :class:`~fastNLP.Sampler` 方式. 若为 ``None`` , 使用 :class:`~fastNLP.SequentialSampler`. + + Default: ``None`` + :param int num_workers: 使用多少个进程来预处理数据 + :param bool pin_memory: 是否将产生的tensor使用pin memory, 可能会加快速度。 + :param bool drop_last: 如果最后一个batch没有batch_size这么多sample,就扔掉最后一个 + :param timeout: 生成一个batch的timeout值 + :param worker_init_fn: 在每个worker启动时调用该函数,会传入一个值,该值是worker的index。 + :param collate_fn: 用于将样本组合成batch的函数""" assert len(dataset) > 0 ins = dataset[0] assert len(ins) == 2 and \ diff --git a/fastNLP/core/callback.py b/fastNLP/core/callback.py index 6ad98b0b..734c1269 100644 --- a/fastNLP/core/callback.py +++ b/fastNLP/core/callback.py @@ -87,12 +87,18 @@ except: from .dataset import DataSet from .tester import Tester from ._logger import logger +from .utils import _check_fp16 try: import fitlog except: pass +try: + from apex import amp +except: + amp = None + class Callback(object): """ @@ -269,14 +275,6 @@ class Callback(object): :return: """ pass - - def on_validation(self): - """ - 如果Trainer中设置了验证,则会在每次需要验证时调用该函数 - - :return: - """ - pass def on_epoch_end(self): """ @@ -470,7 +468,7 @@ class GradientClipCallback(Callback): if self.step%self.update_every==0: if self.parameters is None: if getattr(self.trainer, 'fp16', ''): - from apex import amp + _check_fp16() self.clip_fun(amp.master_params(self.optimizer), self.clip_value) self.clip_fun(self.model.parameters(), self.clip_value) else: @@ -1036,27 +1034,23 @@ class EchoCallback(Callback): return super(EchoCallback, self).__getattribute__(item) -class TesterCallback(Callback): +class _TesterCallback(Callback): def __init__(self, data, model, metrics, metric_key=None, batch_size=16, num_workers=None): - super(TesterCallback, self).__init__() + super(_TesterCallback, self).__init__() if hasattr(model, 'module'): # for data parallel model model = model.module self.tester = Tester(data, model, metrics=metrics, batch_size=batch_size, num_workers=num_workers, verbose=0) - # parse metric_key - # increase_better is True. It means the exp result gets better if the indicator increases. - # It is true by default. 
- self.increase_better = True if metric_key is not None: - self.increase_better = False if metric_key[0] == "-" else True - self.metric_key = metric_key[1:] if metric_key[0] == "+" or metric_key[0] == "-" else metric_key + self.metric_key, self.increase_better = self._parse_metric_key(metric_key) else: self.metric_key = None + self.increase_better = True self.score = None - def on_validation(self): + def on_valid_begin(self): cur_score = self.tester.test() eval_str = "Evaluation at Epoch {}/{}. Step:{}/{}. - {}".format( self.epoch, self.n_epochs, self.step, self.n_steps, @@ -1067,17 +1061,28 @@ class TesterCallback(Callback): self.score = cur_score return cur_score, is_better - def _get_score(self, metric_dict, key): + @staticmethod + def _get_score(metric_dict, key): for metric in metric_dict.items(): if key in metric: return metric[key] return None + @staticmethod + def _parse_metric_key(metric_key): + # parse metric_key + # increase_better is True. It means the exp result gets better if the indicator increases. + # It is true by default. + increase_better = False if metric_key[0] == "-" else True + metric_key = metric_key[1:] if metric_key[0] == "+" or metric_key[0] == "-" else metric_key + return metric_key, increase_better + def compare_better(self, a): if self.score is None: return True if self.metric_key is None: - self.metric_key = list(list(self.score.values())[0].keys())[0] + metric_key = list(list(self.score.values())[0].keys())[0] + self.metric_key, self.increase_better = self._parse_metric_key(metric_key) k = self.metric_key score = self._get_score(self.score, k) new_score = self._get_score(a, k) @@ -1087,7 +1092,3 @@ class TesterCallback(Callback): return score <= new_score else: return score >= new_score - - def on_train_end(self): - self.logger.info('Evaluate on training ends.') - self.on_validation() diff --git a/fastNLP/core/dist_trainer.py b/fastNLP/core/dist_trainer.py index 3a293447..c2804134 100644 --- a/fastNLP/core/dist_trainer.py +++ b/fastNLP/core/dist_trainer.py @@ -17,21 +17,30 @@ from tqdm import tqdm from ._logger import logger from .batch import DataSetIter, BatchIter -from .callback import DistCallbackManager, CallbackException, TesterCallback +from .callback import DistCallbackManager, CallbackException, _TesterCallback from .dataset import DataSet from .losses import _prepare_losser from .optimizer import Optimizer from .utils import _build_args from .utils import _get_func_signature from .utils import _move_dict_value_to_device +from .utils import _check_fp16 + + +try: + from apex import amp +except: + amp = None __all__ = [ 'get_local_rank', 'DistTrainer', ] - def get_local_rank(): + """ + 返回当前进程的 local rank, 0 到 N-1 ,N为当前分布式总进程数 + """ if 'LOCAL_RANK' in os.environ: return int(os.environ['LOCAL_RANK']) from argparse import ArgumentParser @@ -46,7 +55,10 @@ def get_local_rank(): class DistTrainer(): """ - Distributed Trainer that support distributed and mixed precision training + 分布式的 Trainer,支持分布式训练和混合精度的训练。具体实现原理请阅读 pytorch 官方文档。 + + Note: 使用分布式 Trainer 时会同时有多个进程执行训练代码。因此将单进程的训练代码改为多进程之前, + 请仔细检查,确保训练代码中的同步和互斥操作能正确执行(如模型保持,打印日志等) """ def __init__(self, train_data, model, optimizer=None, loss=None, callbacks_all=None, callbacks_master=None, @@ -55,8 +67,43 @@ class DistTrainer(): dev_data=None, metrics=None, metric_key=None, update_every=1, print_every=10, validate_every=-1, save_every=-1, save_path=None, device='auto', - fp16='', backend=None, init_method=None): + fp16='', backend=None, init_method=None, use_tqdm=True): + """ + :param train_data: 训练集, 
:class:`~fastNLP.DataSet` 类型。 + :param nn.modules model: 待训练的模型 + :param optimizer: `torch.optim.Optimizer` 优化器。如果为None,则Trainer使用默认的Adam(model.parameters(), lr=4e-3)这个优化器 + :param loss: 使用的 :class:`~fastNLP.core.losses.LossBase` 对象。当为None时,默认使用 :class:`~fastNLP.LossInForward` + :param list callbacks_all: 用于在train过程中起调节作用的回调函数,作用于所有训练进程中。 + 可使用的callback参见 :doc:`callback模块 ` + :param list callbacks_master: 用于在train过程中起调节作用的回调函数,只作用于其中一个进程( Master 进程)。 + 可使用的callback参见 :doc:`callback模块 ` + :param int batch_size_per_gpu: 训练时,每个进程的 batch 大小。 + :param int n_epochs: 需要优化迭代多少次。 + :param num_workers: int, 有多少个线程来进行数据pad处理。 + :param drop_last: 如果最后一个batch没有正好为batch_size这么多数据,就扔掉最后一个batch + :param dev_data: 用于做验证的DataSet, :class:`~fastNLP.DataSet` 类型。 + :param metrics: 验证的评估函数。可以只使用一个 :class:`Metric` , + 也可以使用多个 :class:`Metric` ,通过列表传入。 + 如验证时取得了更好的验证结果(如果有多个Metric,以列表中第一个Metric为准),且save_path不为None, + 则保存当前模型。Metric种类详见 :doc:`metrics模块 ` 。仅在传入dev_data时有效。 + :param str,None metric_key: :class:`Metric` 有时会有多个指标, + 比如 :class:`~fastNLP.core.metrics.SpanFPreRecMetric` 中包含了'f', 'pre', 'rec'。此时需 + 要指定以哪个指标为准。另外有些指标是越小效果越好,比如语言模型的困惑度,这种情况下,在key前面增加一个'-'来表 + 明验证时,值越小越好(比如: "-ppl")。仅在传入dev_data时有效。 + :param update_every: int, 多少步更新一次梯度。用于希望累计梯度的场景,比如需要128的batch_size, 但是直接设为128 + 会导致内存不足,通过设置batch_size=32, update_every=4达到目的。当optimizer为None时,该参数无效。 + :param int print_every: 多少次反向传播更新tqdm显示的loss; 如果use_tqdm=False, 则多少次反向传播打印loss。 + :param int validate_every: 多少个step在验证集上验证一次; 如果为-1,则每个epoch结束验证一次。仅在传入dev_data时有效。 + :param int save_every: 多少个step保存一次模型,如果为-1,则每个epoch结束保存一次。仅在传入save_path时有效。 + :param str,None save_path: 将模型保存路径,如果路径不存在,将自动创建文件夹。如果为None,则不保存模型。如果dev_data为None,则保存 + 最后一次迭代的模型。保存的时候不仅保存了参数,还保存了模型结构。即便使用DataParallel,这里也只保存模型。 + :param str device: 指定 device,可以是 gpu,cpu 或 auto + :param str fp16: 指定半精度训练的优化等级,可为 O1,O2 或 O3,若为空字符串则不使用半精度。 + :param backend: 指定分布式的backend,详情参考 pytorch 文档 + :param init_method 指定分布式的初始化方法,详情参考 pytorch 文档 + :param bool use_tqdm: 是否使用tqdm来显示训练进度; 如果为False,则将loss打印在终端中。 + """ assert device in ['auto', 'cuda', 'cpu'], "Please set correct device in [auto', 'cuda', 'cpu']" if device == 'auto': device = 'cuda' if torch.cuda.is_available() else 'cpu' @@ -94,7 +141,9 @@ class DistTrainer(): self.callback_manager = DistCallbackManager( env={"trainer": self}, callbacks_all=callbacks_all, callbacks_master=callbacks_master) + self.test_manager = DistCallbackManager(env={'trainer': self}) self.metric_key = metric_key + self.use_tqdm = use_tqdm model.to(self.device) optimizer = self._get_optimizer(optimizer) @@ -102,11 +151,7 @@ class DistTrainer(): # init fp16, must before DataParallel init if len(self.fp16): assert isinstance(self.fp16, str), "Please set Apex AMP optimization level selected in ['O0', 'O1', 'O2', 'O3']" - try: - from apex import amp - except ImportError: - raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.") - assert torch.backends.cudnn.enabled, "Amp requires cudnn backend to be enabled." 
+ _check_fp16() assert device == 'cuda', "Amp requires cuda device" model, optimizer = amp.initialize(model, optimizer, opt_level=self.fp16) @@ -121,14 +166,15 @@ class DistTrainer(): self.optimizer = optimizer self.sampler = DistributedSampler(self.train_data) self.data_iterator = self._get_data_iter(self.train_data) + self.batch_size = self.world_size * self.batch_size_per_gpu self.n_steps = self._get_n_steps() # for evaluation, only run eval on master proc if dev_data and metrics: - cb = TesterCallback( + cb = _TesterCallback( dev_data, model, metrics, batch_size=batch_size_per_gpu, num_workers=num_workers) - self.callback_manager.add_callback([cb], master=True) + self.test_manager.add_callback([cb], master=True) # Setup logging dist.barrier() @@ -178,9 +224,27 @@ class DistTrainer(): @property def is_master(self): + """是否是主进程""" return self.rank == 0 - def train(self, on_exception='auto'): + def train(self, load_best_model=True, on_exception='auto'): + """ + 使用该函数使Trainer开始训练。 + + :param str on_exception: 在训练过程遭遇exception,并被 :py:class:Callback 的on_exception()处理后,是否继续抛出异常。 + 支持'ignore','raise', 'auto': 'ignore'将捕获异常,写在Trainer.train()后面的代码将继续运行; 'raise'将异常抛出; + 'auto'将ignore以下两种Exception: CallbackException与KeyboardInterrupt, raise其它exception. + :return dict: 返回一个字典类型的数据, + 内含以下内容:: + + seconds: float, 表示训练时长 + 以下三个内容只有在提供了dev_data的情况下会有。 + best_eval: Dict of Dict, 表示evaluation的结果。第一层的key为Metric的名称, + 第二层的key为具体的Metric + best_epoch: int,在第几个epoch取得的最佳值 + best_step: int, 在第几个step(batch)更新取得的最佳值 + + """ try: self.logger.info("###### Training epochs started ######") self.logger.info('Total epochs: %d'% self.n_epochs) @@ -222,17 +286,22 @@ class DistTrainer(): results['seconds'] = round(time.time() - start_time, 2) self.logger.info("###### Train finished ######") self.logger.info('Total train time: {} seconds.'. 
format(results['seconds'])) - return results + if load_best_model: + self.load_check_point('best_{}'.format(self.metric_key)) finally: - self.close() + pass + + return results def _train(self): - if self.fp16: - # skip check, done in __init__() - from apex import amp + if not self.use_tqdm: + from .utils import _pseudo_tqdm as inner_tqdm + else: + inner_tqdm = tqdm + self.step = 0 self.epoch = 0 - self.pbar = tqdm(total=self.n_steps, postfix='loss:{0:<6.5f}', + self.pbar = inner_tqdm(total=self.n_steps, postfix='loss:{0:<6.5f}', leave=False, dynamic_ncols=True, disable=not self.is_master) pbar = self.pbar avg_loss = 0 @@ -292,8 +361,8 @@ class DistTrainer(): if self.validate_every < 0: self._do_validation() - if self.save_every < 0 and self.cp_save_path: - self.save_check_point() + if self.save_every < 0 and self.cp_save_path: + self.save_check_point() # lr decay; early stopping self.callback_manager.on_epoch_end() # =============== epochs end =================== # @@ -327,22 +396,35 @@ class DistTrainer(): loss = self.losser(predict, truth) if self.update_every > 1: loss = loss / self.update_every - return loss.mean() + if loss.dim() > 0: + loss = loss.mean() + return loss - def save_check_point(self, only_params=False): + def save_check_point(self, name=None, only_params=False): + """保存当前模型""" # only master save models if self.is_master: + if name is None: + name = 'checkpoint-{}.bin'.format(self.step) os.makedirs(self.cp_save_path, exist_ok=True) - path = os.path.join(self.cp_save_path, 'checkpoint-{}.bin'.format(self.step)) + path = os.path.join(self.cp_save_path, name) self.logger.info("Save checkpoint to {}".format(path)) model_to_save = self.model.module if only_params: model_to_save = model_to_save.state_dict() torch.save(model_to_save, path) + def load_check_point(self, name): + path = os.path.join(self.cp_save_path, name) + self.logger.info('reload best model from %s', path) + model_load = torch.load(path) + if not isinstance(model_load, dict): + model_load = model_load.state_dict() + self.model.load_state_dict(model_load) + def _do_validation(self): self.callback_manager.on_valid_begin() - eval_res = self.callback_manager.on_validation() + eval_res = self.test_manager.on_valid_begin() eval_res = list(filter(lambda x: x is not None, eval_res)) if len(eval_res): eval_res, is_better = list(zip(*eval_res)) @@ -350,7 +432,16 @@ class DistTrainer(): eval_res, is_better = None, None self.callback_manager.on_valid_end( eval_res, self.metric_key, self.optimizer, is_better) + + # save better model + for i, better_flag in enumerate(is_better): + if better_flag: + # TODO to support multiple datasets to evaluate + name = 'best_{}'.format(self.metric_key) + self.save_check_point(name) + break dist.barrier() def close(self): + """关闭Trainer,销毁进程""" dist.destroy_process_group() diff --git a/fastNLP/core/trainer.py b/fastNLP/core/trainer.py index a2c3b1f7..c331ab18 100644 --- a/fastNLP/core/trainer.py +++ b/fastNLP/core/trainer.py @@ -842,6 +842,7 @@ class Trainer(object): @property def is_master(self): + """是否是主进程""" return True DEFAULT_CHECK_BATCH_SIZE = 2 diff --git a/fastNLP/core/utils.py b/fastNLP/core/utils.py index dd2afab7..d5ae563c 100644 --- a/fastNLP/core/utils.py +++ b/fastNLP/core/utils.py @@ -19,6 +19,10 @@ import torch.nn as nn from typing import List from ._logger import logger from prettytable import PrettyTable +try: + from apex import amp +except: + amp = None _CheckRes = namedtuple('_CheckRes', ['missing', 'unused', 'duplicated', 'required', 'all_needed', 'varargs']) @@ -805,3 
+809,10 @@ def sub_column(string: str, c: int, c_size: int, title: str) -> str: if len(string) > avg: string = string[:(avg - 3)] + "..." return string + + +def _check_fp16(): + if amp is None: + raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.") + if not torch.backends.cudnn.enabled: + raise RuntimeError("Amp requires cudnn backend to be enabled.") From 05eb499eb893135d856d7eb440b1b1e1bd244956 Mon Sep 17 00:00:00 2001 From: yunfan Date: Fri, 20 Sep 2019 16:30:37 +0800 Subject: [PATCH 03/16] [bugfix] dist_trainer's save & load --- fastNLP/core/__init__.py | 3 +-- fastNLP/core/callback.py | 2 +- fastNLP/core/dist_trainer.py | 33 +++++++++++++++++---------------- test/core/test_dist_trainer.py | 4 +++- 4 files changed, 22 insertions(+), 20 deletions(-) diff --git a/fastNLP/core/__init__.py b/fastNLP/core/__init__.py index efee08b5..bea80097 100644 --- a/fastNLP/core/__init__.py +++ b/fastNLP/core/__init__.py @@ -49,7 +49,6 @@ __all__ = [ "WarmupCallback", 'SaveModelCallback', "EchoCallback", - "TesterCallback", "CallbackException", "EarlyStopError", @@ -79,7 +78,7 @@ from ._logger import logger from .batch import DataSetIter, BatchIter, TorchLoaderIter from .callback import Callback, GradientClipCallback, EarlyStopCallback, FitlogCallback, EvaluateCallback, \ LRScheduler, ControlC, LRFinder, TensorboardCallback, WarmupCallback, SaveModelCallback, EchoCallback, \ - TesterCallback, CallbackException, EarlyStopError + CallbackException, EarlyStopError from .const import Const from .dataset import DataSet from .field import FieldArray, Padder, AutoPadder, EngChar2DPadder diff --git a/fastNLP/core/callback.py b/fastNLP/core/callback.py index 734c1269..fac1f1f4 100644 --- a/fastNLP/core/callback.py +++ b/fastNLP/core/callback.py @@ -63,7 +63,7 @@ __all__ = [ "WarmupCallback", "SaveModelCallback", "EchoCallback", - "TesterCallback", + "_TesterCallback", "CallbackException", "EarlyStopError" diff --git a/fastNLP/core/dist_trainer.py b/fastNLP/core/dist_trainer.py index c2804134..2451911d 100644 --- a/fastNLP/core/dist_trainer.py +++ b/fastNLP/core/dist_trainer.py @@ -17,7 +17,8 @@ from tqdm import tqdm from ._logger import logger from .batch import DataSetIter, BatchIter -from .callback import DistCallbackManager, CallbackException, _TesterCallback +from .callback import DistCallbackManager, CallbackException +from .callback import _TesterCallback from .dataset import DataSet from .losses import _prepare_losser from .optimizer import Optimizer @@ -174,13 +175,13 @@ class DistTrainer(): cb = _TesterCallback( dev_data, model, metrics, batch_size=batch_size_per_gpu, num_workers=num_workers) - self.test_manager.add_callback([cb], master=True) + self.test_manager.add_callback([cb], master=False) # Setup logging dist.barrier() self.start_time = datetime.now().strftime('%m_%d_%Y-%H_%M') if self.save_path: - self.cp_save_path = os.path.join(self.save_path, 'checkpoints', self.start_time) + self.cp_save_path = os.path.join(self.save_path, 'checkpoints') else: self.cp_save_path = None @@ -286,11 +287,11 @@ class DistTrainer(): results['seconds'] = round(time.time() - start_time, 2) self.logger.info("###### Train finished ######") self.logger.info('Total train time: {} seconds.'. 
format(results['seconds'])) - if load_best_model: - self.load_check_point('best_{}'.format(self.metric_key)) + if load_best_model and self.cp_save_path and len(self.test_manager.callbacks): + self.load_check_point('best') finally: pass - + dist.barrier() return results def _train(self): @@ -417,29 +418,29 @@ class DistTrainer(): def load_check_point(self, name): path = os.path.join(self.cp_save_path, name) self.logger.info('reload best model from %s', path) - model_load = torch.load(path) + model_load = torch.load(path, map_location='cpu') if not isinstance(model_load, dict): model_load = model_load.state_dict() - self.model.load_state_dict(model_load) + self.model.module.load_state_dict(model_load) def _do_validation(self): self.callback_manager.on_valid_begin() + # do evaluate on all nodes eval_res = self.test_manager.on_valid_begin() eval_res = list(filter(lambda x: x is not None, eval_res)) if len(eval_res): eval_res, is_better = list(zip(*eval_res)) else: eval_res, is_better = None, None + # save better model on master node + if self.is_master and is_better is not None and self.cp_save_path: + for i, better_flag in enumerate(is_better): + if better_flag: + # TODO to support multiple datasets to evaluate + self.save_check_point('best') + break self.callback_manager.on_valid_end( eval_res, self.metric_key, self.optimizer, is_better) - - # save better model - for i, better_flag in enumerate(is_better): - if better_flag: - # TODO to support multiple datasets to evaluate - name = 'best_{}'.format(self.metric_key) - self.save_check_point(name) - break dist.barrier() def close(self): diff --git a/test/core/test_dist_trainer.py b/test/core/test_dist_trainer.py index c6879634..03f613e1 100644 --- a/test/core/test_dist_trainer.py +++ b/test/core/test_dist_trainer.py @@ -130,12 +130,14 @@ class TestDistTrainer(unittest.TestCase): train_set, model, optimizer=SGD(lr=0.1), loss=BCELoss(pred="predict", target="y"), batch_size_per_gpu=32, n_epochs=3, print_every=50, dev_data=dev_set, - metrics=AccuracyMetric(pred="predict", target="y"), validate_every=-1, save_path=None, + metrics=AccuracyMetric(pred="predict", target="y"), validate_every=-1, save_path=self.save_path, ) trainer.train() """ # 应该正确运行 """ + if trainer.is_master and os.path.exists(self.save_path): + shutil.rmtree(self.save_path) def run_dist(self, run_id): if torch.cuda.is_available(): From 7413276997731b3e816444c1db3caf624b743405 Mon Sep 17 00:00:00 2001 From: zide05 <845465009@qq.com> Date: Sun, 22 Sep 2019 09:47:33 +0800 Subject: [PATCH 04/16] modify pipe documents --- fastNLP/io/__init__.py | 7 ++ fastNLP/io/pipe/__init__.py | 4 +- fastNLP/io/pipe/classification.py | 162 ++++++++++++++++++++++++------ fastNLP/io/pipe/conll.py | 103 +++++++++++++++---- fastNLP/io/pipe/coreference.py | 30 ++++-- fastNLP/io/pipe/cws.py | 19 +++- fastNLP/io/pipe/matching.py | 51 ++++++++-- 7 files changed, 303 insertions(+), 73 deletions(-) diff --git a/fastNLP/io/__init__.py b/fastNLP/io/__init__.py index c8b3dfaa..63fde69a 100644 --- a/fastNLP/io/__init__.py +++ b/fastNLP/io/__init__.py @@ -25,6 +25,8 @@ __all__ = [ 'SSTLoader', 'SST2Loader', "ChnSentiCorpLoader", + "THUCNewsLoader", + "WeiboSenti100kLoader", 'ConllLoader', 'Conll2003Loader', @@ -45,6 +47,9 @@ __all__ = [ "SNLILoader", "QNLILoader", "RTELoader", + "XNLILoader", + "BQCorpusLoader", + "LCQMCLoader", "Pipe", @@ -54,6 +59,8 @@ __all__ = [ "SST2Pipe", "IMDBPipe", "ChnSentiCorpPipe", + "THUCNewsPipe", + "WeiboSenti100kPipe", "Conll2003Pipe", "Conll2003NERPipe", diff --git 
a/fastNLP/io/pipe/__init__.py b/fastNLP/io/pipe/__init__.py index 0ddb1f2d..212f9e66 100644 --- a/fastNLP/io/pipe/__init__.py +++ b/fastNLP/io/pipe/__init__.py @@ -18,6 +18,8 @@ __all__ = [ "SST2Pipe", "IMDBPipe", "ChnSentiCorpPipe", + "THUCNewsPipe", + "WeiboSenti100kPipe", "Conll2003NERPipe", "OntoNotesNERPipe", @@ -42,7 +44,7 @@ __all__ = [ "CoReferencePipe" ] -from .classification import YelpFullPipe, YelpPolarityPipe, SSTPipe, SST2Pipe, IMDBPipe, ChnSentiCorpPipe +from .classification import YelpFullPipe, YelpPolarityPipe, SSTPipe, SST2Pipe, IMDBPipe, ChnSentiCorpPipe, THUCNewsPipe, WeiboSenti100kPipe from .conll import Conll2003NERPipe, OntoNotesNERPipe, MsraNERPipe, WeiboNERPipe, PeopleDailyPipe from .matching import MatchingBertPipe, RTEBertPipe, SNLIBertPipe, QuoraBertPipe, QNLIBertPipe, MNLIBertPipe, \ MatchingPipe, RTEPipe, SNLIPipe, QuoraPipe, QNLIPipe, MNLIPipe diff --git a/fastNLP/io/pipe/classification.py b/fastNLP/io/pipe/classification.py index 409cfe53..1c44cc23 100644 --- a/fastNLP/io/pipe/classification.py +++ b/fastNLP/io/pipe/classification.py @@ -97,11 +97,22 @@ class YelpFullPipe(_CLSPipe): 处理YelpFull的数据, 处理之后DataSet中的内容如下 .. csv-table:: 下面是使用YelpFullPipe处理后的DataSet所具备的field - :header: "raw_words", "words", "target", "seq_len" + :header: "raw_words", "target", "words", "seq_len" + + "I got 'new' tires from them and within...", 0 ,"[7, 110, 22, 107, 22, 499, 59, 140, 3,...]", 160 + " Don't waste your time. We had two dif... ", 0, "[277, 17, 278, 38, 30, 112, 24, 85, 27...", 40 + "...", ., "[...]", . - "It 's a ...", "[4, 2, 10, ...]", 0, 10 - "Offers that ...", "[20, 40, ...]", 1, 21 - "...", "[...]", ., . + dataset的print_field_meta()函数输出的各个field的被设置成input和target的情况为:: + + +-------------+-----------+--------+-------+---------+ + | field_names | raw_words | target | words | seq_len | + +-------------+-----------+--------+-------+---------+ + | is_input | False | False | True | True | + | is_target | False | True | False | False | + | ignore_type | | False | False | False | + | pad_value | | 0 | 0 | 0 | + +-------------+-----------+--------+-------+---------+ """ @@ -193,11 +204,22 @@ class YelpPolarityPipe(_CLSPipe): 处理YelpPolarity的数据, 处理之后DataSet中的内容如下 .. csv-table:: 下面是使用YelpFullPipe处理后的DataSet所具备的field - :header: "raw_words", "words", "target", "seq_len" + :header: "raw_words", "target", "words", "seq_len" - "It 's a ...", "[4, 2, 10, ...]", 0, 10 - "Offers that ...", "[20, 40, ...]", 1, 21 - "...", "[...]", ., . + "I got 'new' tires from them and within...", 0 ,"[7, 110, 22, 107, 22, 499, 59, 140, 3,...]", 160 + " Don't waste your time. We had two dif... ", 0, "[277, 17, 278, 38, 30, 112, 24, 85, 27...", 40 + "...", ., "[...]", . + + dataset的print_field_meta()函数输出的各个field的被设置成input和target的情况为:: + + +-------------+-----------+--------+-------+---------+ + | field_names | raw_words | target | words | seq_len | + +-------------+-----------+--------+-------+---------+ + | is_input | False | False | True | True | + | is_target | False | True | False | False | + | ignore_type | | False | False | False | + | pad_value | | 0 | 0 | 0 | + +-------------+-----------+--------+-------+---------+ """ @@ -211,6 +233,19 @@ class YelpPolarityPipe(_CLSPipe): self.lower = lower def process(self, data_bundle): + """ + 传入的DataSet应该具备如下的结构 + + .. csv-table:: + :header: "raw_words", "target" + + "I got 'new' tires from them and... ", "1" + "Don't waste your time. We had two...", "1" + "...", "..." 
+ + :param data_bundle: + :return: + """ # 复制一列words data_bundle = _add_words_field(data_bundle, lower=self.lower) @@ -244,9 +279,20 @@ class SSTPipe(_CLSPipe): .. csv-table:: 下面是使用SSTPipe处理后的DataSet所具备的field :header: "raw_words", "words", "target", "seq_len" - "It 's a ...", "[4, 2, 10, ...]", 0, 16 - "Offers that ...", "[20, 40, ...]", 1, 18 - "...", "[...]", ., . + "It 's a lovely film with lovely perfor...", 1, "[187, 6, 5, 132, 120, 70, 132, 188, 25...", 13 + "No one goes unindicted here , which is...", 0, "[191, 126, 192, 193, 194, 4, 195, 17, ...", 13 + "...", ., "[...]", . + + dataset的print_field_meta()函数输出的各个field的被设置成input和target的情况为:: + + +-------------+-----------+--------+-------+---------+ + | field_names | raw_words | target | words | seq_len | + +-------------+-----------+--------+-------+---------+ + | is_input | False | False | True | True | + | is_target | False | True | False | False | + | ignore_type | | False | False | False | + | pad_value | | 0 | 0 | 0 | + +-------------+-----------+--------+-------+---------+ """ @@ -278,11 +324,11 @@ class SSTPipe(_CLSPipe): """ 对DataBundle中的数据进行预处理。输入的DataSet应该至少拥有raw_words这一列,且内容类似与 - .. csv-table:: + .. csv-table:: 下面是使用SSTLoader读取的DataSet所具备的field :header: "raw_words" - "(3 (2 It) (4 (4 (2 's) (4 (3 (2 a)..." - "(4 (4 (2 Offers) (3 (3 (2 that) (3 (3 rare)..." + "(2 (3 (3 Effective) (2 but)) (1 (1 too-tepid)..." + "(3 (3 (2 If) (3 (2 you) (3 (2 sometimes) ..." "..." :param ~fastNLP.io.DataBundle data_bundle: 需要处理的DataBundle对象 @@ -335,12 +381,23 @@ class SST2Pipe(_CLSPipe): 加载SST2的数据, 处理完成之后DataSet将拥有以下的field .. csv-table:: - :header: "raw_words", "words", "target", "seq_len" + :header: "raw_words", "target", "words", "seq_len" - "it 's a charming and... ", "[3, 4, 5, 6, 7,...]", 1, 43 - "unflinchingly bleak and...", "[10, 11, 7,...]", 1, 21 + "it 's a charming and often affecting j... ", 1, "[19, 9, 6, 111, 5, 112, 113, 114, 3]", 9 + "unflinchingly bleak and desperate", 0, "[115, 116, 5, 117]", 4 "...", "...", ., . + dataset的print_field_meta()函数输出的各个field的被设置成input和target的情况为:: + + +-------------+-----------+--------+-------+---------+ + | field_names | raw_words | target | words | seq_len | + +-------------+-----------+--------+-------+---------+ + | is_input | False | False | True | True | + | is_target | False | True | False | False | + | ignore_type | | False | False | False | + | pad_value | | 0 | 0 | 0 | + +-------------+-----------+--------+-------+---------+ + """ def __init__(self, lower=False, tokenizer='spacy'): @@ -357,11 +414,11 @@ class SST2Pipe(_CLSPipe): 可以处理的DataSet应该具备如下的结构 .. csv-table:: - :header: "raw_words", "target" + :header: "raw_words", "target" - "it 's a charming and... ", 1 - "unflinchingly bleak and...", 1 - "...", "..." + "it 's a charming and often affecting...", "1" + "unflinchingly bleak and...", "0" + "..." :param data_bundle: :return: @@ -420,15 +477,26 @@ class IMDBPipe(_CLSPipe): 经过本Pipe处理后DataSet将如下 .. csv-table:: 输出DataSet的field - :header: "raw_words", "words", "target", "seq_len" + :header: "raw_words", "target", "words", "seq_len" - "Bromwell High is a cartoon ... ", "[3, 5, 6, 9, ...]", 0, 20 - "Story of a man who has ...", "[20, 43, 9, 10, ...]", 1, 31 - "...", "[...]", ., . + "Bromwell High is a cartoon ... ", 0, "[3, 5, 6, 9, ...]", 20 + "Story of a man who has ...", 1, "[20, 43, 9, 10, ...]", 31 + "...", ., "[...]", . 
其中raw_words为str类型,是原文; words是转换为index的输入; target是转换为index的目标值; words列被设置为input; target列被设置为target。 + dataset的print_field_meta()函数输出的各个field的被设置成input和target的情况为:: + + +-------------+-----------+--------+-------+---------+ + | field_names | raw_words | target | words | seq_len | + +-------------+-----------+--------+-------+---------+ + | is_input | False | False | True | True | + | is_target | False | True | False | False | + | ignore_type | | False | False | False | + | pad_value | | 0 | 0 | 0 | + +-------------+-----------+--------+-------+---------+ + """ def __init__(self, lower: bool = False, tokenizer: str = 'spacy'): @@ -493,13 +561,23 @@ class ChnSentiCorpPipe(Pipe): 处理之后的DataSet有以下的结构 .. csv-table:: - :header: "raw_chars", "chars", "target", "seq_len" + :header: "raw_chars", "target", "chars", "seq_len" - "這間酒店環境和服務態度亦算不錯,但房間空間太小~~", "[2, 3, 4, 5, ...]", 1, 31 - "<荐书> 推荐所有喜欢<红楼>...", "[10, 21, ....]", 1, 25 + "這間酒店環境和服務態度亦算不錯,但房間空間太小~~", 1, "[2, 3, 4, 5, ...]", 31 + "<荐书> 推荐所有喜欢<红楼>...", 1, "[10, 21, ....]", 25 "..." 其中chars, seq_len是input,target是target + dataset的print_field_meta()函数输出的各个field的被设置成input和target的情况为:: + + +-------------+-----------+--------+-------+---------+ + | field_names | raw_chars | target | chars | seq_len | + +-------------+-----------+--------+-------+---------+ + | is_input | False | True | True | True | + | is_target | False | True | False | False | + | ignore_type | | False | False | False | + | pad_value | | 0 | 0 | 0 | + +-------------+-----------+--------+-------+---------+ """ def __init__(self, bigrams=False, trigrams=False): @@ -590,12 +668,22 @@ class THUCNewsPipe(_CLSPipe): 处理之后的DataSet有以下的结构 .. csv-table:: - :header: "raw_chars", "chars", "target", "seq_len" + :header: "raw_chars", "target", "chars", "seq_len" - "马晓旭意外受伤让国奥警惕 无奈大雨格外青睐殷家军记者傅亚雨沈阳报道...", "[409, 1197, 2146, 213, ...]", 0, 746 + "马晓旭意外受伤让国奥警惕 无奈大雨格外青睐殷家军记者傅亚雨沈阳报道...", 0, "[409, 1197, 2146, 213, ...]", 746 "..." 其中chars, seq_len是input,target是target + dataset的print_field_meta()函数输出的各个field的被设置成input和target的情况为:: + + +-------------+-----------+--------+-------+---------+ + | field_names | raw_chars | target | chars | seq_len | + +-------------+-----------+--------+-------+---------+ + | is_input | False | True | True | True | + | is_target | False | True | False | False | + | ignore_type | | False | False | False | + | pad_value | | 0 | 0 | 0 | + +-------------+-----------+--------+-------+---------+ :param bool bigrams: 是否增加一列bigrams. bigrams的构成是['复', '旦', '大', '学', ...]->["复旦", "旦大", ...]。如果 设置为True,返回的DataSet将有一列名为bigrams, 且已经转换为了index并设置为input,对应的vocab可以通过 @@ -691,12 +779,22 @@ class WeiboSenti100kPipe(_CLSPipe): 处理之后的DataSet有以下的结构 .. csv-table:: - :header: "raw_chars", "chars", "target", "seq_len" + :header: "raw_chars", "target", "chars", "seq_len" - "六一出生的?好讽刺…… //@祭春姬:他爸爸是外星人吧 //@面孔小高:现在的孩子都怎么了 [怒][怒][怒]", "[0, 690, 18, ...]", 0, 56 + "六一出生的?好讽刺…… //@祭春姬:他爸爸是外星人吧 //@面孔小高:现在的孩子都怎么了 [怒][怒][怒]", 0, "[0, 690, 18, ...]", 56 "..." 其中chars, seq_len是input,target是target + dataset的print_field_meta()函数输出的各个field的被设置成input和target的情况为:: + + +-------------+-----------+--------+-------+---------+ + | field_names | raw_chars | target | chars | seq_len | + +-------------+-----------+--------+-------+---------+ + | is_input | False | True | True | True | + | is_target | False | True | False | False | + | ignore_type | | False | False | False | + | pad_value | | 0 | 0 | 0 | + +-------------+-----------+--------+-------+---------+ :param bool bigrams: 是否增加一列bigrams. 
bigrams的构成是['复', '旦', '大', '学', ...]->["复旦", "旦大", ...]。如果 设置为True,返回的DataSet将有一列名为bigrams, 且已经转换为了index并设置为input,对应的vocab可以通过 diff --git a/fastNLP/io/pipe/conll.py b/fastNLP/io/pipe/conll.py index 70af5acb..918cff9f 100644 --- a/fastNLP/io/pipe/conll.py +++ b/fastNLP/io/pipe/conll.py @@ -87,15 +87,26 @@ class Conll2003NERPipe(_NERPipe): 经过该Pipe过后,DataSet中的内容如下所示 .. csv-table:: Following is a demo layout of DataSet returned by Conll2003Loader - :header: "raw_words", "words", "target", "seq_len" + :header: "raw_words", "target", "words", "seq_len" - "[Nadim, Ladki]", "[2, 3]", "[1, 2]", 2 - "[AL-AIN, United, Arab, ...]", "[4, 5, 6,...]", "[3, 4,...]", 6 + "[Nadim, Ladki]", "[1, 2]", "[2, 3]", 2 + "[AL-AIN, United, Arab, ...]", "[3, 4,...]", "[4, 5, 6,...]", 6 "[...]", "[...]", "[...]", . raw_words列为List[str], 是未转换的原始数据; words列为List[int],是转换为index的输入数据; target列是List[int],是转换为index的 target。返回的DataSet中被设置为input有words, target, seq_len; 设置为target有target。 + dataset的print_field_meta()函数输出的各个field的被设置成input和target的情况为:: + + +-------------+-----------+--------+-------+---------+ + | field_names | raw_words | target | words | seq_len | + +-------------+-----------+--------+-------+---------+ + | is_input | False | True | True | True | + | is_target | False | True | False | True | + | ignore_type | | False | False | False | + | pad_value | | 0 | 0 | 0 | + +-------------+-----------+--------+-------+---------+ + """ def process_from_file(self, paths) -> DataBundle: @@ -112,17 +123,28 @@ class Conll2003NERPipe(_NERPipe): class Conll2003Pipe(Pipe): - r""" + """ 经过该Pipe后,DataSet中的内容如下 .. csv-table:: - :header: "raw_words" , "words", "pos", "chunk", "ner", "seq_len" + :header: "raw_words" , "pos", "chunk", "ner", "words", "seq_len" - "[Nadim, Ladki]", "[2, 3]", "[0, 0]", "[1, 2]", "[1, 2]", 2 - "[AL-AIN, United, Arab, ...]", "[4, 5, 6,...]", "[1, 2...]", "[3, 4...]", "[3, 4...]", 6 + "[Nadim, Ladki]", "[0, 0]", "[1, 2]", "[1, 2]", "[2, 3]", 2 + "[AL-AIN, United, Arab, ...]", "[1, 2...]", "[3, 4...]", "[3, 4...]", "[4, 5, 6,...]", 6 "[...]", "[...]", "[...]", "[...]", "[...]", . 其中words, seq_len是input; pos, chunk, ner, seq_len是target + dataset的print_field_meta()函数输出的各个field的被设置成input和target的情况为:: + + +-------------+-----------+-------+-------+-------+-------+---------+ + | field_names | raw_words | pos | chunk | ner | words | seq_len | + +-------------+-----------+-------+-------+-------+-------+---------+ + | is_input | False | False | False | False | True | True | + | is_target | False | True | True | True | False | True | + | ignore_type | | False | False | False | False | False | + | pad_value | | 0 | 0 | 0 | 0 | 0 | + +-------------+-----------+-------+-------+-------+-------+---------+ + """ def __init__(self, chunk_encoding_type='bioes', ner_encoding_type='bioes', lower: bool = False): @@ -202,15 +224,26 @@ class OntoNotesNERPipe(_NERPipe): 处理OntoNotes的NER数据,处理之后DataSet中的field情况为 .. csv-table:: - :header: "raw_words", "words", "target", "seq_len" + :header: "raw_words", "target", "words", "seq_len" - "[Nadim, Ladki]", "[2, 3]", "[1, 2]", 2 - "[AL-AIN, United, Arab, ...]", "[4, 5, 6,...]", "[3, 4]", 6 + "[Nadim, Ladki]", "[1, 2]", "[2, 3]", 2 + "[AL-AIN, United, Arab, ...]", "[3, 4]", "[4, 5, 6,...]", 6 "[...]", "[...]", "[...]", . 
raw_words列为List[str], 是未转换的原始数据; words列为List[int],是转换为index的输入数据; target列是List[int],是转换为index的 target。返回的DataSet中被设置为input有words, target, seq_len; 设置为target有target。 + dataset的print_field_meta()函数输出的各个field的被设置成input和target的情况为:: + + +-------------+-----------+--------+-------+---------+ + | field_names | raw_words | target | words | seq_len | + +-------------+-----------+--------+-------+---------+ + | is_input | False | True | True | True | + | is_target | False | True | False | True | + | ignore_type | | False | False | False | + | pad_value | | 0 | 0 | 0 | + +-------------+-----------+--------+-------+---------+ + """ def process_from_file(self, paths): @@ -306,15 +339,26 @@ class MsraNERPipe(_CNNERPipe): 处理MSRA-NER的数据,处理之后的DataSet的field情况为 .. csv-table:: - :header: "raw_chars", "chars", "target", "seq_len" + :header: "raw_chars", "target", "chars", "seq_len" - "[相, 比, 之, 下,...]", "[2, 3, 4, 5, ...]", "[0, 0, 0, 0, ...]", 11 - "[青, 岛, 海, 牛, 队, 和, ...]", "[10, 21, ....]", "[1, 2, 3, ...]", 21 + "[相, 比, 之, 下,...]", "[0, 0, 0, 0, ...]", "[2, 3, 4, 5, ...]", 11 + "[青, 岛, 海, 牛, 队, 和, ...]", "[1, 2, 3, ...]", "[10, 21, ....]", 21 "[...]", "[...]", "[...]", . raw_chars列为List[str], 是未转换的原始数据; chars列为List[int],是转换为index的输入数据; target列是List[int],是转换为index的 target。返回的DataSet中被设置为input有chars, target, seq_len; 设置为target有target。 + dataset的print_field_meta()函数输出的各个field的被设置成input和target的情况为:: + + +-------------+-----------+--------+-------+---------+ + | field_names | raw_chars | target | chars | seq_len | + +-------------+-----------+--------+-------+---------+ + | is_input | False | True | True | True | + | is_target | False | True | False | True | + | ignore_type | | False | False | False | + | pad_value | | 0 | 0 | 0 | + +-------------+-----------+--------+-------+---------+ + """ def process_from_file(self, paths=None) -> DataBundle: @@ -327,14 +371,26 @@ class PeopleDailyPipe(_CNNERPipe): 处理people daily的ner的数据,处理之后的DataSet的field情况为 .. csv-table:: - :header: "raw_chars", "chars", "target", "seq_len" + :header: "raw_chars", "target", "chars", "seq_len" - "[相, 比, 之, 下,...]", "[2, 3, 4, 5, ...]", "[0, 0, 0, 0, ...]", 11 - "[青, 岛, 海, 牛, 队, 和, ...]", "[10, 21, ....]", "[1, 2, 3, ...]", 21 + "[相, 比, 之, 下,...]", "[0, 0, 0, 0, ...]", "[2, 3, 4, 5, ...]", 11 + "[青, 岛, 海, 牛, 队, 和, ...]", "[1, 2, 3, ...]", "[10, 21, ....]", 21 "[...]", "[...]", "[...]", . raw_chars列为List[str], 是未转换的原始数据; chars列为List[int],是转换为index的输入数据; target列是List[int],是转换为index的 target。返回的DataSet中被设置为input有chars, target, seq_len; 设置为target有target。 + + dataset的print_field_meta()函数输出的各个field的被设置成input和target的情况为:: + + +-------------+-----------+--------+-------+---------+ + | field_names | raw_chars | target | chars | seq_len | + +-------------+-----------+--------+-------+---------+ + | is_input | False | True | True | True | + | is_target | False | True | False | True | + | ignore_type | | False | False | False | + | pad_value | | 0 | 0 | 0 | + +-------------+-----------+--------+-------+---------+ + """ def process_from_file(self, paths=None) -> DataBundle: @@ -349,13 +405,24 @@ class WeiboNERPipe(_CNNERPipe): .. csv-table:: :header: "raw_chars", "chars", "target", "seq_len" - "[相, 比, 之, 下,...]", "[2, 3, 4, 5, ...]", "[0, 0, 0, 0, ...]", 11 - "[青, 岛, 海, 牛, 队, 和, ...]", "[10, 21, ....]", "[1, 2, 3, ...]", 21 + "['老', '百', '姓']", "[4, 3, 3]", "[38, 39, 40]", 3 + "['心']", "[0]", "[41]", 1 "[...]", "[...]", "[...]", . 
raw_chars列为List[str], 是未转换的原始数据; chars列为List[int],是转换为index的输入数据; target列是List[int],是转换为index的 target。返回的DataSet中被设置为input有chars, target, seq_len; 设置为target有target。 + dataset的print_field_meta()函数输出的各个field的被设置成input和target的情况为:: + + +-------------+-----------+--------+-------+---------+ + | field_names | raw_chars | target | chars | seq_len | + +-------------+-----------+--------+-------+---------+ + | is_input | False | True | True | True | + | is_target | False | True | False | True | + | ignore_type | | False | False | False | + | pad_value | | 0 | 0 | 0 | + +-------------+-----------+--------+-------+---------+ + """ def process_from_file(self, paths=None) -> DataBundle: diff --git a/fastNLP/io/pipe/coreference.py b/fastNLP/io/pipe/coreference.py index c1b218a5..0cf6c996 100644 --- a/fastNLP/io/pipe/coreference.py +++ b/fastNLP/io/pipe/coreference.py @@ -18,9 +18,29 @@ from ...core.const import Const class CoReferencePipe(Pipe): """ 对Coreference resolution问题进行处理,得到文章种类/说话者/字符级信息/序列长度。 + + 处理完成后数据包含文章类别、speaker信息、句子信息、句子对应的index、char、句子长度、target: + + .. csv-table:: + :header: "words1", "words2","words3","words4","chars","seq_len","target" + + "bc", "[[0,0],[1,1]]","[['I','am'],[]]","[[1,2],[]]","[[[1],[2,3]],[]]","[2,3]","[[[2,3],[6,7]],[[10,12],[20,22]]]" + "[...]", "[...]","[...]","[...]","[...]","[...]","[...]" + + dataset的print_field_meta()函数输出的各个field的被设置成input和target的情况为:: + + +-------------+-----------+--------+-------+---------+ + | field_names | raw_chars | target | chars | seq_len | + +-------------+-----------+--------+-------+---------+ + | is_input | False | True | True | True | + | is_target | False | True | False | True | + | ignore_type | | False | False | False | + | pad_value | | 0 | 0 | 0 | + +-------------+-----------+--------+-------+---------+ + """ - def __init__(self,config): + def __init__(self, config): super().__init__() self.config = config @@ -35,14 +55,6 @@ class CoReferencePipe(Pipe): "bc/cctv/00/cctv_0000_1", "[['Speaker#1', 'peaker#1'],[]]","[['He','is'],[]]","[[[2,3],[6,7]],[[10,12],[20,22]]]" "[...]", "[...]","[...]","[...]" - 处理完成后数据包含文章类别、speaker信息、句子信息、句子对应的index、char、句子长度、target: - - .. csv-table:: - :header: "words1", "words2","words3","words4","chars","seq_len","target" - - "bc", "[[0,0],[1,1]]","[['I','am'],[]]","[[1,2],[]]","[[[1],[2,3]],[]]","[2,3]","[[[2,3],[6,7]],[[10,12],[20,22]]]" - "[...]", "[...]","[...]","[...]","[...]","[...]","[...]" - :param data_bundle: :return: diff --git a/fastNLP/io/pipe/cws.py b/fastNLP/io/pipe/cws.py index 97bda896..a2f2e7a2 100644 --- a/fastNLP/io/pipe/cws.py +++ b/fastNLP/io/pipe/cws.py @@ -138,13 +138,22 @@ class CWSPipe(Pipe): 对CWS数据进行预处理, 处理之后的数据,具备以下的结构 .. csv-table:: - :header: "raw_words", "chars", "target", "bigrams", "trigrams", "seq_len" + :header: "raw_words", "chars", "target", "seq_len" - "共同 创造 美好...", "[2, 3, 4...]", "[0, 2, 0, 2,...]", "[10, 4, 1,...]","[6, 4, 1,...]", 13 - "2001年 新年 钟声...", "[8, 9, 9, 7, ...]", "[0, 1, 1, 1, 2...]", "[11, 12, ...]","[3, 9, ...]", 20 - "...", "[...]","[...]", "[...]","[...]", . + "共同 创造 美好...", "[2, 3, 4...]", "[0, 2, 0, 2,...]", 13 + "2001年 新年 钟声...", "[8, 9, 9, 7, ...]", "[0, 1, 1, 1, 2...]", 20 + "...", "[...]","[...]", . 
- 其中bigrams仅当bigrams列为True的时候存在 + dataset的print_field_meta()函数输出的各个field的被设置成input和target的情况为:: + + +-------------+-----------+-------+--------+---------+ + | field_names | raw_words | chars | target | seq_len | + +-------------+-----------+-------+--------+---------+ + | is_input | False | True | True | True | + | is_target | False | False | True | True | + | ignore_type | | False | False | False | + | pad_value | | 0 | 0 | 0 | + +-------------+-----------+-------+--------+---------+ """ diff --git a/fastNLP/io/pipe/matching.py b/fastNLP/io/pipe/matching.py index def750c0..7747dec3 100644 --- a/fastNLP/io/pipe/matching.py +++ b/fastNLP/io/pipe/matching.py @@ -37,16 +37,27 @@ class MatchingBertPipe(Pipe): Matching任务的Bert pipe,输出的DataSet将包含以下的field .. csv-table:: - :header: "raw_words1", "raw_words2", "words", "target", "seq_len" + :header: "raw_words1", "raw_words2", "target", "words", "seq_len" - "The new rights are...", "Everyone really likes..", "[2, 3, 4, 5, ...]", 1, 10 - "This site includes a...", "The Government Executive...", "[11, 12, 13,...]", 0, 5 - "...", "...", "[...]", ., . + "The new rights are...", "Everyone really likes..", 1, "[2, 3, 4, 5, ...]", 10 + "This site includes a...", "The Government Executive...", 0, "[11, 12, 13,...]", 5 + "...", "...", ., "[...]", . words列是将raw_words1(即premise), raw_words2(即hypothesis)使用"[SEP]"链接起来转换为index的。 words列被设置为input,target列被设置为target和input(设置为input以方便在forward函数中计算loss, 如果不在forward函数中计算loss也不影响,fastNLP将根据forward函数的形参名进行传参). + dataset的print_field_meta()函数输出的各个field的被设置成input和target的情况为:: + + +-------------+------------+------------+--------+-------+---------+ + | field_names | raw_words1 | raw_words2 | target | words | seq_len | + +-------------+------------+------------+--------+-------+---------+ + | is_input | False | False | False | True | True | + | is_target | False | False | True | False | False | + | ignore_type | | | False | False | False | + | pad_value | | | 0 | 0 | 0 | + +-------------+------------+------------+--------+-------+---------+ + """ def __init__(self, lower=False, tokenizer: str = 'raw'): @@ -75,6 +86,18 @@ class MatchingBertPipe(Pipe): return data_bundle def process(self, data_bundle): + """ + 输入的data_bundle中的dataset需要具有以下结构: + + .. csv-table:: + :header: "raw_words1", "raw_words2", "target" + + "Dana Reeve, the widow of the actor...", "Christopher Reeve had an...", "not_entailment" + "...","..." + + :param data_bundle: + :return: + """ for dataset in data_bundle.datasets.values(): if dataset.has_field(Const.TARGET): dataset.drop(lambda x: x[Const.TARGET] == '-') @@ -178,15 +201,27 @@ class MatchingPipe(Pipe): Matching任务的Pipe。输出的DataSet将包含以下的field .. csv-table:: - :header: "raw_words1", "raw_words2", "words1", "words2", "target", "seq_len1", "seq_len2" + :header: "raw_words1", "raw_words2", "target", "words1", "words2", "seq_len1", "seq_len2" - "The new rights are...", "Everyone really likes..", "[2, 3, 4, 5, ...]", "[10, 20, 6]", 1, 10, 13 - "This site includes a...", "The Government Executive...", "[11, 12, 13,...]", "[2, 7, ...]", 0, 6, 7 - "...", "...", "[...]", "[...]", ., ., . + "The new rights are...", "Everyone really likes..", 1, "[2, 3, 4, 5, ...]", "[10, 20, 6]", 10, 13 + "This site includes a...", "The Government Executive...", 0, "[11, 12, 13,...]", "[2, 7, ...]", 6, 7 + "...", "...", ., "[...]", "[...]", ., . 
words1是premise,words2是hypothesis。其中words1,words2,seq_len1,seq_len2被设置为input;target被设置为target 和input(设置为input以方便在forward函数中计算loss,如果不在forward函数中计算loss也不影响,fastNLP将根据forward函数 的形参名进行传参)。 + + dataset的print_field_meta()函数输出的各个field的被设置成input和target的情况为:: + + +-------------+------------+------------+--------+--------+--------+----------+----------+ + | field_names | raw_words1 | raw_words2 | target | words1 | words2 | seq_len1 | seq_len2 | + +-------------+------------+------------+--------+--------+--------+----------+----------+ + | is_input | False | False | False | True | True | True | True | + | is_target | False | False | True | False | False | False | False | + | ignore_type | | | False | False | False | False | False | + | pad_value | | | 0 | 0 | 0 | 0 | 0 | + +-------------+------------+------------+--------+--------+--------+----------+----------+ + """ def __init__(self, lower=False, tokenizer: str = 'raw'): From b874fba8f2d958ea0848e48f88ba1813069f80ec Mon Sep 17 00:00:00 2001 From: ChenXin Date: Mon, 23 Sep 2019 09:51:19 +0800 Subject: [PATCH 05/16] add the test for modules.utils.summary --- test/modules/test_utils.py | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/test/modules/test_utils.py b/test/modules/test_utils.py index 73226f97..340fedd9 100644 --- a/test/modules/test_utils.py +++ b/test/modules/test_utils.py @@ -1,9 +1,20 @@ import unittest + import torch -from fastNLP.modules.utils import get_dropout_mask + +from fastNLP.models import CNNText +from fastNLP.modules.utils import get_dropout_mask, summary + class TestUtil(unittest.TestCase): def test_get_dropout_mask(self): tensor = torch.randn(3, 4) mask = get_dropout_mask(0.3, tensor) - self.assertSequenceEqual(mask.size(), torch.Size([3, 4])) \ No newline at end of file + self.assertSequenceEqual(mask.size(), torch.Size([3, 4])) + + def test_summary(self): + model = CNNText(embed=(4, 4), num_classes=2, kernel_nums=(9,5), kernel_sizes=(1,3)) + # 4 * 4 + 4 * (9 * 1 + 5 * 3) + 2 * (9 + 5 + 1) = 142 + self.assertSequenceEqual((142, 142, 0), summary(model)) + model.embed.requires_grad = False + self.assertSequenceEqual((142, 126, 16), summary(model)) From 8de495d046ce4d37f788c4ef2e1e8ef1df7bcc4e Mon Sep 17 00:00:00 2001 From: ChenXin Date: Mon, 23 Sep 2019 10:08:57 +0800 Subject: [PATCH 06/16] remove _TesterCallback from __all__ --- docs/source/fastNLP.core.callback.rst | 2 +- docs/source/fastNLP.modules.encoder.rst | 2 +- docs/source/fastNLP.modules.rst | 2 +- docs/source/fastNLP.rst | 2 +- fastNLP/core/callback.py | 1 - 5 files changed, 4 insertions(+), 5 deletions(-) diff --git a/docs/source/fastNLP.core.callback.rst b/docs/source/fastNLP.core.callback.rst index d37ddb11..75b5d0cd 100644 --- a/docs/source/fastNLP.core.callback.rst +++ b/docs/source/fastNLP.core.callback.rst @@ -2,6 +2,6 @@ fastNLP.core.callback ===================== .. 
automodule:: fastNLP.core.callback - :members: Callback, GradientClipCallback, EarlyStopCallback, FitlogCallback, EvaluateCallback, LRScheduler, ControlC, LRFinder, TensorboardCallback, WarmupCallback, SaveModelCallback, EchoCallback, TesterCallback, CallbackException, EarlyStopError + :members: Callback, GradientClipCallback, EarlyStopCallback, FitlogCallback, EvaluateCallback, LRScheduler, ControlC, LRFinder, TensorboardCallback, WarmupCallback, SaveModelCallback, EchoCallback, CallbackException, EarlyStopError :inherited-members: diff --git a/docs/source/fastNLP.modules.encoder.rst b/docs/source/fastNLP.modules.encoder.rst index cca62d05..a402cb67 100644 --- a/docs/source/fastNLP.modules.encoder.rst +++ b/docs/source/fastNLP.modules.encoder.rst @@ -2,5 +2,5 @@ fastNLP.modules.encoder ======================= .. automodule:: fastNLP.modules.encoder - :members: ConvolutionCharEncoder, LSTMCharEncoder, ConvMaxpool, LSTM, StarTransformer, TransformerEncoder, VarRNN, VarLSTM, VarGRU, MaxPool, MaxPoolWithMask, AvgPool, AvgPoolWithMask, MultiHeadAttention, BiAttention, SelfAttention + :members: ConvolutionCharEncoder, LSTMCharEncoder, ConvMaxpool, LSTM, StarTransformer, TransformerEncoder, VarRNN, VarLSTM, VarGRU, MaxPool, MaxPoolWithMask, KMaxPool, AvgPool, AvgPoolWithMask, MultiHeadAttention, BiAttention, SelfAttention diff --git a/docs/source/fastNLP.modules.rst b/docs/source/fastNLP.modules.rst index b7c259ed..9c44e461 100644 --- a/docs/source/fastNLP.modules.rst +++ b/docs/source/fastNLP.modules.rst @@ -2,7 +2,7 @@ fastNLP.modules =============== .. automodule:: fastNLP.modules - :members: ConvolutionCharEncoder, LSTMCharEncoder, ConvMaxpool, LSTM, StarTransformer, TransformerEncoder, VarRNN, VarLSTM, VarGRU, MaxPool, MaxPoolWithMask, AvgPool, AvgPoolWithMask, MultiHeadAttention, MLP, ConditionalRandomField, viterbi_decode, allowed_transitions, TimestepDropout + :members: ConvolutionCharEncoder, LSTMCharEncoder, ConvMaxpool, LSTM, StarTransformer, TransformerEncoder, VarRNN, VarLSTM, VarGRU, MaxPool, MaxPoolWithMask, KMaxPool, AvgPool, AvgPoolWithMask, MultiHeadAttention, MLP, ConditionalRandomField, viterbi_decode, allowed_transitions, TimestepDropout 子模块 ------ diff --git a/docs/source/fastNLP.rst b/docs/source/fastNLP.rst index 95d77705..e92807d7 100644 --- a/docs/source/fastNLP.rst +++ b/docs/source/fastNLP.rst @@ -2,7 +2,7 @@ fastNLP ======= .. 
automodule:: fastNLP - :members: Instance, FieldArray, DataSetIter, BatchIter, TorchLoaderIter, Vocabulary, DataSet, Const, Trainer, Tester, Callback, GradientClipCallback, EarlyStopCallback, FitlogCallback, EvaluateCallback, LRScheduler, ControlC, LRFinder, TensorboardCallback, WarmupCallback, SaveModelCallback, EchoCallback, TesterCallback, CallbackException, EarlyStopError, Padder, AutoPadder, EngChar2DPadder, AccuracyMetric, SpanFPreRecMetric, ExtractiveQAMetric, Optimizer, SGD, Adam, AdamW, Sampler, SequentialSampler, BucketSampler, RandomSampler, LossFunc, CrossEntropyLoss, L1Loss, BCELoss, NLLLoss, LossInForward, cache_results, logger + :members: Instance, FieldArray, DataSetIter, BatchIter, TorchLoaderIter, Vocabulary, DataSet, Const, Trainer, Tester, Callback, GradientClipCallback, EarlyStopCallback, FitlogCallback, EvaluateCallback, LRScheduler, ControlC, LRFinder, TensorboardCallback, WarmupCallback, SaveModelCallback, EchoCallback, CallbackException, EarlyStopError, Padder, AutoPadder, EngChar2DPadder, AccuracyMetric, SpanFPreRecMetric, ExtractiveQAMetric, Optimizer, SGD, Adam, AdamW, Sampler, SequentialSampler, BucketSampler, RandomSampler, LossFunc, CrossEntropyLoss, L1Loss, BCELoss, NLLLoss, LossInForward, cache_results, logger :inherited-members: 子模块 diff --git a/fastNLP/core/callback.py b/fastNLP/core/callback.py index fac1f1f4..36ce2aa5 100644 --- a/fastNLP/core/callback.py +++ b/fastNLP/core/callback.py @@ -63,7 +63,6 @@ __all__ = [ "WarmupCallback", "SaveModelCallback", "EchoCallback", - "_TesterCallback", "CallbackException", "EarlyStopError" From 9555d471a969e59dd2e13c82f144cc6133b3c24a Mon Sep 17 00:00:00 2001 From: ChenXin Date: Mon, 23 Sep 2019 10:13:11 +0800 Subject: [PATCH 07/16] add links for variational RNN --- fastNLP/__init__.py | 1 - fastNLP/modules/encoder/variational_rnn.py | 10 +++++++--- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/fastNLP/__init__.py b/fastNLP/__init__.py index 8800a5a7..ded83308 100644 --- a/fastNLP/__init__.py +++ b/fastNLP/__init__.py @@ -37,7 +37,6 @@ __all__ = [ "WarmupCallback", 'SaveModelCallback', "EchoCallback", - "TesterCallback", "CallbackException", "EarlyStopError", diff --git a/fastNLP/modules/encoder/variational_rnn.py b/fastNLP/modules/encoder/variational_rnn.py index 5f4a5534..b09b3af9 100644 --- a/fastNLP/modules/encoder/variational_rnn.py +++ b/fastNLP/modules/encoder/variational_rnn.py @@ -1,5 +1,6 @@ """undocumented -Variational RNN 的 Pytorch 实现 +Variational RNN 及相关模型的 fastNLP实现,相关论文参考: +`A Theoretically Grounded Application of Dropout in Recurrent Neural Networks (Yarin Gal and Zoubin Ghahramani, 2016) `_ """ __all__ = [ @@ -227,6 +228,7 @@ class VarRNNBase(nn.Module): class VarLSTM(VarRNNBase): """ Variational Dropout LSTM. + 相关论文参考:`A Theoretically Grounded Application of Dropout in Recurrent Neural Networks (Yarin Gal and Zoubin Ghahramani, 2016) `_ """ @@ -253,7 +255,8 @@ class VarLSTM(VarRNNBase): class VarRNN(VarRNNBase): """ Variational Dropout RNN. - + 相关论文参考:`A Theoretically Grounded Application of Dropout in Recurrent Neural Networks (Yarin Gal and Zoubin Ghahramani, 2016) `_ + """ def __init__(self, *args, **kwargs): @@ -279,7 +282,8 @@ class VarRNN(VarRNNBase): class VarGRU(VarRNNBase): """ Variational Dropout GRU. 
- + 相关论文参考:`A Theoretically Grounded Application of Dropout in Recurrent Neural Networks (Yarin Gal and Zoubin Ghahramani, 2016) `_ + """ def __init__(self, *args, **kwargs): From 8f0f280629fa6d95a0f2bef57cbe28869584b9c1 Mon Sep 17 00:00:00 2001 From: ChenXin Date: Mon, 23 Sep 2019 10:24:04 +0800 Subject: [PATCH 08/16] add doc for NaiveClassifier --- fastNLP/models/base_model.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/fastNLP/models/base_model.py b/fastNLP/models/base_model.py index 61edb91f..f1896cb2 100644 --- a/fastNLP/models/base_model.py +++ b/fastNLP/models/base_model.py @@ -22,6 +22,9 @@ class BaseModel(torch.nn.Module): class NaiveClassifier(BaseModel): + """ + 一个简单的分类器例子,可用于各种测试 + """ def __init__(self, in_feature_dim, out_feature_dim): super(NaiveClassifier, self).__init__() self.mlp = MLP([in_feature_dim, in_feature_dim, out_feature_dim]) From 0a4f17f4cee083234fd1ee8bb0affdb33ebb563d Mon Sep 17 00:00:00 2001 From: ChenXin Date: Mon, 23 Sep 2019 11:11:52 +0800 Subject: [PATCH 09/16] add test for ModelSaver & ModelLoader --- test/io/test_model_io.py | 25 +++++++++++++++++++++++++ 1 file changed, 25 insertions(+) create mode 100644 test/io/test_model_io.py diff --git a/test/io/test_model_io.py b/test/io/test_model_io.py new file mode 100644 index 00000000..b8960492 --- /dev/null +++ b/test/io/test_model_io.py @@ -0,0 +1,25 @@ +import os +import unittest + +from fastNLP.io import ModelSaver, ModelLoader +from fastNLP.models import CNNText + + +class TestModelIO(unittest.TestCase): + def test_save_and_load(self): + model = CNNText((10, 10), 2) + saver = ModelSaver('tmp') + loader = ModelLoader() + saver.save_pytorch(model) + + new_cnn = CNNText((10, 10), 2) + loader.load_pytorch(new_cnn, 'tmp') + + new_model = loader.load_pytorch_model('tmp') + + for i in range(10): + for j in range(10): + self.assertEqual(model.embed.embed.weight[i, j], new_cnn.embed.embed.weight[i, j]) + self.assertEqual(model.embed.embed.weight[i, j], new_model["embed.embed.weight"][i, j]) + + os.system('rm tmp') From df123bce0e0b514a28de8f55dc1f6f21b1b8ec48 Mon Sep 17 00:00:00 2001 From: ChenXin Date: Mon, 23 Sep 2019 11:15:10 +0800 Subject: [PATCH 10/16] add doc for doc_utils.py --- fastNLP/doc_utils.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/fastNLP/doc_utils.py b/fastNLP/doc_utils.py index 52e347b9..d5412ff4 100644 --- a/fastNLP/doc_utils.py +++ b/fastNLP/doc_utils.py @@ -1,4 +1,6 @@ -"""undocumented""" +"""undocumented +用于辅助生成 fastNLP 文档的代码 +""" __all__ = [] @@ -15,6 +17,9 @@ def doc_process(m): pass else: module_name = obj.__module__ + + # 识别并标注类和函数在不同层次中的位置 + while 1: defined_m = sys.modules[module_name] if "undocumented" not in defined_m.__doc__ and name in defined_m.__all__: @@ -25,6 +30,8 @@ def doc_process(m): if module_name == m.__name__: # print(name, ": not found defined doc.") break + + # 识别并标注基类,只有基类也在 fastNLP 中定义才显示 if inspect.isclass(obj): for base in obj.__bases__: From e28ecb8b335971402c0658bb7bfe72f1a29d1818 Mon Sep 17 00:00:00 2001 From: ChenXin Date: Mon, 23 Sep 2019 13:44:19 +0800 Subject: [PATCH 11/16] hide the EchoCallback --- fastNLP/__init__.py | 1 - fastNLP/core/__init__.py | 5 +-- fastNLP/core/callback.py | 7 +++- test/core/test_dist_trainer.py | 67 +++++++++++++++++++--------------- 4 files changed, 45 insertions(+), 35 deletions(-) diff --git a/fastNLP/__init__.py b/fastNLP/__init__.py index ded83308..1629ab66 100644 --- a/fastNLP/__init__.py +++ b/fastNLP/__init__.py @@ -36,7 +36,6 @@ __all__ = [ "TensorboardCallback", 
"WarmupCallback", 'SaveModelCallback', - "EchoCallback", "CallbackException", "EarlyStopError", diff --git a/fastNLP/core/__init__.py b/fastNLP/core/__init__.py index bea80097..0588c9aa 100644 --- a/fastNLP/core/__init__.py +++ b/fastNLP/core/__init__.py @@ -48,7 +48,6 @@ __all__ = [ "TensorboardCallback", "WarmupCallback", 'SaveModelCallback', - "EchoCallback", "CallbackException", "EarlyStopError", @@ -77,8 +76,8 @@ __all__ = [ from ._logger import logger from .batch import DataSetIter, BatchIter, TorchLoaderIter from .callback import Callback, GradientClipCallback, EarlyStopCallback, FitlogCallback, EvaluateCallback, \ - LRScheduler, ControlC, LRFinder, TensorboardCallback, WarmupCallback, SaveModelCallback, EchoCallback, \ - CallbackException, EarlyStopError + LRScheduler, ControlC, LRFinder, TensorboardCallback, WarmupCallback, SaveModelCallback, CallbackException, \ + EarlyStopError from .const import Const from .dataset import DataSet from .field import FieldArray, Padder, AutoPadder, EngChar2DPadder diff --git a/fastNLP/core/callback.py b/fastNLP/core/callback.py index 36ce2aa5..dca34db5 100644 --- a/fastNLP/core/callback.py +++ b/fastNLP/core/callback.py @@ -62,7 +62,6 @@ __all__ = [ "TensorboardCallback", "WarmupCallback", "SaveModelCallback", - "EchoCallback", "CallbackException", "EarlyStopError" @@ -710,6 +709,8 @@ class ControlC(Callback): class SmoothValue(object): + """work for LRFinder""" + def __init__(self, beta: float): self.beta, self.n, self.mov_avg = beta, 0, 0 self.smooth = None @@ -1022,6 +1023,10 @@ class EarlyStopError(CallbackException): class EchoCallback(Callback): + """ + 用于测试分布式训练 + + """ def __init__(self, name, out=sys.stdout): super(EchoCallback, self).__init__() self.name = name diff --git a/test/core/test_dist_trainer.py b/test/core/test_dist_trainer.py index 03f613e1..3b53fe50 100644 --- a/test/core/test_dist_trainer.py +++ b/test/core/test_dist_trainer.py @@ -1,33 +1,36 @@ +import os +import shutil +import subprocess import unittest +from argparse import ArgumentParser import numpy as np import torch.cuda + +from fastNLP import AccuracyMetric +from fastNLP import CrossEntropyLoss, BCELoss from fastNLP import DataSet from fastNLP import Instance -from fastNLP import CrossEntropyLoss, BCELoss from fastNLP import SGD +from fastNLP.core.callback import EchoCallback from fastNLP.core.dist_trainer import DistTrainer, get_local_rank from fastNLP.models.base_model import NaiveClassifier -import shutil -import os -import subprocess -from argparse import ArgumentParser -from fastNLP.core.callback import EchoCallback -from fastNLP import AccuracyMetric + def prepare_fake_dataset(): mean = np.array([-3, -3]) cov = np.array([[1, 0], [0, 1]]) class_A = np.random.multivariate_normal(mean, cov, size=(1000,)) - + mean = np.array([3, 3]) cov = np.array([[1, 0], [0, 1]]) class_B = np.random.multivariate_normal(mean, cov, size=(1000,)) - + data_set = DataSet([Instance(x=[float(item[0]), float(item[1])], y=0) for item in class_A] + [Instance(x=[float(item[0]), float(item[1])], y=1) for item in class_B]) return data_set + def prepare_fake_dataset2(*args, size=100): ys = np.random.randint(4, size=100, dtype=np.int64) data = {'y': ys} @@ -35,32 +38,35 @@ def prepare_fake_dataset2(*args, size=100): data[arg] = np.random.randn(size, 5) return DataSet(data=data) + def set_rng_seed(seed): np.random.seed(seed) + def prepare_env(): def prepare_fake_dataset(): mean = np.array([-3, -3]) cov = np.array([[1, 0], [0, 1]]) class_A = np.random.multivariate_normal(mean, cov, size=(1000,)) 
- + mean = np.array([3, 3]) cov = np.array([[1, 0], [0, 1]]) class_B = np.random.multivariate_normal(mean, cov, size=(1000,)) - + data_set = DataSet([Instance(x=[float(item[0]), float(item[1])], y=[0.0]) for item in class_A] + [Instance(x=[float(item[0]), float(item[1])], y=[1.0]) for item in class_B]) return data_set - + data_set = prepare_fake_dataset() data_set.set_input("x") data_set.set_target("y") model = NaiveClassifier(2, 1) return data_set, model + class TestDistTrainer(unittest.TestCase): save_path = './save_cp' - + def run1(self): # test distributed training print('local rank', get_local_rank()) @@ -68,9 +74,9 @@ class TestDistTrainer(unittest.TestCase): data_set = prepare_fake_dataset() data_set.set_input("x", flag=True) data_set.set_target("y", flag=True) - + model = NaiveClassifier(2, 2) - + trainer = DistTrainer( model=model, train_data=data_set, optimizer=SGD(lr=0.1), loss=CrossEntropyLoss(pred="predict", target="y"), @@ -82,7 +88,7 @@ class TestDistTrainer(unittest.TestCase): """ if trainer.is_master and os.path.exists(self.save_path): shutil.rmtree(self.save_path) - + def run2(self): # test fp16 with distributed training print('local rank', get_local_rank()) @@ -90,9 +96,9 @@ class TestDistTrainer(unittest.TestCase): data_set = prepare_fake_dataset() data_set.set_input("x", flag=True) data_set.set_target("y", flag=True) - + model = NaiveClassifier(2, 2) - + trainer = DistTrainer( model=model, train_data=data_set, optimizer=SGD(lr=0.1), loss=CrossEntropyLoss(pred="predict", target="y"), @@ -105,7 +111,7 @@ class TestDistTrainer(unittest.TestCase): """ if trainer.is_master and os.path.exists(self.save_path): shutil.rmtree(self.save_path) - + def run3(self): set_rng_seed(100) data_set, model = prepare_env() @@ -117,15 +123,15 @@ class TestDistTrainer(unittest.TestCase): callbacks_master=[EchoCallback('callbacks_master')] ) trainer.train() - + def run4(self): set_rng_seed(100) data_set, model = prepare_env() - + train_set, dev_set = data_set.split(0.3) - + model = NaiveClassifier(2, 1) - + trainer = DistTrainer( train_set, model, optimizer=SGD(lr=0.1), loss=BCELoss(pred="predict", target="y"), @@ -138,7 +144,7 @@ class TestDistTrainer(unittest.TestCase): """ if trainer.is_master and os.path.exists(self.save_path): shutil.rmtree(self.save_path) - + def run_dist(self, run_id): if torch.cuda.is_available(): ngpu = min(2, torch.cuda.device_count()) @@ -147,23 +153,24 @@ class TestDistTrainer(unittest.TestCase): '--nproc_per_node', str(ngpu), path, '--test', str(run_id)] print(' '.join(cmd)) subprocess.check_call(cmd) - + def test_normal_run(self): self.run_dist(1) - + def no_test_fp16(self): self.run_dist(2) - + def test_callback(self): self.run_dist(3) - + def test_dev_data(self): self.run_dist(4) + if __name__ == '__main__': runner = TestDistTrainer() parser = ArgumentParser() parser.add_argument('--test', type=int) args, _ = parser.parse_known_args() - if args.test and hasattr(runner, 'run%s'%args.test): - getattr(runner, 'run%s'%args.test)() + if args.test and hasattr(runner, 'run%s' % args.test): + getattr(runner, 'run%s' % args.test)() From d8fa75b0585c5870bff31fe843c6c4c27b040f23 Mon Sep 17 00:00:00 2001 From: ChenXin Date: Mon, 23 Sep 2019 14:38:05 +0800 Subject: [PATCH 12/16] add the test for ControlC callback --- test/core/test_callbacks.py | 98 +++++++++++++++++++++++++------------ 1 file changed, 66 insertions(+), 32 deletions(-) diff --git a/test/core/test_callbacks.py b/test/core/test_callbacks.py index 78f76b65..fc555afb 100644 --- a/test/core/test_callbacks.py +++ 
b/test/core/test_callbacks.py @@ -1,39 +1,35 @@ +import os +import tempfile import unittest import numpy as np import torch -import os -import shutil -from fastNLP.core.callback import EarlyStopCallback, GradientClipCallback, LRScheduler, ControlC, \ - LRFinder, TensorboardCallback +from fastNLP import AccuracyMetric +from fastNLP import BCELoss from fastNLP import DataSet from fastNLP import Instance -from fastNLP import BCELoss -from fastNLP import AccuracyMetric from fastNLP import SGD from fastNLP import Trainer -from fastNLP.models.base_model import NaiveClassifier -from fastNLP.core.callback import EarlyStopError +from fastNLP.core.callback import EarlyStopCallback, GradientClipCallback, LRScheduler, ControlC, \ + LRFinder, TensorboardCallback from fastNLP.core.callback import EvaluateCallback, FitlogCallback, SaveModelCallback from fastNLP.core.callback import WarmupCallback -import tempfile +from fastNLP.models.base_model import NaiveClassifier + def prepare_env(): - def prepare_fake_dataset(): - mean = np.array([-3, -3]) - cov = np.array([[1, 0], [0, 1]]) - class_A = np.random.multivariate_normal(mean, cov, size=(1000,)) - - mean = np.array([3, 3]) - cov = np.array([[1, 0], [0, 1]]) - class_B = np.random.multivariate_normal(mean, cov, size=(1000,)) - - data_set = DataSet([Instance(x=[float(item[0]), float(item[1])], y=[0.0]) for item in class_A] + - [Instance(x=[float(item[0]), float(item[1])], y=[1.0]) for item in class_B]) - return data_set + mean = np.array([-3, -3]) + cov = np.array([[1, 0], [0, 1]]) + class_A = np.random.multivariate_normal(mean, cov, size=(1000,)) + + mean = np.array([3, 3]) + cov = np.array([[1, 0], [0, 1]]) + class_B = np.random.multivariate_normal(mean, cov, size=(1000,)) + + data_set = DataSet([Instance(x=[float(item[0]), float(item[1])], y=[0.0]) for item in class_A] + + [Instance(x=[float(item[0]), float(item[1])], y=[1.0]) for item in class_B]) - data_set = prepare_fake_dataset() data_set.set_input("x") data_set.set_target("y") model = NaiveClassifier(2, 1) @@ -43,11 +39,11 @@ def prepare_env(): class TestCallback(unittest.TestCase): def setUp(self): self.tempdir = tempfile.mkdtemp() - + def tearDown(self): pass # shutil.rmtree(self.tempdir) - + def test_gradient_clip(self): data_set, model = prepare_env() trainer = Trainer(data_set, model, optimizer=SGD(lr=0.1), loss=BCELoss(pred="predict", target="y"), @@ -100,7 +96,7 @@ class TestCallback(unittest.TestCase): path = os.path.join("./", 'tensorboard_logs_{}'.format(trainer.start_time)) if os.path.exists(path): shutil.rmtree(path) - + def test_readonly_property(self): from fastNLP.core.callback import Callback passed_epochs = [] @@ -123,19 +119,19 @@ class TestCallback(unittest.TestCase): check_code_level=2) trainer.train() assert passed_epochs == list(range(1, total_epochs + 1)) - + def test_evaluate_callback(self): data_set, model = prepare_env() from fastNLP import Tester tester = Tester(data=data_set, model=model, metrics=AccuracyMetric(pred="predict", target="y")) evaluate_callback = EvaluateCallback(data_set, tester) - + trainer = Trainer(data_set, model, optimizer=SGD(lr=0.1), loss=BCELoss(pred="predict", target="y"), batch_size=32, n_epochs=5, print_every=50, dev_data=data_set, metrics=AccuracyMetric(pred="predict", target="y"), use_tqdm=False, callbacks=evaluate_callback, check_code_level=2) trainer.train() - + def test_fitlog_callback(self): import fitlog fitlog.set_log_dir(self.tempdir) @@ -143,13 +139,13 @@ class TestCallback(unittest.TestCase): from fastNLP import Tester tester = 
Tester(data=data_set, model=model, metrics=AccuracyMetric(pred="predict", target="y")) fitlog_callback = FitlogCallback(data_set, tester) - + trainer = Trainer(data_set, model, optimizer=SGD(lr=0.1), loss=BCELoss(pred="predict", target="y"), batch_size=32, n_epochs=5, print_every=50, dev_data=data_set, metrics=AccuracyMetric(pred="predict", target="y"), use_tqdm=True, callbacks=fitlog_callback, check_code_level=2) trainer.train() - + def test_save_model_callback(self): data_set, model = prepare_env() top = 3 @@ -159,10 +155,10 @@ class TestCallback(unittest.TestCase): metrics=AccuracyMetric(pred="predict", target="y"), use_tqdm=True, callbacks=save_model_callback, check_code_level=2) trainer.train() - + timestamp = os.listdir(self.tempdir)[0] self.assertEqual(len(os.listdir(os.path.join(self.tempdir, timestamp))), top) - + def test_warmup_callback(self): data_set, model = prepare_env() warmup_callback = WarmupCallback() @@ -171,3 +167,41 @@ class TestCallback(unittest.TestCase): metrics=AccuracyMetric(pred="predict", target="y"), use_tqdm=True, callbacks=warmup_callback, check_code_level=2) trainer.train() + + +def test_control_C(): + # 用于测试 ControlC , 再两次训练时用 Control+C 进行退出,如果最后不显示 "Test failed!" 则通过测试 + from fastNLP import ControlC, Callback + import time + + line1 = "\n\n\n\n\n*************************" + line2 = "*************************\n\n\n\n\n" + + + class Wait(Callback): + def on_epoch_end(self): + time.sleep(5) + + + data_set, model = prepare_env() + + print(line1 + "Test starts!" + line2) + trainer = Trainer(data_set, model, optimizer=SGD(lr=0.1), loss=BCELoss(pred="predict", target="y"), + batch_size=32, n_epochs=20, dev_data=data_set, + metrics=AccuracyMetric(pred="predict", target="y"), use_tqdm=True, + callbacks=[Wait(), ControlC(False)], check_code_level=2) + trainer.train() + + print(line1 + "Program goes on ..." + line2) + + trainer = Trainer(data_set, model, optimizer=SGD(lr=0.1), loss=BCELoss(pred="predict", target="y"), + batch_size=32, n_epochs=20, dev_data=data_set, + metrics=AccuracyMetric(pred="predict", target="y"), use_tqdm=True, + callbacks=[Wait(), ControlC(True)], check_code_level=2) + trainer.train() + + print(line1 + "Test failed!" + line2) + + +if __name__ == "__main__": + test_control_C() \ No newline at end of file From 4f0ec4a08103763b2861901bfdfbcce750e005fd Mon Sep 17 00:00:00 2001 From: ChenXin Date: Mon, 23 Sep 2019 15:27:15 +0800 Subject: [PATCH 13/16] add the test for EarlyStopCallback, BUG found! 
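Note: the bug referred to in this subject is presumably the best-metric comparison fixed in the following patch (PATCH 14/16), where the Trainer used `>=` / `<=` when deciding whether a new dev result is better. With the non-strict comparison, a plateauing metric keeps counting as an improvement, so a patience-based callback such as EarlyStopCallback would likely never observe a "no improvement" epoch. The snippet below is a hypothetical, minimal sketch (not fastNLP code; the helper name `is_better` is made up for illustration) of why the strict comparison matters:

    def is_better(new_val, best_val, increase_better=True, strict=True):
        # strict=True mirrors the fixed Trainer logic ('>' / '<'),
        # strict=False mirrors the old behaviour ('>=' / '<=').
        if increase_better:
            return new_val > best_val if strict else new_val >= best_val
        return new_val < best_val if strict else new_val <= best_val

    best = 0.79
    for val in [0.80, 0.80, 0.80]:  # dev accuracy plateaus after one real gain
        old_rule = is_better(val, best, strict=False)  # True, True, True  -> plateau still looks like progress
        new_rule = is_better(val, best, strict=True)   # True, False, False -> early stopping can trigger
        print(old_rule, new_rule)
        best = max(best, val)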
--- test/core/test_callbacks.py | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/test/core/test_callbacks.py b/test/core/test_callbacks.py index fc555afb..db95a32d 100644 --- a/test/core/test_callbacks.py +++ b/test/core/test_callbacks.py @@ -167,6 +167,17 @@ class TestCallback(unittest.TestCase): metrics=AccuracyMetric(pred="predict", target="y"), use_tqdm=True, callbacks=warmup_callback, check_code_level=2) trainer.train() + + def test_early_stop_callback(self): + """ + 需要观察是否真的 EarlyStop + """ + data_set, model = prepare_env() + trainer = Trainer(data_set, model, optimizer=SGD(lr=0.1), loss=BCELoss(pred="predict", target="y"), + batch_size=2, n_epochs=10, print_every=5, dev_data=data_set, + metrics=AccuracyMetric(pred="predict", target="y"), use_tqdm=True, + callbacks=EarlyStopCallback(1), check_code_level=2) + trainer.train() def test_control_C(): @@ -177,12 +188,10 @@ def test_control_C(): line1 = "\n\n\n\n\n*************************" line2 = "*************************\n\n\n\n\n" - class Wait(Callback): def on_epoch_end(self): time.sleep(5) - data_set, model = prepare_env() print(line1 + "Test starts!" + line2) @@ -204,4 +213,4 @@ def test_control_C(): if __name__ == "__main__": - test_control_C() \ No newline at end of file + test_control_C() From 243bf8ce425f11fc1743ce2cf16eec8774f11027 Mon Sep 17 00:00:00 2001 From: ChenXin Date: Mon, 23 Sep 2019 16:38:13 +0800 Subject: [PATCH 14/16] fix the bug on the trainer --- fastNLP/core/trainer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/fastNLP/core/trainer.py b/fastNLP/core/trainer.py index c331ab18..47d9edc5 100644 --- a/fastNLP/core/trainer.py +++ b/fastNLP/core/trainer.py @@ -829,12 +829,12 @@ class Trainer(object): self.best_metric_indicator = indicator_val else: if self.increase_better is True: - if indicator_val >= self.best_metric_indicator: + if indicator_val > self.best_metric_indicator: self.best_metric_indicator = indicator_val else: is_better = False else: - if indicator_val <= self.best_metric_indicator: + if indicator_val < self.best_metric_indicator: self.best_metric_indicator = indicator_val else: is_better = False From a4babc04e2908d45ad8b6403b36e6eb669b4905d Mon Sep 17 00:00:00 2001 From: Yige Xu Date: Wed, 25 Sep 2019 14:42:58 +0800 Subject: [PATCH 15/16] 1. add summary loader; 2. reorganize code in ExtCNNDMPipe; 3. 
reorganize test data and code for ExtCNNDMPipe; --- fastNLP/io/loader/summarization.py | 63 +++++++++ fastNLP/io/pipe/summarization.py | 86 ++++++------ test/data_for_tests/io/cnndm/dev.label.jsonl | 4 + test/data_for_tests/io/cnndm/test.label.jsonl | 4 + .../cnndm/train.cnndm.jsonl} | 0 .../{cnndm.vocab => io/cnndm/vocab} | 0 .../{test_extcnndm.py => test_summary.py} | 128 ++++++++++-------- 7 files changed, 188 insertions(+), 97 deletions(-) create mode 100644 fastNLP/io/loader/summarization.py create mode 100644 test/data_for_tests/io/cnndm/dev.label.jsonl create mode 100644 test/data_for_tests/io/cnndm/test.label.jsonl rename test/data_for_tests/{cnndm.jsonl => io/cnndm/train.cnndm.jsonl} (100%) rename test/data_for_tests/{cnndm.vocab => io/cnndm/vocab} (100%) rename test/io/pipe/{test_extcnndm.py => test_summary.py} (52%) diff --git a/fastNLP/io/loader/summarization.py b/fastNLP/io/loader/summarization.py new file mode 100644 index 00000000..95b18af7 --- /dev/null +++ b/fastNLP/io/loader/summarization.py @@ -0,0 +1,63 @@ +"""undocumented""" + +__all__ = [ + "ExtCNNDMLoader" +] + +import os +from typing import Union, Dict + +from ..data_bundle import DataBundle +from ..utils import check_loader_paths +from .json import JsonLoader + + +class ExtCNNDMLoader(JsonLoader): + """ + 读取之后的DataSet中的field情况为 + + .. csv-table:: + :header: "text", "summary", "label", "publication" + + ["I got new tires from them and... ","..."], ["The new tires...","..."], [0, 1], "cnndm" + ["Don't waste your time. We had two...","..."], ["Time is precious","..."], [1], "cnndm" + ["..."], ["..."], [], "cnndm" + + """ + + def __init__(self, fields=None): + fields = fields or {"text": None, "summary": None, "label": None, "publication": None} + super(ExtCNNDMLoader, self).__init__(fields=fields) + + def load(self, paths: Union[str, Dict[str, str]] = None): + """ + 从指定一个或多个路径中的文件中读取数据,返回 :class:`~fastNLP.io.DataBundle` 。 + + 读取的field根据ExtCNNDMLoader初始化时传入的headers决定。 + + :param str paths: 传入一个目录, 将在该目录下寻找train.label.jsonl, dev.label.jsonl + test.label.jsonl三个文件(该目录还应该需要有一个名字为vocab的文件,在 :class:`~fastNLP.io.ExtCNNDMPipe` + 当中需要用到)。 + + :return: 返回 :class:`~fastNLP.io.DataBundle` + """ + if paths is None: + paths = self.download() + paths = check_loader_paths(paths) + if ('train' in paths) and ('test' not in paths): + paths['test'] = paths['train'] + paths.pop('train') + + datasets = {name: self._load(path) for name, path in paths.items()} + data_bundle = DataBundle(datasets=datasets) + return data_bundle + + def download(self): + """ + 如果你使用了这个数据,请引用 + + https://arxiv.org/pdf/1506.03340.pdf + :return: + """ + output_dir = self._get_dataset_path('ext-cnndm') + return output_dir diff --git a/fastNLP/io/pipe/summarization.py b/fastNLP/io/pipe/summarization.py index 9412b1d3..64fa545d 100644 --- a/fastNLP/io/pipe/summarization.py +++ b/fastNLP/io/pipe/summarization.py @@ -1,15 +1,14 @@ """undocumented""" +import os import numpy as np from .pipe import Pipe -from .utils import get_tokenizer, _indexize, _add_words_field, _drop_empty_instance -from ..loader.json import JsonLoader +from .utils import _drop_empty_instance +from ..loader.summarization import ExtCNNDMLoader from ..data_bundle import DataBundle -from ..loader.classification import IMDBLoader, YelpFullLoader, SSTLoader, SST2Loader, YelpPolarityLoader from ...core.const import Const -from ...core.dataset import DataSet -from ...core.instance import Instance from ...core.vocabulary import Vocabulary +from ...core._logger import logger WORD_PAD = "[PAD]" @@ -18,7 +17,6 
@@ DOMAIN_UNK = "X" TAG_UNK = "X" - class ExtCNNDMPipe(Pipe): """ 对CNN/Daily Mail数据进行适用于extractive summarization task的预处理,预处理之后的数据,具备以下结构: @@ -27,13 +25,13 @@ class ExtCNNDMPipe(Pipe): :header: "text", "summary", "label", "publication", "text_wd", "words", "seq_len", "target" """ - def __init__(self, vocab_size, vocab_path, sent_max_len, doc_max_timesteps, domain=False): + def __init__(self, vocab_size, sent_max_len, doc_max_timesteps, vocab_path=None, domain=False): """ :param vocab_size: int, 词表大小 - :param vocab_path: str, 外部词表路径 :param sent_max_len: int, 句子最大长度,不足的句子将padding,超出的将截断 :param doc_max_timesteps: int, 文章最多句子个数,不足的将padding,超出的将截断 + :param vocab_path: str, 外部词表路径 :param domain: bool, 是否需要建立domain词表 """ self.vocab_size = vocab_size @@ -42,8 +40,7 @@ class ExtCNNDMPipe(Pipe): self.doc_max_timesteps = doc_max_timesteps self.domain = domain - - def process(self, db: DataBundle): + def process(self, data_bundle: DataBundle): """ 传入的DataSet应该具备如下的结构 @@ -64,24 +61,28 @@ class ExtCNNDMPipe(Pipe): [[""],...,[""]], [[],...,[]], [], [] """ - db.apply(lambda x: _lower_text(x['text']), new_field_name='text') - db.apply(lambda x: _lower_text(x['summary']), new_field_name='summary') - db.apply(lambda x: _split_list(x['text']), new_field_name='text_wd') - db.apply(lambda x: _convert_label(x["label"], len(x["text"])), new_field_name=Const.TARGET) + if self.vocab_path is None: + error_msg = 'vocab file is not defined!' + logger.error(error_msg) + raise RuntimeError(error_msg) + data_bundle.apply(lambda x: _lower_text(x['text']), new_field_name='text') + data_bundle.apply(lambda x: _lower_text(x['summary']), new_field_name='summary') + data_bundle.apply(lambda x: _split_list(x['text']), new_field_name='text_wd') + data_bundle.apply(lambda x: _convert_label(x["label"], len(x["text"])), new_field_name=Const.TARGET) - db.apply(lambda x: _pad_sent(x["text_wd"], self.sent_max_len), new_field_name=Const.INPUT) + data_bundle.apply(lambda x: _pad_sent(x["text_wd"], self.sent_max_len), new_field_name=Const.INPUT) # db.apply(lambda x: _token_mask(x["text_wd"], self.sent_max_len), new_field_name="pad_token_mask") # pad document - db.apply(lambda x: _pad_doc(x[Const.INPUT], self.sent_max_len, self.doc_max_timesteps), new_field_name=Const.INPUT) - db.apply(lambda x: _sent_mask(x[Const.INPUT], self.doc_max_timesteps), new_field_name=Const.INPUT_LEN) - db.apply(lambda x: _pad_label(x[Const.TARGET], self.doc_max_timesteps), new_field_name=Const.TARGET) + data_bundle.apply(lambda x: _pad_doc(x[Const.INPUT], self.sent_max_len, self.doc_max_timesteps), new_field_name=Const.INPUT) + data_bundle.apply(lambda x: _sent_mask(x[Const.INPUT], self.doc_max_timesteps), new_field_name=Const.INPUT_LEN) + data_bundle.apply(lambda x: _pad_label(x[Const.TARGET], self.doc_max_timesteps), new_field_name=Const.TARGET) - db = _drop_empty_instance(db, "label") + data_bundle = _drop_empty_instance(data_bundle, "label") # set input and target - db.set_input(Const.INPUT, Const.INPUT_LEN) - db.set_target(Const.TARGET, Const.INPUT_LEN) + data_bundle.set_input(Const.INPUT, Const.INPUT_LEN) + data_bundle.set_target(Const.TARGET, Const.INPUT_LEN) # print("[INFO] Load existing vocab from %s!" 
% self.vocab_path) word_list = [] @@ -96,47 +97,52 @@ class ExtCNNDMPipe(Pipe): vocabs = Vocabulary(max_size=self.vocab_size, padding=WORD_PAD, unknown=WORD_UNK) vocabs.add_word_lst(word_list) vocabs.build_vocab() - db.set_vocab(vocabs, "vocab") + data_bundle.set_vocab(vocabs, "vocab") - if self.domain == True: + if self.domain is True: domaindict = Vocabulary(padding=None, unknown=DOMAIN_UNK) - domaindict.from_dataset(db.get_dataset("train"), field_name="publication") - db.set_vocab(domaindict, "domain") - - return db + domaindict.from_dataset(data_bundle.get_dataset("train"), field_name="publication") + data_bundle.set_vocab(domaindict, "domain") + return data_bundle def process_from_file(self, paths=None): """ - :param paths: dict or string - :return: DataBundle - """ - db = DataBundle() - if isinstance(paths, dict): - for key, value in paths.items(): - db.set_dataset(JsonLoader(fields={"text":None, "summary":None, "label":None, "publication":None})._load(value), key) - else: - db.set_dataset(JsonLoader(fields={"text":None, "summary":None, "label":None, "publication":None})._load(paths), 'test') - self.process(db) + :param paths: dict or string + :return: DataBundle + """ + loader = ExtCNNDMLoader() + if self.vocab_path is None: + if paths is None: + paths = loader.download() + if not os.path.isdir(paths): + error_msg = 'vocab file is not defined!' + logger.error(error_msg) + raise RuntimeError(error_msg) + self.vocab_path = os.path.join(paths, 'vocab') + db = loader.load(paths=paths) + db = self.process(db) for ds in db.datasets.values(): db.get_vocab("vocab").index_dataset(ds, field_name=Const.INPUT, new_field_name=Const.INPUT) return db - def _lower_text(text_list): return [text.lower() for text in text_list] + def _split_list(text_list): return [text.split() for text in text_list] + def _convert_label(label, sent_len): np_label = np.zeros(sent_len, dtype=int) if label != []: np_label[np.array(label)] = 1 return np_label.tolist() + def _pad_sent(text_wd, sent_max_len): pad_text_wd = [] for sent_wd in text_wd: @@ -148,6 +154,7 @@ def _pad_sent(text_wd, sent_max_len): pad_text_wd.append(sent_wd) return pad_text_wd + def _token_mask(text_wd, sent_max_len): token_mask_list = [] for sent_wd in text_wd: @@ -159,6 +166,7 @@ def _token_mask(text_wd, sent_max_len): token_mask_list.append(mask) return token_mask_list + def _pad_label(label, doc_max_timesteps): text_len = len(label) if text_len < doc_max_timesteps: @@ -167,6 +175,7 @@ def _pad_label(label, doc_max_timesteps): pad_label = label[:doc_max_timesteps] return pad_label + def _pad_doc(text_wd, sent_max_len, doc_max_timesteps): text_len = len(text_wd) if text_len < doc_max_timesteps: @@ -176,6 +185,7 @@ def _pad_doc(text_wd, sent_max_len, doc_max_timesteps): pad_text = text_wd[:doc_max_timesteps] return pad_text + def _sent_mask(text_wd, doc_max_timesteps): text_len = len(text_wd) if text_len < doc_max_timesteps: diff --git a/test/data_for_tests/io/cnndm/dev.label.jsonl b/test/data_for_tests/io/cnndm/dev.label.jsonl new file mode 100644 index 00000000..52a56ab0 --- /dev/null +++ b/test/data_for_tests/io/cnndm/dev.label.jsonl @@ -0,0 +1,4 @@ +{"label": [1, 19, 25], "text": ["marseille , france -lrb- cnn -rrb- the french prosecutor leading an investigation into the crash of germanwings flight 9525 insisted wednesday that he was not aware of any video footage from on board the plane .", "marseille prosecutor brice robin told cnn that `` so far no videos were used in the crash investigation . 
''", "he added , `` a person who has such a video needs to immediately give it to the investigators . ''", "robin 's comments follow claims by two magazines , german daily bild and french paris match , of a cell phone video showing the harrowing final seconds from on board germanwings flight 9525 as it crashed into the french alps .", "all 150 on board were killed .", "paris match and bild reported that the video was recovered from a phone at the wreckage site .", "the two publications described the supposed video , but did not post it on their websites .", "the publications said that they watched the video , which was found by a source close to the investigation .", "`` one can hear cries of ` my god ' in several languages , '' paris match reported .", "`` metallic banging can also be heard more than three times , perhaps of the pilot trying to open the cockpit door with a heavy object .", "towards the end , after a heavy shake , stronger than the others , the screaming intensifies .", "then nothing . ''", "`` it is a very disturbing scene , '' said julian reichelt , editor-in-chief of bild online .", "an official with france 's accident investigation agency , the bea , said the agency is not aware of any such video .", "lt. col. jean-marc menichini , a french gendarmerie spokesman in charge of communications on rescue efforts around the germanwings crash site , told cnn that the reports were `` completely wrong '' and `` unwarranted . ''", "cell phones have been collected at the site , he said , but that they `` had n't been exploited yet . ''", "menichini said he believed the cell phones would need to be sent to the criminal research institute in rosny sous-bois , near paris , in order to be analyzed by specialized technicians working hand-in-hand with investigators .", "but none of the cell phones found so far have been sent to the institute , menichini said .", "asked whether staff involved in the search could have leaked a memory card to the media , menichini answered with a categorical `` no . ''", "reichelt told `` erin burnett : outfront '' that he had watched the video and stood by the report , saying bild and paris match are `` very confident '' that the clip is real .", "he noted that investigators only revealed they 'd recovered cell phones from the crash site after bild and paris match published their reports .", "`` that is something we did not know before .", "... 
overall we can say many things of the investigation were n't revealed by the investigation at the beginning , '' he said .", "what was mental state of germanwings co-pilot ?", "german airline lufthansa confirmed tuesday that co-pilot andreas lubitz had battled depression years before he took the controls of germanwings flight 9525 , which he 's accused of deliberately crashing last week in the french alps .", "lubitz told his lufthansa flight training school in 2009 that he had a `` previous episode of severe depression , '' the airline said tuesday .", "email correspondence between lubitz and the school discovered in an internal investigation , lufthansa said , included medical documents he submitted in connection with resuming his flight training .", "the announcement indicates that lufthansa , the parent company of germanwings , knew of lubitz 's battle with depression , allowed him to continue training and ultimately put him in the cockpit .", "lufthansa , whose ceo carsten spohr previously said lubitz was 100 % fit to fly , described its statement tuesday as a `` swift and seamless clarification '' and said it was sharing the information and documents -- including training and medical records -- with public prosecutors .", "spohr traveled to the crash site wednesday , where recovery teams have been working for the past week to recover human remains and plane debris scattered across a steep mountainside .", "he saw the crisis center set up in seyne-les-alpes , laid a wreath in the village of le vernet , closer to the crash site , where grieving families have left flowers at a simple stone memorial .", "menichini told cnn late tuesday that no visible human remains were left at the site but recovery teams would keep searching .", "french president francois hollande , speaking tuesday , said that it should be possible to identify all the victims using dna analysis by the end of the week , sooner than authorities had previously suggested .", "in the meantime , the recovery of the victims ' personal belongings will start wednesday , menichini said .", "among those personal belongings could be more cell phones belonging to the 144 passengers and six crew on board .", "check out the latest from our correspondents .", "the details about lubitz 's correspondence with the flight school during his training were among several developments as investigators continued to delve into what caused the crash and lubitz 's possible motive for downing the jet .", "a lufthansa spokesperson told cnn on tuesday that lubitz had a valid medical certificate , had passed all his examinations and `` held all the licenses required . 
''", "earlier , a spokesman for the prosecutor 's office in dusseldorf , christoph kumpa , said medical records reveal lubitz suffered from suicidal tendencies at some point before his aviation career and underwent psychotherapy before he got his pilot 's license .", "kumpa emphasized there 's no evidence suggesting lubitz was suicidal or acting aggressively before the crash .", "investigators are looking into whether lubitz feared his medical condition would cause him to lose his pilot 's license , a european government official briefed on the investigation told cnn on tuesday .", "while flying was `` a big part of his life , '' the source said , it 's only one theory being considered .", "another source , a law enforcement official briefed on the investigation , also told cnn that authorities believe the primary motive for lubitz to bring down the plane was that he feared he would not be allowed to fly because of his medical problems .", "lubitz 's girlfriend told investigators he had seen an eye doctor and a neuropsychologist , both of whom deemed him unfit to work recently and concluded he had psychological issues , the european government official said .", "but no matter what details emerge about his previous mental health struggles , there 's more to the story , said brian russell , a forensic psychologist .", "`` psychology can explain why somebody would turn rage inward on themselves about the fact that maybe they were n't going to keep doing their job and they 're upset about that and so they 're suicidal , '' he said .", "`` but there is no mental illness that explains why somebody then feels entitled to also take that rage and turn it outward on 149 other people who had nothing to do with the person 's problems . ''", "germanwings crash compensation : what we know .", "who was the captain of germanwings flight 9525 ?", "cnn 's margot haddad reported from marseille and pamela brown from dusseldorf , while laura smith-spark wrote from london .", "cnn 's frederik pleitgen , pamela boykoff , antonia mortensen , sandrine amiel and anna-maja rappard contributed to this report ."], "summary": ["marseille prosecutor says `` so far no videos were used in the crash investigation '' despite media reports .", "journalists at bild and paris match are `` very confident '' the video clip is real , an editor says .", "andreas lubitz had informed his lufthansa training school of an episode of severe depression , airline says ."], "publication": "cnndm", "compression": 22.283333333333335, "coverage": 0.8666666666666667, "density": 4.6} +{"label": [3, 5, 24], "text": ["-lrb- cnn -rrb- the palestinian authority officially became the 123rd member of the international criminal court on wednesday , a step that gives the court jurisdiction over alleged crimes in palestinian territories .", "the formal accession was marked with a ceremony at the hague , in the netherlands , where the court is based .", "the palestinians signed the icc 's founding rome statute in january , when they also accepted its jurisdiction over alleged crimes committed `` in the occupied palestinian territory , including east jerusalem , since june 13 , 2014 . 
''", "later that month , the icc opened a preliminary examination into the situation in palestinian territories , paving the way for possible war crimes investigations against israelis .", "as members of the court , palestinians may be subject to counter-charges as well .", "israel and the united states , neither of which is an icc member , opposed the palestinians ' efforts to join the body .", "but palestinian foreign minister riad al-malki , speaking at wednesday 's ceremony , said it was a move toward greater justice .", "`` as palestine formally becomes a state party to the rome statute today , the world is also a step closer to ending a long era of impunity and injustice , '' he said , according to an icc news release .", "`` indeed , today brings us closer to our shared goals of justice and peace . ''", "judge kuniko ozaki , a vice president of the icc , said acceding to the treaty was just the first step for the palestinians .", "`` as the rome statute today enters into force for the state of palestine , palestine acquires all the rights as well as responsibilities that come with being a state party to the statute .", "these are substantive commitments , which can not be taken lightly , '' she said .", "rights group human rights watch welcomed the development .", "`` governments seeking to penalize palestine for joining the icc should immediately end their pressure , and countries that support universal acceptance of the court 's treaty should speak out to welcome its membership , '' said balkees jarrah , international justice counsel for the group .", "`` what 's objectionable is the attempts to undermine international justice , not palestine 's decision to join a treaty to which over 100 countries around the world are members . ''", "in january , when the preliminary icc examination was opened , israeli prime minister benjamin netanyahu described it as an outrage , saying the court was overstepping its boundaries .", "the united states also said it `` strongly '' disagreed with the court 's decision .", "`` as we have said repeatedly , we do not believe that palestine is a state and therefore we do not believe that it is eligible to join the icc , '' the state department said in a statement .", "it urged the warring sides to resolve their differences through direct negotiations .", "`` we will continue to oppose actions against israel at the icc as counterproductive to the cause of peace , '' it said .", "but the icc begs to differ with the definition of a state for its purposes and refers to the territories as `` palestine . ''", "while a preliminary examination is not a formal investigation , it allows the court to review evidence and determine whether to investigate suspects on both sides .", "prosecutor fatou bensouda said her office would `` conduct its analysis in full independence and impartiality . 
''", "the war between israel and hamas militants in gaza last summer left more than 2,000 people dead .", "the inquiry will include alleged war crimes committed since june .", "the international criminal court was set up in 2002 to prosecute genocide , crimes against humanity and war crimes .", "cnn 's vasco cotovio , kareem khadder and faith karimi contributed to this report ."], "summary": ["membership gives the icc jurisdiction over alleged crimes committed in palestinian territories since last june .", "israel and the united states opposed the move , which could open the door to war crimes investigations against israelis ."], "publication": "cnndm", "compression": 17.57894736842105, "coverage": 0.8947368421052632, "density": 3.1052631578947367} +{"label": [0, 6], "text": ["-lrb- cnn -rrb- governments around the world are using the threat of terrorism -- real or perceived -- to advance executions , amnesty international alleges in its annual report on the death penalty .", "`` the dark trend of governments using the death penalty in a futile attempt to tackle real or imaginary threats to state security and public safety was stark last year , '' said salil shetty , amnesty 's secretary general in a release .", "`` it is shameful that so many states around the world are essentially playing with people 's lives -- putting people to death for ` terrorism ' or to quell internal instability on the ill-conceived premise of deterrence . ''", "the report , `` death sentences and executions 2014 , '' cites the example of pakistan lifting a six-year moratorium on the execution of civilians following the horrific attack on a school in peshawar in december .", "china is also mentioned , as having used the death penalty as a tool in its `` strike hard '' campaign against terrorism in the restive far-western province of xinjiang .", "the annual report catalogs the use of state-sanctioned killing as a punitive measure across the globe , and this year 's edition contains some mixed findings .", "on one hand , the number of executions worldwide has gone down by almost 22 % on the previous year .", "at least 607 people were executed around the world in 2014 , compared to 778 in 2013 .", "amnesty 's figures do not include statistics on executions carried out in china , where information on the practice is regarded as a state secret .", "belarus and vietnam , too , do not release data on death penalty cases .", "`` the long-term trend is definitely positive -- we are seeing a decrease in the number of executions -lrb- worldwide -rrb- , '' audrey gaughran , amnesty 's director of global issues , told cnn .", "`` a number of countries are closer to abolition , and there are some signs that some countries will be abolitionist by 2015 .", "-lrb- there are -rrb- signals of a world that is nearing abolition . ''", "while the report notes some encouraging signs , it also highlights a marked increase in the number of people sentenced to death in 2014 .", "at least 2,466 people globally are confirmed to have been handed the sentence last year , an increase of 28 % compared with 2013 .", "the report notes that the spike in sentencing is attributable to mass-sentencing in countries including egypt and nigeria , `` against scores of people in some cases . 
''", "the organization found `` positive developments '' worldwide , with most regions seeming to show reductions in the number of executions .", "opinion : sharp spike in death sentences .", "sub-saharan africa , for example , saw a 28 % fall in reported cases , and executions recorded in the middle east and north africa were down 23 % compared to 2013 .", "`` even though we 've highlighted some of the negative developments ... i think we would always highlight that there are positive developments , '' gaughran said .", "`` across the board , with the exception of europe and central asia there were fewer reports of executions in every region . ''", "the resumption of the use of capital punishment in belarus -- the only country in europe and central asia to execute people -- after a two year hiatus spoiled an near-universal decrease in countries using the death penalty by region .", "the united states has the dubious distinction of being the only country in the americas to conduct executions , but the number of convicts put to death here fell slightly , from 39 in 2013 to 35 in 2014 .", "the state of washington also imposed a moratorium on executions last year .", "the u.s. remains one of the worst offenders for imposing capital punishment , with only iran -lrb- 289 + -rrb- , iraq -lrb- 61 + -rrb- , and saudi arabia -lrb- 90 + -rrb- executing more people in 2014 .", "while figures are not available , amnesty estimates that china also executes `` thousands '' of prisoners each year , `` more than the rest of the world put together . ''", "the report also highlights the imperfections in the judiciary processes that lead to many sentenced to death .", "`` in the majority of countries where people were sentenced to death or executed , the death penalty was imposed after proceedings that did not meet international fair trial standards , '' the report stated .", "`` in 2014 amnesty international raised particular concerns in relation to court proceedings in afghanistan , bangladesh , china , egypt , iran , iraq , north korea , pakistan , saudi arabia and sri lanka . ''", "the united nations secretary-general , ban ki-moon , last year stressed the need to move toward abolition of capital punishment .", "`` the taking of life is too irreversible for one human being to inflict it on another , '' he said , in marking world day against death penalty in october .", "`` we must continue to argue strongly that the death penalty is unjust and incompatible with fundamental human rights . 
''", "amnesty estimates that at least 19,094 people were believed to be on death row at the end of 2014 ."], "summary": ["amnesty 's annual death penalty report catalogs encouraging signs , but setbacks in numbers of those sentenced to death .", "organization claims that governments around the world are using the threat of terrorism to advance executions .", "the number of executions worldwide has gone down by almost 22 % compared with 2013 , but death sentences up by 28 % ."], "publication": "cnndm", "compression": 14.841269841269842, "coverage": 0.8888888888888888, "density": 5.079365079365079} +{"label": [8, 9, 34], "text": ["-lrb- cnn -rrb- on may 28 , 2014 , some 7,000 people gathered in a stadium in china 's northwestern xinjiang region .", "but they had not come to watch the local football team or any other grand sporting event .", "instead , the authorities paraded scores of prisoners dressed in orange jumpsuits .", "armed soldiers guarded the exits .", "in the patently unfair , open air trial that followed , 55 people were found guilty of a range of offenses linked to violent attacks in the region and jailed .", "three were sentenced to death .", "the public mass sentencing was part a china 's `` strike hard '' campaign against unrest in xinjiang , a campaign the government claims was launched to combat `` terrorism '' and `` separatism . ''", "but it was also indicative of a trend that was starkly evident last year around the world -- governments using the death penalty in a misguided , and often cynical , attempt to tackle crime and terrorism .", "today , amnesty international releases its annual review of the death penalty worldwide .", "much of it makes for grim reading .", "in pakistan , the government lifted a six-year moratorium on the execution of civilians in the wake of the horrific taliban attack on a school in peshawar in december .", "more than 60 people have been put to death since , and the government has threatened to send thousands more death row prisoners to the gallows .", "iran and iraq executed people for `` terrorism , '' and other countries expanded the scope of capital crimes in their penal codes .", "in a year when abhorrent summary executions by armed groups were branded on the global consciousness as never before , governments are themselves resorting to more executions in a knee-jerk reaction to terrorism .", "other countries made use of executions in similarly flawed attempts to address -- or appear to address -- crime rates .", "jordan ended an eight-year moratorium in december , putting 11 murder convicts to death , with the government saying it was a move to end a surge in violent crime .", "in indonesia , authorities announced plans to execute mainly drug traffickers to tackle a public safety `` national emergency . 
''", "six people have already been executed this year .", "a sharp spike in death sentences recorded in 2014 -- up more than 500 on the previous year -- can also be attributed to governments using the death penalty as a political tool .", "the rise was largely because of developments in egypt and nigeria , where courts imposed hundreds of death sentences in the context of internal political instability or crime and armed conflict .", "the simple fact is that governments using the death penalty to tackle crime and security threats are deceiving themselves or the public or both .", "there is no evidence that the threat of execution is more of a deterrent to crime than a prison sentence , as united nations and other studies have repeatedly confirmed .", "it is high time that world leaders stop using the death penalty as an easy way out when times get tough .", "at amnesty international , we have campaigned for an end to the death penalty for decades .", "thankfully , most of the world now appears to agree with us .", "the numbers speak for themselves .", "in 1945 when the united nations was founded , only eight countries had abolished the death penalty .", "today , 140 states are abolitionist in law or practice .", "last year , we recorded executions in 22 countries , down by almost a half from 20 years ago .", "despite the troubling developments we recorded last year , there was still much good news to be found .", "the number of executions recorded around the world dropped significantly in 2014 compared with the previous year , from 778 to 607 .", "this number does not include china , where more people are put to death than the rest of the world put together , but with death penalty statistics treated as a state secret , the true figure is impossible to determine .", "executions were recorded in only three countries in sub-saharan africa -- equatorial guinea , somalia and sudan -- and the number of people put to death went down by more than a quarter .", "the americas continued to be execution-free , apart from the united states .", "those governments that still execute need to realize that they are on the wrong side of history .", "they must join the vast majority of countries which have dropped the ultimate cruel punishment .", "fighting for an end to the death penalty remains an uphill task , but all of us must try to make the world free of this punishment .", "with determination , i know that we can achieve this goal ."], "summary": ["amnesty international releases its annual review of the death penalty worldwide ; much of it makes for grim reading .", "salil shetty : countries that use executions to deal with problems are on the wrong side of history ."], "publication": "cnndm", "compression": 20.85, "coverage": 0.825, "density": 6.375} diff --git a/test/data_for_tests/io/cnndm/test.label.jsonl b/test/data_for_tests/io/cnndm/test.label.jsonl new file mode 100644 index 00000000..d74ebd9f --- /dev/null +++ b/test/data_for_tests/io/cnndm/test.label.jsonl @@ -0,0 +1,4 @@ +{"label": [2, 3], "text": ["-lrb- cnn -rrb- the rev.", "robert h. 
schuller , california televangelist and founder of the television ministry `` hour of power , '' died thursday , according to his family .", "he was 88 years old .", "schuller , also the founder of crystal cathedral megachurch , had been diagnosed with esophageal cancer in august 2013 , a release from `` hour of power '' said .", "`` my father-in-law passed away peacefully early this morning .", "he was a great dad and a great man of god , '' said schuller 's daughter-in-law , donna schuller , in a twitter message .", "schuller 's life followed an almost shakespearean arc .", "he was born in a iowa farmhouse without running water and longed to preach from his earliest days .", "in his autobiography , `` prayer : my soul 's adventure with god , '' he described standing alone by a river and picturing himself delivering sermons to a rapt congregation .", "after attending a hope college and western theological seminary in michigan , he met his wife of more than 60 years , arvella , while preaching at her church -lrb- she was the organist -rrb- .", "with their young family in tow , the schullers caravanned west to california , where he rented a drive-in theater and preached from the roof of the snack bar .", "it was beneath the dignity of christian ministry , some local pastors huffed .", "the `` passion pits '' where teenagers necked was no place for the gospel .", "schuller was undeterred , and he quickly outgrew the drive-in .", "he called the explosive growth of his tiny congregation a `` miracle , '' though his many mainstream critics had other names for it .", "his confident , breezy version of christianity -- too breezy , by some estimations -- drew hordes of seekers and lapsed christians who were put off by the hellfire fulminations of many post-war american preachers .", "schuller sold a softer , gentler message , which borrowed heavily , he acknowledged , from the father of the feel-good gospel , norman vincent peale .", "he preached not to convert or condemn people , but to encourage them , a sentiment he called `` possibility thinking . ''", "people loved it .", "`` evangelicalism at its best wants to be innovative and reach people , '' said timothy larsen , a professor of christian thought at wheaton college in illinois .", "`` and schuller was a master at that . ''", "`` what he got right is that the gospel is good news , '' larsen continued .", "`` and he preached an uplifting message about personal transformation and uplift and hope . ''", "some of schuller 's favored phrases , though , struck others as cornpone christianity .", "`` turn your hurt into a halo ? ''", "said randall balmer , a professor of american religious history at dartmouth college , citing one such phrase .", "`` that 's pretty weak tea . ''", "still , balmer gives schuller some credit .", "`` it may be bad theology , but it 's brilliant marketing . 
''", "in 1970 , schuller began broadcasting `` hour of power , '' believed to be one of the first , if not the very first , sunday service to be shown regularly on television .", "with his genial smile , priestly robes and gray hair , he looked and talked like a guy who wanted nothing more than to see his flock succeed .", "the show , which ran for decades , reached millions , making schuller a televangelist before the term became tarnished by the sins of his many successors .", "schuller 's crowning achievement , at least architecturally , still stands in orange county , california , though it is now owned by the roman catholic church .", "the crystal cathedral , a great gleaming edifice with 10,000 glass panels , gave worshipers a look at the clouds that house the heavens , while schuller preached in the pulpit below .", "the message was clear to many : the road to the former ran through the latter .", "during the 1980s and 1990s , schuller 's star continued to rise , with presidents stopping by the crystal cathedral -- often during campaigns , it should be said -- and future megachurch pastors like rick warren and bill hybels seeking his advice .", "as schuller aged , though , his family was beset by a succession scandal straight from the pages of `` king lear . ''", "he tried to install his only son , bobby jr. , as pastor of crystal cathedral .", "but the preaching styles of father and son were too different for the congregation -- measured at times at 10,000 strong -- to countenance .", "bobby schuller jr. left `` hour of power '' and the pulpit at crystal cathedral after a short time .", "as the family searched for a new successor and tussled over finances , viewers and donations to the church and its television show dropped precipitously .", "crystal cathedral ministries filed for bankruptcy in 2010 , citing debts of more than $ 43 million , according to the associated press .", "schuller 's empire , which once soared as high as his glassy cathedral , had fallen to dust .", "eventually , schuller 's grandson , also named bobby , took over `` hour of power , '' though at a different church .", "in a statement on thursday , the younger schuller recalled standing atop crystal cathedral 's 12-story tower of hope with his grandfather as they surveyed the surrounding landscape .", "`` you could see the whole world from there , '' he said .", "people we 've lost in 2015 .", "cnn 's stella chan reported from los angeles ."], "summary": ["the rev.", "robert schuller , 88 , had been diagnosed with esophageal cancer in 2013 .", "his tv show , `` hour of power , '' was enormously popular in the 1970s and 1980s ."], "publication": "cnndm", "compression": 26.342105263157894, "coverage": 0.8421052631578947, "density": 3.4210526315789473} +{"label": [4, 6], "text": ["-lrb- cnn -rrb- never mind cats having nine lives .", "a stray pooch in washington state has used up at least three of her own after being hit by a car , apparently whacked on the head with a hammer in a misguided mercy killing and then buried in a field -- only to survive .", "that 's according to washington state university , where the dog -- a friendly white-and-black bully breed mix now named theia -- has been receiving care at the veterinary teaching hospital .", "four days after her apparent death , the dog managed to stagger to a nearby farm , dirt-covered and emaciated , where she was found by a worker who took her to a vet for help .", "she was taken in by moses lake , washington , resident sara mellado .", "`` considering 
everything that she 's been through , she 's incredibly gentle and loving , '' mellado said , according to wsu news .", "`` she 's a true miracle dog and she deserves a good life . ''", "theia is only one year old but the dog 's brush with death did not leave her unscathed .", "she suffered a dislocated jaw , leg injuries and a caved-in sinus cavity -- and still requires surgery to help her breathe .", "the veterinary hospital 's good samaritan fund committee awarded some money to help pay for the dog 's treatment , but mellado has set up a fundraising page to help meet the remaining cost of the dog 's care .", "she 's also created a facebook page to keep supporters updated .", "donors have already surpassed the $ 10,000 target , inspired by theia 's tale of survival against the odds .", "on the fundraising page , mellado writes , `` she is in desperate need of extensive medical procedures to fix her nasal damage and reset her jaw .", "i agreed to foster her until she finally found a loving home . ''", "she is dedicated to making sure theia gets the medical attention she needs , mellado adds , and wants to `` make sure she gets placed in a family where this will never happen to her again ! ''", "any additional funds raised will be `` paid forward '' to help other animals .", "theia is not the only animal to apparently rise from the grave in recent weeks .", "a cat in tampa , florida , found seemingly dead after he was hit by a car in january , showed up alive in a neighbor 's yard five days after he was buried by his owner .", "the cat was in bad shape , with maggots covering open wounds on his body and a ruined left eye , but remarkably survived with the help of treatment from the humane society ."], "summary": ["theia , a bully breed mix , was apparently hit by a car , whacked with a hammer and buried in a field .", "`` she 's a true miracle dog and she deserves a good life , '' says sara mellado , who is looking for a home for theia ."], "publication": "cnndm", "compression": 9.150943396226415, "coverage": 0.9433962264150944, "density": 4.7924528301886795} +{"label": [32, 36], "text": ["-lrb- cnn -rrb- if you 've been following the news lately , there are certain things you doubtless know about mohammad javad zarif .", "he is , of course , the iranian foreign minister .", "he has been u.s. secretary of state john kerry 's opposite number in securing a breakthrough in nuclear discussions that could lead to an end to sanctions against iran -- if the details can be worked out in the coming weeks .", "and he received a hero 's welcome as he arrived in iran on a sunny friday morning .", "`` long live zarif , '' crowds chanted as his car rolled slowly down the packed street .", "you may well have read that he is `` polished '' and , unusually for one burdened with such weighty issues , `` jovial . ''", "an internet search for `` mohammad javad zarif '' and `` jovial '' yields thousands of results .", "he certainly has gone a long way to bring iran in from the cold and allow it to rejoin the international community .", "but there are some facts about zarif that are less well-known .", "here are six : .", "in september 2013 , zarif tweeted `` happy rosh hashanah , '' referring to the jewish new year .", "that prompted christine pelosi , the daughter of house minority leader nancy pelosi , to respond with a tweet of her own : `` thanks .", "the new year would be even sweeter if you would end iran 's holocaust denial , sir . 
''", "and , perhaps to her surprise , pelosi got a response .", "`` iran never denied it , '' zarif tweeted back .", "`` the man who was perceived to be denying it is now gone .", "happy new year . ''", "the reference was likely to former iranian president mahmoud ahmadinejad , who had left office the previous month .", "zarif was nominated to be foreign minister by ahmadinejad 's successor , hassan rouhami .", "his foreign ministry notes , perhaps defensively , that `` due to the political and security conditions of the time , he decided to continue his education in the united states . ''", "that is another way of saying that he was outside the country during the demonstrations against the shah of iran , which began in 1977 , and during the iranian revolution , which drove the shah from power in 1979 .", "zarif left the country in 1977 , received his undergraduate degree from san francisco state university in 1981 , his master 's in international relations from the university of denver in 1984 and his doctorate from the university of denver in 1988 .", "both of his children were born in the united states .", "the website of the iranian foreign ministry , which zarif runs , can not even agree with itself on when he was born .", "the first sentence of his official biography , perhaps in a nod to the powers that be in tehran , says zarif was `` born to a religious traditional family in tehran in 1959 . ''", "later on the same page , however , his date of birth is listed as january 8 , 1960 .", "and the iranian diplomacy website says he was born in in 1961 .", "so he is 54 , 55 or maybe even 56 .", "whichever , he is still considerably younger than his opposite number , kerry , who is 71 .", "the feds investigated him over his alleged role in controlling the alavi foundation , a charitable organization .", "the u.s. justice department said the organization was secretly run on behalf of the iranian government to launder money and get around u.s. sanctions .", "but last year , a settlement in the case , under which the foundation agreed to give a 36-story building in manhattan along with other properties to the u.s. government , did not mention zarif 's name .", "early in the iranian revolution , zarif was among the students who took over the iranian consulate in san francisco .", "the aim , says the website iranian.com -- which cites zarif 's memoirs , titled `` mr. ambassador '' -- was to expel from the consulate people who were not sufficiently islamic .", "later , the website says , zarif went to make a similar protest at the iranian mission to the united nations .", "in response , the iranian ambassador to the united nations offered him a job .", "in fact , he has now spent more time with kerry than any other foreign minister in the world .", "and that amount of quality time will only increase as the two men , with help from other foreign ministers as well , try to meet a june 30 deadline for nailing down the details of the agreement they managed to outline this week in switzerland ."], "summary": ["mohammad javad zarif has spent more time with john kerry than any other foreign minister .", "he once participated in a takeover of the iranian consulate in san francisco .", "the iranian foreign minister tweets in english ."], "publication": "cnndm", "compression": 20.85, "coverage": 0.825, "density": 2.825} +{"label": [2], "text": ["-lrb- cnn -rrb- for the first time in eight years , a tv legend returned to doing what he does best .", "contestants told to `` come on down ! 
''", "on the april 1 edition of `` the price is right '' encountered not host drew carey but another familiar face in charge of the proceedings .", "instead , there was bob barker , who hosted the tv game show for 35 years before stepping down in 2007 .", "looking spry at 91 , barker handled the first price-guessing game of the show , the classic `` lucky seven , '' before turning hosting duties over to carey , who finished up .", "despite being away from the show for most of the past eight years , barker did n't seem to miss a beat ."], "summary": ["bob barker returned to host `` the price is right '' on wednesday .", "barker , 91 , had retired as host in 2007 ."], "publication": "cnndm", "compression": 5.346153846153846, "coverage": 0.8076923076923077, "density": 2.5} diff --git a/test/data_for_tests/cnndm.jsonl b/test/data_for_tests/io/cnndm/train.cnndm.jsonl similarity index 100% rename from test/data_for_tests/cnndm.jsonl rename to test/data_for_tests/io/cnndm/train.cnndm.jsonl diff --git a/test/data_for_tests/cnndm.vocab b/test/data_for_tests/io/cnndm/vocab similarity index 100% rename from test/data_for_tests/cnndm.vocab rename to test/data_for_tests/io/cnndm/vocab diff --git a/test/io/pipe/test_extcnndm.py b/test/io/pipe/test_summary.py similarity index 52% rename from test/io/pipe/test_extcnndm.py rename to test/io/pipe/test_summary.py index 1ae3089c..32508a15 100644 --- a/test/io/pipe/test_extcnndm.py +++ b/test/io/pipe/test_summary.py @@ -1,59 +1,69 @@ -#!/usr/bin/python -# -*- coding: utf-8 -*- - -# __author__="Danqing Wang" - -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== - -import unittest -import os -# import sys -# -# sys.path.append("../../../") - -from fastNLP.io import DataBundle -from fastNLP.io.pipe.summarization import ExtCNNDMPipe - -class TestRunExtCNNDMPipe(unittest.TestCase): - - def test_load(self): - data_set_dict = { - 'CNNDM': {"train": 'test/data_for_tests/cnndm.jsonl'}, - } - vocab_size = 100000 - VOCAL_FILE = 'test/data_for_tests/cnndm.vocab' - sent_max_len = 100 - doc_max_timesteps = 50 - dbPipe = ExtCNNDMPipe(vocab_size=vocab_size, - vocab_path=VOCAL_FILE, - sent_max_len=sent_max_len, - doc_max_timesteps=doc_max_timesteps) - dbPipe2 = ExtCNNDMPipe(vocab_size=vocab_size, - vocab_path=VOCAL_FILE, - sent_max_len=sent_max_len, - doc_max_timesteps=doc_max_timesteps, - domain=True) - for k, v in data_set_dict.items(): - db = dbPipe.process_from_file(v) - db2 = dbPipe2.process_from_file(v) - - # print(db2.get_dataset("train")) - - self.assertTrue(isinstance(db, DataBundle)) - self.assertTrue(isinstance(db2, DataBundle)) - - - - +#!/usr/bin/python +# -*- coding: utf-8 -*- + +# __author__="Danqing Wang" + +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +import unittest +import os + +from fastNLP.io import DataBundle +from fastNLP.io.pipe.summarization import ExtCNNDMPipe + + +class TestRunExtCNNDMPipe(unittest.TestCase): + + def test_load(self): + data_dir = 'test/data_for_tests/io/cnndm' + vocab_size = 100000 + VOCAL_FILE = 'test/data_for_tests/io/cnndm/vocab' + sent_max_len = 100 + doc_max_timesteps = 50 + dbPipe = ExtCNNDMPipe(vocab_size=vocab_size, + vocab_path=VOCAL_FILE, + sent_max_len=sent_max_len, + doc_max_timesteps=doc_max_timesteps) + dbPipe2 = ExtCNNDMPipe(vocab_size=vocab_size, + vocab_path=VOCAL_FILE, + sent_max_len=sent_max_len, + doc_max_timesteps=doc_max_timesteps, + domain=True) + db = dbPipe.process_from_file(data_dir) + db2 = dbPipe2.process_from_file(data_dir) + + self.assertTrue(isinstance(db, DataBundle)) + self.assertTrue(isinstance(db2, DataBundle)) + + dbPipe3 = ExtCNNDMPipe(vocab_size=vocab_size, + sent_max_len=sent_max_len, + doc_max_timesteps=doc_max_timesteps, + domain=True) + db3 = dbPipe3.process_from_file(data_dir) + self.assertTrue(isinstance(db3, DataBundle)) + + with self.assertRaises(RuntimeError): + dbPipe4 = ExtCNNDMPipe(vocab_size=vocab_size, + sent_max_len=sent_max_len, + doc_max_timesteps=doc_max_timesteps) + db4 = dbPipe4.process_from_file(os.path.join(data_dir, 'train.cnndm.jsonl')) + + dbPipe5 = ExtCNNDMPipe(vocab_size=vocab_size, + vocab_path=VOCAL_FILE, + sent_max_len=sent_max_len, + doc_max_timesteps=doc_max_timesteps,) + db5 = dbPipe5.process_from_file(os.path.join(data_dir, 'train.cnndm.jsonl')) + self.assertIsInstance(db5, DataBundle) + From ad957077185f12a662f8fb9a64c1fdc1fa4464f5 Mon Sep 17 00:00:00 2001 From: Yige Xu Date: Wed, 25 Sep 2019 14:46:30 +0800 Subject: [PATCH 16/16] 1. reorganize auto download datasets in io/file_utils.py; 2. add auto download for CNNDM and THUCNews; 3. rename XNLI loader and pipe to CNXNLI*; 4. update documents in some download method. 
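
A minimal usage sketch of the renamed loader and the new auto-download entries (sketch only: the cache location and the exact shape of the returned DataBundle are assumptions based on the existing Loader conventions, not verified output):

    # Sketch, assuming the default fastNLP dataset cache.
    # CNXNLILoader (renamed from XNLILoader) and THUCNewsLoader.download()
    # are the methods touched by this patch.
    from fastNLP.io.loader.matching import CNXNLILoader
    from fastNLP.io.loader.classification import THUCNewsLoader

    xnli_dir = CNXNLILoader().download()          # resolves the 'cn-xnli' entry (XNLI.zip) to a local path
    data_bundle = CNXNLILoader().load(xnli_dir)   # DataBundle with train/dev/test DataSets

    thuc_dir = THUCNewsLoader().download()        # resolves the 'thuc-news' entry (THUCNews.zip)
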
--- fastNLP/io/file_utils.py | 24 +++++++-- fastNLP/io/loader/classification.py | 70 +++++--------------------- fastNLP/io/loader/conll.py | 10 ++++ fastNLP/io/loader/coreference.py | 30 +++++++---- fastNLP/io/loader/matching.py | 51 +++++++++++++------ fastNLP/io/pipe/matching.py | 18 +++---- test/io/loader/test_matching_loader.py | 4 +- test/io/pipe/test_matching.py | 6 +-- 8 files changed, 110 insertions(+), 103 deletions(-) diff --git a/fastNLP/io/file_utils.py b/fastNLP/io/file_utils.py index 022af0ac..a4abb575 100644 --- a/fastNLP/io/file_utils.py +++ b/fastNLP/io/file_utils.py @@ -83,27 +83,41 @@ PRETRAIN_STATIC_FILES = { } DATASET_DIR = { + # Classification, English 'aclImdb': "imdb.zip", "yelp-review-full": "yelp_review_full.tar.gz", "yelp-review-polarity": "yelp_review_polarity.tar.gz", + "sst-2": "SST-2.zip", + "sst": "SST.zip", + + # Classification, Chinese + "chn-senti-corp": "chn_senti_corp.zip", + "weibo-senti-100k": "WeiboSenti100k.zip", + "thuc-news": "THUCNews.zip", + + # Matching, English "mnli": "MNLI.zip", "snli": "SNLI.zip", "qnli": "QNLI.zip", - "xnli": "XNLI.zip", - "sst-2": "SST-2.zip", - "sst": "SST.zip", "rte": "RTE.zip", + + # Matching, Chinese + "cn-xnli": "XNLI.zip", + + # Sequence Labeling, Chinese "msra-ner": "MSRA_NER.zip", "peopledaily": "peopledaily.zip", "weibo-ner": "weibo_NER.zip", + # Chinese Word Segmentation "cws-pku": 'cws_pku.zip', "cws-cityu": "cws_cityu.zip", "cws-as": 'cws_as.zip', "cws-msra": 'cws_msra.zip', - "chn-senti-corp" : "chn_senti_corp.zip", - "weibo-senti-100k" : "WeiboSenti100k.zip" + # Summarization, English + "ext-cnndm": "ext-cnndm.zip", + } PRETRAIN_MAP = {'elmo': PRETRAINED_ELMO_MODEL_DIR, diff --git a/fastNLP/io/loader/classification.py b/fastNLP/io/loader/classification.py index ca9b6107..004f3ebd 100644 --- a/fastNLP/io/loader/classification.py +++ b/fastNLP/io/loader/classification.py @@ -373,63 +373,6 @@ class ChnSentiCorpLoader(Loader): """ 从path中读取数据 - :param path: - :return: - """ - ds = DataSet() - with open(path, 'r', encoding='utf-8') as f: - f.readline() - for line in f: - line = line.strip() - tab_index = line.index('\t') - if tab_index!=-1: - target = line[:tab_index] - raw_chars = line[tab_index+1:] - if raw_chars: - ds.append(Instance(raw_chars=raw_chars, target=target)) - return ds - - def download(self)->str: - """ - 自动下载数据,该数据取自https://github.com/pengming617/bert_classification/tree/master/data,在 - https://arxiv.org/pdf/1904.09223.pdf与https://arxiv.org/pdf/1906.08101.pdf有使用 - - :return: - """ - output_dir = self._get_dataset_path('chn-senti-corp') - return output_dir - - -class ChnSentiCorpLoader(Loader): - """ - 支持读取的数据的格式为,第一行为标题(具体内容会被忽略),之后一行为一个sample,第一个制表符之前被认为是label,第 - 一个制表符及之后认为是句子 - - Example:: - - label raw_chars - 1 這間酒店環境和服務態度亦算不錯,但房間空間太小~~ - 1 <荐书> 推荐所有喜欢<红楼>的红迷们一定要收藏这本书,要知道... - 0 商品的不足暂时还没发现,京东的订单处理速度实在.......周二就打包完成,周五才发货... - - 读取后的DataSet具有以下的field - - .. csv-table:: - :header: "raw_chars", "target" - - "這間酒店環境和服務態度亦算不錯,但房間空間太小~~", "1" - "<荐书> 推荐所有喜欢<红楼>...", "1" - "..." 
- - """ - - def __init__(self): - super().__init__() - - def _load(self, path: str): - """ - 从path中读取数据 - :param path: :return: """ @@ -441,7 +384,7 @@ class ChnSentiCorpLoader(Loader): tab_index = line.index('\t') if tab_index != -1: target = line[:tab_index] - raw_chars = line[tab_index + 1:] + raw_chars = line[tab_index+1:] if raw_chars: ds.append(Instance(raw_chars=raw_chars, target=target)) return ds @@ -486,6 +429,17 @@ class THUCNewsLoader(Loader): ds.append(Instance(raw_chars=raw_chars, target=target)) return ds + def download(self) -> str: + """ + 自动下载数据,该数据取自 + + http://thuctc.thunlp.org/#%E4%B8%AD%E6%96%87%E6%96%87%E6%9C%AC%E5%88%86%E7%B1%BB%E6%95%B0%E6%8D%AE%E9%9B%86THUCNews + + :return: + """ + output_dir = self._get_dataset_path('thuc-news') + return output_dir + class WeiboSenti100kLoader(Loader): """ diff --git a/fastNLP/io/loader/conll.py b/fastNLP/io/loader/conll.py index 97842338..96aefa17 100644 --- a/fastNLP/io/loader/conll.py +++ b/fastNLP/io/loader/conll.py @@ -316,6 +316,16 @@ class CTBLoader(Loader): dataset = self.loader._load(path) return dataset + def download(self): + """ + 由于版权限制,不能提供自动下载功能。可参考 + + https://catalog.ldc.upenn.edu/LDC2013T21 + + :return: + """ + raise RuntimeError("CTB cannot be downloaded automatically.") + class CNNERLoader(Loader): def _load(self, path: str): diff --git a/fastNLP/io/loader/coreference.py b/fastNLP/io/loader/coreference.py index 4293f65a..9f120638 100644 --- a/fastNLP/io/loader/coreference.py +++ b/fastNLP/io/loader/coreference.py @@ -13,23 +13,21 @@ from .json import JsonLoader class CoReferenceLoader(JsonLoader): """ - 原始数据中内容应该为, 每一行为一个json对象,其中doc_key包含文章的种类信息,speakers包含每句话的说话者信息,cluster是指向现实中同一个事物的聚集,sentences是文本信息内容。 + 原始数据中内容应该为, 每一行为一个json对象,其中doc_key包含文章的种类信息,speakers包含每句话的说话者信息,cluster是指向现实中同一个事物的聚集,sentences是文本信息内容。 - Example:: + Example:: - {"doc_key":"bc/cctv/00/cctv_001", - "speakers":"[["Speaker1","Speaker1","Speaker1"],["Speaker1","Speaker1","Speaker1"]]", - "clusters":"[[[2,3],[4,5]],[7,8],[18,20]]]", - "sentences":[["I","have","an","apple"],["It","is","good"]] - } + {"doc_key":"bc/cctv/00/cctv_001", + "speakers":"[["Speaker1","Speaker1","Speaker1"],["Speaker1","Speaker1","Speaker1"]]", + "clusters":"[[[2,3],[4,5]],[7,8],[18,20]]]", + "sentences":[["I","have","an","apple"],["It","is","good"]] + } - 读取预处理好的Conll2012数据。 + 读取预处理好的Conll2012数据。 - """ + """ def __init__(self, fields=None, dropna=False): super().__init__(fields, dropna) - # self.fields = {"doc_key":Const.INPUTS(0),"speakers":Const.INPUTS(1), - # "clusters":Const.TARGET,"sentences":Const.INPUTS(2)} self.fields = {"doc_key": Const.RAW_WORDS(0), "speakers": Const.RAW_WORDS(1), "clusters": Const.RAW_WORDS(2), "sentences": Const.RAW_WORDS(3)} @@ -48,3 +46,13 @@ class CoReferenceLoader(JsonLoader): ins = d dataset.append(Instance(**ins)) return dataset + + def download(self): + """ + 由于版权限制,不能提供自动下载功能。可参考 + + https://www.aclweb.org/anthology/W12-4501 + + :return: + """ + raise RuntimeError("CoReference cannot be downloaded automatically.") diff --git a/fastNLP/io/loader/matching.py b/fastNLP/io/loader/matching.py index b9724126..80889507 100644 --- a/fastNLP/io/loader/matching.py +++ b/fastNLP/io/loader/matching.py @@ -7,7 +7,7 @@ __all__ = [ "RTELoader", "QuoraLoader", "BQCorpusLoader", - "XNLILoader", + "CNXNLILoader", "LCQMCLoader" ] @@ -135,12 +135,12 @@ class SNLILoader(JsonLoader): """ 从指定一个或多个路径中的文件中读取数据,返回 :class:`~fastNLP.io.DataBundle` 。 - 读取的field根据ConllLoader初始化时传入的headers决定。 + 读取的field根据Loader初始化时传入的field决定。 :param str paths: 传入一个目录, 
将在该目录下寻找snli_1.0_train.jsonl, snli_1.0_dev.jsonl 和snli_1.0_test.jsonl三个文件。 - :return: 返回的:class:`~fastNLP.io.DataBundle` + :return: 返回的 :class:`~fastNLP.io.DataBundle` """ _paths = {} if paths is None: @@ -222,8 +222,7 @@ class QNLILoader(JsonLoader): """ 如果您的实验使用到了该数据,请引用 - .. todo:: - 补充 + https://arxiv.org/pdf/1809.05053.pdf :return: """ @@ -276,6 +275,13 @@ class RTELoader(Loader): return ds def download(self): + """ + 如果您的实验使用到了该数据,请引用GLUE Benchmark + + https://openreview.net/pdf?id=rJ4km2R5t7 + + :return: + """ return self._get_dataset_path('rte') @@ -321,10 +327,17 @@ class QuoraLoader(Loader): return ds def download(self): + """ + 由于版权限制,不能提供自动下载功能。可参考 + + https://www.kaggle.com/c/quora-question-pairs/data + + :return: + """ raise RuntimeError("Quora cannot be downloaded automatically.") -class XNLILoader(Loader): +class CNXNLILoader(Loader): """ 别名: 数据集简介:中文句对NLI(本为multi-lingual的数据集,但是这里只取了中文的数据集)。原句子已被MOSES tokenizer处理 @@ -341,7 +354,7 @@ class XNLILoader(Loader): """ def __init__(self): - super(XNLILoader, self).__init__() + super(CNXNLILoader, self).__init__() def _load(self, path: str = None): csv_loader = CSVLoader(sep='\t') @@ -384,7 +397,7 @@ class XNLILoader(Loader): https://arxiv.org/pdf/1809.05053.pdf 有使用 :return: """ - output_dir = self._get_dataset_path('xnli') + output_dir = self._get_dataset_path('cn-xnli') return output_dir @@ -423,6 +436,16 @@ class BQCorpusLoader(Loader): ds.append(Instance(raw_chars1=raw_chars1, raw_chars2=raw_chars2, target=target)) return ds + def download(self): + """ + 由于版权限制,不能提供自动下载功能。可参考 + + https://github.com/ymcui/Chinese-BERT-wwm + + :return: + """ + raise RuntimeError("BQCorpus cannot be downloaded automatically.") + class LCQMCLoader(Loader): """ @@ -461,16 +484,14 @@ class LCQMCLoader(Loader): ds.append(Instance(raw_chars1=raw_chars1, raw_chars2=raw_chars2, target=target)) return ds - ''' - def download(self)->str: + def download(self): """ - 自动下载数据,该数据取自论文 LCQMC: A Large-scale Chinese Question Matching Corpus. - InProceedings of the 27thInternational Conference on Computational Linguistics. 1952–1962. 
+ 由于版权限制,不能提供自动下载功能。可参考 + + https://github.com/ymcui/Chinese-BERT-wwm :return: """ - output_dir = self._get_dataset_path('chn-senti-corp') - return output_dir - ''' + raise RuntimeError("LCQMC cannot be downloaded automatically.") diff --git a/fastNLP/io/pipe/matching.py b/fastNLP/io/pipe/matching.py index 7747dec3..90cf17df 100644 --- a/fastNLP/io/pipe/matching.py +++ b/fastNLP/io/pipe/matching.py @@ -7,7 +7,7 @@ __all__ = [ "QuoraBertPipe", "QNLIBertPipe", "MNLIBertPipe", - "XNLIBertPipe", + "CNXNLIBertPipe", "BQCorpusBertPipe", "LCQMCBertPipe", "MatchingPipe", @@ -16,7 +16,7 @@ __all__ = [ "QuoraPipe", "QNLIPipe", "MNLIPipe", - "XNLIPipe", + "CNXNLIPipe", "BQCorpusPipe", "LCQMCPipe", ] @@ -25,7 +25,7 @@ import warnings from .pipe import Pipe from .utils import get_tokenizer -from ..loader.matching import SNLILoader, MNLILoader, QNLILoader, RTELoader, QuoraLoader, BQCorpusLoader, XNLILoader, LCQMCLoader +from ..loader.matching import SNLILoader, MNLILoader, QNLILoader, RTELoader, QuoraLoader, BQCorpusLoader, CNXNLILoader, LCQMCLoader from ...core.const import Const from ...core.vocabulary import Vocabulary from ...core._logger import logger @@ -354,10 +354,10 @@ class LCQMCPipe(MatchingPipe): return data_bundle -class XNLIPipe(MatchingPipe): - def process_from_file(self, paths = None): - data_bundle = XNLILoader().load(paths) - data_bundle = GranularizePipe(task = 'XNLI').process(data_bundle) +class CNXNLIPipe(MatchingPipe): + def process_from_file(self, paths=None): + data_bundle = CNXNLILoader().load(paths) + data_bundle = GranularizePipe(task='XNLI').process(data_bundle) data_bundle = RenamePipe().process(data_bundle) #使中文数据的field data_bundle = self.process(data_bundle) data_bundle = RenamePipe().process(data_bundle) @@ -473,9 +473,9 @@ class BQCorpusBertPipe(MatchingBertPipe): return data_bundle -class XNLIBertPipe(MatchingBertPipe): +class CNXNLIBertPipe(MatchingBertPipe): def process_from_file(self, paths = None): - data_bundle = XNLILoader().load(paths) + data_bundle = CNXNLILoader().load(paths) data_bundle = GranularizePipe(task='XNLI').process(data_bundle) data_bundle = RenamePipe(task='cn-nli-bert').process(data_bundle) data_bundle = self.process(data_bundle) diff --git a/test/io/loader/test_matching_loader.py b/test/io/loader/test_matching_loader.py index 5700ab80..abe21aa9 100644 --- a/test/io/loader/test_matching_loader.py +++ b/test/io/loader/test_matching_loader.py @@ -5,7 +5,7 @@ import os from fastNLP.io import DataBundle from fastNLP.io.loader.matching import RTELoader, QNLILoader, SNLILoader, QuoraLoader, MNLILoader, \ - BQCorpusLoader, XNLILoader, LCQMCLoader + BQCorpusLoader, CNXNLILoader, LCQMCLoader @unittest.skipIf('TRAVIS' in os.environ, "Skip in travis") @@ -31,7 +31,7 @@ class TestMatchingLoad(unittest.TestCase): 'MNLI': ('test/data_for_tests/io/MNLI', MNLILoader, (5, 5, 5, 5, 6), True), 'Quora': ('test/data_for_tests/io/Quora', QuoraLoader, (2, 2, 2), False), 'BQCorpus': ('test/data_for_tests/io/BQCorpus', BQCorpusLoader, (5, 5, 5), False), - 'XNLI': ('test/data_for_tests/io/XNLI', XNLILoader, (6, 7, 6), False), + 'XNLI': ('test/data_for_tests/io/XNLI', CNXNLILoader, (6, 7, 6), False), 'LCQMC': ('test/data_for_tests/io/LCQMC', LCQMCLoader, (5, 6, 6), False), } for k, v in data_set_dict.items(): diff --git a/test/io/pipe/test_matching.py b/test/io/pipe/test_matching.py index 6d872692..52d372d5 100644 --- a/test/io/pipe/test_matching.py +++ b/test/io/pipe/test_matching.py @@ -4,9 +4,9 @@ import os from fastNLP.io import DataBundle from fastNLP.io.pipe.matching 
import SNLIPipe, RTEPipe, QNLIPipe, QuoraPipe, MNLIPipe, \ - XNLIPipe, BQCorpusPipe, LCQMCPipe + CNXNLIPipe, BQCorpusPipe, LCQMCPipe from fastNLP.io.pipe.matching import SNLIBertPipe, RTEBertPipe, QNLIBertPipe, QuoraBertPipe, MNLIBertPipe, \ - XNLIBertPipe, BQCorpusBertPipe, LCQMCBertPipe + CNXNLIBertPipe, BQCorpusBertPipe, LCQMCBertPipe @unittest.skipIf('TRAVIS' in os.environ, "Skip in travis") @@ -38,7 +38,7 @@ class TestRunMatchingPipe(unittest.TestCase): 'QNLI': ('test/data_for_tests/io/QNLI', QNLIPipe, QNLIBertPipe, (5, 5, 5), (372, 2), True), 'MNLI': ('test/data_for_tests/io/MNLI', MNLIPipe, MNLIBertPipe, (5, 5, 5, 5, 6), (459, 3), True), 'BQCorpus': ('test/data_for_tests/io/BQCorpus', BQCorpusPipe, BQCorpusBertPipe, (5, 5, 5), (32, 2), False), - 'XNLI': ('test/data_for_tests/io/XNLI', XNLIPipe, XNLIBertPipe, (6, 7, 6), (37, 3), False), + 'XNLI': ('test/data_for_tests/io/XNLI', CNXNLIPipe, CNXNLIBertPipe, (6, 7, 6), (37, 3), False), 'LCQMC': ('test/data_for_tests/io/LCQMC', LCQMCPipe, LCQMCBertPipe, (5, 6, 6), (36, 2), False), } for k, v in data_set_dict.items():
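
Editorial note, not part of the patch: after the rename, the Chinese XNLI pipeline is driven exactly as the test above drives it. A minimal sketch follows; the fixture path and split sizes are taken from the test table, everything else assumes the usual Pipe interface:

    # Sketch: mirrors test/io/pipe/test_matching.py after the XNLI -> CNXNLI rename.
    from fastNLP.io.pipe.matching import CNXNLIPipe, CNXNLIBertPipe

    data_bundle = CNXNLIPipe().process_from_file('test/data_for_tests/io/XNLI')
    print(data_bundle)   # expected per the test: train/dev/test DataSets of 6/7/6 instances

    bert_bundle = CNXNLIBertPipe().process_from_file('test/data_for_tests/io/XNLI')
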