| @@ -7,6 +7,7 @@ from dataclasses import is_dataclass | |||
| import os | |||
| from pathlib import Path | |||
| import io | |||
| import inspect | |||
| __all__ = [ | |||
| 'Trainer', | |||
| @@ -402,11 +402,11 @@ class DataSet: | |||
| r""" | |||
| 将 DataSet 中的每个 instance 中的名为 `field_name` 的 field 传给 func,并获取它的返回值。 | |||
| :param num_proc: 进程的数量 | |||
| :param field_name: 传入 func 的是哪个 field。 | |||
| :param func: input是 instance 中名为 `field_name` 的 field 的内容。 | |||
| :param new_field_name: 将 func 返回的内容放入到 `new_field_name` 这个 field 中,如果名称与已有的 field 相同,则覆 | |||
| 盖之前的 field。如果为 None 则不创建新的 field。 | |||
| :param num_proc: 进程的数量。请注意,由于python语言的特性,多少进程就会导致多少倍内存的增长。 | |||
| :param progress_desc: progress_desc 的值,默认为 Main | |||
| :param show_progress_bar: 是否展示进度条,默认展示进度条 | |||
| """ | |||
| @@ -435,10 +435,10 @@ class DataSet: | |||
| ``apply_field_more`` 与 ``apply_field`` 的区别参考 :method:`~fastNLP.DataSet.apply_more` 中关于 ``apply_more`` 与 | |||
| ``apply`` 区别的介绍。 | |||
| :param num_proc: 进程的数量 | |||
| :param field_name: 传入func的是哪个field。 | |||
| :param func: 参数是 ``DataSet`` 中的 ``Instance`` ,返回值是一个字典,key 是field 的名字,value 是对应的结果 | |||
| :param modify_fields: 是否用结果修改 `DataSet` 中的 `Field`, 默认为 True | |||
| :param num_proc: 进程的数量。请注意,由于python语言的特性,多少进程就会导致多少倍内存的增长。 | |||
| :param show_progress_bar: 是否显示进度条,默认展示 | |||
| :param progress_desc: 当show_progress_bar为True时,可以显示当前正在处理的进度条描述字符 | |||
| :return Dict[str:Field]: 返回一个字典 | |||
| @@ -479,7 +479,7 @@ class DataSet: | |||
| show_progress_bar: bool = True, _apply_field: str = None, | |||
| progress_desc: str = 'Main') -> list: | |||
| """ | |||
| :param num_proc: 进程的数量 | |||
| :param num_proc: 进程的数量。请注意,由于python语言的特性,多少进程就会导致多少倍内存的增长。 | |||
| :param func: 用户自定义处理函数,参数是 ``DataSet`` 中的 ``Instance`` | |||
| :param _apply_field: 需要传进去func的数据集的field_name | |||
| :param show_progress_bar: 是否展示progress进度条,默认为展示 | |||
| @@ -552,7 +552,7 @@ class DataSet: | |||
| :param modify_fields: 是否用结果修改 ``DataSet`` 中的 ``Field`` , 默认为 True | |||
| :param func: 参数是 ``DataSet`` 中的 ``Instance`` ,返回值是一个字典,key 是field 的名字,value 是对应的结果 | |||
| :param num_proc: 进程的数量 | |||
| :param show_progress_bar: 是否使用tqd显示预处理进度 | |||
| :param num_proc: 进程的数量。请注意,由于python语言的特性,多少进程就会导致多少倍内存的增长。 | |||
| :param progress_desc: 当show_progress_bar为True时,可以显示当前正在处理的进度条名称 | |||
| :return Dict[str:Field]: 返回一个字典 | |||
| """ | |||
| @@ -596,7 +596,7 @@ class DataSet: | |||
| :param func: 参数是 ``DataSet`` 中的 ``Instance`` ,返回值是一个字典,key 是field 的名字,value 是对应的结果 | |||
| :param new_field_name: 将func返回的内容放入到 `new_field_name` 这个field中,如果名称与已有的field相同,则覆 | |||
| 盖之前的field。如果为None则不创建新的field。 | |||
| :param num_proc: 进程的数量。 | |||
| :param num_proc: 进程的数量。请注意,由于python语言的特性,多少进程就会导致多少倍内存的增长。 | |||
| :param show_progress_bar: 是否显示进度条。 | |||
| :param progress_desc: progress bar 显示的值,默认为空。 | |||
| """ | |||
| @@ -135,8 +135,11 @@ class FastNLPLogger(logging.Logger, metaclass=LoggerSingleton): | |||
| :return: | |||
| """ | |||
| if os.environ.get(FASTNLP_GLOBAL_RANK, '0') == '0': | |||
| if once and msg in self._warning_msgs: | |||
| return | |||
| if once: | |||
| if msg in self._warning_msgs: | |||
| return | |||
| self._warning_msgs.add(msg) | |||
| if self.isEnabledFor(WARNING): | |||
| kwargs = self._add_rank_info(kwargs) | |||
| self._log(WARNING, msg, args, **kwargs) | |||
| @@ -200,29 +200,25 @@ def _check_valid_parameters_number(fn, expected_params:List[str], fn_name=None): | |||
| if fn_name is not None: | |||
| assert callable(fn), f"`{fn_name}` should be callable, instead of `{type(fn)}`." | |||
| parameters = list(inspect.signature(fn).parameters.values()) | |||
| if inspect.ismethod(fn): | |||
| if len(parameters)>0 and parameters[0].name == 'self': | |||
| parameters = parameters[1:] # 去掉self | |||
| no_var_param = True # 没有 * 这种参数 | |||
| number_param_need_value = 0 | |||
| for param in parameters: | |||
| if param.kind is param.VAR_POSITIONAL: | |||
| no_var_param = False | |||
| elif param.kind is param.VAR_KEYWORD: | |||
| no_var_param = False | |||
| else: | |||
| if param.default is param.empty: | |||
| number_param_need_value += 1 | |||
| if len(parameters)<len(expected_params) and no_var_param: | |||
| raise RuntimeError(f"The function:{_get_fun_msg(fn)} accepts {len(parameters)} parameters, " | |||
| f"but {len(expected_params)} parameters:{expected_params} will be provided.") | |||
| if number_param_need_value>len(expected_params): | |||
| raise RuntimeError(f"The function:{_get_fun_msg(fn)} expects {len(parameters)} parameters, but only" | |||
| f" {len(expected_params)} parameters:{expected_params} will be provided.") | |||
| try: | |||
| args = [] | |||
| kwargs = {} | |||
| name = '' | |||
| if isinstance(fn, functools.partial) and not hasattr(fn, '__name__'): | |||
| name = 'partial:' | |||
| f = fn.func | |||
| while isinstance(f, functools.partial): | |||
| name += 'partial:' | |||
| f = f.func | |||
| fn.__name__ = name + f.__name__ | |||
| inspect.getcallargs(fn, *args, *expected_params, **kwargs) | |||
| if name: # 如果一开始没有name的,需要给人家删除掉 | |||
| delattr(fn, '__name__') | |||
| except TypeError as e: | |||
| logger.error(f"The function:{_get_fun_msg(fn)} will be provided with parameters:{expected_params}. " | |||
| f"The following exception will happen.") | |||
| raise e | |||
| def check_user_specific_params(user_params: Dict, fn: Callable): | |||
| @@ -231,8 +231,8 @@ class DataBundle: | |||
| 盖之前的field。如果为None则不创建新的field。 | |||
| :param bool ignore_miss_dataset: 当某个field名称在某个dataset不存在时,如果为True,则直接忽略该DataSet; | |||
| 如果为False,则报错 | |||
| :param ignore_miss_dataset: | |||
| :param num_proc: | |||
| :param num_proc: 进程的数量。请注意,由于python语言的特性,多少进程就会导致多少倍内存的增长。 | |||
| :param ignore_miss_dataset: 如果 dataset 没有 {field_name} ,就直接跳过这个 dataset 。 | |||
| :param progress_desc 当show_progress_barm为True时,可以显示当前tqdm正在处理的名称 | |||
| :param show_progress_bar 是否显示tqdm进度条 | |||
| @@ -260,11 +260,11 @@ class DataBundle: | |||
| :param callable func: 参数是 ``DataSet`` 中的 ``Instance`` ,返回值是一个字典,key 是field 的名字,value 是对应的结果 | |||
| :param str field_name: 传入func的是哪个field。 | |||
| :param bool modify_fields: 是否用结果修改 `DataSet` 中的 `Field`, 默认为 True | |||
| :param num_proc: 进程的数量。请注意,由于python语言的特性,多少进程就会导致多少倍内存的增长。 | |||
| :param bool ignore_miss_dataset: 当某个field名称在某个dataset不存在时,如果为True,则直接忽略该DataSet; | |||
| 如果为False,则报错 | |||
| :param show_progress_bar: 是否显示tqdm进度条 | |||
| :param progress_desc: 当show_progress_barm为True时,可以显示当前tqdm正在处理的名称 | |||
| :param num_proc: | |||
| :return Dict[str:Dict[str:Field]]: 返回一个字典套字典,第一层的 key 是 dataset 的名字,第二层的 key 是 field 的名字 | |||
| @@ -283,7 +283,7 @@ class DataBundle: | |||
| return res | |||
| def apply(self, func: Callable, new_field_name: str, num_proc: int = 0, | |||
| progress_desc: str = '', show_progress_bar: bool = True, _apply_field: str = None): | |||
| progress_desc: str = '', show_progress_bar: bool = True): | |||
| r""" | |||
| 对 :class:`~fastNLP.io.DataBundle` 中所有的 dataset 使用 :method:`~fastNLP.DataSet.apply` 方法 | |||
| @@ -292,18 +292,16 @@ class DataBundle: | |||
| :param callable func: input是instance中名为 `field_name` 的field的内容。 | |||
| :param str new_field_name: 将func返回的内容放入到 `new_field_name` 这个field中,如果名称与已有的field相同,则覆 | |||
| 盖之前的field。如果为None则不创建新的field。 | |||
| :param _apply_field: | |||
| :param num_proc: 进程的数量。请注意,由于python语言的特性,多少进程就会导致多少倍内存的增长。 | |||
| :param show_progress_bar: 是否显示tqd进度条 | |||
| :param progress_desc: 当show_progress_bar为True时,可以显示当前tqd正在处理的名称 | |||
| :param num_proc | |||
| """ | |||
| _progress_desc = progress_desc | |||
| for name, dataset in self.datasets.items(): | |||
| if _progress_desc: | |||
| progress_desc = _progress_desc + f' for `{name}`' | |||
| dataset.apply(func, new_field_name=new_field_name, num_proc=num_proc, show_progress_bar=show_progress_bar, | |||
| progress_desc=progress_desc, _apply_field=_apply_field) | |||
| progress_desc=progress_desc) | |||
| return self | |||
| def apply_more(self, func: Callable, modify_fields=True, num_proc: int = 0, | |||
| @@ -317,9 +315,9 @@ class DataBundle: | |||
| :param callable func: 参数是 ``DataSet`` 中的 ``Instance`` ,返回值是一个字典,key 是field 的名字,value 是对应的结果 | |||
| :param bool modify_fields: 是否用结果修改 ``DataSet`` 中的 ``Field`` , 默认为 True | |||
| :param num_proc: 进程的数量。请注意,由于python语言的特性,多少进程就会导致多少倍内存的增长。 | |||
| :param show_progress_bar: 是否显示tqd进度条 | |||
| :param progress_desc: 当show_progress_bar为True时,可以显示当前tqd正在处理的名称 | |||
| :param num_proc | |||
| :return Dict[str:Dict[str:Field]]: 返回一个字典套字典,第一层的 key 是 dataset 的名字,第二层的 key 是 field 的名字 | |||
| """ | |||
| @@ -0,0 +1,22 @@ | |||
| import pytest | |||
| from fastNLP import Trainer, Event | |||
| def test_on(): | |||
| with pytest.raises(TypeError): | |||
| @Trainer.on(Event.on_before_backward()) | |||
| def before_backend(): | |||
| pass | |||
| @Trainer.on(Event.on_before_backward()) | |||
| def before_backend(*args): | |||
| pass | |||
| with pytest.raises(TypeError): | |||
| @Trainer.on(Event.on_before_backward()) | |||
| def before_backend(*args, s): | |||
| pass | |||
| @Trainer.on(Event.on_before_backward()) | |||
| def before_backend(*args, s=2): | |||
| pass | |||
| @@ -210,4 +210,43 @@ def test_trainer_validate_every( | |||
| dist.destroy_process_group() | |||
| @pytest.mark.torch | |||
| @pytest.mark.parametrize("driver,device", [("torch", 'cpu')]) # ("torch", [0, 1]),("torch", 1) | |||
| @magic_argv_env_context | |||
| def test_trainer_on( | |||
| model_and_optimizers: TrainerParameters, | |||
| driver, | |||
| device, | |||
| n_epochs=2, | |||
| ): | |||
| from fastNLP import Event | |||
| @Trainer.on(Event.on_before_backward()) | |||
| def before_backend(trainer, outputs): | |||
| pass | |||
| @Trainer.on(Event.on_before_backward()) | |||
| def before_backend_2(*args): | |||
| pass | |||
| trainer = Trainer( | |||
| model=model_and_optimizers.model, | |||
| driver=driver, | |||
| device=device, | |||
| optimizers=model_and_optimizers.optimizers, | |||
| train_dataloader=model_and_optimizers.train_dataloader, | |||
| evaluate_dataloaders=model_and_optimizers.evaluate_dataloaders, | |||
| input_mapping=model_and_optimizers.input_mapping, | |||
| output_mapping=model_and_optimizers.output_mapping, | |||
| metrics=model_and_optimizers.metrics, | |||
| n_epochs=n_epochs, | |||
| output_from_new_proc="all", | |||
| evaluate_every=-1 | |||
| ) | |||
| trainer.run() | |||
| @@ -124,9 +124,8 @@ class TestCheckNumberOfParameters: | |||
| # 无默认值,多了报错 | |||
| def validate_every(trainer, other): | |||
| pass | |||
| with pytest.raises(RuntimeError) as exc_info: | |||
| with pytest.raises(TypeError) as exc_info: | |||
| _check_valid_parameters_number(validate_every, expected_params=['trainer']) | |||
| assert "2 parameters" in exc_info.value.args[0] | |||
| print(exc_info.value.args[0]) | |||
| # 有默认值ok | |||
| @@ -137,19 +136,18 @@ class TestCheckNumberOfParameters: | |||
| # 参数多了 | |||
| def validate_every(trainer): | |||
| pass | |||
| with pytest.raises(RuntimeError) as exc_info: | |||
| with pytest.raises(TypeError) as exc_info: | |||
| _check_valid_parameters_number(validate_every, expected_params=['trainer', 'other']) | |||
| assert "accepts 1 parameters" in exc_info.value.args[0] | |||
| print(exc_info.value.args[0]) | |||
| # 使用partial | |||
| def validate_every(trainer, other): | |||
| pass | |||
| _check_valid_parameters_number(partial(validate_every, other=1), expected_params=['trainer']) | |||
| _check_valid_parameters_number(partial(validate_every, other=1), expected_params=['trainer', 'other']) | |||
| with pytest.raises(RuntimeError) as exc_info: | |||
| with pytest.raises(TypeError): | |||
| _check_valid_parameters_number(partial(validate_every, other=1), expected_params=['trainer', 'other']) | |||
| with pytest.raises(TypeError) as exc_info: | |||
| _check_valid_parameters_number(partial(validate_every, other=1), expected_params=['trainer', 'other', 'more']) | |||
| assert 'accepts 2 parameters' in exc_info.value.args[0] | |||
| print(exc_info.value.args[0]) | |||
| # 如果存在 *args 或 *kwargs 不报错多的 | |||
| @@ -159,7 +157,8 @@ class TestCheckNumberOfParameters: | |||
| def validate_every(trainer, **kwargs): | |||
| pass | |||
| _check_valid_parameters_number(partial(validate_every, trainer=1), expected_params=['trainer', 'other', 'more']) | |||
| with pytest.raises(TypeError): | |||
| _check_valid_parameters_number(partial(validate_every, trainer=1), expected_params=['trainer', 'other', 'more']) | |||
| # class 的方法删掉self | |||
| class InnerClass: | |||
| @@ -173,10 +172,8 @@ class TestCheckNumberOfParameters: | |||
| pass | |||
| inner = InnerClass() | |||
| with pytest.raises(RuntimeError) as exc_info: | |||
| with pytest.raises(TypeError) as exc_info: | |||
| _check_valid_parameters_number(inner.demo, expected_params=['trainer', 'other', 'more']) | |||
| assert 'accepts 1 parameters' in exc_info.value.args[0] | |||
| _check_valid_parameters_number(inner.demo, expected_params=['trainer']) | |||