Browse Source

Trainer.on支持*args等形式的输入

tags/v1.0.0alpha
yh_cc 3 years ago
parent
commit
2a80a9de44
8 changed files with 106 additions and 50 deletions
  1. +1
    -0
      fastNLP/core/controllers/trainer.py
  2. +5
    -5
      fastNLP/core/dataset/dataset.py
  3. +5
    -2
      fastNLP/core/log/logger.py
  4. +19
    -23
      fastNLP/core/utils/utils.py
  5. +7
    -9
      fastNLP/io/data_bundle.py
  6. +22
    -0
      tests/core/controllers/test_trainer.py
  7. +39
    -0
      tests/core/controllers/test_trainer_w_evaluator_torch.py
  8. +8
    -11
      tests/core/utils/test_utils.py

+ 1
- 0
fastNLP/core/controllers/trainer.py View File

@@ -7,6 +7,7 @@ from dataclasses import is_dataclass
import os import os
from pathlib import Path from pathlib import Path
import io import io
import inspect


__all__ = [ __all__ = [
'Trainer', 'Trainer',


+ 5
- 5
fastNLP/core/dataset/dataset.py View File

@@ -402,11 +402,11 @@ class DataSet:
r""" r"""
将 DataSet 中的每个 instance 中的名为 `field_name` 的 field 传给 func,并获取它的返回值。 将 DataSet 中的每个 instance 中的名为 `field_name` 的 field 传给 func,并获取它的返回值。


:param num_proc: 进程的数量
:param field_name: 传入 func 的是哪个 field。 :param field_name: 传入 func 的是哪个 field。
:param func: input是 instance 中名为 `field_name` 的 field 的内容。 :param func: input是 instance 中名为 `field_name` 的 field 的内容。
:param new_field_name: 将 func 返回的内容放入到 `new_field_name` 这个 field 中,如果名称与已有的 field 相同,则覆 :param new_field_name: 将 func 返回的内容放入到 `new_field_name` 这个 field 中,如果名称与已有的 field 相同,则覆
盖之前的 field。如果为 None 则不创建新的 field。 盖之前的 field。如果为 None 则不创建新的 field。
:param num_proc: 进程的数量。请注意,由于python语言的特性,多少进程就会导致多少倍内存的增长。
:param progress_desc: progress_desc 的值,默认为 Main :param progress_desc: progress_desc 的值,默认为 Main
:param show_progress_bar: 是否展示进度条,默认展示进度条 :param show_progress_bar: 是否展示进度条,默认展示进度条
""" """
@@ -435,10 +435,10 @@ class DataSet:
``apply_field_more`` 与 ``apply_field`` 的区别参考 :method:`~fastNLP.DataSet.apply_more` 中关于 ``apply_more`` 与 ``apply_field_more`` 与 ``apply_field`` 的区别参考 :method:`~fastNLP.DataSet.apply_more` 中关于 ``apply_more`` 与
``apply`` 区别的介绍。 ``apply`` 区别的介绍。


:param num_proc: 进程的数量
:param field_name: 传入func的是哪个field。 :param field_name: 传入func的是哪个field。
:param func: 参数是 ``DataSet`` 中的 ``Instance`` ,返回值是一个字典,key 是field 的名字,value 是对应的结果 :param func: 参数是 ``DataSet`` 中的 ``Instance`` ,返回值是一个字典,key 是field 的名字,value 是对应的结果
:param modify_fields: 是否用结果修改 `DataSet` 中的 `Field`, 默认为 True :param modify_fields: 是否用结果修改 `DataSet` 中的 `Field`, 默认为 True
:param num_proc: 进程的数量。请注意,由于python语言的特性,多少进程就会导致多少倍内存的增长。
:param show_progress_bar: 是否显示进度条,默认展示 :param show_progress_bar: 是否显示进度条,默认展示
:param progress_desc: 当show_progress_bar为True时,可以显示当前正在处理的进度条描述字符 :param progress_desc: 当show_progress_bar为True时,可以显示当前正在处理的进度条描述字符
:return Dict[str:Field]: 返回一个字典 :return Dict[str:Field]: 返回一个字典
@@ -479,7 +479,7 @@ class DataSet:
show_progress_bar: bool = True, _apply_field: str = None, show_progress_bar: bool = True, _apply_field: str = None,
progress_desc: str = 'Main') -> list: progress_desc: str = 'Main') -> list:
""" """
:param num_proc: 进程的数量
:param num_proc: 进程的数量。请注意,由于python语言的特性,多少进程就会导致多少倍内存的增长。
:param func: 用户自定义处理函数,参数是 ``DataSet`` 中的 ``Instance`` :param func: 用户自定义处理函数,参数是 ``DataSet`` 中的 ``Instance``
:param _apply_field: 需要传进去func的数据集的field_name :param _apply_field: 需要传进去func的数据集的field_name
:param show_progress_bar: 是否展示progress进度条,默认为展示 :param show_progress_bar: 是否展示progress进度条,默认为展示
@@ -552,7 +552,7 @@ class DataSet:
:param modify_fields: 是否用结果修改 ``DataSet`` 中的 ``Field`` , 默认为 True :param modify_fields: 是否用结果修改 ``DataSet`` 中的 ``Field`` , 默认为 True
:param func: 参数是 ``DataSet`` 中的 ``Instance`` ,返回值是一个字典,key 是field 的名字,value 是对应的结果 :param func: 参数是 ``DataSet`` 中的 ``Instance`` ,返回值是一个字典,key 是field 的名字,value 是对应的结果
:param num_proc: 进程的数量 :param num_proc: 进程的数量
:param show_progress_bar: 是否使用tqd显示预处理进度
:param num_proc: 进程的数量。请注意,由于python语言的特性,多少进程就会导致多少倍内存的增长。
:param progress_desc: 当show_progress_bar为True时,可以显示当前正在处理的进度条名称 :param progress_desc: 当show_progress_bar为True时,可以显示当前正在处理的进度条名称
:return Dict[str:Field]: 返回一个字典 :return Dict[str:Field]: 返回一个字典
""" """
@@ -596,7 +596,7 @@ class DataSet:
:param func: 参数是 ``DataSet`` 中的 ``Instance`` ,返回值是一个字典,key 是field 的名字,value 是对应的结果 :param func: 参数是 ``DataSet`` 中的 ``Instance`` ,返回值是一个字典,key 是field 的名字,value 是对应的结果
:param new_field_name: 将func返回的内容放入到 `new_field_name` 这个field中,如果名称与已有的field相同,则覆 :param new_field_name: 将func返回的内容放入到 `new_field_name` 这个field中,如果名称与已有的field相同,则覆
盖之前的field。如果为None则不创建新的field。 盖之前的field。如果为None则不创建新的field。
:param num_proc: 进程的数量。
:param num_proc: 进程的数量。请注意,由于python语言的特性,多少进程就会导致多少倍内存的增长。
:param show_progress_bar: 是否显示进度条。 :param show_progress_bar: 是否显示进度条。
:param progress_desc: progress bar 显示的值,默认为空。 :param progress_desc: progress bar 显示的值,默认为空。
""" """


+ 5
- 2
fastNLP/core/log/logger.py View File

@@ -135,8 +135,11 @@ class FastNLPLogger(logging.Logger, metaclass=LoggerSingleton):
:return: :return:
""" """
if os.environ.get(FASTNLP_GLOBAL_RANK, '0') == '0': if os.environ.get(FASTNLP_GLOBAL_RANK, '0') == '0':
if once and msg in self._warning_msgs:
return
if once:
if msg in self._warning_msgs:
return
self._warning_msgs.add(msg)

if self.isEnabledFor(WARNING): if self.isEnabledFor(WARNING):
kwargs = self._add_rank_info(kwargs) kwargs = self._add_rank_info(kwargs)
self._log(WARNING, msg, args, **kwargs) self._log(WARNING, msg, args, **kwargs)


+ 19
- 23
fastNLP/core/utils/utils.py View File

@@ -200,29 +200,25 @@ def _check_valid_parameters_number(fn, expected_params:List[str], fn_name=None):
if fn_name is not None: if fn_name is not None:
assert callable(fn), f"`{fn_name}` should be callable, instead of `{type(fn)}`." assert callable(fn), f"`{fn_name}` should be callable, instead of `{type(fn)}`."


parameters = list(inspect.signature(fn).parameters.values())
if inspect.ismethod(fn):
if len(parameters)>0 and parameters[0].name == 'self':
parameters = parameters[1:] # 去掉self

no_var_param = True # 没有 * 这种参数
number_param_need_value = 0
for param in parameters:
if param.kind is param.VAR_POSITIONAL:
no_var_param = False
elif param.kind is param.VAR_KEYWORD:
no_var_param = False
else:
if param.default is param.empty:
number_param_need_value += 1

if len(parameters)<len(expected_params) and no_var_param:
raise RuntimeError(f"The function:{_get_fun_msg(fn)} accepts {len(parameters)} parameters, "
f"but {len(expected_params)} parameters:{expected_params} will be provided.")

if number_param_need_value>len(expected_params):
raise RuntimeError(f"The function:{_get_fun_msg(fn)} expects {len(parameters)} parameters, but only"
f" {len(expected_params)} parameters:{expected_params} will be provided.")
try:
args = []
kwargs = {}
name = ''
if isinstance(fn, functools.partial) and not hasattr(fn, '__name__'):
name = 'partial:'
f = fn.func
while isinstance(f, functools.partial):
name += 'partial:'
f = f.func
fn.__name__ = name + f.__name__
inspect.getcallargs(fn, *args, *expected_params, **kwargs)
if name: # 如果一开始没有name的,需要给人家删除掉
delattr(fn, '__name__')

except TypeError as e:
logger.error(f"The function:{_get_fun_msg(fn)} will be provided with parameters:{expected_params}. "
f"The following exception will happen.")
raise e




def check_user_specific_params(user_params: Dict, fn: Callable): def check_user_specific_params(user_params: Dict, fn: Callable):


+ 7
- 9
fastNLP/io/data_bundle.py View File

@@ -231,8 +231,8 @@ class DataBundle:
盖之前的field。如果为None则不创建新的field。 盖之前的field。如果为None则不创建新的field。
:param bool ignore_miss_dataset: 当某个field名称在某个dataset不存在时,如果为True,则直接忽略该DataSet; :param bool ignore_miss_dataset: 当某个field名称在某个dataset不存在时,如果为True,则直接忽略该DataSet;
如果为False,则报错 如果为False,则报错
:param ignore_miss_dataset:
:param num_proc:
:param num_proc: 进程的数量。请注意,由于python语言的特性,多少进程就会导致多少倍内存的增长。
:param ignore_miss_dataset: 如果 dataset 没有 {field_name} ,就直接跳过这个 dataset 。
:param progress_desc 当show_progress_barm为True时,可以显示当前tqdm正在处理的名称 :param progress_desc 当show_progress_barm为True时,可以显示当前tqdm正在处理的名称
:param show_progress_bar 是否显示tqdm进度条 :param show_progress_bar 是否显示tqdm进度条


@@ -260,11 +260,11 @@ class DataBundle:
:param callable func: 参数是 ``DataSet`` 中的 ``Instance`` ,返回值是一个字典,key 是field 的名字,value 是对应的结果 :param callable func: 参数是 ``DataSet`` 中的 ``Instance`` ,返回值是一个字典,key 是field 的名字,value 是对应的结果
:param str field_name: 传入func的是哪个field。 :param str field_name: 传入func的是哪个field。
:param bool modify_fields: 是否用结果修改 `DataSet` 中的 `Field`, 默认为 True :param bool modify_fields: 是否用结果修改 `DataSet` 中的 `Field`, 默认为 True
:param num_proc: 进程的数量。请注意,由于python语言的特性,多少进程就会导致多少倍内存的增长。
:param bool ignore_miss_dataset: 当某个field名称在某个dataset不存在时,如果为True,则直接忽略该DataSet; :param bool ignore_miss_dataset: 当某个field名称在某个dataset不存在时,如果为True,则直接忽略该DataSet;
如果为False,则报错 如果为False,则报错
:param show_progress_bar: 是否显示tqdm进度条 :param show_progress_bar: 是否显示tqdm进度条
:param progress_desc: 当show_progress_barm为True时,可以显示当前tqdm正在处理的名称 :param progress_desc: 当show_progress_barm为True时,可以显示当前tqdm正在处理的名称
:param num_proc:


:return Dict[str:Dict[str:Field]]: 返回一个字典套字典,第一层的 key 是 dataset 的名字,第二层的 key 是 field 的名字 :return Dict[str:Dict[str:Field]]: 返回一个字典套字典,第一层的 key 是 dataset 的名字,第二层的 key 是 field 的名字


@@ -283,7 +283,7 @@ class DataBundle:
return res return res


def apply(self, func: Callable, new_field_name: str, num_proc: int = 0, def apply(self, func: Callable, new_field_name: str, num_proc: int = 0,
progress_desc: str = '', show_progress_bar: bool = True, _apply_field: str = None):
progress_desc: str = '', show_progress_bar: bool = True):
r""" r"""
对 :class:`~fastNLP.io.DataBundle` 中所有的 dataset 使用 :method:`~fastNLP.DataSet.apply` 方法 对 :class:`~fastNLP.io.DataBundle` 中所有的 dataset 使用 :method:`~fastNLP.DataSet.apply` 方法


@@ -292,18 +292,16 @@ class DataBundle:
:param callable func: input是instance中名为 `field_name` 的field的内容。 :param callable func: input是instance中名为 `field_name` 的field的内容。
:param str new_field_name: 将func返回的内容放入到 `new_field_name` 这个field中,如果名称与已有的field相同,则覆 :param str new_field_name: 将func返回的内容放入到 `new_field_name` 这个field中,如果名称与已有的field相同,则覆
盖之前的field。如果为None则不创建新的field。 盖之前的field。如果为None则不创建新的field。
:param _apply_field:
:param num_proc: 进程的数量。请注意,由于python语言的特性,多少进程就会导致多少倍内存的增长。
:param show_progress_bar: 是否显示tqd进度条 :param show_progress_bar: 是否显示tqd进度条
:param progress_desc: 当show_progress_bar为True时,可以显示当前tqd正在处理的名称 :param progress_desc: 当show_progress_bar为True时,可以显示当前tqd正在处理的名称
:param num_proc

""" """
_progress_desc = progress_desc _progress_desc = progress_desc
for name, dataset in self.datasets.items(): for name, dataset in self.datasets.items():
if _progress_desc: if _progress_desc:
progress_desc = _progress_desc + f' for `{name}`' progress_desc = _progress_desc + f' for `{name}`'
dataset.apply(func, new_field_name=new_field_name, num_proc=num_proc, show_progress_bar=show_progress_bar, dataset.apply(func, new_field_name=new_field_name, num_proc=num_proc, show_progress_bar=show_progress_bar,
progress_desc=progress_desc, _apply_field=_apply_field)
progress_desc=progress_desc)
return self return self


def apply_more(self, func: Callable, modify_fields=True, num_proc: int = 0, def apply_more(self, func: Callable, modify_fields=True, num_proc: int = 0,
@@ -317,9 +315,9 @@ class DataBundle:


:param callable func: 参数是 ``DataSet`` 中的 ``Instance`` ,返回值是一个字典,key 是field 的名字,value 是对应的结果 :param callable func: 参数是 ``DataSet`` 中的 ``Instance`` ,返回值是一个字典,key 是field 的名字,value 是对应的结果
:param bool modify_fields: 是否用结果修改 ``DataSet`` 中的 ``Field`` , 默认为 True :param bool modify_fields: 是否用结果修改 ``DataSet`` 中的 ``Field`` , 默认为 True
:param num_proc: 进程的数量。请注意,由于python语言的特性,多少进程就会导致多少倍内存的增长。
:param show_progress_bar: 是否显示tqd进度条 :param show_progress_bar: 是否显示tqd进度条
:param progress_desc: 当show_progress_bar为True时,可以显示当前tqd正在处理的名称 :param progress_desc: 当show_progress_bar为True时,可以显示当前tqd正在处理的名称
:param num_proc


:return Dict[str:Dict[str:Field]]: 返回一个字典套字典,第一层的 key 是 dataset 的名字,第二层的 key 是 field 的名字 :return Dict[str:Dict[str:Field]]: 返回一个字典套字典,第一层的 key 是 dataset 的名字,第二层的 key 是 field 的名字
""" """


+ 22
- 0
tests/core/controllers/test_trainer.py View File

@@ -0,0 +1,22 @@
import pytest
from fastNLP import Trainer, Event


def test_on():
with pytest.raises(TypeError):
@Trainer.on(Event.on_before_backward())
def before_backend():
pass

@Trainer.on(Event.on_before_backward())
def before_backend(*args):
pass

with pytest.raises(TypeError):
@Trainer.on(Event.on_before_backward())
def before_backend(*args, s):
pass

@Trainer.on(Event.on_before_backward())
def before_backend(*args, s=2):
pass

+ 39
- 0
tests/core/controllers/test_trainer_w_evaluator_torch.py View File

@@ -210,4 +210,43 @@ def test_trainer_validate_every(
dist.destroy_process_group() dist.destroy_process_group()




@pytest.mark.torch
@pytest.mark.parametrize("driver,device", [("torch", 'cpu')]) # ("torch", [0, 1]),("torch", 1)
@magic_argv_env_context
def test_trainer_on(
model_and_optimizers: TrainerParameters,
driver,
device,
n_epochs=2,
):
from fastNLP import Event
@Trainer.on(Event.on_before_backward())
def before_backend(trainer, outputs):
pass

@Trainer.on(Event.on_before_backward())
def before_backend_2(*args):
pass

trainer = Trainer(
model=model_and_optimizers.model,
driver=driver,
device=device,
optimizers=model_and_optimizers.optimizers,
train_dataloader=model_and_optimizers.train_dataloader,
evaluate_dataloaders=model_and_optimizers.evaluate_dataloaders,
input_mapping=model_and_optimizers.input_mapping,
output_mapping=model_and_optimizers.output_mapping,
metrics=model_and_optimizers.metrics,
n_epochs=n_epochs,
output_from_new_proc="all",
evaluate_every=-1
)



trainer.run()






+ 8
- 11
tests/core/utils/test_utils.py View File

@@ -124,9 +124,8 @@ class TestCheckNumberOfParameters:
# 无默认值,多了报错 # 无默认值,多了报错
def validate_every(trainer, other): def validate_every(trainer, other):
pass pass
with pytest.raises(RuntimeError) as exc_info:
with pytest.raises(TypeError) as exc_info:
_check_valid_parameters_number(validate_every, expected_params=['trainer']) _check_valid_parameters_number(validate_every, expected_params=['trainer'])
assert "2 parameters" in exc_info.value.args[0]
print(exc_info.value.args[0]) print(exc_info.value.args[0])


# 有默认值ok # 有默认值ok
@@ -137,19 +136,18 @@ class TestCheckNumberOfParameters:
# 参数多了 # 参数多了
def validate_every(trainer): def validate_every(trainer):
pass pass
with pytest.raises(RuntimeError) as exc_info:
with pytest.raises(TypeError) as exc_info:
_check_valid_parameters_number(validate_every, expected_params=['trainer', 'other']) _check_valid_parameters_number(validate_every, expected_params=['trainer', 'other'])
assert "accepts 1 parameters" in exc_info.value.args[0]
print(exc_info.value.args[0]) print(exc_info.value.args[0])


# 使用partial # 使用partial
def validate_every(trainer, other): def validate_every(trainer, other):
pass pass
_check_valid_parameters_number(partial(validate_every, other=1), expected_params=['trainer']) _check_valid_parameters_number(partial(validate_every, other=1), expected_params=['trainer'])
_check_valid_parameters_number(partial(validate_every, other=1), expected_params=['trainer', 'other'])
with pytest.raises(RuntimeError) as exc_info:
with pytest.raises(TypeError):
_check_valid_parameters_number(partial(validate_every, other=1), expected_params=['trainer', 'other'])
with pytest.raises(TypeError) as exc_info:
_check_valid_parameters_number(partial(validate_every, other=1), expected_params=['trainer', 'other', 'more']) _check_valid_parameters_number(partial(validate_every, other=1), expected_params=['trainer', 'other', 'more'])
assert 'accepts 2 parameters' in exc_info.value.args[0]
print(exc_info.value.args[0]) print(exc_info.value.args[0])


# 如果存在 *args 或 *kwargs 不报错多的 # 如果存在 *args 或 *kwargs 不报错多的
@@ -159,7 +157,8 @@ class TestCheckNumberOfParameters:


def validate_every(trainer, **kwargs): def validate_every(trainer, **kwargs):
pass pass
_check_valid_parameters_number(partial(validate_every, trainer=1), expected_params=['trainer', 'other', 'more'])
with pytest.raises(TypeError):
_check_valid_parameters_number(partial(validate_every, trainer=1), expected_params=['trainer', 'other', 'more'])


# class 的方法删掉self # class 的方法删掉self
class InnerClass: class InnerClass:
@@ -173,10 +172,8 @@ class TestCheckNumberOfParameters:
pass pass


inner = InnerClass() inner = InnerClass()
with pytest.raises(RuntimeError) as exc_info:
with pytest.raises(TypeError) as exc_info:
_check_valid_parameters_number(inner.demo, expected_params=['trainer', 'other', 'more']) _check_valid_parameters_number(inner.demo, expected_params=['trainer', 'other', 'more'])
assert 'accepts 1 parameters' in exc_info.value.args[0]

_check_valid_parameters_number(inner.demo, expected_params=['trainer']) _check_valid_parameters_number(inner.demo, expected_params=['trainer'])






Loading…
Cancel
Save