From ff504377027aeaa685224c04ae7467787ab478d8 Mon Sep 17 00:00:00 2001 From: yh Date: Mon, 23 May 2022 19:43:29 +0800 Subject: [PATCH] =?UTF-8?q?1.=E4=BF=AE=E5=A4=8D=E6=95=B0=E6=8D=AE=E5=A4=84?= =?UTF-8?q?=E7=90=86=E7=9A=84=E6=97=B6=E5=80=99=EF=BC=8C=E5=A4=9A=E8=BF=9B?= =?UTF-8?q?=E7=A8=8Bprint=E4=BC=9A=E6=8A=A5=E9=94=99=E7=9A=84=E9=97=AE?= =?UTF-8?q?=E9=A2=98;2.Trainer/Evaluator=E5=A2=9E=E5=8A=A0check=5Fdataload?= =?UTF-8?q?er=5Flegality?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- fastNLP/core/controllers/evaluator.py | 22 ++++++++-- fastNLP/core/controllers/trainer.py | 12 +++++- fastNLP/core/dataloaders/__init__.py | 3 +- .../dataloaders/torch_dataloader/__init__.py | 4 +- .../{ => torch_dataloader}/mix_dataloader.py | 0 fastNLP/core/dataset/dataset.py | 41 +++++++++++++++---- fastNLP/core/drivers/driver.py | 8 ++++ .../drivers/jittor_driver/jittor_driver.py | 19 +++------ .../drivers/paddle_driver/paddle_driver.py | 26 ++---------- .../core/drivers/torch_driver/torch_driver.py | 23 ++--------- fastNLP/core/utils/dummy_class.py | 6 +++ fastNLP/core/utils/tqdm_progress.py | 9 ++-- fastNLP/envs/utils.py | 19 +++++++++ fastNLP/io/data_bundle.py | 39 ++++++++++++++---- tests/core/dataset/test_dataset.py | 9 ++++ 15 files changed, 158 insertions(+), 82 deletions(-) rename fastNLP/core/dataloaders/{ => torch_dataloader}/mix_dataloader.py (100%) diff --git a/fastNLP/core/controllers/evaluator.py b/fastNLP/core/controllers/evaluator.py index 3a584b4b..bafabbe9 100644 --- a/fastNLP/core/controllers/evaluator.py +++ b/fastNLP/core/controllers/evaluator.py @@ -111,6 +111,7 @@ class Evaluator: 分布式进行设置。如果为 ``True``,将使得每个进程上的 ``dataloader`` 自动使用不同数据,所有进程的数据并集是整个数据集; * *output_from_new_proc* -- 等价于 ``Trainer`` 中的 ``output_from_new_proc`` 参数; * *progress_bar* -- 等价于 ``Trainer`` 中的 ``progress_bar`` 参数; + * *check_dataloader_legality* -- 是否检查 ``DataLoader`` 是否合法,默认为 ``True`` 。 """ @@ -134,6 +135,8 @@ class Evaluator: self.device = device self.verbose = verbose + self.evaluate_batch_loop = EvaluateBatchLoop(batch_step_fn=evaluate_batch_step_fn) + if evaluate_batch_step_fn is not None: _check_valid_parameters_number(evaluate_batch_step_fn, ['evaluator', 'batch'], fn_name='evaluate_batch_step_fn') self.evaluate_batch_step_fn = evaluate_batch_step_fn @@ -141,10 +144,23 @@ class Evaluator: self.input_mapping = input_mapping self.output_mapping = output_mapping + # check dataloader if not isinstance(dataloaders, dict): + if kwargs.get('check_dataloader_legality', True): + try: + self.driver.check_dataloader_legality(dataloader=dataloaders) + except TypeError as e: + logger.error("`dataloaders` is invalid.") + raise e dataloaders = {None: dataloaders} - - self.evaluate_batch_loop = EvaluateBatchLoop(batch_step_fn=evaluate_batch_step_fn) + else: + if kwargs.get('check_dataloader_legality', True): + for key, dataloader in dataloaders.items(): + try: + self.driver.check_dataloader_legality(dataloader=dataloader) + except TypeError as e: + logger.error(f"The dataloader named:{key} is invalid.") + raise e self.driver.setup() self.driver.barrier() @@ -333,7 +349,7 @@ class Evaluator: @evaluate_batch_loop.setter def evaluate_batch_loop(self, loop: Loop): - if self.evaluate_batch_step_fn is not None: + if getattr(self, 'evaluate_step_fn', None) is not None: logger.rank_zero_warning("`evaluate_batch_step_fn` was customized in the Evaluator initialization, it will be ignored " "when the `evaluate_batch_loop` is also customized.") self._evaluate_batch_loop = loop diff --git a/fastNLP/core/controllers/trainer.py b/fastNLP/core/controllers/trainer.py index 79cc36a0..076f674b 100644 --- a/fastNLP/core/controllers/trainer.py +++ b/fastNLP/core/controllers/trainer.py @@ -304,6 +304,7 @@ class Trainer(TrainerEventTrigger): * *train_output_mapping* -- 与 output_mapping 一致,但是只用于 ``Trainer`` 中。与 output_mapping 互斥。 * *evaluate_input_mapping* -- 与 input_mapping 一致,但是只用于 ``Evaluator`` 中。与 input_mapping 互斥。 * *evaluate_output_mapping* -- 与 output_mapping 一致,但是只用于 ``Evaluator`` 中。与 output_mapping 互斥。 + * *check_dataloader_legality* -- 是否检查 ``DataLoader`` 是否合法,默认为 ``True`` 。 .. note:: ``Trainer`` 是通过在内部直接初始化一个 ``Evaluator`` 来进行验证; @@ -463,6 +464,14 @@ class Trainer(TrainerEventTrigger): self.driver.setup() self.driver.barrier() + # check train_dataloader + if kwargs.get('check_dataloader_legality', True): + try: + self.driver.check_dataloader_legality(dataloader=train_dataloader) + except TypeError as e: + logger.error("`train_dataloader` is invalid.") + raise e + use_dist_sampler = kwargs.get("use_dist_sampler", self.driver.is_distributed()) if use_dist_sampler: _dist_sampler = "dist" @@ -482,7 +491,8 @@ class Trainer(TrainerEventTrigger): evaluate_fn=evaluate_fn, input_mapping=evaluate_input_mapping, output_mapping=evaluate_output_mapping, fp16=fp16, verbose=0, use_dist_sampler=kwargs.get("evaluate_use_dist_sampler", use_dist_sampler), - progress_bar=progress_bar) + progress_bar=progress_bar, + check_dataloader_legality=kwargs.get('check_dataloader_legality', True)) if train_fn is not None and not isinstance(train_fn, str): raise TypeError("Parameter `train_fn` can only be `str` type when it is not None.") diff --git a/fastNLP/core/dataloaders/__init__.py b/fastNLP/core/dataloaders/__init__.py index 976788d9..84f8b288 100644 --- a/fastNLP/core/dataloaders/__init__.py +++ b/fastNLP/core/dataloaders/__init__.py @@ -10,8 +10,7 @@ __all__ = [ "prepare_dataloader" ] -from .mix_dataloader import MixDataLoader from .jittor_dataloader import JittorDataLoader, prepare_jittor_dataloader -from .torch_dataloader import TorchDataLoader, prepare_torch_dataloader +from .torch_dataloader import TorchDataLoader, prepare_torch_dataloader, MixDataLoader from .paddle_dataloader import PaddleDataLoader, prepare_paddle_dataloader from .prepare_dataloader import prepare_dataloader \ No newline at end of file diff --git a/fastNLP/core/dataloaders/torch_dataloader/__init__.py b/fastNLP/core/dataloaders/torch_dataloader/__init__.py index 4f3fc707..a55d3d0d 100644 --- a/fastNLP/core/dataloaders/torch_dataloader/__init__.py +++ b/fastNLP/core/dataloaders/torch_dataloader/__init__.py @@ -1,6 +1,8 @@ __all__ = [ "TorchDataLoader", - "prepare_torch_dataloader" + "prepare_torch_dataloader", + "MixDataLoader" ] from .fdl import TorchDataLoader, prepare_torch_dataloader +from .mix_dataloader import MixDataLoader diff --git a/fastNLP/core/dataloaders/mix_dataloader.py b/fastNLP/core/dataloaders/torch_dataloader/mix_dataloader.py similarity index 100% rename from fastNLP/core/dataloaders/mix_dataloader.py rename to fastNLP/core/dataloaders/torch_dataloader/mix_dataloader.py diff --git a/fastNLP/core/dataset/dataset.py b/fastNLP/core/dataset/dataset.py index 025d33e5..d3622803 100644 --- a/fastNLP/core/dataset/dataset.py +++ b/fastNLP/core/dataset/dataset.py @@ -168,6 +168,7 @@ from fastNLP.core.collators import Collator from fastNLP.core.utils.rich_progress import f_rich_progress, DummyFRichProgress from fastNLP.core.utils.tqdm_progress import f_tqdm_progress from ..log import logger +from fastNLP.core.utils.dummy_class import DummyClass progress_bars = { @@ -231,7 +232,8 @@ def _multi_proc(ds, _apply_field, func, counter, queue): """ idx = -1 import contextlib - with contextlib.redirect_stdout(None): # 避免打印触发 rich 的锁 + null = DummyClass() + with contextlib.redirect_stdout(null): # 避免打印触发 rich 的锁 logger.set_stdout(stdout='raw') results = [] try: @@ -597,7 +599,8 @@ class DataSet: .. note:: - 由于 ``python`` 语言的特性,设置该参数后会导致相应倍数的内存增长,这可能会对您程序的执行带来一定的影响。 + 由于 ``python`` 语言的特性,设置该参数后会导致相应倍数的内存增长,这可能会对您程序的执行带来一定的影响。另外,使用多进程时, + ``func`` 函数中的打印将不会输出。 :param progress_desc: 进度条的描述字符,默认为 ``Processing``; :param progress_bar: 显示 progress_bar 的方式,支持 `["rich", "tqdm", None]`。 @@ -631,8 +634,14 @@ class DataSet: :param field_name: 传入func的是哪个field。 :param func: 参数是 ``DataSet`` 中的 ``Instance`` ,返回值是一个字典,key 是field 的名字,value 是对应的结果 :param modify_fields: 是否用结果修改 `DataSet` 中的 `Field`, 默认为 True - :param num_proc: 进程的数量。请注意,由于python语言的特性,多少进程就会导致多少倍内存的增长。 - :param progress_bar: 显示 progress_bar 的方式,支持 `["rich", "tqdm", None]`。 + :param num_proc: 使用进程的数量。 + + .. note:: + + 由于 ``python`` 语言的特性,设置该参数后会导致相应倍数的内存增长,这可能会对您程序的执行带来一定的影响。另外,使用多进程时, + ``func`` 函数中的打印将不会输出。 + + :param progress_bar: 显示 progress_bar 的方式,支持 `["rich", "tqdm", None]`。 :param progress_desc: 当显示 progress_bar 时,显示当前正在处理的进度条描述字符 :return Dict[str:Field]: 返回一个字典 """ @@ -672,7 +681,13 @@ class DataSet: progress_bar: str = 'rich', _apply_field: str = None, progress_desc: str = 'Main') -> list: """ - :param num_proc: 进程的数量。请注意,由于python语言的特性,多少进程就会导致多少倍内存的增长。 + :param num_proc: 使用进程的数量。 + + .. note:: + + 由于 ``python`` 语言的特性,设置该参数后会导致相应倍数的内存增长,这可能会对您程序的执行带来一定的影响。另外,使用多进程时, + ``func`` 函数中的打印将不会输出。 + :param func: 用户自定义处理函数,参数是 ``DataSet`` 中的 ``Instance`` :param _apply_field: 需要传进去func的数据集的field_name :param progress_bar: 显示 progress_bar 的方式,支持 `["rich", "tqdm", None]`。 @@ -744,7 +759,13 @@ class DataSet: :param modify_fields: 是否用结果修改 ``DataSet`` 中的 ``Field`` , 默认为 True :param func: 参数是 ``DataSet`` 中的 ``Instance`` ,返回值是一个字典,key 是field 的名字,value 是对应的结果 - :param num_proc: 进程的数量。请注意,由于python语言的特性,多少进程就会导致多少倍内存的增长。 + :param num_proc: 使用进程的数量。 + + .. note:: + + 由于 ``python`` 语言的特性,设置该参数后会导致相应倍数的内存增长,这可能会对您程序的执行带来一定的影响。另外,使用多进程时, + ``func`` 函数中的打印将不会输出。 + :param progress_desc: 当 progress_bar 不为 None 时,可以显示当前正在处理的进度条名称 :param progress_bar: 显示 progress_bar 的方式,支持 `["rich", "tqdm", None]`。 :return Dict[str:Field]: 返回一个字典 @@ -789,7 +810,13 @@ class DataSet: :param func: 参数是 ``DataSet`` 中的 ``Instance`` ,返回值是一个字典,key 是field 的名字,value 是对应的结果 :param new_field_name: 将func返回的内容放入到 `new_field_name` 这个field中,如果名称与已有的field相同,则覆 盖之前的field。如果为None则不创建新的field。 - :param num_proc: 进程的数量。请注意,由于python语言的特性,多少进程就会导致多少倍内存的增长。 + :param num_proc: 使用进程的数量。 + + .. note:: + + 由于 ``python`` 语言的特性,设置该参数后会导致相应倍数的内存增长,这可能会对您程序的执行带来一定的影响。另外,使用多进程时, + ``func`` 函数中的打印将不会输出。 + :param progress_bar: 显示 progress_bar 的方式,支持 `["rich", "tqdm", None]`。 :param progress_desc: progress bar 显示的值,默认为空。 """ diff --git a/fastNLP/core/drivers/driver.py b/fastNLP/core/drivers/driver.py index 1b6f2931..bd06e705 100644 --- a/fastNLP/core/drivers/driver.py +++ b/fastNLP/core/drivers/driver.py @@ -175,6 +175,14 @@ class Driver(ABC): raise NotImplementedError( "Each specific driver should implemented its own `_check_optimizer_legality` function.") + def check_dataloader_legality(self, dataloader): + """ + 检测 DataLoader 是否合法,如果不合法,会 raise TypeError 。 + + :param dataloder: + :return: + """ + def set_optimizers(self, optimizers=None): r""" trainer 会调用该函数将用户传入的 optimizers 挂载到 driver 实例上; diff --git a/fastNLP/core/drivers/jittor_driver/jittor_driver.py b/fastNLP/core/drivers/jittor_driver/jittor_driver.py index 14cb2b84..4f7f23bd 100644 --- a/fastNLP/core/drivers/jittor_driver/jittor_driver.py +++ b/fastNLP/core/drivers/jittor_driver/jittor_driver.py @@ -13,6 +13,7 @@ if _NEED_IMPORT_JITTOR: import jittor as jt from jittor import Module from jittor.optim import Optimizer + from jittor.dataset import Dataset _reduces = { 'max': jt.max, @@ -52,21 +53,11 @@ class JittorDriver(Driver): # 用来设置是否关闭 auto_param_call 中的参数匹配问题; self.wo_auto_param_call = kwargs.get("model_wo_auto_param_call", False) - @staticmethod - def check_dataloader_legality(dataloader, dataloader_name, is_train: bool = False): + def check_dataloader_legality(self, dataloader): # 在fastnlp中实现了JittorDataLoader - # TODO: 是否允许传入Dataset? - if is_train: - if not isinstance(dataloader, JittorDataLoader): - raise ValueError(f"Parameter `{dataloader_name}` should be 'JittorDataLoader' type, not {type(dataloader)}.") - else: - if not isinstance(dataloader, Dict): - raise ValueError(f"Parameter `{dataloader_name}` should be 'Dict' type, not {type(dataloader)}.") - else: - for each_dataloader in dataloader.values(): - if not isinstance(each_dataloader, JittorDataLoader): - raise ValueError(f"Each dataloader of parameter `{dataloader_name}` should be 'JittorDataLoader' " - f"type, not {type(each_dataloader)}.") + if not isinstance(dataloader, Dataset): + raise TypeError(f"{Dataset} is expected, instead of `{type(dataloader)}`") + @staticmethod def _check_optimizer_legality(optimizers): diff --git a/fastNLP/core/drivers/paddle_driver/paddle_driver.py b/fastNLP/core/drivers/paddle_driver/paddle_driver.py index aec20ac6..e879dd90 100644 --- a/fastNLP/core/drivers/paddle_driver/paddle_driver.py +++ b/fastNLP/core/drivers/paddle_driver/paddle_driver.py @@ -94,29 +94,9 @@ class PaddleDriver(Driver): self.grad_scaler.step(optimizer) self.grad_scaler.update() - @staticmethod - def check_dataloader_legality(dataloader, dataloader_name, is_train: bool = False): - if is_train: - if not isinstance(dataloader, DataLoader): - raise ValueError(f"Parameter `{dataloader_name}` should be 'paddle.io.DataLoader' type, not {type(dataloader)}.") - # TODO 我们先禁止 dataloader 的 dataset 是 IterableDataset 种类; - if isinstance(dataloader.dataset, IterableDataset): - raise TypeError("`IterableDataset` is not allowed.") - if dataloader.batch_sampler is None and dataloader.batch_size is None: - raise ValueError(f"At least one of `{dataloader_name}`'s `batch_sampler` and `batch_size` should be set.") - else: - if not isinstance(dataloader, Dict): - raise ValueError(f"Parameter `{dataloader_name}` should be 'Dict' type, not {type(dataloader)}.") - else: - for each_dataloader in dataloader.values(): - if not isinstance(each_dataloader, DataLoader): - raise ValueError(f"Each dataloader of parameter `{dataloader_name}` should be 'paddle.io.DataLoader' " - f"type, not {type(each_dataloader)}.") - if isinstance(each_dataloader.dataset, IterableDataset): - raise TypeError("`IterableDataset` is not allowed.") - if each_dataloader.batch_sampler is None and each_dataloader.batch_size is None: - raise ValueError(f"For each dataloader of parameter `{dataloader_name}`, at least one of " - f"`batch_sampler` and `batch_size` should be set.") + def check_dataloader_legality(self, dataloader): + if not isinstance(dataloader, DataLoader): + raise TypeError(f"{DataLoader} is expected, instead of `{type(dataloader)}`") @staticmethod def _check_optimizer_legality(optimizers): diff --git a/fastNLP/core/drivers/torch_driver/torch_driver.py b/fastNLP/core/drivers/torch_driver/torch_driver.py index a0c562f7..21325b5c 100644 --- a/fastNLP/core/drivers/torch_driver/torch_driver.py +++ b/fastNLP/core/drivers/torch_driver/torch_driver.py @@ -91,26 +91,9 @@ class TorchDriver(Driver): self.grad_scaler.step(optimizer) self.grad_scaler.update() - @staticmethod - def check_dataloader_legality(dataloader, dataloader_name, is_train: bool = False): - if is_train: - if not isinstance(dataloader, DataLoader): - raise ValueError(f"Parameter `{dataloader_name}` should be 'DataLoader' type, not {type(dataloader)}.") - - # todo 我们先禁止 dataloader 的 dataset 是 IterableDataset 种类; - if isinstance(dataloader.dataset, IterableDataset): - raise TypeError("`IterableDataset` is not allowed.") - - else: - if not isinstance(dataloader, Dict): - raise ValueError(f"Parameter `{dataloader_name}` should be 'Dict' type, not {type(dataloader)}.") - else: - for each_dataloader in dataloader.values(): - if not isinstance(each_dataloader, DataLoader): - raise ValueError(f"Each dataloader of parameter `{dataloader_name}` should be 'DataLoader' " - f"type, not {type(each_dataloader)}.") - if isinstance(each_dataloader.dataset, IterableDataset): - raise TypeError("`IterableDataset` is not allowed.") + def check_dataloader_legality(self, dataloader): + if not isinstance(dataloader, DataLoader): + raise TypeError(f"{DataLoader} is expected, instead of `{type(dataloader)}`") @staticmethod def _check_optimizer_legality(optimizers): diff --git a/fastNLP/core/utils/dummy_class.py b/fastNLP/core/utils/dummy_class.py index e7596607..afd610ce 100644 --- a/fastNLP/core/utils/dummy_class.py +++ b/fastNLP/core/utils/dummy_class.py @@ -3,3 +3,9 @@ __all__ = [] class DummyClass: def __init__(self, *args, **kwargs): pass + + def __getattr__(self, item): + return lambda *args, **kwargs: ... + + def __call__(self, *args, **kwargs): + pass \ No newline at end of file diff --git a/fastNLP/core/utils/tqdm_progress.py b/fastNLP/core/utils/tqdm_progress.py index 897c4b7d..d40cb81f 100644 --- a/fastNLP/core/utils/tqdm_progress.py +++ b/fastNLP/core/utils/tqdm_progress.py @@ -4,7 +4,8 @@ __all__ = [ import uuid import sys -from ...envs.imports import _module_available, _compare_version +from ...envs.utils import _module_available, _compare_version, _get_version + from ...envs import get_global_rank from .utils import is_notebook from ..log import logger @@ -82,8 +83,10 @@ class TqdmProgress(metaclass=Singleton): :param kwargs: :return: """ - assert _module_available('tqdm') and _compare_version('tqdm', operator.ge, '4.57'), \ - f"To use tqdm, tqdm>=4.57 is needed." + if not _module_available('tqdm'): + raise ModuleNotFoundError("Package tqdm is not installed.") + elif not _compare_version('tqdm', operator.ge, '4.57'): + raise RuntimeError(f"Package tqdm>=4.57 is needed, instead of {_get_version('tqdm')}.") from .rich_progress import f_rich_progress assert not f_rich_progress.not_empty(), "Cannot use tqdm before rich finish loop." diff --git a/fastNLP/envs/utils.py b/fastNLP/envs/utils.py index 3936771e..541bfba7 100644 --- a/fastNLP/envs/utils.py +++ b/fastNLP/envs/utils.py @@ -26,6 +26,25 @@ def _module_available(module_path: str) -> bool: return False +def _get_version(package, use_base_version: bool = False): + try: + pkg = importlib.import_module(package) + except (ModuleNotFoundError, DistributionNotFound): + return False + try: + if hasattr(pkg, "__version__"): + pkg_version = Version(pkg.__version__) + else: + # try pkg_resources to infer version + pkg_version = Version(pkg_resources.get_distribution(package).version) + except TypeError: + # this is mocked by Sphinx, so it should return True to generate all summaries + return True + if use_base_version: + pkg_version = Version(pkg_version.base_version) + return pkg_version + + def _compare_version(package: str, op: Callable, version: str, use_base_version: bool = False) -> bool: """Compare package version with some requirements. diff --git a/fastNLP/io/data_bundle.py b/fastNLP/io/data_bundle.py index 81ff3b84..58538d61 100644 --- a/fastNLP/io/data_bundle.py +++ b/fastNLP/io/data_bundle.py @@ -231,7 +231,13 @@ class DataBundle: 盖之前的field。如果为None则不创建新的field。 :param bool ignore_miss_dataset: 当某个field名称在某个dataset不存在时,如果为True,则直接忽略该DataSet; 如果为False,则报错 - :param num_proc: 进程的数量。请注意,由于python语言的特性,多少进程就会导致多少倍内存的增长。 + :param num_proc: 使用进程的数量。 + + .. note:: + + 由于 ``python`` 语言的特性,设置该参数后会导致相应倍数的内存增长,这可能会对您程序的执行带来一定的影响。另外,使用多进程时, + ``func`` 函数中的打印将不会输出。 + :param ignore_miss_dataset: 如果 dataset 没有 {field_name} ,就直接跳过这个 dataset 。 :param progress_desc: 当显示 progress 时,可以显示当前正在处理的名称 :param progress_bar: 显示 progress_bar 的方式,支持 `["rich", "tqdm", None]`。 @@ -260,10 +266,16 @@ class DataBundle: :param callable func: 参数是 ``DataSet`` 中的 ``Instance`` ,返回值是一个字典,key 是field 的名字,value 是对应的结果 :param str field_name: 传入func的是哪个field。 :param bool modify_fields: 是否用结果修改 `DataSet` 中的 `Field`, 默认为 True - :param num_proc: 进程的数量。请注意,由于python语言的特性,多少进程就会导致多少倍内存的增长。 + :param num_proc: 使用进程的数量。 + + .. note:: + + 由于 ``python`` 语言的特性,设置该参数后会导致相应倍数的内存增长,这可能会对您程序的执行带来一定的影响。另外,使用多进程时, + ``func`` 函数中的打印将不会输出。 + :param bool ignore_miss_dataset: 当某个field名称在某个dataset不存在时,如果为True,则直接忽略该DataSet; 如果为False,则报错 - :param progress_bar: 显示 progress_bar 的方式,支持 `["rich", "tqdm", None]`。 + :param progress_bar: 显示 progress_bar 的方式,支持 `["rich", "tqdm", None]`。 :param progress_desc: 当显示 progress_bar 时,可以显示 ``progress`` 的名称。 :return Dict[str:Dict[str:Field]]: 返回一个字典套字典,第一层的 key 是 dataset 的名字,第二层的 key 是 field 的名字 @@ -292,8 +304,14 @@ class DataBundle: :param callable func: input是instance中名为 `field_name` 的field的内容。 :param str new_field_name: 将func返回的内容放入到 `new_field_name` 这个field中,如果名称与已有的field相同,则覆 盖之前的field。如果为None则不创建新的field。 - :param num_proc: 进程的数量。请注意,由于python语言的特性,多少进程就会导致多少倍内存的增长。 - :param progress_bar: 显示 progress_bar 的方式,支持 `["rich", "tqdm", None]`。 + :param num_proc: 使用进程的数量。 + + .. note:: + + 由于 ``python`` 语言的特性,设置该参数后会导致相应倍数的内存增长,这可能会对您程序的执行带来一定的影响。另外,使用多进程时, + ``func`` 函数中的打印将不会输出。 + + :param progress_bar: 显示 progress_bar 的方式,支持 `["rich", "tqdm", None]`。 :param progress_desc: 当显示 progress bar 时,可以显示当前正在处理的名称 """ @@ -316,8 +334,14 @@ class DataBundle: :param callable func: 参数是 ``DataSet`` 中的 ``Instance`` ,返回值是一个字典,key 是field 的名字,value 是对应的结果 :param bool modify_fields: 是否用结果修改 ``DataSet`` 中的 ``Field`` , 默认为 True - :param num_proc: 进程的数量。请注意,由于python语言的特性,多少进程就会导致多少倍内存的增长。 - :param progress_bar: 显示 progress_bar 的方式,支持 `["rich", "tqdm", None]`。 + :param num_proc: 使用进程的数量。 + + .. note:: + + 由于 ``python`` 语言的特性,设置该参数后会导致相应倍数的内存增长,这可能会对您程序的执行带来一定的影响。另外,使用多进程时, + ``func`` 函数中的打印将不会输出。 + + :param progress_bar: 显示 progress_bar 的方式,支持 `["rich", "tqdm", None]`。 :param progress_desc: 当显示 progress_bar 时,可以显示当前正在处理的名称 :return Dict[str:Dict[str:Field]]: 返回一个字典套字典,第一层的 key 是 dataset 的名字,第二层的 key 是 field 的名字 @@ -382,4 +406,3 @@ class DataBundle: for name, vocab in self.vocabs.items(): _str += '\t{} has {} entries.\n'.format(name, len(vocab)) return _str - diff --git a/tests/core/dataset/test_dataset.py b/tests/core/dataset/test_dataset.py index f8458a0c..8fd0e726 100644 --- a/tests/core/dataset/test_dataset.py +++ b/tests/core/dataset/test_dataset.py @@ -4,6 +4,7 @@ import pytest import numpy as np from fastNLP.core.dataset import DataSet, FieldArray, Instance, ApplyResultException +from fastNLP import logger class TestDataSetInit: @@ -379,6 +380,14 @@ class TestDataSetMethods: data = DataSet({'x': ['xxxxas1w xw zxw xz', 'xxxxas1w xw zxw xz'] * 100, 'y': [0, 1] * 100}) data.apply_field(lambda x: len(x), field_name='x', new_field_name='len_x', num_proc=0) + def test_apply_more_proc(self): + def func(x): + print("x") + logger.info("demo") + return len(x) + data = DataSet({'x': ['xxxxas1w xw zxw xz', 'xxxxas1w xw zxw xz'] * 100, 'y': [0, 1] * 100}) + data.apply_field(func, field_name='x', new_field_name='len_x', num_proc=2) + class TestFieldArrayInit: """