diff --git a/fastNLP/core/callbacks/callback.py b/fastNLP/core/callbacks/callback.py
index 0b9020fe..902421c8 100644
--- a/fastNLP/core/callbacks/callback.py
+++ b/fastNLP/core/callbacks/callback.py
@@ -10,6 +10,7 @@ from .utils import _get_monitor_value
 from fastNLP.core.callbacks.callback_events import _SingleEventState
 from fastNLP.core.log import logger
 from fastNLP.core.utils import apply_to_collection
+from fastNLP.core.utils.utils import _check_valid_parameters_number


 class Callback:
@@ -299,7 +300,11 @@ class HasMonitorCallback(Callback):
         self.must_have_moinitor = must_have_monitor

     def set_monitor(self, monitor, larger_better):
-        self.monitor = str(monitor) if monitor is not None else None
+        if callable(monitor):  # check that it can accept exactly one argument
+            _check_valid_parameters_number(monitor, expected_params=['results'], fn_name='monitor')
+            self.monitor = monitor
+        else:
+            self.monitor = str(monitor) if monitor is not None else None
         self.larger_better = bool(larger_better)
         if larger_better:
             self.monitor_value = float('-inf')
@@ -322,24 +327,33 @@ class HasMonitorCallback(Callback):
             raise RuntimeError(f"No `monitor` is set for {self.__class__.__name__}. "
                                f"You can set it in the initialization or through Trainer.")

-    def get_monitor_value(self, results:Dict)->float:
+    def get_monitor_value(self, results:Dict)->Union[float, None]:
         """
         Fetch the value of the monitor. If the monitor key is not found directly, fuzzy matching is
         attempted, and the matched key is recorded on the self._real_monitor attribute.

         :param results:
-        :return:
+        :return: None means that no suitable monitor was found this time.
         """
         if len(results)==0:
-            return 0
+            return None
         # make sure every tensor has been converted to a plain Python type
         results = apply_to_collection(results, dtype=CanItemDataType, function=lambda x: x.item())

         use_monitor, monitor_value = _get_monitor_value(monitor=self.monitor,
                                                         real_monitor=self._real_monitor,
                                                         res=results)
-        if self._real_monitor != use_monitor:  # a substitution happened, so report it
-            logger.warning(
-                f"We cannot find `{self.monitor}` in the evaluation result (with keys as {list(results.keys())}), "
-                f"we use `{use_monitor}` as the monitor for {self.__class__.__name__}.")
+        if monitor_value is None:
+            return monitor_value
+        # first run
+        if isinstance(self.monitor, str) and self._real_monitor == self.monitor and use_monitor != self.monitor:
+            logger.warning(f"We cannot find `{self.monitor}` in the evaluation result (with keys as "
+                           f"{list(results.keys())}), we use `{use_monitor}` as the monitor for "
+                           f"`{self.__class__.__name__}`.")
+        # the matched key differs from the one used last time
+        elif isinstance(self.monitor, str) and self._real_monitor != self.monitor and use_monitor != self._real_monitor:
+            logger.warning(f"Change of monitor detected for `{self.__class__.__name__}`. "
+                           f"The expected monitor is:`{self.monitor}`, last used monitor is:"
+                           f"`{self._real_monitor}` and current monitor is:`{use_monitor}`. Please consider using a "
+                           f"customized monitor function when the evaluation results vary between validations.")
+
         self._real_monitor = use_monitor
         return monitor_value

@@ -347,10 +361,12 @@ class HasMonitorCallback(Callback):
         """
         Check whether monitor_value is better than the best value seen so far.

-        :param monitor_value:
+        :param monitor_value: the value to check. If it is None, False is returned.
         :param keep_if_better: if the passed-in monitor_value is better, keep it as the new best value.
         :return:
         """
+        if monitor_value is None:
+            return False
         better = self.is_former_monitor_value_better(monitor_value, self.monitor_value)
         if keep_if_better and better:
             self.monitor_value = monitor_value
@@ -364,6 +380,12 @@ class HasMonitorCallback(Callback):
         :param monitor_value2:
         :return:
         """
+        if monitor_value1 is None and monitor_value2 is None:
+            return True
+        if monitor_value1 is None:
+            return False
+        if monitor_value2 is None:
+            return True
         better = False
         if (self.larger_better and monitor_value1 > monitor_value2) or \
                 (not self.larger_better and monitor_value1 < monitor_value2):
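Editor's note: set_monitor above now takes either a key name or a callable over the raw evaluation dict. A minimal sketch of such a monitor function, assuming evaluation produces the invented keys below; the single-argument requirement is exactly what _check_valid_parameters_number(monitor, expected_params=['results']) enforces:

    # Hedged sketch; 'acc#dev' and 'f1#dev' are hypothetical result keys.
    def my_monitor(results: dict) -> float:
        return 0.5 * results['acc#dev'] + 0.5 * results['f1#dev']

Any HasMonitorCallback subclass (and, as later hunks show, the Trainer's own monitor argument) can take such a function in place of a string.
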
diff --git a/fastNLP/core/callbacks/checkpoint_callback.py b/fastNLP/core/callbacks/checkpoint_callback.py
index 82bfe404..a5be2b4c 100644
--- a/fastNLP/core/callbacks/checkpoint_callback.py
+++ b/fastNLP/core/callbacks/checkpoint_callback.py
@@ -10,8 +10,7 @@ from copy import deepcopy

 import fastNLP

-from .callback import Callback, HasMonitorCallback
-from fastNLP.core.callbacks.utils import _get_monitor_value
+from .callback import HasMonitorCallback
 from fastNLP.core.log import logger
 from fastNLP.envs import FASTNLP_LAUNCH_TIME
 from fastNLP.core.utils import synchronize_safe_rm, synchronize_mkdir
@@ -166,6 +165,8 @@ class CheckpointCallback(HasMonitorCallback):
         """
         if self.save_topk is not None:
             monitor_value = self.get_monitor_value(results=results)
+            if monitor_value is None:
+                return
             folder_name = f"{self.folder_prefix}-epoch_{trainer.cur_epoch_idx}-batch_{trainer.global_forward_batches}" \
                           f"-{self._real_monitor}_{monitor_value}"
@@ -231,7 +232,8 @@ class ModelCheckpointCallback(CheckpointCallback):
         If model_save_fn is not None, fastNLP passes the absolute path of folder to that function and creates
         no file inside the folder itself.

     :param monitor: name of the metric to monitor. If no exact match is found in the evaluation results, fuzzy
-        (longest common substring) matching picks the closest key as the monitor. If None, it is taken from the Trainer.
+        (longest common substring) matching picks the closest key as the monitor. If None, it is taken from the Trainer.
+        A function is also accepted; it receives the evaluation results (a dict) and returns a float used as the monitor value.
     :param save_folder: the folder to save into; fastNLP creates a timestamped subfolder inside it, so different
         runs are saved into different timestamped folders. If None, the current folder is used.
     :param save_every_n_epochs: save once every this many epochs.
@@ -278,7 +280,8 @@ class TrainerCheckpointCallback(CheckpointCallback):
         If model_save_fn is not None, fastNLP only generates the fastnlp_trainer.pkl.tar file under each folder.

     :param monitor: name of the metric to monitor. If no exact match is found in the evaluation results, fuzzy
-        (longest common substring) matching picks the closest key as the monitor. If None, it is taken from the Trainer.
+        (longest common substring) matching picks the closest key as the monitor. If None, it is taken from the Trainer.
+        A function is also accepted; it receives the evaluation results (a dict) and returns a float used as the monitor value.
     :param save_folder: the folder to save into; fastNLP creates a timestamped subfolder inside it, so different
         runs are saved into different timestamped folders. If None, the current folder is used.
     :param save_every_n_epochs: save once every this many epochs.
diff --git a/fastNLP/core/callbacks/early_stop_callback.py b/fastNLP/core/callbacks/early_stop_callback.py
index 602236f7..b1842d43 100644
--- a/fastNLP/core/callbacks/early_stop_callback.py
+++ b/fastNLP/core/callbacks/early_stop_callback.py
@@ -2,17 +2,18 @@ __all__ = [
     'EarlyStopCallback'
 ]

-from typing import Dict
+from typing import Dict, Union, Callable

 from .callback import HasMonitorCallback
 from fastNLP.core.utils.exceptions import EarlyStopException


 class EarlyStopCallback(HasMonitorCallback):
-    def __init__(self, monitor:str=None, larger_better:bool=True, patience:int=10):
+    def __init__(self, monitor:Union[str, Callable]=None, larger_better:bool=True, patience:int=10):
         """
-        :param str monitor: the metric to monitor. If None, fall back to the monitor set on the Trainer.
+        :param str monitor: the metric to monitor. If None, fall back to the monitor set on the Trainer. A function
+            is also accepted; it receives the evaluation results (a dict) and returns a float used as the monitor value.
         :param larger_better: whether larger monitor values are better.
         :param patience: stop after this many validations without improvement.
         """
@@ -21,9 +22,9 @@ class EarlyStopCallback(HasMonitorCallback):
         self.patience = patience

     def on_validate_end(self, trainer, results):
-        if len(results)==0:
-            return
         monitor_value = self.get_monitor_value(results)
+        if monitor_value is None:
+            return
         if self.is_better_monitor_value(monitor_value, keep_if_better=True):
             self.wait = 0
         else:
diff --git a/fastNLP/core/callbacks/load_best_model_callback.py b/fastNLP/core/callbacks/load_best_model_callback.py
index 9a4bb65f..e068326b 100644
--- a/fastNLP/core/callbacks/load_best_model_callback.py
+++ b/fastNLP/core/callbacks/load_best_model_callback.py
@@ -3,7 +3,7 @@ __all__ = [
 ]

 import os
-from typing import Optional, Callable
+from typing import Optional, Callable, Union
 from .callback import HasMonitorCallback
 from io import BytesIO
 import shutil
@@ -14,14 +14,15 @@ from fastNLP.envs import all_rank_call


 class LoadBestModelCallback(HasMonitorCallback):
-    def __init__(self, monitor:str=None, larger_better:bool = True, only_state_dict:bool = True,
+    def __init__(self, monitor:Union[str, Callable]=None, larger_better:bool = True, only_state_dict:bool = True,
                  save_folder:Optional[str] = None, model_save_fn:Optional[Callable] = None,
                  model_load_fn:Optional[Callable] = None,
                  delete_after_train:bool = True):
         """
         Save the model with the best monitor value and reload it when training ends. The best model can only be
         loaded when training finishes normally.

-        :param str monitor: the metric to monitor. If None, fall back to the monitor set on the Trainer.
+        :param str monitor: the metric to monitor. If None, fall back to the monitor set on the Trainer. A function
+            is also accepted; it receives the evaluation results (a dict) and returns a float used as the monitor value.
         :param larger_better: whether larger metric values are better.
         :param save_folder: the folder to save into. If empty, the weights are kept in memory; otherwise a copy of the
             weights is written to file. For multi-machine training this path must be accessible from every machine when
             not empty. It must not be empty when model_save_fn is not None.
@@ -78,9 +79,9 @@ class LoadBestModelCallback(HasMonitorCallback):
             self.get_monitor_value(sanity_check_res)

     def on_validate_end(self, trainer, results):
-        if len(results)==0:
-            return
         monitor_value = self.get_monitor_value(results)
+        if monitor_value is None:
+            return
         if self.is_better_monitor_value(monitor_value, keep_if_better=True):
             if self.real_save_folder:
                 trainer.save_model(folder=self.real_save_folder, only_state_dict=self.only_state_dict,
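Editor's note: taken together with the callback.py changes, the None guard above means a monitor that cannot produce a value simply skips the early-stop/save check instead of comparing against 0. A hedged usage sketch; the import path and the 'loss#dev' key are assumptions for illustration:

    # Minimal sketch: drive two callbacks with one custom monitor function.
    from fastNLP.core.callbacks import EarlyStopCallback, LoadBestModelCallback

    def dev_loss(results: dict) -> float:
        return results['loss#dev']   # hypothetical result key

    callbacks = [
        EarlyStopCallback(monitor=dev_loss, larger_better=False, patience=5),
        LoadBestModelCallback(monitor=dev_loss, larger_better=False),
    ]
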
diff --git a/fastNLP/core/callbacks/progress_callback.py b/fastNLP/core/callbacks/progress_callback.py
index 756d236b..67176387 100644
--- a/fastNLP/core/callbacks/progress_callback.py
+++ b/fastNLP/core/callbacks/progress_callback.py
@@ -45,6 +45,7 @@ class RichCallback(ProgressCallback):
     :param print_every: update the display every this many batches.
     :param loss_round_ndigit: number of significant digits to keep when displaying the loss.
     :param monitor: when a better result for this key is detected, print it in a different color as a hint. If None,
         fall back to the monitor set on the trainer.
+        A function is also accepted; it receives the evaluation results (a dict) and returns a float used as the monitor value.
     :param larger_better: whether larger monitor values are better.
     :param format_json: whether to format the json before printing.
     """
@@ -135,7 +136,8 @@ class RawTextCallback(ProgressCallback):

     :param print_every: update the display every this many batches.
     :param loss_round_ndigit: number of significant digits to keep when displaying the loss.
-    :param monitor: when a better result for this key is detected, print it in a different color as a hint.
+    :param monitor: when a better result for this key is detected, print it in a different color as a hint. A function
+        is also accepted; it receives the evaluation results (a dict) and returns a float used as the monitor value.
     :param larger_better: whether larger monitor values are better.
     :param format_json: whether to format the json before printing.
     """
diff --git a/fastNLP/core/callbacks/utils.py b/fastNLP/core/callbacks/utils.py
index 2720ba3f..7ece3bb9 100644
--- a/fastNLP/core/callbacks/utils.py
+++ b/fastNLP/core/callbacks/utils.py
@@ -1,9 +1,10 @@
-from typing import Optional
+from typing import Optional, Union
 from fastNLP.core.log.logger import logger
 from difflib import SequenceMatcher
+from fastNLP.core.utils.utils import _get_fun_msg


-def _get_monitor_value(monitor: str, real_monitor: Optional[str], res: dict) ->(str, float):
+def _get_monitor_value(monitor: Union[callable, str], real_monitor: Optional[str], res: dict) ->(str, float):
     """
     Look up monitor in res and return it. If monitor is not found, try _real_monitor; if _real_monitor is None,
     fall back to fuzzy matching on the value of monitor.

     :param monitor:
     :param real_monitor:
     :param res:
-    :return: two values (str, value): the key that was finally used, and the value under that key
+    :return: two values (str, value): the key that was finally used, and the value under that key. If value is None,
+        no matching monitor was found in the current results.
     """
     if len(res)==0:
-        return monitor, 0
+        return monitor, None
+
+    if callable(monitor):
+        try:
+            monitor_value = monitor(res)
+        except BaseException as e:
+            logger.error(f"Exception happens when calling customized monitor function:{_get_fun_msg(monitor)}.")
+            raise e
+        return monitor, monitor_value

     if monitor in res:
         return monitor, res[monitor]
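Editor's note: the two lookup paths _get_monitor_value now supports, sketched with invented result keys:

    res = {'acc#dev': 0.91, 'loss#dev': 0.3}

    # string monitor: exact hit in the results dict
    _get_monitor_value(monitor='acc#dev', real_monitor=None, res=res)   # -> ('acc#dev', 0.91)

    # callable monitor: the function itself is returned in the key slot
    _get_monitor_value(monitor=lambda r: -r['loss#dev'], real_monitor=None, res=res)   # -> (<lambda>, -0.3)
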
diff --git a/fastNLP/core/collators/collator.py b/fastNLP/core/collators/collator.py
index f468dd4c..b6b6de14 100644
--- a/fastNLP/core/collators/collator.py
+++ b/fastNLP/core/collators/collator.py
@@ -5,7 +5,7 @@ __all__ = [

 from abc import ABCMeta, abstractmethod
-from typing import Any, Dict, List, Callable, Union
+from typing import Any, Dict, List, Callable, Union, Tuple
 from numbers import Number
 import warnings

@@ -35,7 +35,7 @@ class SetInputOrTargetException(Exception):
         self.field_name = field_name  # name of the current field


-def _get_ele_type_and_dim(cell: Any, dim=0):
+def _get_ele_type_and_dim(cell: Any, dim=0) -> Tuple[Any, int]:
     r"""
     Identify the type and the number of dimensions of cell.

@@ -206,7 +206,7 @@ class AutoCollator(Collator):
     def __init__(self, as_numpy: bool):
         super(AutoCollator, self).__init__()
         self.pad_field_value = {}  # custom padding values per field, 0 by default
-        self.need_inputs = []  # the field names that are needed
+        self.need_inputs = set()  # the field names that are needed
         self.field_dtypes = None  # dtype of the cells in each column
         self.field_dims = None  # dimensionality of the cells in each column
         self.as_numpy = as_numpy

     def __call__(self, ins_lst: List[Dict]) -> dict:
         if len(self.need_inputs) == 0:
             raise ValueError({"set_inputs is None, you should use set_inputs method first!!"})
+        # TODO this should first check which fields need padding, and then check whether they can be padded
+
         # case 1: the values were configured via set_input
         # case 2: decide whether to pad from the type of the data
         if self.field_dtypes is None and self.field_dims is None:
-            self.field_dtypes, self.field_dims = _get_ds_type_dim(ins_lst[0])
+            field_dtypes, field_dims = {}, {}
+            for key, value in ins_lst[0].items():
+                if key in self.need_inputs and self.pad_field_value.get(key, 0) is not None:
+                    field_dtypes[key], field_dims[key] = _get_ele_type_and_dim(value)
+            self.field_dtypes = field_dtypes
+            self.field_dims = field_dims

         pack_ins_lst, pad_ins_lst = {field_name: []
                                      for field_name in ins_lst[0].keys() if field_name in self.need_inputs}, {}
@@ -233,13 +240,13 @@ class AutoCollator(Collator):

         if len(self.pad_field_value.keys()) > 0:
             # drop the columns that need no padding; set_input columns that do not exist are ignored
-            drop_field_names = []
+            non_pad_field_names = []
             for k, v in self.pad_field_value.items():
                 if v is None:
-                    drop_field_names.append(k)
+                    non_pad_field_names.append(k)
             # drop_field_names = list(set(list(ins_lst[0].keys())) - set(drop_fields))
-            for field_name in drop_field_names:
+            for field_name in non_pad_field_names:
                 field_array = pack_ins_lst.pop(field_name)
                 pad_ins_lst[field_name] = np.array(field_array)

@@ -269,7 +276,7 @@ class AutoCollator(Collator):

     def set_input(self, *field_names):
         for field_name in field_names:
-            self.need_inputs.append(field_name)
+            self.need_inputs.add(field_name)


 def pad_content(content, field_name: str, field_type, field_dim: int, pad_val: int, as_numpy: bool):
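Editor's note: how the two small AutoCollator changes interact, as a hedged sketch that pokes directly at the attributes shown above (a public setter for pad values presumably exists, but only the internals are visible in this hunk):

    collator = AutoCollator(as_numpy=False)
    collator.set_input('words', 'words', 'target')   # need_inputs is a set now: each field stored once
    collator.pad_field_value['raw_words'] = None     # None means: collect the field, but do not pad it
    # at __call__ time, dtype/dim detection now runs only for fields that are both
    # in need_inputs and actually meant to be padded
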
diff --git a/fastNLP/core/controllers/evaluator.py b/fastNLP/core/controllers/evaluator.py
index 2e3678d3..5196f8c7 100644
--- a/fastNLP/core/controllers/evaluator.py
+++ b/fastNLP/core/controllers/evaluator.py
@@ -11,11 +11,12 @@ __all__ = [
 from fastNLP.core.drivers import Driver
 from fastNLP.core.drivers.utils import choose_driver
 from .loops import Loop, EvaluateBatchLoop
-from fastNLP.core.utils import check_fn_not_empty_params, auto_param_call, dataclass_to_dict, \
+from fastNLP.core.utils import auto_param_call, dataclass_to_dict, \
     match_and_substitute_params, f_rich_progress
 from fastNLP.core.metrics import Metric
 from fastNLP.core.metrics.utils import _is_torchmetrics_metric, _is_paddle_metric, _is_allennlp_metric
 from fastNLP.core.controllers.utils.utils import _TruncatedDataLoader
+from fastNLP.core.utils.utils import _check_valid_parameters_number
 from fastNLP.core.log import logger

@@ -38,11 +39,11 @@ class Evaluator:
                  driver: Union[str, Driver] = 'single',
                  device: Optional[Union[int, List[int], str]] = None,
                  batch_step_fn: Optional[callable] = None,
-                 mode: str = "validate",
+                 mode: Optional[Union[str, callable]] = 'validate',  # try evaluate_step first, then forward; a callable is also allowed
                  input_mapping: Optional[Union[Callable, Dict]] = None,
                  output_mapping: Optional[Union[Callable, Dict]] = None,
                  model_wo_auto_param_call: bool = False,
-                 fp16: Optional[bool] = False,
+                 fp16: bool = False,
                  verbose: int = 1,
                  **kwargs
                  ):
@@ -92,8 +93,8 @@ class Evaluator:
         self.device = device
         self.verbose = verbose

-        assert check_fn_not_empty_params(batch_step_fn, 2), "Parameter `batch_step_fn` should be a callable object with " \
-                                                            "two parameters."
+        if batch_step_fn is not None:
+            _check_valid_parameters_number(batch_step_fn, ['trainer', 'batch'], fn_name='batch_step_fn')
         self.batch_step_fn = batch_step_fn

         self.mode = mode
@@ -135,6 +136,7 @@ class Evaluator:
         if self.progress_bar == 'auto':
             self.progress_bar = 'rich' if (sys.stdin and sys.stdin.isatty()) else 'raw'

+        self.driver.check_evaluator_mode(self.mode)
         self.driver.barrier()

     def run(self, num_eval_batch_per_dl: int = -1, **kwargs) -> Dict:
@@ -154,8 +156,6 @@ class Evaluator:
         assert isinstance(num_eval_batch_per_dl, int), "num_eval_batch_per_dl must be of int type."
         assert num_eval_batch_per_dl > 0 or num_eval_batch_per_dl == -1, "num_eval_batch_per_dl must be -1 or larger than 0."

-        self.driver.check_evaluator_mode(self.mode)
-
         if self.mode == 'validate':
             assert self.driver.has_validate_dataloaders()
         else:
@@ -367,9 +367,10 @@ class _MetricsWrapper:
             raise RuntimeError(f"The output of your model is of type:`{type(outputs)}`, please either directly"
                                f" return a dict from your model or use `output_mapping` to convert it into dict type.")
         if isinstance(metric, Metric):
-            auto_param_call(metric.update, outputs, *args)
+            # pass signature_fn so that errors raised inside auto_param_call name the right function
+            auto_param_call(metric.update, outputs, *args, signature_fn=metric.update.__wrapped__)
         elif _is_torchmetrics_metric(metric):
-            auto_param_call(metric.update, outputs, *args)
+            auto_param_call(metric.update, outputs, *args, signature_fn=metric.update.__wrapped__)
         elif _is_allennlp_metric(metric):
             auto_param_call(metric.__call__, outputs, *args)
         elif _is_paddle_metric(metric):
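Editor's note: what the new parameter-count check accepts and rejects for batch_step_fn, as a hedged sketch. Parameter names are free; only the number of parameters without default values is checked against the two values (['trainer', 'batch']) that will be provided:

    # accepted: exactly two required parameters
    def my_batch_step(evaluator, batch):
        pass  # run the model on the batch here

    # rejected by _check_valid_parameters_number: three required parameters,
    # while only two arguments will be supplied at call time
    def bad_batch_step(evaluator, batch, extra):
        pass
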
diff --git a/fastNLP/core/controllers/trainer.py b/fastNLP/core/controllers/trainer.py
index e1f31375..66e88827 100644
--- a/fastNLP/core/controllers/trainer.py
+++ b/fastNLP/core/controllers/trainer.py
@@ -14,6 +14,7 @@ __all__ = [

 from .loops import Loop, TrainBatchLoop
 from .utils import State, TrainerState
+from .utils.utils import check_validate_every
 from .evaluator import Evaluator
 from fastNLP.core.controllers.utils.utils import TrainerEventTrigger, _TruncatedDataLoader
 from fastNLP.core.callbacks import Callback, CallbackManager, Events, EventsList, Filter
@@ -21,7 +22,8 @@ from fastNLP.core.callbacks.callback import _CallbackWrapper
 from fastNLP.core.callbacks.callback_events import _SingleEventState
 from fastNLP.core.drivers import Driver
 from fastNLP.core.drivers.utils import choose_driver
-from fastNLP.core.utils import check_fn_not_empty_params, get_fn_arg_names, match_and_substitute_params, nullcontext
+from fastNLP.core.utils import get_fn_arg_names, match_and_substitute_params, nullcontext
+from fastNLP.core.utils.utils import _check_valid_parameters_number
 from fastNLP.envs import rank_zero_call
 from fastNLP.core.log import logger
 from fastNLP.envs import FASTNLP_MODEL_FILENAME
@@ -42,7 +44,7 @@ class Trainer(TrainerEventTrigger):
                  validate_dataloaders=None,
                  batch_step_fn: Optional[Callable] = None,
                  validate_batch_step_fn: Optional[Callable] = None,
-                 validate_mode: str = "validate",
+                 validate_mode: Union[str, callable] = 'validate',
                  callbacks: Union[List[Callback], Callback, None] = None,
                  metrics: Optional[dict] = None,
                  validate_every: Optional[Union[int, callable]] = -1,
@@ -51,7 +53,7 @@ class Trainer(TrainerEventTrigger):
                  model_wo_auto_param_call: bool = False,
                  accumulation_steps: int = 1,
                  fp16: bool = False,
-                 monitor: str = None,
+                 monitor: Union[str, callable] = None,
                  larger_better: bool = True,
                  marker: Optional[str] = None,
                  **kwargs
@@ -90,11 +92,8 @@ class Trainer(TrainerEventTrigger):
         :param callbacks: callbacks triggered during training; a list whose elements all inherit from `Callback`;
         :param metrics: a dict whose keys name the monitors, e.g. {"acc1": AccMetric(), "acc2": AccMetric()};
         :param validate_every: a negative int, a positive int, or a function; negative means validate once every that
             many epochs, positive once every that many batches;
-            a function means a user-supplied hook controlling how often the Trainer validates. Its parameters should be
-            (filter, trainer), where the filter object automatically tracks two variables: filter.num_called, the number
-            of validation attempts so far (equal to the total batch count), and filter.num_executed, the number of
-            validations actually run; trainer is the Trainer object. The function should return a bool, True meaning a
-            validation should run. For example, (filter.num_called % trainer.num_batches_per_epoch == 0 and
-            trainer.cur_epoch_idx > 10) validates at the end of every epoch after epoch 10.
+            a function means a user-supplied hook controlling how often the Trainer validates; it should accept the
+            current trainer object as its argument and
+            return a bool, True meaning a validation should run; it is called after every batch to decide whether to validate.
         :param input_mapping: a dict or a function describing how a freshly fetched training batch should be remapped at
             the current step. If it is a dict and the batch is also a `Dict`, keys of the batch that also appear in
             input_mapping are renamed to the corresponding values; if the batch is a `dataclass`, it is first converted
             into a Dict and then remapped as above; if the batch is of some other
@@ -111,7 +110,7 @@ class Trainer(TrainerEventTrigger):
         :param fp16: whether to enable mixed-precision training; defaults to False;
         :param monitor: the default name of the monitor metric when validate_dataloaders is given. A callback that takes
             a monitor argument but was not given one at initialization adopts this value. If no exact match is found in
             the evaluation results, fuzzy (longest common substring) matching picks the closest key
-            as the monitor.
+            as the monitor. A function is also accepted; it receives the evaluation results (a dict) and returns a float used as the monitor value.
         :param larger_better: whether larger monitor values are better.
         :param marker: marks a Trainer instance, so that when the user calls `Trainer.on` the callback function can be
             attributed to the specific 'trainer' instance; defaults to None;
         :param kwargs: other parameters that may be needed;
@@ -142,10 +141,9 @@
         self.input_mapping = input_mapping
         self.output_mapping = output_mapping

-        assert check_fn_not_empty_params(batch_step_fn, 2), "`batch_step_fn` should be a callable object with " \
-                                                            "two parameters."
         self.batch_step_fn = batch_step_fn
         if batch_step_fn is not None:
+            _check_valid_parameters_number(batch_step_fn, ['trainer', 'batch'], fn_name='batch_step_fn')
             self.check_batch_step_fn = partial(self._check_callback_called_legality, check_mode=True)
         else:
             self.check_batch_step_fn = lambda *args, **kwargs: ...
@@ -221,18 +219,11 @@
         if metrics is not None and validate_dataloaders is None:
             raise ValueError("You have set 'metrics' but forget to set 'validate_dataloader'.")

-        # To check on every training iteration whether a validation is due, the functions to run at each point in time
-        # are fixed up front when the trainer is initialized;
-        # _epoch_validate validates once every few epochs; _step_validate once every few steps;
         self.evaluator = None
         self.monitor = monitor
         self.larger_better = larger_better
         if metrics is not None and validate_dataloaders is not None:
-            if not callable(validate_every) and (not isinstance(validate_every, int) or validate_every == 0):
-                raise ValueError("Parameter 'validate_every' should be set to 'int' type and either < 0 or > 0.")
-            if callable(validate_every):
-                logger.info("Notice you are using a 'filter function' as the value of parameter `validate_every`, "
-                            "and in this way, the kind of controlling frequency is depending on the 'step'.")
-
+            check_validate_every(validate_every)
             self.evaluator = Evaluator(
                 model=model,
                 dataloaders=validate_dataloaders,
@@ -352,33 +343,32 @@
             _validate_res: dict = validate_fn()
             trainer.on_validate_end(_validate_res)

-        self.validate_fn = partial(_validate_fn, self, partial(self.evaluator.run, num_eval_batch_per_dl))
+        self.run_evaluate = partial(_validate_fn, self, partial(self.evaluator.run, num_eval_batch_per_dl))

     def step_validate(self):
-        if self.evaluator is not None:
-            should_run_validate = False
+        """
+        Called at the end of every batch; runs evaluate according to the settings.

+        :return:
+        """
+        if self.evaluator is not None:
             if callable(self.validate_every):
                 if self.validate_every(self):
-                    should_run_validate = True
-            elif self.validate_every > 0:
-                if self.global_forward_batches % self.validate_every == 0:
-                    should_run_validate = True
-
-            if should_run_validate:
-                self.validate_fn()
+                    self.run_evaluate()
+            elif self.validate_every > 0 and self.global_forward_batches % self.validate_every == 0:
+                self.run_evaluate()

     def epoch_validate(self):
-        if self.evaluator is not None:
-            should_run_validate = False
+        """
+        Called at the end of every epoch; runs evaluate according to the settings.

+        :return:
+        """
+        if self.evaluator is not None:
             if isinstance(self.validate_every, int) and self.validate_every < 0:
                 validate_every = -self.validate_every
                 if self.cur_epoch_idx % validate_every == 0:
-                    should_run_validate = True
-
-            if should_run_validate:
-                self.validate_fn()
+                    self.run_evaluate()

     def add_callback_fn(self, event: Optional[Union[Events, EventsList]], fn: Callable):
         r"""
@@ -410,9 +400,7 @@
         def wrapper(fn: Callable) -> Callable:
             cls._custom_callbacks[marker].append((event, fn))
             callback_fn_args = get_fn_arg_names(getattr(Callback, event.value))[1:]
-            assert check_fn_not_empty_params(fn, len(callback_fn_args)), \
-                f"The callback function at `{event.value.lower()}`'s parameters should be {callback_fn_args}, but your "\
-                f"function {fn.__name__} only has these parameters: {get_fn_arg_names(fn)}."
+            _check_valid_parameters_number(fn, callback_fn_args)
             return fn

         return wrapper
diff --git a/fastNLP/core/controllers/utils/utils.py b/fastNLP/core/controllers/utils/utils.py
index 0dce0b27..6e0824a1 100644
--- a/fastNLP/core/controllers/utils/utils.py
+++ b/fastNLP/core/controllers/utils/utils.py
@@ -1,8 +1,9 @@
-from collections.abc import Iterator
+import inspect
 from typing import Dict

 from fastNLP.core.callbacks import CallbackManager
 from .state import TrainerState
+from fastNLP.core.utils.utils import _check_valid_parameters_number


 class TrainerEventTrigger:
@@ -125,5 +126,8 @@ class _TruncatedDataLoader:
         return getattr(self.dataloader, item)


-
-
+def check_validate_every(validate_every):
+    if not callable(validate_every) and (not isinstance(validate_every, int) or validate_every == 0):
+        raise ValueError("Parameter 'validate_every' should be set to 'int' type and either < 0 or > 0.")
+    if callable(validate_every):
+        _check_valid_parameters_number(validate_every, expected_params=['trainer'])
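Editor's note: a hedged sketch of the simplified validate_every contract described in the docstring above, using only trainer attributes that appear in the surrounding code and docstrings:

    # Run validation at the end of every epoch, but only after the 10th epoch.
    # The hook is invoked after every batch, so detect the last batch of an epoch.
    def validate_every(trainer) -> bool:
        return (trainer.cur_epoch_idx > 10 and
                trainer.global_forward_batches % trainer.num_batches_per_epoch == 0)

    # trainer = Trainer(..., validate_every=validate_every)
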
diff --git a/fastNLP/core/dataset/dataset.py b/fastNLP/core/dataset/dataset.py
index 5b8ec635..cd887253 100644
--- a/fastNLP/core/dataset/dataset.py
+++ b/fastNLP/core/dataset/dataset.py
@@ -178,10 +178,11 @@ class DataSet:
         elif isinstance(idx, slice):
             if idx.start is not None and (idx.start >= len(self) or idx.start <= -len(self)):
                 raise RuntimeError(f"Start index {idx.start} out of range 0-{len(self) - 1}")
-            data_set = DataSet()
+            dataset = DataSet()
             for field_name, field in self.field_arrays.items():
-                data_set.add_field(field_name=field_name, fields=field.content[idx])
-            return data_set
+                dataset.add_field(field_name=field_name, fields=field.content[idx])
+            dataset.collate_fns = deepcopy(self.collate_fns)
+            return dataset
         elif isinstance(idx, str):
             if idx not in self:
                 raise KeyError("No such field called {} in DataSet.".format(idx))
@@ -192,6 +193,7 @@ class DataSet:
                 assert isinstance(i, int), "Only int index allowed."
                 instance = self[i]
                 dataset.append(instance)
+            dataset.collate_fns = deepcopy(self.collate_fns)
             return dataset
         else:
             raise KeyError("Unrecognized type {} for idx in __getitem__ method".format(type(idx)))
@@ -674,6 +676,8 @@ class DataSet:
             dev_set.append(self[idx])
         for idx in train_indices:
             train_set.append(self[idx])
+        dev_set.collate_fns = deepcopy(self.collate_fns)
+        train_set.collate_fns = deepcopy(self.collate_fns)

         return dev_set, train_set

@@ -795,7 +799,7 @@ class DataSet:
         :param val: defaults to 0. If None, the field is not padded.
         :return:
         """
-        # TODO needs de-duplication
+        # TODO must not be empty
         for field_name in field_names:
             self.collate_fns.set_pad_val(field_name, val=val)
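Editor's note: the visible effect of deep-copying collate_fns into every derived DataSet, sketched under the assumption that set_pad_val is the public wrapper shown in the last hunk:

    ds.set_pad_val('words', val=-1)   # configure padding on the parent dataset
    sub = ds[:10]                     # slice: carries a deep copy of collate_fns
    picked = ds[[0, 2, 4]]            # list indexing: same
    # both keep pad value -1 for 'words' instead of silently reverting to the defaults,
    # and the dev/train pair returned by the split routine above behaves the same way
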
diff --git a/fastNLP/core/drivers/jittor_driver/jittor_driver.py b/fastNLP/core/drivers/jittor_driver/jittor_driver.py
index a8ad32e8..411fdf69 100644
--- a/fastNLP/core/drivers/jittor_driver/jittor_driver.py
+++ b/fastNLP/core/drivers/jittor_driver/jittor_driver.py
@@ -66,7 +66,7 @@ class JittorDriver(Driver):
         if mode == "validate":
             if not hasattr(model, "validate_step"):
                 if hasattr(model, "test_step"):
-                    logger.warning(
+                    logger.warning_once(
                         "Your model does not have 'validate_step' method but has 'test_step' method, but you"
                         "are using 'mode=validate', we are going to use 'test_step' to substitute for"
                         "'validate_step'.")
@@ -74,7 +74,7 @@ class JittorDriver(Driver):
         else:
             if not hasattr(model, "test_step"):
                 if hasattr(model, "validate_step"):
-                    logger.warning("Your model does not have 'test_step' method but has 'validate_step' method, but you"
+                    logger.warning_once("Your model does not have 'test_step' method but has 'validate_step' method, but you"
                                    "are using 'mode=test', we are going to use 'validate_step' to substitute for"
                                    "'test_step'.")
diff --git a/fastNLP/core/drivers/paddle_driver/paddle_driver.py b/fastNLP/core/drivers/paddle_driver/paddle_driver.py
index 4362dcce..931921fd 100644
--- a/fastNLP/core/drivers/paddle_driver/paddle_driver.py
+++ b/fastNLP/core/drivers/paddle_driver/paddle_driver.py
@@ -133,7 +133,7 @@ class PaddleDriver(Driver):
         else:
             if not hasattr(model, "test_step"):
                 if hasattr(model, "validate_step"):
-                    logger.warning("Your model does not have 'test_step' method but has 'validate_step' method, but you"
+                    logger.warning_once("Your model does not have 'test_step' method but has 'validate_step' method, but you"
                                    "are using 'Evaluator.test', we are going to use 'validate_step' to substitute for"
                                    "'test_step'.")
diff --git a/fastNLP/core/drivers/torch_driver/dist_utils.py b/fastNLP/core/drivers/torch_driver/dist_utils.py
index ad9e6794..37110577 100644
--- a/fastNLP/core/drivers/torch_driver/dist_utils.py
+++ b/fastNLP/core/drivers/torch_driver/dist_utils.py
@@ -333,10 +333,8 @@ def all_gather_object(object_list, obj, group=None):
         >>> output
         ['foo', 12, {1: 2}]
     """
-    if dist._rank_not_in_group(group):
+    if dist.distributed_c10d._rank_not_in_group(group):
         return
-
-    input_tensor, local_size = _object_to_tensor(obj)
     if _TORCH_GREATER_EQUAL_1_8:
         current_device = torch.device("cpu")
         is_nccl_backend = _check_for_nccl_backend(group)
@@ -345,10 +343,11 @@ def all_gather_object(object_list, obj, group=None):
             # We cannot simply use my_rank since rank == device is not necessarily
             # true.
             current_device = torch.device("cuda", torch.cuda.current_device())
-            input_tensor = input_tensor.to(current_device)
-            local_size = local_size.to(current_device)
     else:
         current_device = torch.cuda.current_device()
+
+    input_tensor, local_size = _object_to_tensor(obj, device=current_device)
+
     # Gather all local sizes. This is so that we can find the max size, and index
     # until the correct size when deserializing the tensors.
     group_size = dist.get_world_size(group=group)
@@ -379,3 +378,4 @@ def all_gather_object(object_list, obj, group=None):
         tensor = tensor.cpu()
         tensor_size = object_size_list[i]
         object_list[i] = _tensor_to_object(tensor, tensor_size)
+    return object_list
\ No newline at end of file
diff --git a/fastNLP/core/drivers/torch_driver/torch_driver.py b/fastNLP/core/drivers/torch_driver/torch_driver.py
index c60d1552..f1e33d5e 100644
--- a/fastNLP/core/drivers/torch_driver/torch_driver.py
+++ b/fastNLP/core/drivers/torch_driver/torch_driver.py
@@ -113,7 +113,7 @@ class TorchDriver(Driver):
         if mode == "validate":
             if not hasattr(model, "validate_step"):
                 if hasattr(model, "test_step"):
-                    logger.warning(
+                    logger.warning_once(
                         "Your model does not have 'validate_step' method but has 'test_step' method, but you"
                         "are using 'mode=validate', we are going to use 'test_step' to substitute for"
                         "'validate_step'.")
diff --git a/fastNLP/core/log/logger.py b/fastNLP/core/log/logger.py
index 9763ab4a..004bfb16 100644
--- a/fastNLP/core/log/logger.py
+++ b/fastNLP/core/log/logger.py
@@ -125,9 +125,9 @@ class FastNLPLogger(logging.Logger, metaclass=LoggerSingleton):
             self._warning_msgs.add(msg)

     def warn(self, msg, *args, **kwargs):
-        warnings.warn("The 'warn' method is deprecated, "
-                      "use 'warning' instead", DeprecationWarning, 2)
-        self.warning(msg, *args, **kwargs)
+        if self.isEnabledFor(WARNING):
+            kwargs = self._add_rank_info(kwargs)
+            self._log(WARNING, msg, args, **kwargs)

     def error(self, msg, *args, **kwargs):
         """
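Editor's note: the drivers above now call warning_once, which, judging from the _warning_msgs set visible in the logger excerpt, deduplicates warnings by message text. A hedged sketch of the intended behavior:

    from fastNLP.core.log import logger

    for _ in range(3):
        logger.warning_once("substituting 'test_step' for 'validate_step'")  # emitted once
    logger.warning("an ordinary warning")  # emitted on every call
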
diff --git a/fastNLP/core/metrics/accuracy.py b/fastNLP/core/metrics/accuracy.py
index 0a60e4d7..d1ac1776 100644
--- a/fastNLP/core/metrics/accuracy.py
+++ b/fastNLP/core/metrics/accuracy.py
@@ -14,8 +14,7 @@ from fastNLP.core.utils.utils import seq_len_to_mask


 class Accuracy(Metric):
-    def __init__(self, backend: Union[str, Backend, None] = 'auto',
-                 aggregate_when_get_metric: bool = True):
+    def __init__(self, backend: Union[str, Backend, None] = 'auto', aggregate_when_get_metric: bool = True):
         super(Accuracy, self).__init__(backend=backend, aggregate_when_get_metric=aggregate_when_get_metric)
         self.register_element(name='correct', value=0, aggregate_method='sum', backend=backend)
         self.register_element(name='total', value=0, aggregate_method="sum", backend=backend)
@@ -64,7 +63,7 @@ class Accuracy(Metric):
                 warnings.warn("You are not passing `seq_len` to exclude pad when calculate accuracy.")

         else:
-            raise RuntimeError(f"when pred has size:{pred.shape}, target should have size: {pred.shape} or "
+            raise RuntimeError(f"when pred has size:{pred.shape}, target should have size: {pred.shape} or "
                                f"{pred.shape[:-1]}, got {target.shape}.")

         if masks is not None:
diff --git a/fastNLP/core/samplers/__init__.py b/fastNLP/core/samplers/__init__.py
index c3cc2d39..61433e8e 100644
--- a/fastNLP/core/samplers/__init__.py
+++ b/fastNLP/core/samplers/__init__.py
@@ -23,14 +23,14 @@ __all__ = [
     "BucketedBatchSampler",
     "ReproducibleBatchSampler",

-    "re_instantiate_sampler",
-    "conversion_between_reproducible_and_unrepeated_sampler"
+    "re_instantiate_sampler"
 ]

 from .sampler import BucketSampler, SortedSampler, ConstTokenNumSampler, ConstantTokenNumSampler
 from .unrepeated_sampler import UnrepeatedSampler, UnrepeatedRandomSampler, UnrepeatedSortedSampler, UnrepeatedSequentialSampler
 from .mix_sampler import MixSampler, DopedSampler, MixSequentialSampler, PollingSampler
 from .reproducible_sampler import ReproducibleSampler, RandomSampler, SequentialSampler, SortedSampler
-from .utils import re_instantiate_sampler, conversion_between_reproducible_and_unrepeated_sampler
+from .utils import re_instantiate_sampler
+from .conversion_utils import conversion_between_reproducible_and_unrepeated_sampler
 from .reproducible_batch_sampler import RandomBatchSampler, BucketedBatchSampler, ReproducibleBatchSampler
diff --git a/fastNLP/core/samplers/conversion_utils.py b/fastNLP/core/samplers/conversion_utils.py
new file mode 100644
index 00000000..d5d97d0c
--- /dev/null
+++ b/fastNLP/core/samplers/conversion_utils.py
@@ -0,0 +1,33 @@
+from fastNLP.core.samplers import re_instantiate_sampler
+from fastNLP.core.samplers.reproducible_sampler import ReproducibleSampler, RandomSampler, SequentialSampler, \
+    SortedSampler
+from fastNLP.core.samplers.unrepeated_sampler import UnrepeatedSampler, UnrepeatedRandomSampler, \
+    UnrepeatedSequentialSampler, UnrepeatedSortedSampler
+
+
+def conversion_between_reproducible_and_unrepeated_sampler(sampler):
+    """
+    Replace sampler with its reproducible or unrepeated counterpart. If the input is an UnrepeatedSampler
+    with no matching ReproducibleSampler (or the other way around), a TypeError is raised.
+
+    :param sampler:
+    :return:
+    """
+    assert isinstance(sampler, UnrepeatedSampler) or isinstance(sampler, ReproducibleSampler), \
+        "The sampler must be UnrepeatedSampler or ReproducibleSampler"
+    if isinstance(sampler, UnrepeatedSampler):
+        if isinstance(sampler, UnrepeatedRandomSampler):
+            return re_instantiate_sampler(sampler, new_sampler_class=RandomSampler)
+        elif isinstance(sampler, UnrepeatedSequentialSampler):
+            return re_instantiate_sampler(sampler, new_sampler_class=SequentialSampler)
+        elif isinstance(sampler, UnrepeatedSortedSampler):
+            return re_instantiate_sampler(sampler, new_sampler_class=SortedSampler)
+        raise TypeError(f"{sampler.__class__} has no reproducible version.")
+    else:
+        if isinstance(sampler, RandomSampler):
+            return re_instantiate_sampler(sampler, new_sampler_class=UnrepeatedRandomSampler)
+        elif isinstance(sampler, SequentialSampler):
+            return re_instantiate_sampler(sampler, new_sampler_class=UnrepeatedSequentialSampler)
+        elif isinstance(sampler, SortedSampler):
+            return re_instantiate_sampler(sampler, new_sampler_class=UnrepeatedSortedSampler)
+        raise TypeError(f"{sampler.__class__} has no unrepeated version.")
\ No newline at end of file
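Editor's note (the two TypeError messages in the new file were swapped relative to the branch they sit in and have been corrected above). A hedged sketch of the round trip the helper enables; the RandomSampler constructor arguments are assumptions:

    from fastNLP.core.samplers import (RandomSampler,
                                       conversion_between_reproducible_and_unrepeated_sampler)

    sampler = RandomSampler(dataset, shuffle=True)      # reproducible version
    unrepeated = conversion_between_reproducible_and_unrepeated_sampler(sampler)
    # -> an UnrepeatedRandomSampler re-instantiated from the same attributes
    restored = conversion_between_reproducible_and_unrepeated_sampler(unrepeated)
    # -> back to a RandomSampler
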
") - states = { - 'seed': self.seed, - 'epoch': self.epoch, - 'num_consumed_samples': self.num_consumed_samples, # 注意该值是计算所有 rank 上训练的所有数据; - 'sampler_type': self.__class__.__name__, - 'length': len(self.dataset), - 'shuffle': self.shuffle, - 'batch_size': self.batch_size, - 'num_batch_per_bucket': self.num_batch_per_bucket, - 'num_replicas': self.num_replicas - } + states = {'seed': self.seed, 'epoch': self.epoch, 'num_consumed_samples': self.num_consumed_samples, + 'sampler_type': self.__class__.__name__, 'length': len(self.dataset), 'shuffle': self.shuffle, + 'batch_size': self.batch_size, 'num_batch_per_bucket': self.num_batch_per_bucket, + 'num_replicas': self.num_replicas, + 'num_consumed_samples_array': getattr(self, 'num_consumed_samples_array', None)} - states['num_consumed_samples_array'] = getattr(self, 'num_consumed_samples_array', None) return states def load_state_dict(self, states: Dict): diff --git a/fastNLP/core/samplers/reproducible_sampler.py b/fastNLP/core/samplers/reproducible_sampler.py index 396e69b2..6ea9cc6b 100644 --- a/fastNLP/core/samplers/reproducible_sampler.py +++ b/fastNLP/core/samplers/reproducible_sampler.py @@ -1,3 +1,10 @@ +__all__ = [ + 'ReproducibleSampler', + 'RandomSampler', + "SortedSampler", + "SequentialSampler" +] + from typing import Dict, List, Union import math import os @@ -10,13 +17,6 @@ from fastNLP.envs.env import FASTNLP_DEQUE_SIZE from .utils import NumConsumedSamplesArray -__all__ = [ - 'ReproducibleSampler', - 'RandomSampler', - "SortedSampler", - "SequentialSampler" -] - class ReproducibleSampler: """ diff --git a/fastNLP/core/samplers/utils.py b/fastNLP/core/samplers/utils.py index 80af1787..ddcff37f 100644 --- a/fastNLP/core/samplers/utils.py +++ b/fastNLP/core/samplers/utils.py @@ -1,42 +1,10 @@ __all__ = [ - 're_instantiate_sampler', - 'conversion_between_reproducible_and_unrepeated_sampler' + 're_instantiate_sampler' ] from array import array from typing import Sequence from collections import deque -from fastNLP.core.samplers.unrepeated_sampler import * -from fastNLP.core.samplers.reproducible_sampler import * - - -def conversion_between_reproducible_and_unrepeated_sampler(sampler): - """ - 将 sampler 替换成其对应的 reproducible 版本或 unrepeated 版本。如果输入是 UnrepeatedSampler 但是没找到对应的 - ReproducibleSampler, - - :param sampler: - :return: - """ - assert isinstance(sampler, UnrepeatedSampler) or isinstance(sampler, ReproducibleSampler), \ - "The sampler must be UnrepeatedSampler or ReproducibleSampler" - if isinstance(sampler, UnrepeatedSampler): - if isinstance(sampler, UnrepeatedRandomSampler): - return re_instantiate_sampler(sampler, new_sampler_class=RandomSampler) - elif isinstance(sampler, UnrepeatedSequentialSampler): - return re_instantiate_sampler(sampler, new_sampler_class=SequentialSampler) - elif isinstance(sampler, UnrepeatedSortedSampler): - return re_instantiate_sampler(sampler, new_sampler_class=SortedSampler) - raise TypeError(f"{sampler.__class__} has no unrepeated version.") - else: - if isinstance(sampler, RandomSampler): - return re_instantiate_sampler(sampler, new_sampler_class=UnrepeatedRandomSampler) - elif isinstance(sampler, SequentialSampler): - return re_instantiate_sampler(sampler, new_sampler_class=UnrepeatedSequentialSampler) - elif isinstance(sampler, SortedSampler): - return re_instantiate_sampler(sampler, new_sampler_class=UnrepeatedSortedSampler) - raise TypeError(f"{sampler.__class__} has no reproducible version.") - def re_instantiate_sampler(sampler, new_sampler_class=None): all_attributes = vars(sampler) 
diff --git a/fastNLP/core/samplers/reproducible_sampler.py b/fastNLP/core/samplers/reproducible_sampler.py
index 396e69b2..6ea9cc6b 100644
--- a/fastNLP/core/samplers/reproducible_sampler.py
+++ b/fastNLP/core/samplers/reproducible_sampler.py
@@ -1,3 +1,10 @@
+__all__ = [
+    'ReproducibleSampler',
+    'RandomSampler',
+    "SortedSampler",
+    "SequentialSampler"
+]
+
 from typing import Dict, List, Union
 import math
 import os
@@ -10,13 +17,6 @@ from fastNLP.envs.env import FASTNLP_DEQUE_SIZE
 from .utils import NumConsumedSamplesArray


-__all__ = [
-    'ReproducibleSampler',
-    'RandomSampler',
-    "SortedSampler",
-    "SequentialSampler"
-]
-

 class ReproducibleSampler:
     """
diff --git a/fastNLP/core/samplers/utils.py b/fastNLP/core/samplers/utils.py
index 80af1787..ddcff37f 100644
--- a/fastNLP/core/samplers/utils.py
+++ b/fastNLP/core/samplers/utils.py
@@ -1,42 +1,10 @@
 __all__ = [
-    're_instantiate_sampler',
-    'conversion_between_reproducible_and_unrepeated_sampler'
+    're_instantiate_sampler'
 ]

 from array import array
 from typing import Sequence
 from collections import deque

-from fastNLP.core.samplers.unrepeated_sampler import *
-from fastNLP.core.samplers.reproducible_sampler import *
-
-
-def conversion_between_reproducible_and_unrepeated_sampler(sampler):
-    """
-    Replace sampler with its reproducible or unrepeated counterpart. If the input is an UnrepeatedSampler
-    with no matching ReproducibleSampler (or the other way around), a TypeError is raised.
-
-    :param sampler:
-    :return:
-    """
-    assert isinstance(sampler, UnrepeatedSampler) or isinstance(sampler, ReproducibleSampler), \
-        "The sampler must be UnrepeatedSampler or ReproducibleSampler"
-    if isinstance(sampler, UnrepeatedSampler):
-        if isinstance(sampler, UnrepeatedRandomSampler):
-            return re_instantiate_sampler(sampler, new_sampler_class=RandomSampler)
-        elif isinstance(sampler, UnrepeatedSequentialSampler):
-            return re_instantiate_sampler(sampler, new_sampler_class=SequentialSampler)
-        elif isinstance(sampler, UnrepeatedSortedSampler):
-            return re_instantiate_sampler(sampler, new_sampler_class=SortedSampler)
-        raise TypeError(f"{sampler.__class__} has no reproducible version.")
-    else:
-        if isinstance(sampler, RandomSampler):
-            return re_instantiate_sampler(sampler, new_sampler_class=UnrepeatedRandomSampler)
-        elif isinstance(sampler, SequentialSampler):
-            return re_instantiate_sampler(sampler, new_sampler_class=UnrepeatedSequentialSampler)
-        elif isinstance(sampler, SortedSampler):
-            return re_instantiate_sampler(sampler, new_sampler_class=UnrepeatedSortedSampler)
-    raise TypeError(f"{sampler.__class__} has no unrepeated version.")
-

 def re_instantiate_sampler(sampler, new_sampler_class=None):
     all_attributes = vars(sampler)
diff --git a/fastNLP/core/utils/__init__.py b/fastNLP/core/utils/__init__.py
index 1d1c9d16..cceb948f 100644
--- a/fastNLP/core/utils/__init__.py
+++ b/fastNLP/core/utils/__init__.py
@@ -13,7 +13,6 @@ __all__ = [
     'torch_paddle_move_data_to_device',
     'torch_move_data_to_device',
     'get_fn_arg_names',
-    'check_fn_not_empty_params',
     'auto_param_call',
     'check_user_specific_params',
     'dataclass_to_dict',
@@ -36,7 +35,7 @@ from .paddle_utils import paddle_to, paddle_move_data_to_device, get_paddle_device
 from .rich_progress import f_rich_progress
 from .torch_paddle_utils import torch_paddle_move_data_to_device
 from .torch_utils import torch_move_data_to_device
-from .utils import get_fn_arg_names, check_fn_not_empty_params, auto_param_call, check_user_specific_params, \
+from .utils import get_fn_arg_names, auto_param_call, check_user_specific_params, \
     dataclass_to_dict, match_and_substitute_params, apply_to_collection, nullcontext, pretty_table_printer, Option, \
     indice_collate_wrapper, deprecated, seq_len_to_mask, synchronize_safe_rm, synchronize_mkdir
diff --git a/fastNLP/core/utils/utils.py b/fastNLP/core/utils/utils.py
index d593f4ee..7af6557f 100644
--- a/fastNLP/core/utils/utils.py
+++ b/fastNLP/core/utils/utils.py
@@ -1,3 +1,4 @@
+import functools
 import inspect
 from inspect import Parameter
 import dataclasses
@@ -24,10 +25,8 @@ from fastNLP.core.log import logger
 from fastNLP.envs import FASTNLP_GLOBAL_RANK

-
 __all__ = [
     'get_fn_arg_names',
-    'check_fn_not_empty_params',
     'auto_param_call',
     'check_user_specific_params',
     'dataclass_to_dict',
@@ -54,30 +53,6 @@ def get_fn_arg_names(fn: Callable) -> List[str]:
     return list(inspect.signature(fn).parameters)


-def check_fn_not_empty_params(fn: Optional[Callable] = None, param_num: Optional[int] = None) -> bool:
-    r"""
-    Check that the passed-in batch_step_fn is legal: (1) it is callable; (2) it has exactly the given number of
-    parameters without default values. The user may also pass in a partial function, as long as the `trainer` and
-    `batch` parameter slots are preserved;
-
-    :param fn: the function meant to replace the 'step' function of the Loop;
-    :param param_num: the expected number of parameters without default values;
-
-    :return: bool, whether the passed-in `batch_step_fn` is valid;
-    """
-
-    if fn is None:
-        return True
-    if not callable(fn):
-        return False
-    else:
-        params = inspect.signature(fn).parameters
-        not_default_params = {}
-        for _name, _param in params.items():
-            if _param.default == Parameter.empty:
-                not_default_params[_name] = _param
-        return len(not_default_params) == param_num
-
-
 def auto_param_call(fn: Callable, *args, signature_fn: Optional[Callable] = None,
                     mapping: Optional[Dict[AnyStr, AnyStr]] = None) -> Any:
     r"""
@@ -95,7 +70,6 @@ def auto_param_call(fn: Callable, *args, signature_fn: Optional[Callable] = None
     :param signature_fn: a function used to replace the signature of `fn`. If not None, the signature is extracted from
         this function, the parameter values are matched against that signature, and then passed to `fn` for the actual call;
     :param mapping: a dict used to rename the keys of the preceding dicts;
-    :param wo_auto_param_call: whether to disable the automatic parameter-matching behavior;

     :return: the result of running `fn`;

@@ -123,7 +97,8 @@ def auto_param_call(fn: Callable, *args, signature_fn: Optional[Callable] = None
     _kwargs = None
     for _name, _param in _need_params.items():
         if _param.kind == Parameter.VAR_POSITIONAL:
-            raise ValueError(f"It is not allowed to have parameter `*args` in your function:{fn.__name__}.")
+            fn_msg = _get_fun_msg(fn if signature_fn is None else signature_fn)
+            raise ValueError(f"It is not allowed to have parameter `*args` in your function:{fn_msg}.")
         if _param.kind == Parameter.VAR_KEYWORD:
             _kwargs = (_name, _param)

@@ -136,12 +111,17 @@ def auto_param_call(fn: Callable, *args, signature_fn: Optional[Callable] = None
             _default_params[_name] = _param.default

     if mapping is not None:
-        assert isinstance(mapping, Dict), f"Parameter `mapping` should be of 'Dict' type, instead of {type(mapping)}."
+        fn_msg = _get_fun_msg(fn if signature_fn is None else signature_fn)
+        assert isinstance(mapping, Dict), f"Exception happens when calling {fn_msg}. " \
+                                          f"Parameter `mapping` should be of 'Dict' type, instead of {type(mapping)}."

     _has_params = {}
     duplicate_names = []
     for arg in args:
-        assert isinstance(arg, Dict), "The input part of function `auto_param_call` can only be `Dict` type."
+        if not isinstance(arg, Dict):
+            fn_msg = _get_fun_msg(fn if signature_fn is None else signature_fn)
+            raise TypeError(f"Exception happens when calling {fn_msg}. "
+                            f"The input part of function `auto_param_call` must be `Dict` type, instead of {type(arg)}.")
         for _name, _value in arg.items():
             if mapping is not None and _name in mapping:
                 _name = mapping[_name]
@@ -153,7 +133,8 @@ def auto_param_call(fn: Callable, *args, signature_fn: Optional[Callable] = None
             elif _name in _need_params and not (_has_params[_name] is _value):
                 duplicate_names.append(_name)
     if duplicate_names:
-        raise ValueError(f"The following key present in several inputs:{duplicate_names}")
+        fn_msg = _get_fun_msg(fn if signature_fn is None else signature_fn)
+        raise ValueError(f"The following key present in several inputs:{duplicate_names} when calling {fn_msg}.")

     # pass along the parameters that have default values and were not overridden by the inputs;
     for _name, _value in _default_params.items():
@@ -162,11 +143,89 @@ def auto_param_call(fn: Callable, *args, signature_fn: Optional[Callable] = None
     if len(_has_params) < len(_need_params):
         miss_params = list(set(_need_params.keys()) - set(_has_params.keys()))
         fn_msg = _get_fun_msg(fn if signature_fn is None else signature_fn)
         _provided_keys = _get_keys(args)
         raise TypeError(f"{fn_msg} is missing {len(miss_params)} needed argument(s): {miss_params}, "
                         f"provided keys are: {_provided_keys}.")
     return fn(**_has_params)


+def _get_keys(args:List[Dict]) -> List[List[str]]:
+    """
+    Return the keys of each dict.
+
+    :param args:
+    :return:
+    """
+    _provided_keys = []
+    for arg in args:
+        _provided_keys.append(list(arg.keys()))
+    return _provided_keys
+
+
+def _get_fun_msg(fn) -> str:
+    """
+    Collect basic information about a function, to help error reporting.
+    ex:
+        print(_get_fun_msg(_get_fun_msg))
+        # `_get_fun_msg(fn) -> str`(In file:/Users/hnyan/Desktop/projects/fastNLP/fastNLP/fastNLP/core/utils/utils.py)
+
+    :param callable fn:
+    :return:
+    """
+    if isinstance(fn, functools.partial):
+        return _get_fun_msg(fn.func)
+    try:
+        fn_name = fn.__qualname__ + str(inspect.signature(fn))
+    except:
+        fn_name = str(fn)
+    try:
+        fp = '(In file:' + os.path.abspath(inspect.getfile(fn)) + ')'
+    except:
+        fp = ''
+    msg = f'`{fn_name}`' + fp
+    return msg
+
+
+def _check_valid_parameters_number(fn, expected_params:List[str], fn_name=None):
+    """
+    Check whether a function can accept the expected_params parameters (the count must match). self (for methods)
+    and parameters with default values are excluded. If the counts do not match, an error is raised.
+
+    :param fn: the function to check; may be a method or a plain function.
+    :param expected_params: the parameters the function is expected to support.
+    :param fn_name: the name of fn, used for a clearer error message when the passed-in fn is not callable.
+    :return:
+    """
+    if fn_name is not None:
+        assert callable(fn), f"{fn_name} should be callable, instead of {type(fn)}."
+
+    parameters = list(inspect.signature(fn).parameters.values())
+    if inspect.ismethod(fn):
+        if len(parameters) > 0 and parameters[0].name == 'self':
+            parameters = parameters[1:]  # drop self
+
+    no_var_param = True  # the signature has no *args/**kwargs parameter
+    number_param_need_value = 0
+    for param in parameters:
+        if param.kind is param.VAR_POSITIONAL:
+            no_var_param = False
+        elif param.kind is param.VAR_KEYWORD:
+            no_var_param = False
+        else:
+            if param.default is param.empty:
+                number_param_need_value += 1
+
+    if no_var_param and len(parameters) < len(expected_params):
+        raise RuntimeError(f"The function:{_get_fun_msg(fn)} accepts {len(parameters)} parameters, "
+                           f"but {len(expected_params)} parameters:{expected_params} will be provided.")
+    if no_var_param and number_param_need_value > len(expected_params):
+        raise RuntimeError(f"The function:{_get_fun_msg(fn)} expects {len(parameters)} parameters, but only"
+                           f" {len(expected_params)} parameters:{expected_params} will be provided.")


 def check_user_specific_params(user_params: Dict, fn: Callable):
     """
     This function assigns the parameters of the given function from the user's input;
@@ -592,4 +651,24 @@ def synchronize_mkdir(path: Optional[Union[str, Path]]):
     wait_to_success(path.exists)


+def get_class_that_defined_method(method):
+    """
+    Given a method, return the class object on which this method is defined.
+    :param method:
+    :return:
+    """
+    if isinstance(method, functools.partial):
+        return get_class_that_defined_method(method.func)
+    if inspect.ismethod(method) or (inspect.isbuiltin(method) and getattr(method, '__self__', None) is not None and getattr(method.__self__, '__class__', None)):
+        for cls in inspect.getmro(method.__self__.__class__):
+            if method.__name__ in cls.__dict__:
+                return cls
+        method = getattr(method, '__func__', method)  # fallback to __qualname__ parsing
+    if inspect.isfunction(method):
+        cls = getattr(inspect.getmodule(method),
+                      method.__qualname__.split('.', 1)[0].rsplit('.', 1)[0],
+                      None)
+        if isinstance(cls, type):
+            return cls
+    return getattr(method, '__objclass__', None)  # handle special descriptor objects
\ No newline at end of file
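Editor's note: the missing-argument branch of auto_param_call and the two raises inside _check_valid_parameters_number were garbled in this excerpt and have been reconstructed above from the assertions in the test file further down; the exact message wording is an approximation. A hedged sketch of what get_class_that_defined_method resolves, covering the three code paths:

    import functools

    class A:
        def f(self):
            pass

    a = A()
    get_class_that_defined_method(a.f)                      # -> A, via the MRO of the bound method
    get_class_that_defined_method(functools.partial(a.f))   # -> A, the partial is unwrapped first
    get_class_that_defined_method(A.f)                      # -> A, plain function resolved via __qualname__
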
diff --git a/fastNLP/io/data_bundle.py b/fastNLP/io/data_bundle.py
index 5daee519..2796bb69 100644
--- a/fastNLP/io/data_bundle.py
+++ b/fastNLP/io/data_bundle.py
@@ -251,10 +251,10 @@ class DataBundle:
     def apply_field_more(self, func: Callable, field_name: str, num_proc: int = 0, modify_fields=True,
                          ignore_miss_dataset=True, progress_desc: str = '', show_progress_bar: bool = True):
         r"""
-        Apply :meth:`~fastNLP.DataSet.apply_field_more` to every dataset in the :class:`~fastNLP.io.DataBundle`
+        Apply :method:`~fastNLP.DataSet.apply_field_more` to every dataset in the :class:`~fastNLP.io.DataBundle`

         .. note::
-            For the difference between ``apply_field_more`` and ``apply_field``, see the discussion of ``apply_more`` versus
-            ``apply`` in :meth:`fastNLP.DataSet.apply_more`.
+            For the difference between ``apply_field_more`` and ``apply_field``, see the discussion of ``apply_more`` versus
+            ``apply`` in :method:`fastNLP.DataSet.apply_more`.

         :param callable func: takes an ``Instance`` of the ``DataSet`` as input and returns a dict whose keys are field
             names and whose values are the corresponding results
@@ -285,7 +285,7 @@ class DataBundle:
     def apply(self, func: Callable, new_field_name: str, num_proc: int = 0,
               progress_desc: str = '', show_progress_bar: bool = True, _apply_field: str = None):
         r"""
-        Apply :meth:`~fastNLP.DataSet.apply` to every dataset in the :class:`~fastNLP.io.DataBundle`
+        Apply :method:`~fastNLP.DataSet.apply` to every dataset in the :class:`~fastNLP.io.DataBundle`

         Apply the apply method to every dataset in the DataBundle

@@ -309,10 +309,10 @@ class DataBundle:
     def apply_more(self, func: Callable, modify_fields=True, num_proc: int = 0, progress_desc: str = '',
                    show_progress_bar: bool = True):
         r"""
-        Apply :meth:`~fastNLP.DataSet.apply_more` to every dataset in the :class:`~fastNLP.io.DataBundle`
+        Apply :method:`~fastNLP.DataSet.apply_more` to every dataset in the :class:`~fastNLP.io.DataBundle`

         .. note::
-            For the difference between ``apply_more`` and ``apply``, see the discussion of ``apply_more`` versus
-            ``apply`` in :meth:`fastNLP.DataSet.apply_more`.
+            For the difference between ``apply_more`` and ``apply``, see the discussion of ``apply_more`` versus
+            ``apply`` in :method:`fastNLP.DataSet.apply_more`.

         :param callable func: takes an ``Instance`` of the ``DataSet`` as input and returns a dict whose keys are field
             names and whose values are the corresponding results
diff --git a/fastNLP/io/pipe/classification.py b/fastNLP/io/pipe/classification.py
index 0e5915a9..b5db4bd6 100644
--- a/fastNLP/io/pipe/classification.py
+++ b/fastNLP/io/pipe/classification.py
@@ -87,7 +87,7 @@ class CLSBasePipe(Pipe):

     def process_from_file(self, paths) -> DataBundle:
         r"""
-        Take file paths and return a processed DataBundle. For the path formats supported by paths, see ::meth:`fastNLP.io.Loader.load()`
+        Take file paths and return a processed DataBundle. For the path formats supported by paths, see ::method:`fastNLP.io.Loader.load()`

         :param paths:
         :return: DataBundle
diff --git a/fastNLP/io/pipe/construct_graph.py b/fastNLP/io/pipe/construct_graph.py
index 1448765e..1b6d192a 100644
--- a/fastNLP/io/pipe/construct_graph.py
+++ b/fastNLP/io/pipe/construct_graph.py
@@ -164,7 +164,7 @@ class GraphBuilderBase:

     def build_graph_from_file(self, path: str):
         r"""
-        Take a file path and return a processed scipy_sparse_matrix object. For the path formats supported by paths, see ::meth:`fastNLP.io.Loader.load()`
+        Take a file path and return a processed scipy_sparse_matrix object. For the path formats supported by paths, see ::method:`fastNLP.io.Loader.load()`

         :param path:
         :return: scipy_sparse_matrix
diff --git a/fastNLP/io/pipe/pipe.py b/fastNLP/io/pipe/pipe.py
index 4916bf09..f974b548 100644
--- a/fastNLP/io/pipe/pipe.py
+++ b/fastNLP/io/pipe/pipe.py
@@ -33,7 +33,7 @@ class Pipe:

     def process_from_file(self, paths: str) -> DataBundle:
         r"""
-        Take file paths and return a processed DataBundle. For the path formats supported by paths, see ::meth:`fastNLP.io.Loader.load()`
+        Take file paths and return a processed DataBundle. For the path formats supported by paths, see ::method:`fastNLP.io.Loader.load()`

         :param str paths:
         :return: DataBundle
diff --git a/tests/core/utils/test_utils.py b/tests/core/utils/test_utils.py
new file mode 100644
index 00000000..a7aeffb1
--- /dev/null
+++ b/tests/core/utils/test_utils.py
@@ -0,0 +1,187 @@
+from functools import partial
+
+import pytest
+
+from fastNLP.core.utils.utils import auto_param_call, _check_valid_parameters_number, _get_fun_msg
+from fastNLP.core.metrics import Metric
+
+
+class TestAutoParamCall:
+    def test_basic(self):
+        def fn(x):
+            return x
+        x = {'x': 3, 'y': 4}
+        r = auto_param_call(fn, x)
+        assert r == 3
+
+        xs = []
+        for i in range(10):
+            xs.append({f'x{i}': i})
+        def fn(x0, x1, x2, x3):
+            return x0 + x1 + x2 + x3
+        r = auto_param_call(fn, *xs)
+        assert r == 0 + 1 + 2 + 3
+
+        def fn(chongfu1, chongfu2, buChongFu):
+            pass
+        with pytest.raises(BaseException) as exc_info:
+            auto_param_call(fn, {'chongfu1': 3, "chongfu2": 4, 'buChongFu': 2},
+                            {'chongfu1': 1, 'chongfu2': 2, 'buChongFu': 2})
+        assert 'The following key present in several inputs' in exc_info.value.args[0]
+        assert 'chongfu1' in exc_info.value.args[0] and 'chongfu2' in exc_info.value.args[0]
+
+        # duplicated keys that are not needed do not raise
+        def fn(chongfu1, buChongFu):
+            pass
+        auto_param_call(fn, {'chongfu1': 1, "chongfu2": 4, 'buChongFu': 2},
+                        {'chongfu1': 1, 'chongfu2': 2, 'buChongFu': 2})
+
+        # signature_fn can be customized
+        def fn1(**kwargs):
+            kwargs.pop('x')
+            kwargs.pop('y')
+            assert len(kwargs) == 0
+        def fn(x, y):
+            pass
+        x = {'x': 3, 'y': 4}
+        r = auto_param_call(fn1, x, signature_fn=fn)
+
+        # raise when a needed parameter is not provided
+        def fn(meiti1, meiti2, tigong):
+            pass
+        with pytest.raises(BaseException) as exc_info:
+            auto_param_call(fn, {'tigong': 1})
+        assert 'meiti1' in exc_info.value.args[0] and 'meiti2' in exc_info.value.args[0]
+
+        # default values are replaced when provided
+        def fn(x, y=100):
+            return x + y
+        r = auto_param_call(fn, {'x': 10, 'y': 20})
+        assert r == 30
+        assert auto_param_call(fn, {'x': 10, 'z': 20}) == 110
+
+        # test the use of mapping
+        def fn(x, y=100):
+            return x + y
+        r = auto_param_call(fn, {'x1': 10, 'y1': 20}, mapping={'x1': 'x', 'y1': 'y', 'meiyong': 'meiyong'})
+        assert r == 30
+
+        # a function that needs no parameters at all
+        def fn():
+            return 1
+        assert 1 == auto_param_call(fn, {'x': 1})
+
+        # calling methods of a class works as well
+        assert 2 == auto_param_call(self.call_this, {'x': 1, 'y': 1})
+        assert 2 == auto_param_call(self.call_this, {'x': 1, 'y': 1, 'z': 1}, mapping={'z': 'self'})
+
+    def test_msg(self):
+        with pytest.raises(BaseException) as exc_info:
+            auto_param_call(self.call_this, {'x': 1})
+        assert 'TestAutoParamCall.call_this' in exc_info.value.args[0]
+
+        with pytest.raises(BaseException) as exc_info:
+            auto_param_call(call_this_for_auto_param_call, {'x': 1})
+        assert __file__ in exc_info.value.args[0]
+        assert 'call_this_for_auto_param_call' in exc_info.value.args[0]
+
+        with pytest.raises(BaseException) as exc_info:
+            auto_param_call(self.call_this_two, {'x': 1})
+        assert __file__ in exc_info.value.args[0]
+
+        with pytest.raises(BaseException) as exc_info:
+            auto_param_call(call_this_for_auto_param_call, {'x': 1}, signature_fn=self.call_this)
+        assert 'TestAutoParamCall.call_this' in exc_info.value.args[0]  # the message should come from the signature fn
+
+    def call_this(self, x, y):
+        return x + y
+
+    def call_this_two(self, x, y, z=pytest, **kwargs):
+        return x + y
+
+    def test_metric_auto_param_call(self):
+        metric = AutoParamCallMetric()
+        with pytest.raises(BaseException):
+            auto_param_call(metric.update, {'y': 1}, signature_fn=metric.update.__wrapped__)
+
+
+class AutoParamCallMetric(Metric):
+    def update(self, x):
+        pass
+
+
+def call_this_for_auto_param_call(x, y):
+    return x + y
+
+
+class TestCheckNumberOfParameters:
+    def test_validate_every(self):
+        def validate_every(trainer):
+            pass
+        _check_valid_parameters_number(validate_every, expected_params=['trainer'])
+
+        # an extra parameter without a default value raises
+        def validate_every(trainer, other):
+            pass
+        with pytest.raises(RuntimeError) as exc_info:
+            _check_valid_parameters_number(validate_every, expected_params=['trainer'])
+        assert "2 parameters" in exc_info.value.args[0]
+        print(exc_info.value.args[0])
+
+        # an extra parameter with a default value is fine
+        def validate_every(trainer, other=1):
+            pass
+        _check_valid_parameters_number(validate_every, expected_params=['trainer'])
+
+        # more expected parameters than the function accepts
+        def validate_every(trainer):
+            pass
+        with pytest.raises(RuntimeError) as exc_info:
+            _check_valid_parameters_number(validate_every, expected_params=['trainer', 'other'])
+        assert "accepts 1 parameters" in exc_info.value.args[0]
+        print(exc_info.value.args[0])
+
+        # works with partial
+        def validate_every(trainer, other):
+            pass
+        _check_valid_parameters_number(partial(validate_every, other=1), expected_params=['trainer'])
+        _check_valid_parameters_number(partial(validate_every, other=1), expected_params=['trainer', 'other'])
+        with pytest.raises(RuntimeError) as exc_info:
+            _check_valid_parameters_number(partial(validate_every, other=1), expected_params=['trainer', 'other', 'more'])
+        assert 'accepts 2 parameters' in exc_info.value.args[0]
+        print(exc_info.value.args[0])
+
+        # *args or **kwargs absorb any number of extra expected parameters
+        def validate_every(trainer, *args):
+            pass
+        _check_valid_parameters_number(validate_every, expected_params=['trainer', 'other', 'more'])
+
+        def validate_every(trainer, **kwargs):
+            pass
+        _check_valid_parameters_number(partial(validate_every, trainer=1), expected_params=['trainer', 'other', 'more'])
+
+        # self is stripped from methods of a class
+        class InnerClass:
+            def demo(self, x):
+                pass
+
+            def no_param(self):
+                pass
+
+            def param_kwargs(self, **kwargs):
+                pass
+
+        inner = InnerClass()
+        with pytest.raises(RuntimeError) as exc_info:
+            _check_valid_parameters_number(inner.demo, expected_params=['trainer', 'other', 'more'])
+        assert 'accepts 1 parameters' in exc_info.value.args[0]
+
+        _check_valid_parameters_number(inner.demo, expected_params=['trainer'])
+
+
+def test_get_fun_msg():
+    def demo(x):
+        pass
+
+    print(_get_fun_msg(_get_fun_msg))
\ No newline at end of file
diff --git a/tests/helpers/utils.py b/tests/helpers/utils.py
index b876c289..9a4af07c 100644
--- a/tests/helpers/utils.py
+++ b/tests/helpers/utils.py
@@ -2,37 +2,19 @@ import os
 import sys
 import __main__
 from functools import wraps
-import inspect
 from inspect import ismethod
-import functools
 from copy import deepcopy
 from io import StringIO
 import time

 import numpy as np

+from fastNLP.core.utils.utils import get_class_that_defined_method
 from fastNLP.envs.env import FASTNLP_GLOBAL_RANK
 from fastNLP.core.drivers.utils import distributed_open_proc
 from fastNLP.core.log import logger


-def get_class_that_defined_method(meth):
-    if isinstance(meth, functools.partial):
-        return get_class_that_defined_method(meth.func)
-    if inspect.ismethod(meth) or (inspect.isbuiltin(meth) and getattr(meth, '__self__', None) is not None and getattr(meth.__self__, '__class__', None)):
-        for cls in inspect.getmro(meth.__self__.__class__):
-            if meth.__name__ in cls.__dict__:
-                return cls
-        meth = getattr(meth, '__func__', meth)  # fallback to __qualname__ parsing
-    if inspect.isfunction(meth):
-        cls = getattr(inspect.getmodule(meth),
-                      meth.__qualname__.split('.', 1)[0].rsplit('.', 1)[0],
-                      None)
-        if isinstance(cls, type):
-            return cls
-    return getattr(meth, '__objclass__', None)  # handle special descriptor objects
-
-
 def recover_logger(fn):
     @wraps(fn)
     def wrapper(*args, **kwargs):