From 2a80a9de44f719b2cc48ea6a473965f8097ba4f4 Mon Sep 17 00:00:00 2001 From: yh_cc Date: Fri, 6 May 2022 14:32:25 +0800 Subject: [PATCH] =?UTF-8?q?Trainer.on=E6=94=AF=E6=8C=81*args=E7=AD=89?= =?UTF-8?q?=E5=BD=A2=E5=BC=8F=E7=9A=84=E8=BE=93=E5=85=A5?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- fastNLP/core/controllers/trainer.py | 1 + fastNLP/core/dataset/dataset.py | 10 ++--- fastNLP/core/log/logger.py | 7 +++- fastNLP/core/utils/utils.py | 42 +++++++++---------- fastNLP/io/data_bundle.py | 16 ++++--- tests/core/controllers/test_trainer.py | 22 ++++++++++ .../test_trainer_w_evaluator_torch.py | 39 +++++++++++++++++ tests/core/utils/test_utils.py | 19 ++++----- 8 files changed, 106 insertions(+), 50 deletions(-) create mode 100644 tests/core/controllers/test_trainer.py diff --git a/fastNLP/core/controllers/trainer.py b/fastNLP/core/controllers/trainer.py index 8fd3c65e..46eaa175 100644 --- a/fastNLP/core/controllers/trainer.py +++ b/fastNLP/core/controllers/trainer.py @@ -7,6 +7,7 @@ from dataclasses import is_dataclass import os from pathlib import Path import io +import inspect __all__ = [ 'Trainer', diff --git a/fastNLP/core/dataset/dataset.py b/fastNLP/core/dataset/dataset.py index 98f23286..a861f901 100644 --- a/fastNLP/core/dataset/dataset.py +++ b/fastNLP/core/dataset/dataset.py @@ -402,11 +402,11 @@ class DataSet: r""" 将 DataSet 中的每个 instance 中的名为 `field_name` 的 field 传给 func,并获取它的返回值。 - :param num_proc: 进程的数量 :param field_name: 传入 func 的是哪个 field。 :param func: input是 instance 中名为 `field_name` 的 field 的内容。 :param new_field_name: 将 func 返回的内容放入到 `new_field_name` 这个 field 中,如果名称与已有的 field 相同,则覆 盖之前的 field。如果为 None 则不创建新的 field。 + :param num_proc: 进程的数量。请注意,由于python语言的特性,多少进程就会导致多少倍内存的增长。 :param progress_desc: progress_desc 的值,默认为 Main :param show_progress_bar: 是否展示进度条,默认展示进度条 """ @@ -435,10 +435,10 @@ class DataSet: ``apply_field_more`` 与 ``apply_field`` 的区别参考 :method:`~fastNLP.DataSet.apply_more` 中关于 ``apply_more`` 与 ``apply`` 区别的介绍。 - :param num_proc: 进程的数量 :param field_name: 传入func的是哪个field。 :param func: 参数是 ``DataSet`` 中的 ``Instance`` ,返回值是一个字典,key 是field 的名字,value 是对应的结果 :param modify_fields: 是否用结果修改 `DataSet` 中的 `Field`, 默认为 True + :param num_proc: 进程的数量。请注意,由于python语言的特性,多少进程就会导致多少倍内存的增长。 :param show_progress_bar: 是否显示进度条,默认展示 :param progress_desc: 当show_progress_bar为True时,可以显示当前正在处理的进度条描述字符 :return Dict[str:Field]: 返回一个字典 @@ -479,7 +479,7 @@ class DataSet: show_progress_bar: bool = True, _apply_field: str = None, progress_desc: str = 'Main') -> list: """ - :param num_proc: 进程的数量 + :param num_proc: 进程的数量。请注意,由于python语言的特性,多少进程就会导致多少倍内存的增长。 :param func: 用户自定义处理函数,参数是 ``DataSet`` 中的 ``Instance`` :param _apply_field: 需要传进去func的数据集的field_name :param show_progress_bar: 是否展示progress进度条,默认为展示 @@ -552,7 +552,7 @@ class DataSet: :param modify_fields: 是否用结果修改 ``DataSet`` 中的 ``Field`` , 默认为 True :param func: 参数是 ``DataSet`` 中的 ``Instance`` ,返回值是一个字典,key 是field 的名字,value 是对应的结果 :param num_proc: 进程的数量 - :param show_progress_bar: 是否使用tqd显示预处理进度 + :param num_proc: 进程的数量。请注意,由于python语言的特性,多少进程就会导致多少倍内存的增长。 :param progress_desc: 当show_progress_bar为True时,可以显示当前正在处理的进度条名称 :return Dict[str:Field]: 返回一个字典 """ @@ -596,7 +596,7 @@ class DataSet: :param func: 参数是 ``DataSet`` 中的 ``Instance`` ,返回值是一个字典,key 是field 的名字,value 是对应的结果 :param new_field_name: 将func返回的内容放入到 `new_field_name` 这个field中,如果名称与已有的field相同,则覆 盖之前的field。如果为None则不创建新的field。 - :param num_proc: 进程的数量。 + :param num_proc: 进程的数量。请注意,由于python语言的特性,多少进程就会导致多少倍内存的增长。 :param show_progress_bar: 是否显示进度条。 :param progress_desc: progress bar 显示的值,默认为空。 """ diff --git a/fastNLP/core/log/logger.py b/fastNLP/core/log/logger.py index eea54f36..809e9c5c 100644 --- a/fastNLP/core/log/logger.py +++ b/fastNLP/core/log/logger.py @@ -135,8 +135,11 @@ class FastNLPLogger(logging.Logger, metaclass=LoggerSingleton): :return: """ if os.environ.get(FASTNLP_GLOBAL_RANK, '0') == '0': - if once and msg in self._warning_msgs: - return + if once: + if msg in self._warning_msgs: + return + self._warning_msgs.add(msg) + if self.isEnabledFor(WARNING): kwargs = self._add_rank_info(kwargs) self._log(WARNING, msg, args, **kwargs) diff --git a/fastNLP/core/utils/utils.py b/fastNLP/core/utils/utils.py index c894131d..16ec8238 100644 --- a/fastNLP/core/utils/utils.py +++ b/fastNLP/core/utils/utils.py @@ -200,29 +200,25 @@ def _check_valid_parameters_number(fn, expected_params:List[str], fn_name=None): if fn_name is not None: assert callable(fn), f"`{fn_name}` should be callable, instead of `{type(fn)}`." - parameters = list(inspect.signature(fn).parameters.values()) - if inspect.ismethod(fn): - if len(parameters)>0 and parameters[0].name == 'self': - parameters = parameters[1:] # 去掉self - - no_var_param = True # 没有 * 这种参数 - number_param_need_value = 0 - for param in parameters: - if param.kind is param.VAR_POSITIONAL: - no_var_param = False - elif param.kind is param.VAR_KEYWORD: - no_var_param = False - else: - if param.default is param.empty: - number_param_need_value += 1 - - if len(parameters)len(expected_params): - raise RuntimeError(f"The function:{_get_fun_msg(fn)} expects {len(parameters)} parameters, but only" - f" {len(expected_params)} parameters:{expected_params} will be provided.") + try: + args = [] + kwargs = {} + name = '' + if isinstance(fn, functools.partial) and not hasattr(fn, '__name__'): + name = 'partial:' + f = fn.func + while isinstance(f, functools.partial): + name += 'partial:' + f = f.func + fn.__name__ = name + f.__name__ + inspect.getcallargs(fn, *args, *expected_params, **kwargs) + if name: # 如果一开始没有name的,需要给人家删除掉 + delattr(fn, '__name__') + + except TypeError as e: + logger.error(f"The function:{_get_fun_msg(fn)} will be provided with parameters:{expected_params}. " + f"The following exception will happen.") + raise e def check_user_specific_params(user_params: Dict, fn: Callable): diff --git a/fastNLP/io/data_bundle.py b/fastNLP/io/data_bundle.py index 5a0dc78d..804f7060 100644 --- a/fastNLP/io/data_bundle.py +++ b/fastNLP/io/data_bundle.py @@ -231,8 +231,8 @@ class DataBundle: 盖之前的field。如果为None则不创建新的field。 :param bool ignore_miss_dataset: 当某个field名称在某个dataset不存在时,如果为True,则直接忽略该DataSet; 如果为False,则报错 - :param ignore_miss_dataset: - :param num_proc: + :param num_proc: 进程的数量。请注意,由于python语言的特性,多少进程就会导致多少倍内存的增长。 + :param ignore_miss_dataset: 如果 dataset 没有 {field_name} ,就直接跳过这个 dataset 。 :param progress_desc 当show_progress_barm为True时,可以显示当前tqdm正在处理的名称 :param show_progress_bar 是否显示tqdm进度条 @@ -260,11 +260,11 @@ class DataBundle: :param callable func: 参数是 ``DataSet`` 中的 ``Instance`` ,返回值是一个字典,key 是field 的名字,value 是对应的结果 :param str field_name: 传入func的是哪个field。 :param bool modify_fields: 是否用结果修改 `DataSet` 中的 `Field`, 默认为 True + :param num_proc: 进程的数量。请注意,由于python语言的特性,多少进程就会导致多少倍内存的增长。 :param bool ignore_miss_dataset: 当某个field名称在某个dataset不存在时,如果为True,则直接忽略该DataSet; 如果为False,则报错 :param show_progress_bar: 是否显示tqdm进度条 :param progress_desc: 当show_progress_barm为True时,可以显示当前tqdm正在处理的名称 - :param num_proc: :return Dict[str:Dict[str:Field]]: 返回一个字典套字典,第一层的 key 是 dataset 的名字,第二层的 key 是 field 的名字 @@ -283,7 +283,7 @@ class DataBundle: return res def apply(self, func: Callable, new_field_name: str, num_proc: int = 0, - progress_desc: str = '', show_progress_bar: bool = True, _apply_field: str = None): + progress_desc: str = '', show_progress_bar: bool = True): r""" 对 :class:`~fastNLP.io.DataBundle` 中所有的 dataset 使用 :method:`~fastNLP.DataSet.apply` 方法 @@ -292,18 +292,16 @@ class DataBundle: :param callable func: input是instance中名为 `field_name` 的field的内容。 :param str new_field_name: 将func返回的内容放入到 `new_field_name` 这个field中,如果名称与已有的field相同,则覆 盖之前的field。如果为None则不创建新的field。 - :param _apply_field: + :param num_proc: 进程的数量。请注意,由于python语言的特性,多少进程就会导致多少倍内存的增长。 :param show_progress_bar: 是否显示tqd进度条 :param progress_desc: 当show_progress_bar为True时,可以显示当前tqd正在处理的名称 - :param num_proc - """ _progress_desc = progress_desc for name, dataset in self.datasets.items(): if _progress_desc: progress_desc = _progress_desc + f' for `{name}`' dataset.apply(func, new_field_name=new_field_name, num_proc=num_proc, show_progress_bar=show_progress_bar, - progress_desc=progress_desc, _apply_field=_apply_field) + progress_desc=progress_desc) return self def apply_more(self, func: Callable, modify_fields=True, num_proc: int = 0, @@ -317,9 +315,9 @@ class DataBundle: :param callable func: 参数是 ``DataSet`` 中的 ``Instance`` ,返回值是一个字典,key 是field 的名字,value 是对应的结果 :param bool modify_fields: 是否用结果修改 ``DataSet`` 中的 ``Field`` , 默认为 True + :param num_proc: 进程的数量。请注意,由于python语言的特性,多少进程就会导致多少倍内存的增长。 :param show_progress_bar: 是否显示tqd进度条 :param progress_desc: 当show_progress_bar为True时,可以显示当前tqd正在处理的名称 - :param num_proc :return Dict[str:Dict[str:Field]]: 返回一个字典套字典,第一层的 key 是 dataset 的名字,第二层的 key 是 field 的名字 """ diff --git a/tests/core/controllers/test_trainer.py b/tests/core/controllers/test_trainer.py new file mode 100644 index 00000000..8788c239 --- /dev/null +++ b/tests/core/controllers/test_trainer.py @@ -0,0 +1,22 @@ +import pytest +from fastNLP import Trainer, Event + + +def test_on(): + with pytest.raises(TypeError): + @Trainer.on(Event.on_before_backward()) + def before_backend(): + pass + + @Trainer.on(Event.on_before_backward()) + def before_backend(*args): + pass + + with pytest.raises(TypeError): + @Trainer.on(Event.on_before_backward()) + def before_backend(*args, s): + pass + + @Trainer.on(Event.on_before_backward()) + def before_backend(*args, s=2): + pass \ No newline at end of file diff --git a/tests/core/controllers/test_trainer_w_evaluator_torch.py b/tests/core/controllers/test_trainer_w_evaluator_torch.py index f44bd735..8971b2fe 100644 --- a/tests/core/controllers/test_trainer_w_evaluator_torch.py +++ b/tests/core/controllers/test_trainer_w_evaluator_torch.py @@ -210,4 +210,43 @@ def test_trainer_validate_every( dist.destroy_process_group() +@pytest.mark.torch +@pytest.mark.parametrize("driver,device", [("torch", 'cpu')]) # ("torch", [0, 1]),("torch", 1) +@magic_argv_env_context +def test_trainer_on( + model_and_optimizers: TrainerParameters, + driver, + device, + n_epochs=2, +): + from fastNLP import Event + @Trainer.on(Event.on_before_backward()) + def before_backend(trainer, outputs): + pass + + @Trainer.on(Event.on_before_backward()) + def before_backend_2(*args): + pass + + trainer = Trainer( + model=model_and_optimizers.model, + driver=driver, + device=device, + optimizers=model_and_optimizers.optimizers, + train_dataloader=model_and_optimizers.train_dataloader, + evaluate_dataloaders=model_and_optimizers.evaluate_dataloaders, + input_mapping=model_and_optimizers.input_mapping, + output_mapping=model_and_optimizers.output_mapping, + metrics=model_and_optimizers.metrics, + n_epochs=n_epochs, + output_from_new_proc="all", + evaluate_every=-1 + ) + + + + trainer.run() + + + diff --git a/tests/core/utils/test_utils.py b/tests/core/utils/test_utils.py index 556f85ff..9160973d 100644 --- a/tests/core/utils/test_utils.py +++ b/tests/core/utils/test_utils.py @@ -124,9 +124,8 @@ class TestCheckNumberOfParameters: # 无默认值,多了报错 def validate_every(trainer, other): pass - with pytest.raises(RuntimeError) as exc_info: + with pytest.raises(TypeError) as exc_info: _check_valid_parameters_number(validate_every, expected_params=['trainer']) - assert "2 parameters" in exc_info.value.args[0] print(exc_info.value.args[0]) # 有默认值ok @@ -137,19 +136,18 @@ class TestCheckNumberOfParameters: # 参数多了 def validate_every(trainer): pass - with pytest.raises(RuntimeError) as exc_info: + with pytest.raises(TypeError) as exc_info: _check_valid_parameters_number(validate_every, expected_params=['trainer', 'other']) - assert "accepts 1 parameters" in exc_info.value.args[0] print(exc_info.value.args[0]) # 使用partial def validate_every(trainer, other): pass _check_valid_parameters_number(partial(validate_every, other=1), expected_params=['trainer']) - _check_valid_parameters_number(partial(validate_every, other=1), expected_params=['trainer', 'other']) - with pytest.raises(RuntimeError) as exc_info: + with pytest.raises(TypeError): + _check_valid_parameters_number(partial(validate_every, other=1), expected_params=['trainer', 'other']) + with pytest.raises(TypeError) as exc_info: _check_valid_parameters_number(partial(validate_every, other=1), expected_params=['trainer', 'other', 'more']) - assert 'accepts 2 parameters' in exc_info.value.args[0] print(exc_info.value.args[0]) # 如果存在 *args 或 *kwargs 不报错多的 @@ -159,7 +157,8 @@ class TestCheckNumberOfParameters: def validate_every(trainer, **kwargs): pass - _check_valid_parameters_number(partial(validate_every, trainer=1), expected_params=['trainer', 'other', 'more']) + with pytest.raises(TypeError): + _check_valid_parameters_number(partial(validate_every, trainer=1), expected_params=['trainer', 'other', 'more']) # class 的方法删掉self class InnerClass: @@ -173,10 +172,8 @@ class TestCheckNumberOfParameters: pass inner = InnerClass() - with pytest.raises(RuntimeError) as exc_info: + with pytest.raises(TypeError) as exc_info: _check_valid_parameters_number(inner.demo, expected_params=['trainer', 'other', 'more']) - assert 'accepts 1 parameters' in exc_info.value.args[0] - _check_valid_parameters_number(inner.demo, expected_params=['trainer'])