@@ -7,6 +7,7 @@ from dataclasses import is_dataclass | |||
import os | |||
from pathlib import Path | |||
import io | |||
import inspect | |||
__all__ = [ | |||
'Trainer', | |||
@@ -402,11 +402,11 @@ class DataSet: | |||
r""" | |||
将 DataSet 中的每个 instance 中的名为 `field_name` 的 field 传给 func,并获取它的返回值。 | |||
:param num_proc: 进程的数量 | |||
:param field_name: 传入 func 的是哪个 field。 | |||
:param func: input是 instance 中名为 `field_name` 的 field 的内容。 | |||
:param new_field_name: 将 func 返回的内容放入到 `new_field_name` 这个 field 中,如果名称与已有的 field 相同,则覆 | |||
盖之前的 field。如果为 None 则不创建新的 field。 | |||
:param num_proc: 进程的数量。请注意,由于python语言的特性,多少进程就会导致多少倍内存的增长。 | |||
:param progress_desc: progress_desc 的值,默认为 Main | |||
:param show_progress_bar: 是否展示进度条,默认展示进度条 | |||
""" | |||
@@ -435,10 +435,10 @@ class DataSet: | |||
``apply_field_more`` 与 ``apply_field`` 的区别参考 :method:`~fastNLP.DataSet.apply_more` 中关于 ``apply_more`` 与 | |||
``apply`` 区别的介绍。 | |||
:param num_proc: 进程的数量 | |||
:param field_name: 传入func的是哪个field。 | |||
:param func: 参数是 ``DataSet`` 中的 ``Instance`` ,返回值是一个字典,key 是field 的名字,value 是对应的结果 | |||
:param modify_fields: 是否用结果修改 `DataSet` 中的 `Field`, 默认为 True | |||
:param num_proc: 进程的数量。请注意,由于python语言的特性,多少进程就会导致多少倍内存的增长。 | |||
:param show_progress_bar: 是否显示进度条,默认展示 | |||
:param progress_desc: 当show_progress_bar为True时,可以显示当前正在处理的进度条描述字符 | |||
:return Dict[str:Field]: 返回一个字典 | |||
@@ -479,7 +479,7 @@ class DataSet: | |||
show_progress_bar: bool = True, _apply_field: str = None, | |||
progress_desc: str = 'Main') -> list: | |||
""" | |||
:param num_proc: 进程的数量 | |||
:param num_proc: 进程的数量。请注意,由于python语言的特性,多少进程就会导致多少倍内存的增长。 | |||
:param func: 用户自定义处理函数,参数是 ``DataSet`` 中的 ``Instance`` | |||
:param _apply_field: 需要传进去func的数据集的field_name | |||
:param show_progress_bar: 是否展示progress进度条,默认为展示 | |||
@@ -552,7 +552,7 @@ class DataSet: | |||
:param modify_fields: 是否用结果修改 ``DataSet`` 中的 ``Field`` , 默认为 True | |||
:param func: 参数是 ``DataSet`` 中的 ``Instance`` ,返回值是一个字典,key 是field 的名字,value 是对应的结果 | |||
:param num_proc: 进程的数量 | |||
:param show_progress_bar: 是否使用tqd显示预处理进度 | |||
:param num_proc: 进程的数量。请注意,由于python语言的特性,多少进程就会导致多少倍内存的增长。 | |||
:param progress_desc: 当show_progress_bar为True时,可以显示当前正在处理的进度条名称 | |||
:return Dict[str:Field]: 返回一个字典 | |||
""" | |||
@@ -596,7 +596,7 @@ class DataSet: | |||
:param func: 参数是 ``DataSet`` 中的 ``Instance`` ,返回值是一个字典,key 是field 的名字,value 是对应的结果 | |||
:param new_field_name: 将func返回的内容放入到 `new_field_name` 这个field中,如果名称与已有的field相同,则覆 | |||
盖之前的field。如果为None则不创建新的field。 | |||
:param num_proc: 进程的数量。 | |||
:param num_proc: 进程的数量。请注意,由于python语言的特性,多少进程就会导致多少倍内存的增长。 | |||
:param show_progress_bar: 是否显示进度条。 | |||
:param progress_desc: progress bar 显示的值,默认为空。 | |||
""" | |||
@@ -135,8 +135,11 @@ class FastNLPLogger(logging.Logger, metaclass=LoggerSingleton): | |||
:return: | |||
""" | |||
if os.environ.get(FASTNLP_GLOBAL_RANK, '0') == '0': | |||
if once and msg in self._warning_msgs: | |||
return | |||
if once: | |||
if msg in self._warning_msgs: | |||
return | |||
self._warning_msgs.add(msg) | |||
if self.isEnabledFor(WARNING): | |||
kwargs = self._add_rank_info(kwargs) | |||
self._log(WARNING, msg, args, **kwargs) | |||
@@ -200,29 +200,25 @@ def _check_valid_parameters_number(fn, expected_params:List[str], fn_name=None): | |||
if fn_name is not None: | |||
assert callable(fn), f"`{fn_name}` should be callable, instead of `{type(fn)}`." | |||
parameters = list(inspect.signature(fn).parameters.values()) | |||
if inspect.ismethod(fn): | |||
if len(parameters)>0 and parameters[0].name == 'self': | |||
parameters = parameters[1:] # 去掉self | |||
no_var_param = True # 没有 * 这种参数 | |||
number_param_need_value = 0 | |||
for param in parameters: | |||
if param.kind is param.VAR_POSITIONAL: | |||
no_var_param = False | |||
elif param.kind is param.VAR_KEYWORD: | |||
no_var_param = False | |||
else: | |||
if param.default is param.empty: | |||
number_param_need_value += 1 | |||
if len(parameters)<len(expected_params) and no_var_param: | |||
raise RuntimeError(f"The function:{_get_fun_msg(fn)} accepts {len(parameters)} parameters, " | |||
f"but {len(expected_params)} parameters:{expected_params} will be provided.") | |||
if number_param_need_value>len(expected_params): | |||
raise RuntimeError(f"The function:{_get_fun_msg(fn)} expects {len(parameters)} parameters, but only" | |||
f" {len(expected_params)} parameters:{expected_params} will be provided.") | |||
try: | |||
args = [] | |||
kwargs = {} | |||
name = '' | |||
if isinstance(fn, functools.partial) and not hasattr(fn, '__name__'): | |||
name = 'partial:' | |||
f = fn.func | |||
while isinstance(f, functools.partial): | |||
name += 'partial:' | |||
f = f.func | |||
fn.__name__ = name + f.__name__ | |||
inspect.getcallargs(fn, *args, *expected_params, **kwargs) | |||
if name: # 如果一开始没有name的,需要给人家删除掉 | |||
delattr(fn, '__name__') | |||
except TypeError as e: | |||
logger.error(f"The function:{_get_fun_msg(fn)} will be provided with parameters:{expected_params}. " | |||
f"The following exception will happen.") | |||
raise e | |||
def check_user_specific_params(user_params: Dict, fn: Callable): | |||
@@ -231,8 +231,8 @@ class DataBundle: | |||
盖之前的field。如果为None则不创建新的field。 | |||
:param bool ignore_miss_dataset: 当某个field名称在某个dataset不存在时,如果为True,则直接忽略该DataSet; | |||
如果为False,则报错 | |||
:param ignore_miss_dataset: | |||
:param num_proc: | |||
:param num_proc: 进程的数量。请注意,由于python语言的特性,多少进程就会导致多少倍内存的增长。 | |||
:param ignore_miss_dataset: 如果 dataset 没有 {field_name} ,就直接跳过这个 dataset 。 | |||
:param progress_desc 当show_progress_barm为True时,可以显示当前tqdm正在处理的名称 | |||
:param show_progress_bar 是否显示tqdm进度条 | |||
@@ -260,11 +260,11 @@ class DataBundle: | |||
:param callable func: 参数是 ``DataSet`` 中的 ``Instance`` ,返回值是一个字典,key 是field 的名字,value 是对应的结果 | |||
:param str field_name: 传入func的是哪个field。 | |||
:param bool modify_fields: 是否用结果修改 `DataSet` 中的 `Field`, 默认为 True | |||
:param num_proc: 进程的数量。请注意,由于python语言的特性,多少进程就会导致多少倍内存的增长。 | |||
:param bool ignore_miss_dataset: 当某个field名称在某个dataset不存在时,如果为True,则直接忽略该DataSet; | |||
如果为False,则报错 | |||
:param show_progress_bar: 是否显示tqdm进度条 | |||
:param progress_desc: 当show_progress_barm为True时,可以显示当前tqdm正在处理的名称 | |||
:param num_proc: | |||
:return Dict[str:Dict[str:Field]]: 返回一个字典套字典,第一层的 key 是 dataset 的名字,第二层的 key 是 field 的名字 | |||
@@ -283,7 +283,7 @@ class DataBundle: | |||
return res | |||
def apply(self, func: Callable, new_field_name: str, num_proc: int = 0, | |||
progress_desc: str = '', show_progress_bar: bool = True, _apply_field: str = None): | |||
progress_desc: str = '', show_progress_bar: bool = True): | |||
r""" | |||
对 :class:`~fastNLP.io.DataBundle` 中所有的 dataset 使用 :method:`~fastNLP.DataSet.apply` 方法 | |||
@@ -292,18 +292,16 @@ class DataBundle: | |||
:param callable func: input是instance中名为 `field_name` 的field的内容。 | |||
:param str new_field_name: 将func返回的内容放入到 `new_field_name` 这个field中,如果名称与已有的field相同,则覆 | |||
盖之前的field。如果为None则不创建新的field。 | |||
:param _apply_field: | |||
:param num_proc: 进程的数量。请注意,由于python语言的特性,多少进程就会导致多少倍内存的增长。 | |||
:param show_progress_bar: 是否显示tqd进度条 | |||
:param progress_desc: 当show_progress_bar为True时,可以显示当前tqd正在处理的名称 | |||
:param num_proc | |||
""" | |||
_progress_desc = progress_desc | |||
for name, dataset in self.datasets.items(): | |||
if _progress_desc: | |||
progress_desc = _progress_desc + f' for `{name}`' | |||
dataset.apply(func, new_field_name=new_field_name, num_proc=num_proc, show_progress_bar=show_progress_bar, | |||
progress_desc=progress_desc, _apply_field=_apply_field) | |||
progress_desc=progress_desc) | |||
return self | |||
def apply_more(self, func: Callable, modify_fields=True, num_proc: int = 0, | |||
@@ -317,9 +315,9 @@ class DataBundle: | |||
:param callable func: 参数是 ``DataSet`` 中的 ``Instance`` ,返回值是一个字典,key 是field 的名字,value 是对应的结果 | |||
:param bool modify_fields: 是否用结果修改 ``DataSet`` 中的 ``Field`` , 默认为 True | |||
:param num_proc: 进程的数量。请注意,由于python语言的特性,多少进程就会导致多少倍内存的增长。 | |||
:param show_progress_bar: 是否显示tqd进度条 | |||
:param progress_desc: 当show_progress_bar为True时,可以显示当前tqd正在处理的名称 | |||
:param num_proc | |||
:return Dict[str:Dict[str:Field]]: 返回一个字典套字典,第一层的 key 是 dataset 的名字,第二层的 key 是 field 的名字 | |||
""" | |||
@@ -0,0 +1,22 @@ | |||
import pytest | |||
from fastNLP import Trainer, Event | |||
def test_on(): | |||
with pytest.raises(TypeError): | |||
@Trainer.on(Event.on_before_backward()) | |||
def before_backend(): | |||
pass | |||
@Trainer.on(Event.on_before_backward()) | |||
def before_backend(*args): | |||
pass | |||
with pytest.raises(TypeError): | |||
@Trainer.on(Event.on_before_backward()) | |||
def before_backend(*args, s): | |||
pass | |||
@Trainer.on(Event.on_before_backward()) | |||
def before_backend(*args, s=2): | |||
pass |
@@ -210,4 +210,43 @@ def test_trainer_validate_every( | |||
dist.destroy_process_group() | |||
@pytest.mark.torch | |||
@pytest.mark.parametrize("driver,device", [("torch", 'cpu')]) # ("torch", [0, 1]),("torch", 1) | |||
@magic_argv_env_context | |||
def test_trainer_on( | |||
model_and_optimizers: TrainerParameters, | |||
driver, | |||
device, | |||
n_epochs=2, | |||
): | |||
from fastNLP import Event | |||
@Trainer.on(Event.on_before_backward()) | |||
def before_backend(trainer, outputs): | |||
pass | |||
@Trainer.on(Event.on_before_backward()) | |||
def before_backend_2(*args): | |||
pass | |||
trainer = Trainer( | |||
model=model_and_optimizers.model, | |||
driver=driver, | |||
device=device, | |||
optimizers=model_and_optimizers.optimizers, | |||
train_dataloader=model_and_optimizers.train_dataloader, | |||
evaluate_dataloaders=model_and_optimizers.evaluate_dataloaders, | |||
input_mapping=model_and_optimizers.input_mapping, | |||
output_mapping=model_and_optimizers.output_mapping, | |||
metrics=model_and_optimizers.metrics, | |||
n_epochs=n_epochs, | |||
output_from_new_proc="all", | |||
evaluate_every=-1 | |||
) | |||
trainer.run() | |||
@@ -124,9 +124,8 @@ class TestCheckNumberOfParameters: | |||
# 无默认值,多了报错 | |||
def validate_every(trainer, other): | |||
pass | |||
with pytest.raises(RuntimeError) as exc_info: | |||
with pytest.raises(TypeError) as exc_info: | |||
_check_valid_parameters_number(validate_every, expected_params=['trainer']) | |||
assert "2 parameters" in exc_info.value.args[0] | |||
print(exc_info.value.args[0]) | |||
# 有默认值ok | |||
@@ -137,19 +136,18 @@ class TestCheckNumberOfParameters: | |||
# 参数多了 | |||
def validate_every(trainer): | |||
pass | |||
with pytest.raises(RuntimeError) as exc_info: | |||
with pytest.raises(TypeError) as exc_info: | |||
_check_valid_parameters_number(validate_every, expected_params=['trainer', 'other']) | |||
assert "accepts 1 parameters" in exc_info.value.args[0] | |||
print(exc_info.value.args[0]) | |||
# 使用partial | |||
def validate_every(trainer, other): | |||
pass | |||
_check_valid_parameters_number(partial(validate_every, other=1), expected_params=['trainer']) | |||
_check_valid_parameters_number(partial(validate_every, other=1), expected_params=['trainer', 'other']) | |||
with pytest.raises(RuntimeError) as exc_info: | |||
with pytest.raises(TypeError): | |||
_check_valid_parameters_number(partial(validate_every, other=1), expected_params=['trainer', 'other']) | |||
with pytest.raises(TypeError) as exc_info: | |||
_check_valid_parameters_number(partial(validate_every, other=1), expected_params=['trainer', 'other', 'more']) | |||
assert 'accepts 2 parameters' in exc_info.value.args[0] | |||
print(exc_info.value.args[0]) | |||
# 如果存在 *args 或 *kwargs 不报错多的 | |||
@@ -159,7 +157,8 @@ class TestCheckNumberOfParameters: | |||
def validate_every(trainer, **kwargs): | |||
pass | |||
_check_valid_parameters_number(partial(validate_every, trainer=1), expected_params=['trainer', 'other', 'more']) | |||
with pytest.raises(TypeError): | |||
_check_valid_parameters_number(partial(validate_every, trainer=1), expected_params=['trainer', 'other', 'more']) | |||
# class 的方法删掉self | |||
class InnerClass: | |||
@@ -173,10 +172,8 @@ class TestCheckNumberOfParameters: | |||
pass | |||
inner = InnerClass() | |||
with pytest.raises(RuntimeError) as exc_info: | |||
with pytest.raises(TypeError) as exc_info: | |||
_check_valid_parameters_number(inner.demo, expected_params=['trainer', 'other', 'more']) | |||
assert 'accepts 1 parameters' in exc_info.value.args[0] | |||
_check_valid_parameters_number(inner.demo, expected_params=['trainer']) | |||