@@ -7,6 +7,7 @@ from dataclasses import is_dataclass | |||||
import os | import os | ||||
from pathlib import Path | from pathlib import Path | ||||
import io | import io | ||||
import inspect | |||||
__all__ = [ | __all__ = [ | ||||
'Trainer', | 'Trainer', | ||||
@@ -402,11 +402,11 @@ class DataSet: | |||||
r""" | r""" | ||||
将 DataSet 中的每个 instance 中的名为 `field_name` 的 field 传给 func,并获取它的返回值。 | 将 DataSet 中的每个 instance 中的名为 `field_name` 的 field 传给 func,并获取它的返回值。 | ||||
:param num_proc: 进程的数量 | |||||
:param field_name: 传入 func 的是哪个 field。 | :param field_name: 传入 func 的是哪个 field。 | ||||
:param func: input是 instance 中名为 `field_name` 的 field 的内容。 | :param func: input是 instance 中名为 `field_name` 的 field 的内容。 | ||||
:param new_field_name: 将 func 返回的内容放入到 `new_field_name` 这个 field 中,如果名称与已有的 field 相同,则覆 | :param new_field_name: 将 func 返回的内容放入到 `new_field_name` 这个 field 中,如果名称与已有的 field 相同,则覆 | ||||
盖之前的 field。如果为 None 则不创建新的 field。 | 盖之前的 field。如果为 None 则不创建新的 field。 | ||||
:param num_proc: 进程的数量。请注意,由于python语言的特性,多少进程就会导致多少倍内存的增长。 | |||||
:param progress_desc: progress_desc 的值,默认为 Main | :param progress_desc: progress_desc 的值,默认为 Main | ||||
:param show_progress_bar: 是否展示进度条,默认展示进度条 | :param show_progress_bar: 是否展示进度条,默认展示进度条 | ||||
""" | """ | ||||
@@ -435,10 +435,10 @@ class DataSet: | |||||
``apply_field_more`` 与 ``apply_field`` 的区别参考 :method:`~fastNLP.DataSet.apply_more` 中关于 ``apply_more`` 与 | ``apply_field_more`` 与 ``apply_field`` 的区别参考 :method:`~fastNLP.DataSet.apply_more` 中关于 ``apply_more`` 与 | ||||
``apply`` 区别的介绍。 | ``apply`` 区别的介绍。 | ||||
:param num_proc: 进程的数量 | |||||
:param field_name: 传入func的是哪个field。 | :param field_name: 传入func的是哪个field。 | ||||
:param func: 参数是 ``DataSet`` 中的 ``Instance`` ,返回值是一个字典,key 是field 的名字,value 是对应的结果 | :param func: 参数是 ``DataSet`` 中的 ``Instance`` ,返回值是一个字典,key 是field 的名字,value 是对应的结果 | ||||
:param modify_fields: 是否用结果修改 `DataSet` 中的 `Field`, 默认为 True | :param modify_fields: 是否用结果修改 `DataSet` 中的 `Field`, 默认为 True | ||||
:param num_proc: 进程的数量。请注意,由于python语言的特性,多少进程就会导致多少倍内存的增长。 | |||||
:param show_progress_bar: 是否显示进度条,默认展示 | :param show_progress_bar: 是否显示进度条,默认展示 | ||||
:param progress_desc: 当show_progress_bar为True时,可以显示当前正在处理的进度条描述字符 | :param progress_desc: 当show_progress_bar为True时,可以显示当前正在处理的进度条描述字符 | ||||
:return Dict[str:Field]: 返回一个字典 | :return Dict[str:Field]: 返回一个字典 | ||||
@@ -479,7 +479,7 @@ class DataSet: | |||||
show_progress_bar: bool = True, _apply_field: str = None, | show_progress_bar: bool = True, _apply_field: str = None, | ||||
progress_desc: str = 'Main') -> list: | progress_desc: str = 'Main') -> list: | ||||
""" | """ | ||||
:param num_proc: 进程的数量 | |||||
:param num_proc: 进程的数量。请注意,由于python语言的特性,多少进程就会导致多少倍内存的增长。 | |||||
:param func: 用户自定义处理函数,参数是 ``DataSet`` 中的 ``Instance`` | :param func: 用户自定义处理函数,参数是 ``DataSet`` 中的 ``Instance`` | ||||
:param _apply_field: 需要传进去func的数据集的field_name | :param _apply_field: 需要传进去func的数据集的field_name | ||||
:param show_progress_bar: 是否展示progress进度条,默认为展示 | :param show_progress_bar: 是否展示progress进度条,默认为展示 | ||||
@@ -552,7 +552,7 @@ class DataSet: | |||||
:param modify_fields: 是否用结果修改 ``DataSet`` 中的 ``Field`` , 默认为 True | :param modify_fields: 是否用结果修改 ``DataSet`` 中的 ``Field`` , 默认为 True | ||||
:param func: 参数是 ``DataSet`` 中的 ``Instance`` ,返回值是一个字典,key 是field 的名字,value 是对应的结果 | :param func: 参数是 ``DataSet`` 中的 ``Instance`` ,返回值是一个字典,key 是field 的名字,value 是对应的结果 | ||||
:param num_proc: 进程的数量 | :param num_proc: 进程的数量 | ||||
:param show_progress_bar: 是否使用tqd显示预处理进度 | |||||
:param num_proc: 进程的数量。请注意,由于python语言的特性,多少进程就会导致多少倍内存的增长。 | |||||
:param progress_desc: 当show_progress_bar为True时,可以显示当前正在处理的进度条名称 | :param progress_desc: 当show_progress_bar为True时,可以显示当前正在处理的进度条名称 | ||||
:return Dict[str:Field]: 返回一个字典 | :return Dict[str:Field]: 返回一个字典 | ||||
""" | """ | ||||
@@ -596,7 +596,7 @@ class DataSet: | |||||
:param func: 参数是 ``DataSet`` 中的 ``Instance`` ,返回值是一个字典,key 是field 的名字,value 是对应的结果 | :param func: 参数是 ``DataSet`` 中的 ``Instance`` ,返回值是一个字典,key 是field 的名字,value 是对应的结果 | ||||
:param new_field_name: 将func返回的内容放入到 `new_field_name` 这个field中,如果名称与已有的field相同,则覆 | :param new_field_name: 将func返回的内容放入到 `new_field_name` 这个field中,如果名称与已有的field相同,则覆 | ||||
盖之前的field。如果为None则不创建新的field。 | 盖之前的field。如果为None则不创建新的field。 | ||||
:param num_proc: 进程的数量。 | |||||
:param num_proc: 进程的数量。请注意,由于python语言的特性,多少进程就会导致多少倍内存的增长。 | |||||
:param show_progress_bar: 是否显示进度条。 | :param show_progress_bar: 是否显示进度条。 | ||||
:param progress_desc: progress bar 显示的值,默认为空。 | :param progress_desc: progress bar 显示的值,默认为空。 | ||||
""" | """ | ||||
@@ -135,8 +135,11 @@ class FastNLPLogger(logging.Logger, metaclass=LoggerSingleton): | |||||
:return: | :return: | ||||
""" | """ | ||||
if os.environ.get(FASTNLP_GLOBAL_RANK, '0') == '0': | if os.environ.get(FASTNLP_GLOBAL_RANK, '0') == '0': | ||||
if once and msg in self._warning_msgs: | |||||
return | |||||
if once: | |||||
if msg in self._warning_msgs: | |||||
return | |||||
self._warning_msgs.add(msg) | |||||
if self.isEnabledFor(WARNING): | if self.isEnabledFor(WARNING): | ||||
kwargs = self._add_rank_info(kwargs) | kwargs = self._add_rank_info(kwargs) | ||||
self._log(WARNING, msg, args, **kwargs) | self._log(WARNING, msg, args, **kwargs) | ||||
@@ -200,29 +200,25 @@ def _check_valid_parameters_number(fn, expected_params:List[str], fn_name=None): | |||||
if fn_name is not None: | if fn_name is not None: | ||||
assert callable(fn), f"`{fn_name}` should be callable, instead of `{type(fn)}`." | assert callable(fn), f"`{fn_name}` should be callable, instead of `{type(fn)}`." | ||||
parameters = list(inspect.signature(fn).parameters.values()) | |||||
if inspect.ismethod(fn): | |||||
if len(parameters)>0 and parameters[0].name == 'self': | |||||
parameters = parameters[1:] # 去掉self | |||||
no_var_param = True # 没有 * 这种参数 | |||||
number_param_need_value = 0 | |||||
for param in parameters: | |||||
if param.kind is param.VAR_POSITIONAL: | |||||
no_var_param = False | |||||
elif param.kind is param.VAR_KEYWORD: | |||||
no_var_param = False | |||||
else: | |||||
if param.default is param.empty: | |||||
number_param_need_value += 1 | |||||
if len(parameters)<len(expected_params) and no_var_param: | |||||
raise RuntimeError(f"The function:{_get_fun_msg(fn)} accepts {len(parameters)} parameters, " | |||||
f"but {len(expected_params)} parameters:{expected_params} will be provided.") | |||||
if number_param_need_value>len(expected_params): | |||||
raise RuntimeError(f"The function:{_get_fun_msg(fn)} expects {len(parameters)} parameters, but only" | |||||
f" {len(expected_params)} parameters:{expected_params} will be provided.") | |||||
try: | |||||
args = [] | |||||
kwargs = {} | |||||
name = '' | |||||
if isinstance(fn, functools.partial) and not hasattr(fn, '__name__'): | |||||
name = 'partial:' | |||||
f = fn.func | |||||
while isinstance(f, functools.partial): | |||||
name += 'partial:' | |||||
f = f.func | |||||
fn.__name__ = name + f.__name__ | |||||
inspect.getcallargs(fn, *args, *expected_params, **kwargs) | |||||
if name: # 如果一开始没有name的,需要给人家删除掉 | |||||
delattr(fn, '__name__') | |||||
except TypeError as e: | |||||
logger.error(f"The function:{_get_fun_msg(fn)} will be provided with parameters:{expected_params}. " | |||||
f"The following exception will happen.") | |||||
raise e | |||||
def check_user_specific_params(user_params: Dict, fn: Callable): | def check_user_specific_params(user_params: Dict, fn: Callable): | ||||
@@ -231,8 +231,8 @@ class DataBundle: | |||||
盖之前的field。如果为None则不创建新的field。 | 盖之前的field。如果为None则不创建新的field。 | ||||
:param bool ignore_miss_dataset: 当某个field名称在某个dataset不存在时,如果为True,则直接忽略该DataSet; | :param bool ignore_miss_dataset: 当某个field名称在某个dataset不存在时,如果为True,则直接忽略该DataSet; | ||||
如果为False,则报错 | 如果为False,则报错 | ||||
:param ignore_miss_dataset: | |||||
:param num_proc: | |||||
:param num_proc: 进程的数量。请注意,由于python语言的特性,多少进程就会导致多少倍内存的增长。 | |||||
:param ignore_miss_dataset: 如果 dataset 没有 {field_name} ,就直接跳过这个 dataset 。 | |||||
:param progress_desc 当show_progress_barm为True时,可以显示当前tqdm正在处理的名称 | :param progress_desc 当show_progress_barm为True时,可以显示当前tqdm正在处理的名称 | ||||
:param show_progress_bar 是否显示tqdm进度条 | :param show_progress_bar 是否显示tqdm进度条 | ||||
@@ -260,11 +260,11 @@ class DataBundle: | |||||
:param callable func: 参数是 ``DataSet`` 中的 ``Instance`` ,返回值是一个字典,key 是field 的名字,value 是对应的结果 | :param callable func: 参数是 ``DataSet`` 中的 ``Instance`` ,返回值是一个字典,key 是field 的名字,value 是对应的结果 | ||||
:param str field_name: 传入func的是哪个field。 | :param str field_name: 传入func的是哪个field。 | ||||
:param bool modify_fields: 是否用结果修改 `DataSet` 中的 `Field`, 默认为 True | :param bool modify_fields: 是否用结果修改 `DataSet` 中的 `Field`, 默认为 True | ||||
:param num_proc: 进程的数量。请注意,由于python语言的特性,多少进程就会导致多少倍内存的增长。 | |||||
:param bool ignore_miss_dataset: 当某个field名称在某个dataset不存在时,如果为True,则直接忽略该DataSet; | :param bool ignore_miss_dataset: 当某个field名称在某个dataset不存在时,如果为True,则直接忽略该DataSet; | ||||
如果为False,则报错 | 如果为False,则报错 | ||||
:param show_progress_bar: 是否显示tqdm进度条 | :param show_progress_bar: 是否显示tqdm进度条 | ||||
:param progress_desc: 当show_progress_barm为True时,可以显示当前tqdm正在处理的名称 | :param progress_desc: 当show_progress_barm为True时,可以显示当前tqdm正在处理的名称 | ||||
:param num_proc: | |||||
:return Dict[str:Dict[str:Field]]: 返回一个字典套字典,第一层的 key 是 dataset 的名字,第二层的 key 是 field 的名字 | :return Dict[str:Dict[str:Field]]: 返回一个字典套字典,第一层的 key 是 dataset 的名字,第二层的 key 是 field 的名字 | ||||
@@ -283,7 +283,7 @@ class DataBundle: | |||||
return res | return res | ||||
def apply(self, func: Callable, new_field_name: str, num_proc: int = 0, | def apply(self, func: Callable, new_field_name: str, num_proc: int = 0, | ||||
progress_desc: str = '', show_progress_bar: bool = True, _apply_field: str = None): | |||||
progress_desc: str = '', show_progress_bar: bool = True): | |||||
r""" | r""" | ||||
对 :class:`~fastNLP.io.DataBundle` 中所有的 dataset 使用 :method:`~fastNLP.DataSet.apply` 方法 | 对 :class:`~fastNLP.io.DataBundle` 中所有的 dataset 使用 :method:`~fastNLP.DataSet.apply` 方法 | ||||
@@ -292,18 +292,16 @@ class DataBundle: | |||||
:param callable func: input是instance中名为 `field_name` 的field的内容。 | :param callable func: input是instance中名为 `field_name` 的field的内容。 | ||||
:param str new_field_name: 将func返回的内容放入到 `new_field_name` 这个field中,如果名称与已有的field相同,则覆 | :param str new_field_name: 将func返回的内容放入到 `new_field_name` 这个field中,如果名称与已有的field相同,则覆 | ||||
盖之前的field。如果为None则不创建新的field。 | 盖之前的field。如果为None则不创建新的field。 | ||||
:param _apply_field: | |||||
:param num_proc: 进程的数量。请注意,由于python语言的特性,多少进程就会导致多少倍内存的增长。 | |||||
:param show_progress_bar: 是否显示tqd进度条 | :param show_progress_bar: 是否显示tqd进度条 | ||||
:param progress_desc: 当show_progress_bar为True时,可以显示当前tqd正在处理的名称 | :param progress_desc: 当show_progress_bar为True时,可以显示当前tqd正在处理的名称 | ||||
:param num_proc | |||||
""" | """ | ||||
_progress_desc = progress_desc | _progress_desc = progress_desc | ||||
for name, dataset in self.datasets.items(): | for name, dataset in self.datasets.items(): | ||||
if _progress_desc: | if _progress_desc: | ||||
progress_desc = _progress_desc + f' for `{name}`' | progress_desc = _progress_desc + f' for `{name}`' | ||||
dataset.apply(func, new_field_name=new_field_name, num_proc=num_proc, show_progress_bar=show_progress_bar, | dataset.apply(func, new_field_name=new_field_name, num_proc=num_proc, show_progress_bar=show_progress_bar, | ||||
progress_desc=progress_desc, _apply_field=_apply_field) | |||||
progress_desc=progress_desc) | |||||
return self | return self | ||||
def apply_more(self, func: Callable, modify_fields=True, num_proc: int = 0, | def apply_more(self, func: Callable, modify_fields=True, num_proc: int = 0, | ||||
@@ -317,9 +315,9 @@ class DataBundle: | |||||
:param callable func: 参数是 ``DataSet`` 中的 ``Instance`` ,返回值是一个字典,key 是field 的名字,value 是对应的结果 | :param callable func: 参数是 ``DataSet`` 中的 ``Instance`` ,返回值是一个字典,key 是field 的名字,value 是对应的结果 | ||||
:param bool modify_fields: 是否用结果修改 ``DataSet`` 中的 ``Field`` , 默认为 True | :param bool modify_fields: 是否用结果修改 ``DataSet`` 中的 ``Field`` , 默认为 True | ||||
:param num_proc: 进程的数量。请注意,由于python语言的特性,多少进程就会导致多少倍内存的增长。 | |||||
:param show_progress_bar: 是否显示tqd进度条 | :param show_progress_bar: 是否显示tqd进度条 | ||||
:param progress_desc: 当show_progress_bar为True时,可以显示当前tqd正在处理的名称 | :param progress_desc: 当show_progress_bar为True时,可以显示当前tqd正在处理的名称 | ||||
:param num_proc | |||||
:return Dict[str:Dict[str:Field]]: 返回一个字典套字典,第一层的 key 是 dataset 的名字,第二层的 key 是 field 的名字 | :return Dict[str:Dict[str:Field]]: 返回一个字典套字典,第一层的 key 是 dataset 的名字,第二层的 key 是 field 的名字 | ||||
""" | """ | ||||
@@ -0,0 +1,22 @@ | |||||
import pytest | |||||
from fastNLP import Trainer, Event | |||||
def test_on(): | |||||
with pytest.raises(TypeError): | |||||
@Trainer.on(Event.on_before_backward()) | |||||
def before_backend(): | |||||
pass | |||||
@Trainer.on(Event.on_before_backward()) | |||||
def before_backend(*args): | |||||
pass | |||||
with pytest.raises(TypeError): | |||||
@Trainer.on(Event.on_before_backward()) | |||||
def before_backend(*args, s): | |||||
pass | |||||
@Trainer.on(Event.on_before_backward()) | |||||
def before_backend(*args, s=2): | |||||
pass |
@@ -210,4 +210,43 @@ def test_trainer_validate_every( | |||||
dist.destroy_process_group() | dist.destroy_process_group() | ||||
@pytest.mark.torch | |||||
@pytest.mark.parametrize("driver,device", [("torch", 'cpu')]) # ("torch", [0, 1]),("torch", 1) | |||||
@magic_argv_env_context | |||||
def test_trainer_on( | |||||
model_and_optimizers: TrainerParameters, | |||||
driver, | |||||
device, | |||||
n_epochs=2, | |||||
): | |||||
from fastNLP import Event | |||||
@Trainer.on(Event.on_before_backward()) | |||||
def before_backend(trainer, outputs): | |||||
pass | |||||
@Trainer.on(Event.on_before_backward()) | |||||
def before_backend_2(*args): | |||||
pass | |||||
trainer = Trainer( | |||||
model=model_and_optimizers.model, | |||||
driver=driver, | |||||
device=device, | |||||
optimizers=model_and_optimizers.optimizers, | |||||
train_dataloader=model_and_optimizers.train_dataloader, | |||||
evaluate_dataloaders=model_and_optimizers.evaluate_dataloaders, | |||||
input_mapping=model_and_optimizers.input_mapping, | |||||
output_mapping=model_and_optimizers.output_mapping, | |||||
metrics=model_and_optimizers.metrics, | |||||
n_epochs=n_epochs, | |||||
output_from_new_proc="all", | |||||
evaluate_every=-1 | |||||
) | |||||
trainer.run() | |||||
@@ -124,9 +124,8 @@ class TestCheckNumberOfParameters: | |||||
# 无默认值,多了报错 | # 无默认值,多了报错 | ||||
def validate_every(trainer, other): | def validate_every(trainer, other): | ||||
pass | pass | ||||
with pytest.raises(RuntimeError) as exc_info: | |||||
with pytest.raises(TypeError) as exc_info: | |||||
_check_valid_parameters_number(validate_every, expected_params=['trainer']) | _check_valid_parameters_number(validate_every, expected_params=['trainer']) | ||||
assert "2 parameters" in exc_info.value.args[0] | |||||
print(exc_info.value.args[0]) | print(exc_info.value.args[0]) | ||||
# 有默认值ok | # 有默认值ok | ||||
@@ -137,19 +136,18 @@ class TestCheckNumberOfParameters: | |||||
# 参数多了 | # 参数多了 | ||||
def validate_every(trainer): | def validate_every(trainer): | ||||
pass | pass | ||||
with pytest.raises(RuntimeError) as exc_info: | |||||
with pytest.raises(TypeError) as exc_info: | |||||
_check_valid_parameters_number(validate_every, expected_params=['trainer', 'other']) | _check_valid_parameters_number(validate_every, expected_params=['trainer', 'other']) | ||||
assert "accepts 1 parameters" in exc_info.value.args[0] | |||||
print(exc_info.value.args[0]) | print(exc_info.value.args[0]) | ||||
# 使用partial | # 使用partial | ||||
def validate_every(trainer, other): | def validate_every(trainer, other): | ||||
pass | pass | ||||
_check_valid_parameters_number(partial(validate_every, other=1), expected_params=['trainer']) | _check_valid_parameters_number(partial(validate_every, other=1), expected_params=['trainer']) | ||||
_check_valid_parameters_number(partial(validate_every, other=1), expected_params=['trainer', 'other']) | |||||
with pytest.raises(RuntimeError) as exc_info: | |||||
with pytest.raises(TypeError): | |||||
_check_valid_parameters_number(partial(validate_every, other=1), expected_params=['trainer', 'other']) | |||||
with pytest.raises(TypeError) as exc_info: | |||||
_check_valid_parameters_number(partial(validate_every, other=1), expected_params=['trainer', 'other', 'more']) | _check_valid_parameters_number(partial(validate_every, other=1), expected_params=['trainer', 'other', 'more']) | ||||
assert 'accepts 2 parameters' in exc_info.value.args[0] | |||||
print(exc_info.value.args[0]) | print(exc_info.value.args[0]) | ||||
# 如果存在 *args 或 *kwargs 不报错多的 | # 如果存在 *args 或 *kwargs 不报错多的 | ||||
@@ -159,7 +157,8 @@ class TestCheckNumberOfParameters: | |||||
def validate_every(trainer, **kwargs): | def validate_every(trainer, **kwargs): | ||||
pass | pass | ||||
_check_valid_parameters_number(partial(validate_every, trainer=1), expected_params=['trainer', 'other', 'more']) | |||||
with pytest.raises(TypeError): | |||||
_check_valid_parameters_number(partial(validate_every, trainer=1), expected_params=['trainer', 'other', 'more']) | |||||
# class 的方法删掉self | # class 的方法删掉self | ||||
class InnerClass: | class InnerClass: | ||||
@@ -173,10 +172,8 @@ class TestCheckNumberOfParameters: | |||||
pass | pass | ||||
inner = InnerClass() | inner = InnerClass() | ||||
with pytest.raises(RuntimeError) as exc_info: | |||||
with pytest.raises(TypeError) as exc_info: | |||||
_check_valid_parameters_number(inner.demo, expected_params=['trainer', 'other', 'more']) | _check_valid_parameters_number(inner.demo, expected_params=['trainer', 'other', 'more']) | ||||
assert 'accepts 1 parameters' in exc_info.value.args[0] | |||||
_check_valid_parameters_number(inner.demo, expected_params=['trainer']) | _check_valid_parameters_number(inner.demo, expected_params=['trainer']) | ||||