| @@ -10,6 +10,7 @@ from .utils import _get_monitor_value | |||
| from fastNLP.core.callbacks.callback_events import _SingleEventState | |||
| from fastNLP.core.log import logger | |||
| from fastNLP.core.utils import apply_to_collection | |||
| from fastNLP.core.utils.utils import _check_valid_parameters_number | |||
| class Callback: | |||
| @@ -299,7 +300,11 @@ class HasMonitorCallback(Callback): | |||
| self.must_have_moinitor = must_have_monitor | |||
| def set_monitor(self, monitor, larger_better): | |||
| self.monitor = str(monitor) if monitor is not None else None | |||
| if callable(monitor): # 检查是否能够接受一个参数 | |||
| _check_valid_parameters_number(monitor, expected_params=['results'], fn_name='monitor') | |||
| self.monitor = monitor | |||
| else: | |||
| self.monitor = str(monitor) if monitor is not None else None | |||
| self.larger_better = bool(larger_better) | |||
| if larger_better: | |||
| self.monitor_value = float('-inf') | |||
| @@ -322,24 +327,33 @@ class HasMonitorCallback(Callback): | |||
| raise RuntimeError(f"No `monitor` is set for {self.__class__.__name__}. " | |||
| f"You can set it in the initialization or through Trainer.") | |||
def get_monitor_value(self, results: Dict) -> Union[float, None]:
    """
    Fetch the monitored value from an evaluation result.

    The lookup itself is delegated to ``_get_monitor_value``. The key that was actually used
    is stored on ``self._real_monitor``. When ``self.monitor`` is a string and a different
    key had to be used, a warning is logged.

    :param results: the evaluation results.
    :return: the monitored value, or ``None`` when ``results`` is empty or no suitable
        monitor was found this time.
    """
    if not results:
        return None
    # Make sure every tensor has been converted to a plain python value.
    results = apply_to_collection(results, dtype=CanItemDataType, function=lambda x: x.item())
    use_monitor, monitor_value = _get_monitor_value(monitor=self.monitor,
                                                    real_monitor=self._real_monitor,
                                                    res=results)
    if monitor_value is None:
        return None
    # Only string monitors can be substituted; nothing to report for a callable monitor.
    if isinstance(self.monitor, str):
        if self._real_monitor == self.monitor and use_monitor != self.monitor:
            # First run: the requested key is missing and another key is used instead.
            logger.warning(f"We can not find `{self.monitor}` in the evaluation result (with keys as {list(results.keys())}), "
                           f"we use the `{use_monitor}` as the monitor for `{self.__class__.__name__}`.")
        elif self._real_monitor != self.monitor and use_monitor != self._real_monitor:
            # The key used this time differs from the one used last time.
            logger.warning(f"Change of monitor detected for `{self.__class__.__name__}`. "
                           f"The expected monitor is:`{self.monitor}`, last used monitor is:"
                           f"`{self._real_monitor}` and current monitor is:`{use_monitor}`. Please consider using a "
                           f"customized monitor function when the evaluation results are varying between validation.")
    self._real_monitor = use_monitor
    return monitor_value
| @@ -347,10 +361,12 @@ class HasMonitorCallback(Callback): | |||
| """ | |||
| 检测 monitor_value 是否是更好的 | |||
| :param monitor_value: | |||
| :param monitor_value: 待检查的 monitor_value 。如果为 None ,返回 False | |||
| :param keep_if_better: 如果传入的 monitor_value 值更好,则将其保存下来。 | |||
| :return: | |||
| """ | |||
| if monitor_value is None: | |||
| return False | |||
| better = self.is_former_monitor_value_better(monitor_value, self.monitor_value) | |||
| if keep_if_better and better: | |||
| self.monitor_value = monitor_value | |||
| @@ -364,6 +380,12 @@ class HasMonitorCallback(Callback): | |||
| :param monitor_value2: | |||
| :return: | |||
| """ | |||
| if monitor_value1 is None and monitor_value2 is None: | |||
| return True | |||
| if monitor_value1 is None: | |||
| return False | |||
| if monitor_value2 is None: | |||
| return True | |||
| better = False | |||
| if (self.larger_better and monitor_value1 > monitor_value2) or \ | |||
| (not self.larger_better and monitor_value1 < monitor_value2): | |||
| @@ -10,8 +10,7 @@ from copy import deepcopy | |||
| import fastNLP | |||
| from .callback import Callback, HasMonitorCallback | |||
| from fastNLP.core.callbacks.utils import _get_monitor_value | |||
| from .callback import HasMonitorCallback | |||
| from fastNLP.core.log import logger | |||
| from fastNLP.envs import FASTNLP_LAUNCH_TIME | |||
| from fastNLP.core.utils import synchronize_safe_rm, synchronize_mkdir | |||
| @@ -166,6 +165,8 @@ class CheckpointCallback(HasMonitorCallback): | |||
| """ | |||
| if self.save_topk is not None: | |||
| monitor_value = self.get_monitor_value(results=results) | |||
| if monitor_value is None: | |||
| return | |||
| folder_name = f"{self.folder_prefix}-epoch_{trainer.cur_epoch_idx}-batch_{trainer.global_forward_batches}" \ | |||
| f"-{self._real_monitor}_{monitor_value}" | |||
| @@ -231,7 +232,8 @@ class ModelCheckpointCallback(CheckpointCallback): | |||
| 若 model_save_fn 不为 None,则 fastNLP 将 folder 绝对路径传递给该函数,fastNLP 不在该 folder 下创建任何文件。 | |||
| :param monitor: 监控的 metric 的名称。如果在 evaluation 结果中没有找到完全一致的名称,将使用 最短公共字符串算法 找到最匹配 | |||
| 的那个作为 monitor 。如果为 None 将尝试从 Trainer 中获取该值。 | |||
| 的那个作为 monitor 。如果为 None 将尝试从 Trainer 中获取该值。也可以传入一个函数,接受参数为 evaluation 的结果(字典类型), | |||
| 返回一个 float 值作为 monitor 的结果。 | |||
| :param save_folder: 保存的文件夹,fastNLP 将在该文件下以时间戳创建子文件夹,并在里面保存。因此不同次运行可以将被保存到不同的 | |||
| 时间戳文件夹中。如果为 None ,默认使用当前文件夹。 | |||
| :param save_every_n_epochs: 多少个 epoch 保存一次。 | |||
| @@ -278,7 +280,8 @@ class TrainerCheckpointCallback(CheckpointCallback): | |||
| 若 model_save_fn 不为 None,则 fastNLP 只会在每个 folder 下生成 fastnlp_trainer.pkl.tar 文件。 | |||
| :param monitor: 监控的 metric 的名称。如果在 evaluation 结果中没有找到完全一致的名称,将使用 最短公共字符串算法 找到最匹配 | |||
| 的那个作为 monitor 。如果为 None 将尝试从 Trainer 中获取该值。 | |||
| 的那个作为 monitor 。如果为 None 将尝试从 Trainer 中获取该值。也可以传入一个函数,接受参数为 evaluation 的结果(字典类型), | |||
| 返回一个 float 值作为 monitor 的结果。 | |||
| :param save_folder: 保存的文件夹,fastNLP 将在该文件下以时间戳创建子文件夹,并在里面保存。因此不同次运行可以将被保存到不同的 | |||
| 时间戳文件夹中。如果为 None ,默认使用当前文件夹。 | |||
| :param save_every_n_epochs: 多少个 epoch 保存一次。 | |||
| @@ -2,17 +2,18 @@ __all__ = [ | |||
| 'EarlyStopCallback' | |||
| ] | |||
| from typing import Dict | |||
| from typing import Dict, Union, Callable | |||
| from .callback import HasMonitorCallback | |||
| from fastNLP.core.utils.exceptions import EarlyStopException | |||
| class EarlyStopCallback(HasMonitorCallback): | |||
| def __init__(self, monitor:str=None, larger_better:bool=True, patience:int=10): | |||
| def __init__(self, monitor:Union[str, Callable]=None, larger_better:bool=True, patience:int=10): | |||
| """ | |||
| :param str monitor: 监控的 metric 值。如果为 None,将尝试使用 Trainer 设置的 monitor 。 | |||
| :param str monitor: 监控的 metric 值。如果为 None,将尝试使用 Trainer 设置的 monitor 。也可以传入一个函数,接受参数为 | |||
| evaluation 的结果(字典类型),返回一个 float 值作为 monitor 的结果。 | |||
| :param larger_better: monitor 的值是否是越大越好。 | |||
| :param patience: 多少次 validate 不没有提升就停止。 | |||
| """ | |||
| @@ -21,9 +22,9 @@ class EarlyStopCallback(HasMonitorCallback): | |||
| self.patience = patience | |||
| def on_validate_end(self, trainer, results): | |||
| if len(results)==0: | |||
| return | |||
| monitor_value = self.get_monitor_value(results) | |||
| if monitor_value is None: | |||
| return | |||
| if self.is_better_monitor_value(monitor_value, keep_if_better=True): | |||
| self.wait = 0 | |||
| else: | |||
| @@ -3,7 +3,7 @@ __all__ = [ | |||
| ] | |||
| import os | |||
| from typing import Optional, Callable | |||
| from typing import Optional, Callable, Union | |||
| from .callback import HasMonitorCallback | |||
| from io import BytesIO | |||
| import shutil | |||
| @@ -14,14 +14,15 @@ from fastNLP.envs import all_rank_call | |||
| class LoadBestModelCallback(HasMonitorCallback): | |||
| def __init__(self, monitor:str=None, larger_better:bool = True, only_state_dict:bool = True, | |||
| def __init__(self, monitor:Union[str, Callable]=None, larger_better:bool = True, only_state_dict:bool = True, | |||
| save_folder:Optional[str] = None, model_save_fn:Optional[Callable] = None, | |||
| model_load_fn:Optional[Callable] = None, | |||
| delete_after_train:bool = True): | |||
| """ | |||
| 保存最佳的 monitor 值最佳的模型,并在训练结束的时候重新加载模型。仅在训练正常结束的时候才能加载最好的模型。 | |||
| :param str monitor: 监控的 metric 值。如果为 None,将尝试使用 Trainer 设置的 monitor 。 | |||
| :param str monitor: 监控的 metric 值。如果为 None,将尝试使用 Trainer 设置的 monitor 。也可以传入一个函数,接受参数为 | |||
| evaluation 的结果(字典类型),返回一个 float 值作为 monitor 的结果。 | |||
| :param larger_better: 该 metric 值是否是越大越好。 | |||
| :param save_folder: 保存的文件夹,如果为空,则保存在内存中。不为空,则保存一份权重到文件中,当为多机训练,且本值不为空时,请确保 | |||
| 不同的机器均可访问当该路径。当 model_save_fn 不为 None 时该值一定不能为空。 | |||
| @@ -78,9 +79,9 @@ class LoadBestModelCallback(HasMonitorCallback): | |||
| self.get_monitor_value(sanity_check_res) | |||
| def on_validate_end(self, trainer, results): | |||
| if len(results)==0: | |||
| return | |||
| monitor_value = self.get_monitor_value(results) | |||
| if monitor_value is None: | |||
| return | |||
| if self.is_better_monitor_value(monitor_value, keep_if_better=True): | |||
| if self.real_save_folder: | |||
| trainer.save_model(folder=self.real_save_folder, only_state_dict=self.only_state_dict, | |||
| @@ -45,6 +45,7 @@ class RichCallback(ProgressCallback): | |||
| :param print_every: 多少个 batch 更新一次显示。 | |||
| :param loss_round_ndigit: 显示的 loss 保留多少位有效数字 | |||
| :param monitor: 当检测到这个key的结果更好时,会打印出不同的颜色进行提示。如果为 None ,会尝试使用 trainer 中设置的 monitor 。 | |||
| 也可以传入一个函数,接受参数为 evaluation 的结果(字典类型),返回一个 float 值作为 monitor 的结果。 | |||
| :param larger_better: 是否是monitor的结果越大越好。 | |||
| :param format_json: 是否format json再打印 | |||
| """ | |||
| @@ -135,7 +136,8 @@ class RawTextCallback(ProgressCallback): | |||
| :param print_every: 多少个 batch 更新一次显示。 | |||
| :param loss_round_ndigit: 显示的 loss 保留多少位有效数字 | |||
| :param monitor: 当检测到这个key的结果更好时,会打印出不同的颜色进行提示。 | |||
| :param monitor: 当检测到这个key的结果更好时,会打印出不同的颜色进行提示。也可以传入一个函数,接受参数为 evaluation 的结果( | |||
| 字典类型),返回一个 float 值作为 monitor 的结果。 | |||
| :param larger_better: 是否是monitor的结果越大越好。 | |||
| :param format_json: 是否format json再打印 | |||
| """ | |||
| @@ -1,9 +1,10 @@ | |||
| from typing import Optional | |||
| from typing import Optional, Union | |||
| from fastNLP.core.log.logger import logger | |||
| from difflib import SequenceMatcher | |||
| from fastNLP.core.utils.utils import _get_fun_msg | |||
| def _get_monitor_value(monitor: str, real_monitor: Optional[str], res: dict) ->(str, float): | |||
| def _get_monitor_value(monitor: Union[callable, str], real_monitor: Optional[str], res: dict) ->(str, float): | |||
| """ | |||
| 从res中寻找 monitor 并返回。如果 monitor 没找到则尝试用 _real_monitor ,若 _real_monitor 为 None 则尝试使用 monitor 的值进行 | |||
| 匹配。 | |||
| @@ -11,10 +12,19 @@ def _get_monitor_value(monitor: str, real_monitor: Optional[str], res: dict) ->( | |||
| :param monitor: | |||
| :param real_monitor: | |||
| :param res: | |||
| :return: 返回两个值(str, value),其中str就是最终要到的key,value就是这个key对应的value | |||
| :return: 返回两个值(str, value),其中str就是最终要到的key,value就是这个key对应的value。如果value为None说明当前results中没有 | |||
| 找到对应的 monitor | |||
| """ | |||
| if len(res)==0: | |||
| return monitor, 0 | |||
| return monitor, None | |||
| if callable(monitor): | |||
| try: | |||
| monitor_value = monitor(res) | |||
| except BaseException as e: | |||
| logger.error(f"Exception happens when calling customized monitor function:{_get_fun_msg(monitor)}.") | |||
| raise e | |||
| return monitor, monitor_value | |||
| if monitor in res: | |||
| return monitor, res[monitor] | |||
| @@ -5,7 +5,7 @@ __all__ = [ | |||
| from abc import ABCMeta, abstractmethod | |||
| from typing import Any, Dict, List, Callable, Union | |||
| from typing import Any, Dict, List, Callable, Union, Tuple | |||
| from numbers import Number | |||
| import warnings | |||
| @@ -35,7 +35,7 @@ class SetInputOrTargetException(Exception): | |||
| self.field_name = field_name # 标示当前 field 的名称 | |||
| def _get_ele_type_and_dim(cell: Any, dim=0): | |||
| def _get_ele_type_and_dim(cell: Any, dim=0) -> Tuple[Any, int]: | |||
| r""" | |||
| 识别cell的类别与dimension的数量 | |||
| @@ -206,7 +206,7 @@ class AutoCollator(Collator): | |||
| def __init__(self, as_numpy: bool): | |||
| super(AutoCollator, self).__init__() | |||
| self.pad_field_value = {} # field padding 自定义的 padding 值, 默认为0 | |||
| self.need_inputs = [] # 需要的 field name | |||
| self.need_inputs = set() # 需要的 field name | |||
| self.field_dtypes = None # 每列数据单元的 dtype 类型 | |||
| self.field_dims = None # 每列数据单元维度 | |||
| self.as_numpy = as_numpy | |||
| @@ -214,10 +214,17 @@ class AutoCollator(Collator): | |||
| def __call__(self, ins_lst: List[Dict]) -> dict: | |||
| if len(self.need_inputs) == 0: | |||
| raise ValueError({"set_inputs is None, you should use set_inputs method first!!"}) | |||
| # TODO 这里应该是先 check 有哪些需要 padding,然后check这些是否是可以pad的 | |||
| # 第一种情况,设置了 set_input 的值 | |||
| # 第二种情况, 根据数据的类型的判断是否 padding | |||
| if self.field_dtypes is None and self.field_dims is None: | |||
| self.field_dtypes, self.field_dims = _get_ds_type_dim(ins_lst[0]) | |||
| field_dtypes, field_dims = {}, {} | |||
| for key, value in ins_lst[0].items(): | |||
| if key in self.need_inputs and self.pad_field_value.get(key, 0) is not None: | |||
| field_dtypes[key], field_dims[key] = _get_ele_type_and_dim(value) | |||
| self.field_dtypes = field_dtypes | |||
| self.field_dims = field_dims | |||
| pack_ins_lst, pad_ins_lst = {field_name: [] | |||
| for field_name in ins_lst[0].keys() if field_name in self.need_inputs}, {} | |||
| @@ -233,13 +240,13 @@ class AutoCollator(Collator): | |||
| if len(self.pad_field_value.keys()) > 0: | |||
| # 去掉不需要 pad 的列,如果 set_input 的列不存在则忽略 | |||
| drop_field_names = [] | |||
| non_pad_field_names = [] | |||
| for k, v in self.pad_field_value.items(): | |||
| if v is None: | |||
| drop_field_names.append(k) | |||
| non_pad_field_names.append(k) | |||
| # drop_field_names = list(set(list(ins_lst[0].keys())) - set(drop_fields)) | |||
| for field_name in drop_field_names: | |||
| for field_name in non_pad_field_names: | |||
| field_array = pack_ins_lst.pop(field_name) | |||
| pad_ins_lst[field_name] = np.array(field_array) | |||
| @@ -269,7 +276,7 @@ class AutoCollator(Collator): | |||
def set_input(self, *field_names):
    """
    Register the fields this collator has to handle.

    ``self.need_inputs`` is a set, so registering the same field name twice has no effect.

    :param field_names: names of the fields to register.
    """
    self.need_inputs.update(field_names)
| def pad_content(content, field_name: str, field_type, field_dim: int, pad_val: int, as_numpy: bool): | |||
| @@ -11,11 +11,12 @@ __all__ = [ | |||
| from fastNLP.core.drivers import Driver | |||
| from fastNLP.core.drivers.utils import choose_driver | |||
| from .loops import Loop, EvaluateBatchLoop | |||
| from fastNLP.core.utils import check_fn_not_empty_params, auto_param_call, dataclass_to_dict, \ | |||
| from fastNLP.core.utils import auto_param_call, dataclass_to_dict, \ | |||
| match_and_substitute_params, f_rich_progress | |||
| from fastNLP.core.metrics import Metric | |||
| from fastNLP.core.metrics.utils import _is_torchmetrics_metric, _is_paddle_metric, _is_allennlp_metric | |||
| from fastNLP.core.controllers.utils.utils import _TruncatedDataLoader | |||
| from fastNLP.core.utils.utils import _check_valid_parameters_number | |||
| from fastNLP.core.log import logger | |||
| @@ -38,11 +39,11 @@ class Evaluator: | |||
| driver: Union[str, Driver] = 'single', | |||
| device: Optional[Union[int, List[int], str]] = None, | |||
| batch_step_fn: Optional[callable] = None, | |||
| mode: str = "validate", | |||
| mode: Optional[Union[str, callable]] = 'validate', # 首先尝试找 evaluate_step, 找不到 forward, callable | |||
| input_mapping: Optional[Union[Callable, Dict]] = None, | |||
| output_mapping: Optional[Union[Callable, Dict]] = None, | |||
| model_wo_auto_param_call: bool = False, | |||
| fp16: Optional[bool] = False, | |||
| fp16: bool = False, | |||
| verbose: int = 1, | |||
| **kwargs | |||
| ): | |||
| @@ -92,8 +93,8 @@ class Evaluator: | |||
| self.device = device | |||
| self.verbose = verbose | |||
| assert check_fn_not_empty_params(batch_step_fn, 2), "Parameter `batch_step_fn` should be a callable object with " \ | |||
| "two parameters." | |||
| if batch_step_fn is not None: | |||
| _check_valid_parameters_number(batch_step_fn, ['trainer', 'batch'], fn_name='batch_step_fn') | |||
| self.batch_step_fn = batch_step_fn | |||
| self.mode = mode | |||
| @@ -135,6 +136,7 @@ class Evaluator: | |||
| if self.progress_bar == 'auto': | |||
| self.progress_bar = 'rich' if (sys.stdin and sys.stdin.isatty()) else 'raw' | |||
| self.driver.check_evaluator_mode(self.mode) | |||
| self.driver.barrier() | |||
| def run(self, num_eval_batch_per_dl: int = -1, **kwargs) -> Dict: | |||
| @@ -154,8 +156,6 @@ class Evaluator: | |||
| assert isinstance(num_eval_batch_per_dl, int), "num_eval_batch_per_dl must be of int type." | |||
| assert num_eval_batch_per_dl > 0 or num_eval_batch_per_dl == -1, "num_eval_batch_per_dl must be -1 or larger than 0." | |||
| self.driver.check_evaluator_mode(self.mode) | |||
| if self.mode == 'validate': | |||
| assert self.driver.has_validate_dataloaders() | |||
| else: | |||
| @@ -367,9 +367,10 @@ class _MetricsWrapper: | |||
| raise RuntimeError(f"The output of your model is of type:`{type(outputs)}`, please either directly" | |||
| f" return a dict from your model or use `output_mapping` to convert it into dict type.") | |||
| if isinstance(metric, Metric): | |||
| auto_param_call(metric.update, outputs, *args) | |||
| # 这样在 auto_param_call 报错的时候才清晰。 | |||
| auto_param_call(metric.update, outputs, *args, signature_fn=metric.update.__wrapped__) | |||
| elif _is_torchmetrics_metric(metric): | |||
| auto_param_call(metric.update, outputs, *args) | |||
| auto_param_call(metric.update, outputs, *args, signature_fn=metric.update.__wrapped__) | |||
| elif _is_allennlp_metric(metric): | |||
| auto_param_call(metric.__call__, outputs, *args) | |||
| elif _is_paddle_metric(metric): | |||
| @@ -14,6 +14,7 @@ __all__ = [ | |||
| from .loops import Loop, TrainBatchLoop | |||
| from .utils import State, TrainerState | |||
| from .utils.utils import check_validate_every | |||
| from .evaluator import Evaluator | |||
| from fastNLP.core.controllers.utils.utils import TrainerEventTrigger, _TruncatedDataLoader | |||
| from fastNLP.core.callbacks import Callback, CallbackManager, Events, EventsList, Filter | |||
| @@ -21,7 +22,8 @@ from fastNLP.core.callbacks.callback import _CallbackWrapper | |||
| from fastNLP.core.callbacks.callback_events import _SingleEventState | |||
| from fastNLP.core.drivers import Driver | |||
| from fastNLP.core.drivers.utils import choose_driver | |||
| from fastNLP.core.utils import check_fn_not_empty_params, get_fn_arg_names, match_and_substitute_params, nullcontext | |||
| from fastNLP.core.utils import get_fn_arg_names, match_and_substitute_params, nullcontext | |||
| from fastNLP.core.utils.utils import _check_valid_parameters_number | |||
| from fastNLP.envs import rank_zero_call | |||
| from fastNLP.core.log import logger | |||
| from fastNLP.envs import FASTNLP_MODEL_FILENAME | |||
| @@ -42,7 +44,7 @@ class Trainer(TrainerEventTrigger): | |||
| validate_dataloaders=None, | |||
| batch_step_fn: Optional[Callable] = None, | |||
| validate_batch_step_fn: Optional[Callable] = None, | |||
| validate_mode: str = "validate", | |||
| validate_mode: Union[str, callable] = 'validate', | |||
| callbacks: Union[List[Callback], Callback, None] = None, | |||
| metrics: Optional[dict] = None, | |||
| validate_every: Optional[Union[int, callable]] = -1, | |||
| @@ -51,7 +53,7 @@ class Trainer(TrainerEventTrigger): | |||
| model_wo_auto_param_call: bool = False, | |||
| accumulation_steps: int = 1, | |||
| fp16: bool = False, | |||
| monitor: str = None, | |||
| monitor: Union[str, callable] = None, | |||
| larger_better: bool = True, | |||
| marker: Optional[str] = None, | |||
| **kwargs | |||
| @@ -90,11 +92,8 @@ class Trainer(TrainerEventTrigger): | |||
| :param callbacks: 训练当中触发的 callback 类,该参数应当为一个列表,其中的每一个元素都应当继承 `Callback` 类; | |||
| :param metrics: 应当为一个字典,其中 key 表示 monitor,例如 {"acc1": AccMetric(), "acc2": AccMetric()}; | |||
| :param validate_every: 可以为负数、正数或者函数;为负数时表示每隔几个 epoch validate 一次;为正数则表示每隔几个 batch validate 一次; | |||
| 为函数时表示用户自己传入的用于控制 Trainer 中的 validate 的频率的函数,该函数的参数应该为 (filter, trainer) , 其中的 filter 对象 | |||
| 中自动记录了两个变量: filter.num_called 表示有多少次尝试 validate (实际等同于到当前时刻 batch 的总数), filter.num_executed | |||
| 表示 validate 实际被执行了多少次;trainer 参数即为 Trainer 对象。 函数返回值应为 bool ,返回为 True 说明需要进行 validate 。 | |||
| 例如: (filter.num_called % trainer.num_batches_per_epoch == 0 and trainer.cur_epoch_idx > 10) 表示在第 10 个 epoch | |||
| 之后,每个 epoch 结束进行一次 validate 。 | |||
| 为函数时表示用户自己传入的用于控制 Trainer 中的 validate 的频率的函数,该函数的应该接受当前 trainer 对象作为参数,并 | |||
| 返回一个 bool 值,返回为 True 说明需要进行 validate ;将在每个 batch 结束后调用该函数判断是否需要 validate 。 | |||
| :param input_mapping: 应当为一个字典或者一个函数,表示在当前 step 拿到一个 batch 的训练数据后,应当做怎样的映射处理;如果其是 | |||
| 一个字典,并且 batch 也是一个 `Dict`,那么我们会把 batch 中同样在 input_mapping 中的 key 修改为 input_mapping 的对应 key 的 | |||
| value;如果 batch 是一个 `dataclass`,那么我们会先将该 dataclass 转换为一个 Dict,然后再进行上述转换;如果 batch 此时是其它 | |||
| @@ -111,7 +110,7 @@ class Trainer(TrainerEventTrigger): | |||
| :param fp16: 是否开启混合精度训练;默认为 False; | |||
| :param monitor: 当存在 validate_dataloaders 时,默认的 monitor metric 的名字。传入的 callback 如果有 monitor 参数且没有 | |||
| 在 callback 初始化设定的,将采取这个值。如果在 evaluation 结果中没有找到完全一致的名称,将使用 最短公共字符串算法 找到最匹配 | |||
| 的那个作为 monitor 。 | |||
| 的那个作为 monitor 。也可以传入一个函数,接受参数为 evaluation 的结果(字典类型),返回一个 float 值作为 monitor 的结果。 | |||
| :param larger_better: monitor 的值是否是越大越好。 | |||
| :param marker: 用于标记一个 Trainer 实例,从而在用户调用 `Trainer.on` 函数时,标记该 callback 函数属于哪一个具体的 'trainer' 实例;默认为 None; | |||
| :param kwargs: 一些其它的可能需要的参数; | |||
| @@ -142,10 +141,9 @@ class Trainer(TrainerEventTrigger): | |||
| self.input_mapping = input_mapping | |||
| self.output_mapping = output_mapping | |||
| assert check_fn_not_empty_params(batch_step_fn, 2), "`batch_step_fn` should be a callable object with " \ | |||
| "two parameters." | |||
| self.batch_step_fn = batch_step_fn | |||
| if batch_step_fn is not None: | |||
| _check_valid_parameters_number(batch_step_fn, ['trainer', 'batch'], fn_name='batch_step_fn') | |||
| self.check_batch_step_fn = partial(self._check_callback_called_legality, check_mode=True) | |||
| else: | |||
| self.check_batch_step_fn = lambda *args, **kwargs: ... | |||
| @@ -221,18 +219,11 @@ class Trainer(TrainerEventTrigger): | |||
| if metrics is not None and validate_dataloaders is None: | |||
| raise ValueError("You have set 'metrics' but forget to set 'validate_dataloader'.") | |||
| # 为了在 train 的循环中每次都检查是否需要进行 validate,这里我们提前在 trainer 初始化的时候就将对应时间点需要运行的函数确定下来; | |||
| # _epoch_validate 表示每隔几个 epoch validate 一次;_step_validate 表示每隔几个 step validate 一次; | |||
| self.evaluator = None | |||
| self.monitor = monitor | |||
| self.larger_better = larger_better | |||
| if metrics is not None and validate_dataloaders is not None: | |||
| if not callable(validate_every) and (not isinstance(validate_every, int) or validate_every == 0): | |||
| raise ValueError("Parameter 'validate_every' should be set to 'int' type and either < 0 or > 0.") | |||
| if callable(validate_every): | |||
| logger.info("Notice you are using a 'filter function' as the value of parameter `validate_every`, " | |||
| "and in this way, the kind of controlling frequency is depending on the 'step'.") | |||
| check_validate_every(validate_every) | |||
| self.evaluator = Evaluator( | |||
| model=model, | |||
| dataloaders=validate_dataloaders, | |||
| @@ -352,33 +343,32 @@ class Trainer(TrainerEventTrigger): | |||
| _validate_res: dict = validate_fn() | |||
| trainer.on_validate_end(_validate_res) | |||
| self.validate_fn = partial(_validate_fn, self, partial(self.evaluator.run, num_eval_batch_per_dl)) | |||
| self.run_evaluate = partial(_validate_fn, self, partial(self.evaluator.run, num_eval_batch_per_dl)) | |||
def step_validate(self):
    """
    Called at the end of every batch; run an evaluation if the settings ask for one.

    Nothing happens when there is no evaluator. If ``validate_every`` is callable it is
    called with this trainer and evaluation runs when it returns a truthy value. If it is a
    positive number, evaluation runs whenever ``global_forward_batches`` is a multiple of it.

    :return:
    """
    if self.evaluator is None:
        return
    if callable(self.validate_every):
        if self.validate_every(self):
            self.run_evaluate()
    elif self.validate_every > 0 and self.global_forward_batches % self.validate_every == 0:
        self.run_evaluate()
def epoch_validate(self):
    """
    Called at the end of every epoch; run an evaluation if the settings ask for one.

    Only a negative integer ``validate_every`` triggers evaluation here: evaluation runs
    when ``cur_epoch_idx`` is a multiple of ``-validate_every``.

    :return:
    """
    if self.evaluator is None:
        return
    if isinstance(self.validate_every, int) and self.validate_every < 0:
        every_n_epochs = -self.validate_every
        if self.cur_epoch_idx % every_n_epochs == 0:
            self.run_evaluate()
| def add_callback_fn(self, event: Optional[Union[Events, EventsList]], fn: Callable): | |||
| r""" | |||
| @@ -410,9 +400,7 @@ class Trainer(TrainerEventTrigger): | |||
| def wrapper(fn: Callable) -> Callable: | |||
| cls._custom_callbacks[marker].append((event, fn)) | |||
| callback_fn_args = get_fn_arg_names(getattr(Callback, event.value))[1:] | |||
| assert check_fn_not_empty_params(fn, len(callback_fn_args)), \ | |||
| f"The callback function at `{event.value.lower()}`'s parameters should be {callback_fn_args}, but your "\ | |||
| f"function {fn.__name__} only has these parameters: {get_fn_arg_names(fn)}." | |||
| _check_valid_parameters_number(fn, callback_fn_args) | |||
| return fn | |||
| return wrapper | |||
| @@ -1,8 +1,9 @@ | |||
| from collections.abc import Iterator | |||
| import inspect | |||
| from typing import Dict | |||
| from fastNLP.core.callbacks import CallbackManager | |||
| from .state import TrainerState | |||
| from fastNLP.core.utils.utils import _check_valid_parameters_number | |||
| class TrainerEventTrigger: | |||
| @@ -125,5 +126,8 @@ class _TruncatedDataLoader: | |||
| return getattr(self.dataloader, item) | |||
def check_validate_every(validate_every):
    """
    Validate the ``validate_every`` argument.

    A callable is handed to ``_check_valid_parameters_number`` with the expected parameter
    list ``['trainer']``. Anything else must be a non-zero ``int``.

    :param validate_every: the value to check.
    :raises ValueError: if ``validate_every`` is neither callable nor a non-zero int.
    """
    if callable(validate_every):
        # Pass fn_name so a signature error names the offending argument.
        _check_valid_parameters_number(validate_every, expected_params=['trainer'], fn_name='validate_every')
        return
    if not isinstance(validate_every, int) or validate_every == 0:
        raise ValueError("Parameter 'validate_every' should be set to 'int' type and either < 0 or > 0.")
| @@ -178,10 +178,11 @@ class DataSet: | |||
| elif isinstance(idx, slice): | |||
| if idx.start is not None and (idx.start >= len(self) or idx.start <= -len(self)): | |||
| raise RuntimeError(f"Start index {idx.start} out of range 0-{len(self) - 1}") | |||
| data_set = DataSet() | |||
| dataset = DataSet() | |||
| for field_name, field in self.field_arrays.items(): | |||
| data_set.add_field(field_name=field_name, fields=field.content[idx]) | |||
| return data_set | |||
| dataset.add_field(field_name=field_name, fields=field.content[idx]) | |||
| dataset.collate_fns = deepcopy(self.collate_fns) | |||
| return dataset | |||
| elif isinstance(idx, str): | |||
| if idx not in self: | |||
| raise KeyError("No such field called {} in DataSet.".format(idx)) | |||
| @@ -192,6 +193,7 @@ class DataSet: | |||
| assert isinstance(i, int), "Only int index allowed." | |||
| instance = self[i] | |||
| dataset.append(instance) | |||
| dataset.collate_fns = deepcopy(self.collate_fns) | |||
| return dataset | |||
| else: | |||
| raise KeyError("Unrecognized type {} for idx in __getitem__ method".format(type(idx))) | |||
| @@ -674,6 +676,8 @@ class DataSet: | |||
| dev_set.append(self[idx]) | |||
| for idx in train_indices: | |||
| train_set.append(self[idx]) | |||
| dev_set.collate_fns = deepcopy(self.collate_fns) | |||
| train_set.collate_fns = deepcopy(self.collate_fns) | |||
| return dev_set, train_set | |||
| @@ -795,7 +799,7 @@ class DataSet: | |||
| :param val: 默认为0。如果为 None ,则为不对 field 进行 padding 。 | |||
| :return: | |||
| """ | |||
| # TODO 需要去重复 | |||
| # TODO 不能为空 | |||
| for field_name in field_names: | |||
| self.collate_fns.set_pad_val(field_name, val=val) | |||
| @@ -66,7 +66,7 @@ class JittorDriver(Driver): | |||
| if mode == "validate": | |||
| if not hasattr(model, "validate_step"): | |||
| if hasattr(model, "test_step"): | |||
| logger.warning( | |||
| logger.warning_once( | |||
| "Your model does not have 'validate_step' method but has 'test_step' method, but you" | |||
| "are using 'mode=validate', we are going to use 'test_step' to substitute for" | |||
| "'validate_step'.") | |||
| @@ -74,7 +74,7 @@ class JittorDriver(Driver): | |||
| else: | |||
| if not hasattr(model, "test_step"): | |||
| if hasattr(model, "validate_step"): | |||
| logger.warning("Your model does not have 'test_step' method but has 'validate' method, but you" | |||
| logger.warning_once("Your model does not have 'test_step' method but has 'validate' method, but you" | |||
| "are using 'mode=test', we are going to use 'validate_step' to substitute for" | |||
| "'test_step'.") | |||
| @@ -133,7 +133,7 @@ class PaddleDriver(Driver): | |||
| else: | |||
| if not hasattr(model, "test_step"): | |||
| if hasattr(model, "validate_step"): | |||
| logger.warning("Your model does not have 'test_step' method but has 'validate' method, but you" | |||
| logger.warning_once("Your model does not have 'test_step' method but has 'validate' method, but you" | |||
| "are using 'Evaluator.test', we are going to use 'validate_step' to substitute for" | |||
| "'test_step'.") | |||
| @@ -333,10 +333,8 @@ def all_gather_object(object_list, obj, group=None): | |||
| >>> output | |||
| ['foo', 12, {1: 2}] | |||
| """ | |||
| if dist._rank_not_in_group(group): | |||
| if dist.distributed_c10d._rank_not_in_group(group): | |||
| return | |||
| input_tensor, local_size = _object_to_tensor(obj) | |||
| if _TORCH_GREATER_EQUAL_1_8: | |||
| current_device = torch.device("cpu") | |||
| is_nccl_backend = _check_for_nccl_backend(group) | |||
| @@ -345,10 +343,11 @@ def all_gather_object(object_list, obj, group=None): | |||
| # We cannot simply use my_rank since rank == device is not necessarily | |||
| # true. | |||
| current_device = torch.device("cuda", torch.cuda.current_device()) | |||
| input_tensor = input_tensor.to(current_device) | |||
| local_size = local_size.to(current_device) | |||
| else: | |||
| current_device = torch.cuda.current_device() | |||
| input_tensor, local_size = _object_to_tensor(obj, device=current_device) | |||
| # Gather all local sizes. This is so that we can find the max size, and index | |||
| # until the correct size when deserializing the tensors. | |||
| group_size = dist.get_world_size(group=group) | |||
| @@ -379,3 +378,4 @@ def all_gather_object(object_list, obj, group=None): | |||
| tensor = tensor.cpu() | |||
| tensor_size = object_size_list[i] | |||
| object_list[i] = _tensor_to_object(tensor, tensor_size) | |||
| return object_list | |||
| @@ -113,7 +113,7 @@ class TorchDriver(Driver): | |||
| if mode == "validate": | |||
| if not hasattr(model, "validate_step"): | |||
| if hasattr(model, "test_step"): | |||
| logger.warning( | |||
| logger.warning_once( | |||
| "Your model does not have 'validate_step' method but has 'test_step' method, but you" | |||
| "are using 'mode=validate', we are going to use 'test_step' to substitute for" | |||
| "'validate_step'.") | |||
| @@ -125,9 +125,9 @@ class FastNLPLogger(logging.Logger, metaclass=LoggerSingleton): | |||
| self._warning_msgs.add(msg) | |||
def warn(self, msg, *args, **kwargs):
    """
    Log ``msg`` at WARNING level.

    The message is emitted only when WARNING is enabled for this logger. The keyword
    arguments are passed through ``self._add_rank_info`` before logging.

    :param msg: the message (or format string) to log.
    :param args: arguments merged into ``msg``.
    :param kwargs: keyword arguments forwarded to the underlying ``_log`` call.
    """
    if not self.isEnabledFor(WARNING):
        return
    kwargs = self._add_rank_info(kwargs)
    self._log(WARNING, msg, args, **kwargs)
| def error(self, msg, *args, **kwargs): | |||
| """ | |||
| @@ -14,8 +14,7 @@ from fastNLP.core.utils.utils import seq_len_to_mask | |||
| class Accuracy(Metric): | |||
| def __init__(self, backend: Union[str, Backend, None] = 'auto', | |||
| aggregate_when_get_metric: bool = True): | |||
| def __init__(self, backend: Union[str, Backend, None] = 'auto', aggregate_when_get_metric: bool = True): | |||
| super(Accuracy, self).__init__(backend=backend, aggregate_when_get_metric=aggregate_when_get_metric) | |||
| self.register_element(name='correct', value=0, aggregate_method='sum', backend=backend) | |||
| self.register_element(name='total', value=0, aggregate_method="sum", backend=backend) | |||
| @@ -64,7 +63,7 @@ class Accuracy(Metric): | |||
| warnings.warn("You are not passing `seq_len` to exclude pad when calculate accuracy.") | |||
| else: | |||
| raise RuntimeError(f"when pred havesize:{pred.shape}, target should have size: {pred.shape} or " | |||
| raise RuntimeError(f"when pred have size:{pred.shape}, target should have size: {pred.shape} or " | |||
| f"{pred.shape[:-1]}, got {target.shape}.") | |||
| if masks is not None: | |||
| @@ -23,14 +23,14 @@ __all__ = [ | |||
| "BucketedBatchSampler", | |||
| "ReproducibleBatchSampler", | |||
| "re_instantiate_sampler", | |||
| "conversion_between_reproducible_and_unrepeated_sampler" | |||
| "re_instantiate_sampler" | |||
| ] | |||
| from .sampler import BucketSampler, SortedSampler, ConstTokenNumSampler, ConstantTokenNumSampler | |||
| from .unrepeated_sampler import UnrepeatedSampler, UnrepeatedRandomSampler, UnrepeatedSortedSampler, UnrepeatedSequentialSampler | |||
| from .mix_sampler import MixSampler, DopedSampler, MixSequentialSampler, PollingSampler | |||
| from .reproducible_sampler import ReproducibleSampler, RandomSampler, SequentialSampler, SortedSampler | |||
| from .utils import re_instantiate_sampler, conversion_between_reproducible_and_unrepeated_sampler | |||
| from .utils import re_instantiate_sampler | |||
| from .conversion_utils import conversion_between_reproducible_and_unrepeated_sampler | |||
| from .reproducible_batch_sampler import RandomBatchSampler, BucketedBatchSampler, ReproducibleBatchSampler | |||
| @@ -0,0 +1,33 @@ | |||
| from fastNLP.core.samplers import re_instantiate_sampler | |||
| from fastNLP.core.samplers.reproducible_sampler import ReproducibleSampler, RandomSampler, SequentialSampler, \ | |||
| SortedSampler | |||
| from fastNLP.core.samplers.unrepeated_sampler import UnrepeatedSampler, UnrepeatedRandomSampler, \ | |||
| UnrepeatedSequentialSampler, UnrepeatedSortedSampler | |||
def conversion_between_reproducible_and_unrepeated_sampler(sampler):
    """
    Replace ``sampler`` with its counterpart from the other sampler family.

    An ``UnrepeatedSampler`` is converted to the matching ``ReproducibleSampler`` and a
    ``ReproducibleSampler`` is converted to the matching ``UnrepeatedSampler``. The new
    object is built by ``re_instantiate_sampler`` with the counterpart class.

    :param sampler: an instance of ``UnrepeatedSampler`` or ``ReproducibleSampler``.
    :return: the result of ``re_instantiate_sampler`` for the counterpart class.
    :raises AssertionError: if ``sampler`` belongs to neither family.
    :raises TypeError: if ``sampler`` belongs to one family but is none of the
        random / sequential / sorted variants, so no counterpart is known.
    """
    assert isinstance(sampler, UnrepeatedSampler) or isinstance(sampler, ReproducibleSampler), \
        "The sampler must be UnrepeatedSampler or ReproducibleSampler"
    if isinstance(sampler, UnrepeatedSampler):
        if isinstance(sampler, UnrepeatedRandomSampler):
            return re_instantiate_sampler(sampler, new_sampler_class=RandomSampler)
        elif isinstance(sampler, UnrepeatedSequentialSampler):
            return re_instantiate_sampler(sampler, new_sampler_class=SequentialSampler)
        elif isinstance(sampler, UnrepeatedSortedSampler):
            return re_instantiate_sampler(sampler, new_sampler_class=SortedSampler)
        # The input is an unrepeated sampler, so what is missing is the reproducible one.
        raise TypeError(f"{sampler.__class__} has no reproducible version.")
    else:
        if isinstance(sampler, RandomSampler):
            return re_instantiate_sampler(sampler, new_sampler_class=UnrepeatedRandomSampler)
        elif isinstance(sampler, SequentialSampler):
            return re_instantiate_sampler(sampler, new_sampler_class=UnrepeatedSequentialSampler)
        elif isinstance(sampler, SortedSampler):
            return re_instantiate_sampler(sampler, new_sampler_class=UnrepeatedSortedSampler)
        # The input is a reproducible sampler, so what is missing is the unrepeated one.
        raise TypeError(f"{sampler.__class__} has no unrepeated version.")
| @@ -378,7 +378,7 @@ class BucketedBatchSampler(ReproducibleBatchSampler): | |||
| batch_indices = list(batch_indices[:-1]) | |||
| rng = np.random.default_rng(abs(seed)) # 这里防止由于bucket长度不同,对随机数状态有影响 | |||
| rng.shuffle(batch_indices) # 不同的 batch 也 shuffle ,当前这种可以保证每张卡上每个 batch 长度都接近的。 | |||
| batches = (np.array(batches)[batch_indices]).tolist() | |||
| batches = (np.array(batches, dtype=object)[batch_indices]).tolist() | |||
| if last_batches: | |||
| batches = batches + last_batches | |||
| return batches | |||
| @@ -387,19 +387,12 @@ class BucketedBatchSampler(ReproducibleBatchSampler): | |||
| if self.old_batch_size != self.batch_size or self.old_num_batch_per_bucket != self.num_batch_per_bucket: | |||
| raise RuntimeError("BucketedBatchSampler does not support saving before last checkpoint states have been" | |||
| " consumed. ") | |||
| states = { | |||
| 'seed': self.seed, | |||
| 'epoch': self.epoch, | |||
| 'num_consumed_samples': self.num_consumed_samples, # 注意该值是计算所有 rank 上训练的所有数据; | |||
| 'sampler_type': self.__class__.__name__, | |||
| 'length': len(self.dataset), | |||
| 'shuffle': self.shuffle, | |||
| 'batch_size': self.batch_size, | |||
| 'num_batch_per_bucket': self.num_batch_per_bucket, | |||
| 'num_replicas': self.num_replicas | |||
| } | |||
| states = {'seed': self.seed, 'epoch': self.epoch, 'num_consumed_samples': self.num_consumed_samples, | |||
| 'sampler_type': self.__class__.__name__, 'length': len(self.dataset), 'shuffle': self.shuffle, | |||
| 'batch_size': self.batch_size, 'num_batch_per_bucket': self.num_batch_per_bucket, | |||
| 'num_replicas': self.num_replicas, | |||
| 'num_consumed_samples_array': getattr(self, 'num_consumed_samples_array', None)} | |||
| states['num_consumed_samples_array'] = getattr(self, 'num_consumed_samples_array', None) | |||
| return states | |||
| def load_state_dict(self, states: Dict): | |||
| @@ -1,3 +1,10 @@ | |||
| __all__ = [ | |||
| 'ReproducibleSampler', | |||
| 'RandomSampler', | |||
| "SortedSampler", | |||
| "SequentialSampler" | |||
| ] | |||
| from typing import Dict, List, Union | |||
| import math | |||
| import os | |||
| @@ -10,13 +17,6 @@ from fastNLP.envs.env import FASTNLP_DEQUE_SIZE | |||
| from .utils import NumConsumedSamplesArray | |||
| __all__ = [ | |||
| 'ReproducibleSampler', | |||
| 'RandomSampler', | |||
| "SortedSampler", | |||
| "SequentialSampler" | |||
| ] | |||
| class ReproducibleSampler: | |||
| """ | |||
| @@ -1,42 +1,10 @@ | |||
| __all__ = [ | |||
| 're_instantiate_sampler', | |||
| 'conversion_between_reproducible_and_unrepeated_sampler' | |||
| 're_instantiate_sampler' | |||
| ] | |||
| from array import array | |||
| from typing import Sequence | |||
| from collections import deque | |||
| from fastNLP.core.samplers.unrepeated_sampler import * | |||
| from fastNLP.core.samplers.reproducible_sampler import * | |||
| def conversion_between_reproducible_and_unrepeated_sampler(sampler): | |||
| """ | |||
| 将 sampler 替换成其对应的 reproducible 版本或 unrepeated 版本。如果输入是 UnrepeatedSampler 但是没找到对应的 | |||
| ReproducibleSampler, | |||
| :param sampler: | |||
| :return: | |||
| """ | |||
| assert isinstance(sampler, UnrepeatedSampler) or isinstance(sampler, ReproducibleSampler), \ | |||
| "The sampler must be UnrepeatedSampler or ReproducibleSampler" | |||
| if isinstance(sampler, UnrepeatedSampler): | |||
| if isinstance(sampler, UnrepeatedRandomSampler): | |||
| return re_instantiate_sampler(sampler, new_sampler_class=RandomSampler) | |||
| elif isinstance(sampler, UnrepeatedSequentialSampler): | |||
| return re_instantiate_sampler(sampler, new_sampler_class=SequentialSampler) | |||
| elif isinstance(sampler, UnrepeatedSortedSampler): | |||
| return re_instantiate_sampler(sampler, new_sampler_class=SortedSampler) | |||
| raise TypeError(f"{sampler.__class__} has no unrepeated version.") | |||
| else: | |||
| if isinstance(sampler, RandomSampler): | |||
| return re_instantiate_sampler(sampler, new_sampler_class=UnrepeatedRandomSampler) | |||
| elif isinstance(sampler, SequentialSampler): | |||
| return re_instantiate_sampler(sampler, new_sampler_class=UnrepeatedSequentialSampler) | |||
| elif isinstance(sampler, SortedSampler): | |||
| return re_instantiate_sampler(sampler, new_sampler_class=UnrepeatedSortedSampler) | |||
| raise TypeError(f"{sampler.__class__} has no reproducible version.") | |||
| def re_instantiate_sampler(sampler, new_sampler_class=None): | |||
| all_attributes = vars(sampler) | |||
| @@ -13,7 +13,6 @@ __all__ = [ | |||
| 'torch_paddle_move_data_to_device', | |||
| 'torch_move_data_to_device', | |||
| 'get_fn_arg_names', | |||
| 'check_fn_not_empty_params', | |||
| 'auto_param_call', | |||
| 'check_user_specific_params', | |||
| 'dataclass_to_dict', | |||
| @@ -36,7 +35,7 @@ from .paddle_utils import paddle_to, paddle_move_data_to_device, get_paddle_devi | |||
| from .rich_progress import f_rich_progress | |||
| from .torch_paddle_utils import torch_paddle_move_data_to_device | |||
| from .torch_utils import torch_move_data_to_device | |||
| from .utils import get_fn_arg_names, check_fn_not_empty_params, auto_param_call, check_user_specific_params, \ | |||
| from .utils import get_fn_arg_names, auto_param_call, check_user_specific_params, \ | |||
| dataclass_to_dict, match_and_substitute_params, apply_to_collection, nullcontext, pretty_table_printer, Option, \ | |||
| indice_collate_wrapper, deprecated, seq_len_to_mask, synchronize_safe_rm, synchronize_mkdir | |||
| @@ -1,3 +1,4 @@ | |||
| import functools | |||
| import inspect | |||
| from inspect import Parameter | |||
| import dataclasses | |||
| @@ -24,10 +25,8 @@ from fastNLP.core.log import logger | |||
| from fastNLP.envs import FASTNLP_GLOBAL_RANK | |||
| __all__ = [ | |||
| 'get_fn_arg_names', | |||
| 'check_fn_not_empty_params', | |||
| 'auto_param_call', | |||
| 'check_user_specific_params', | |||
| 'dataclass_to_dict', | |||
| @@ -54,30 +53,6 @@ def get_fn_arg_names(fn: Callable) -> List[str]: | |||
| return list(inspect.signature(fn).parameters) | |||
| def check_fn_not_empty_params(fn: Optional[Callable] = None, param_num: Optional[int] = None) -> bool: | |||
| r""" | |||
| 检查传入的batch_step_fn是否是合法的:(1) 是否是 callable 的; (2) 没有默认值的参数是否只有指定个数; | |||
| 用户也可以传进一个 partial 的函数进来,只要其保证留有 `trainer` 和 `batch` 的参数位置即可; | |||
| :param fn: 传入的用以代替 Loop 中 'step' 函数的函数; | |||
| :param param_num: 检测的函数的应当的没有默认值的参数的个数; | |||
| :return: bool,表示传入的 `batch_step_fn` 是否正确; | |||
| """ | |||
| if fn is None: | |||
| return True | |||
| if not callable(fn): | |||
| return False | |||
| else: | |||
| params = inspect.signature(fn).parameters | |||
| not_default_params = {} | |||
| for _name, _param in params.items(): | |||
| if _param.default == Parameter.empty: | |||
| not_default_params[_name] = _param | |||
| return len(not_default_params) == param_num | |||
| def auto_param_call(fn: Callable, *args, signature_fn: Optional[Callable] = None, | |||
| mapping: Optional[Dict[AnyStr, AnyStr]] = None) -> Any: | |||
| r""" | |||
| @@ -95,7 +70,6 @@ def auto_param_call(fn: Callable, *args, signature_fn: Optional[Callable] = None | |||
| :param signature_fn: 函数,用来替换 `fn` 的函数签名,如果该参数不为 None,那么我们首先会从该函数中提取函数签名,然后通过该函数签名提取 | |||
| 参数值后,再传给 `fn` 进行实际的运算; | |||
| :param mapping: 一个字典,用来更改其前面的字典的键值; | |||
| :param wo_auto_param_call: 是否关闭默认的参数匹配行为; | |||
| :return: 返回 `fn` 运行的结果; | |||
| @@ -123,7 +97,8 @@ def auto_param_call(fn: Callable, *args, signature_fn: Optional[Callable] = None | |||
| _kwargs = None | |||
| for _name, _param in _need_params.items(): | |||
| if _param.kind == Parameter.VAR_POSITIONAL: | |||
| raise ValueError(f"It is not allowed to have parameter `*args` in your function:{fn.__name__}.") | |||
| fn_msg = _get_fun_msg(fn if signature_fn is None else signature_fn) | |||
| raise ValueError(f"It is not allowed to have parameter `*args` in your function:{fn_msg}.") | |||
| if _param.kind == Parameter.VAR_KEYWORD: | |||
| _kwargs = (_name, _param) | |||
| @@ -136,12 +111,17 @@ def auto_param_call(fn: Callable, *args, signature_fn: Optional[Callable] = None | |||
| _default_params[_name] = _param.default | |||
| if mapping is not None: | |||
| assert isinstance(mapping, Dict), f"Parameter `mapping` should be of 'Dict' type, instead of {type(mapping)}." | |||
| fn_msg = _get_fun_msg(fn if signature_fn is None else signature_fn) | |||
| assert isinstance(mapping, Dict), f"Exception happens when calling {fn_msg}. " \ | |||
| f"Parameter `mapping` should be of 'Dict' type, instead of {type(mapping)}." | |||
| _has_params = {} | |||
| duplicate_names = [] | |||
| for arg in args: | |||
| assert isinstance(arg, Dict), "The input part of function `auto_param_call` can only be `Dict` type." | |||
| if not isinstance(arg, Dict): | |||
| fn_msg = _get_fun_msg(fn if signature_fn is None else signature_fn) | |||
| raise TypeError(f"Exception happens when calling {fn_msg}. " | |||
| f"The input part of function `auto_param_call` must be `Dict` type, instead of {type(arg)}.") | |||
| for _name, _value in arg.items(): | |||
| if mapping is not None and _name in mapping: | |||
| _name = mapping[_name] | |||
| @@ -153,7 +133,8 @@ def auto_param_call(fn: Callable, *args, signature_fn: Optional[Callable] = None | |||
| elif _name in _need_params and not (_has_params[_name] is _value): | |||
| duplicate_names.append(_name) | |||
| if duplicate_names: | |||
| raise ValueError(f"The following key present in several inputs:{duplicate_names}") | |||
| fn_msg = _get_fun_msg(fn if signature_fn is None else signature_fn) | |||
| raise ValueError(f"The following key present in several inputs:{duplicate_names} when calling {fn_msg}.") | |||
| # 将具有默认值但是没有被输入修改过的参数值传进去; | |||
| for _name, _value in _default_params.items(): | |||
| @@ -162,11 +143,89 @@ def auto_param_call(fn: Callable, *args, signature_fn: Optional[Callable] = None | |||
| if len(_has_params)<len(_need_params): | |||
| miss_params = list(set(_need_params.keys()) - set(_has_params.keys())) | |||
| raise ValueError(f"The parameters:`{miss_params}` needed by function:{fn.__name__} are not found in the input.") | |||
| fn_msg = _get_fun_msg(fn if signature_fn is None else signature_fn) | |||
| _provided_keys = _get_keys(args) | |||
| raise ValueError(f"The parameters:`{miss_params}` needed by function:{fn_msg} " | |||
| f"are not found in the input keys({_provided_keys}).") | |||
| return fn(**_has_params) | |||
def _get_keys(args:List[Dict]) -> List[List[str]]:
    """
    Collect the key names of every dict in ``args``.

    :param args: the dicts whose keys should be listed.
    :return: one list of keys per dict, in the same order as ``args``.
    """
    return [list(mapping.keys()) for mapping in args]
| def _get_fun_msg(fn)->str: | |||
| """ | |||
| 获取函数的基本信息,帮助报错。 | |||
| ex: | |||
| print(_get_fun_msg(_get_fun_msg)) | |||
| # `_get_fun_msg(fn) -> str`(In file:/Users/hnyan/Desktop/projects/fastNLP/fastNLP/fastNLP/core/utils/utils.py) | |||
| :param callable fn: | |||
| :return: | |||
| """ | |||
| if isinstance(fn, functools.partial): | |||
| return _get_fun_msg(fn.func) | |||
| try: | |||
| fn_name = fn.__qualname__ + str(inspect.signature(fn)) | |||
| except: | |||
| fn_name = str(fn) | |||
| try: | |||
| fp = '(In file:' + os.path.abspath(inspect.getfile(fn)) + ')' | |||
| except: | |||
| fp = '' | |||
| msg = f'`{fn_name}`' + fp | |||
| return msg | |||
def _check_valid_parameters_number(fn, expected_params:List[str], fn_name=None):
    """
    Verify that ``fn`` can be called with exactly the arguments in ``expected_params``.

    Only the number of parameters is compared, not their names. Parameters with a
    default value are not required, and a ``*args`` / ``**kwargs`` parameter lets ``fn``
    accept any number of extra arguments.

    :param fn: the function or method to inspect.
    :param expected_params: names of the arguments that will be supplied to ``fn``.
    :param fn_name: display name of ``fn``; when given, ``fn`` is first asserted to be
        callable and this name is used in the assertion message.
    :return: None
    :raises RuntimeError: if ``fn`` accepts fewer parameters than will be supplied, or
        requires more parameters than will be supplied.
    """
    if fn_name is not None:
        assert callable(fn), f"{fn_name} should be callable, instead of {type(fn)}."

    params = list(inspect.signature(fn).parameters.values())
    # Drop a leading ``self`` if a method's signature still exposes it.
    if inspect.ismethod(fn) and len(params)>0 and params[0].name == 'self':
        params = params[1:]

    variadic_kinds = (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD)
    accepts_variadic = any(p.kind in variadic_kinds for p in params)
    # Parameters that must be supplied: not variadic and without a default value.
    required_count = sum(1 for p in params if p.kind not in variadic_kinds and p.default is p.empty)

    if not accepts_variadic and len(params)<len(expected_params):
        raise RuntimeError(f"The function:{_get_fun_msg(fn)} accepts {len(params)} parameters, "
                           f"but {len(expected_params)} parameters:{expected_params} will be provided.")

    if required_count>len(expected_params):
        raise RuntimeError(f"The function:{_get_fun_msg(fn)} expects {len(params)} parameters, but only"
                           f" {len(expected_params)} parameters:{expected_params} will be provided.")
| def check_user_specific_params(user_params: Dict, fn: Callable): | |||
| """ | |||
| 该函数使用用户的输入来对指定函数的参数进行赋值; | |||
| @@ -592,4 +651,24 @@ def synchronize_mkdir(path: Optional[Union[str, Path]]): | |||
| wait_to_success(path.exists) | |||
def get_class_that_defined_method(method):
    """
    Return the class object in which ``method`` is defined.

    :param method: a bound method, builtin method, plain function, descriptor or
        ``functools.partial`` wrapping one of these.
    :return: the defining class, or ``None`` when it cannot be determined.
    """
    if isinstance(method, functools.partial):
        return get_class_that_defined_method(method.func)

    # A bound Python method, or a builtin bound to an object that exposes a class.
    bound = inspect.ismethod(method)
    if not bound and inspect.isbuiltin(method):
        bound = getattr(method, '__self__', None) is not None \
                and bool(getattr(method.__self__, '__class__', None))

    if bound:
        # Walk the MRO and pick the first class whose own namespace holds the name.
        for klass in inspect.getmro(method.__self__.__class__):
            if method.__name__ in klass.__dict__:
                return klass
        method = getattr(method, '__func__', method)  # fallback to __qualname__ parsing

    if inspect.isfunction(method):
        # "Outer.Inner.func" -> "Outer.Inner"; anything after ".<locals>" is discarded first.
        owner_name = method.__qualname__.split('.<locals>', 1)[0].rsplit('.', 1)[0]
        candidate = getattr(inspect.getmodule(method), owner_name, None)
        if isinstance(candidate, type):
            return candidate

    return getattr(method, '__objclass__', None)  # handle special descriptor objects
| @@ -251,10 +251,10 @@ class DataBundle: | |||
| def apply_field_more(self, func: Callable, field_name: str, num_proc: int = 0, modify_fields=True, | |||
| ignore_miss_dataset=True, progress_desc: str = '', show_progress_bar: bool = True): | |||
| r""" | |||
| 对 :class:`~fastNLP.io.DataBundle` 中所有的 dataset 使用 :meth:`~fastNLP.DataSet.apply_field_more` 方法 | |||
| 对 :class:`~fastNLP.io.DataBundle` 中所有的 dataset 使用 :method:`~fastNLP.DataSet.apply_field_more` 方法 | |||
| .. note:: | |||
| ``apply_field_more`` 与 ``apply_field`` 的区别参考 :meth:`fastNLP.DataSet.apply_more` 中关于 ``apply_more`` 与 | |||
| ``apply_field_more`` 与 ``apply_field`` 的区别参考 :method:`fastNLP.DataSet.apply_more` 中关于 ``apply_more`` 与 | |||
| ``apply`` 区别的介绍。 | |||
| :param callable func: 参数是 ``DataSet`` 中的 ``Instance`` ,返回值是一个字典,key 是field 的名字,value 是对应的结果 | |||
| @@ -285,7 +285,7 @@ class DataBundle: | |||
| def apply(self, func: Callable, new_field_name: str, num_proc: int = 0, | |||
| progress_desc: str = '', show_progress_bar: bool = True, _apply_field: str = None): | |||
| r""" | |||
| 对 :class:`~fastNLP.io.DataBundle` 中所有的 dataset 使用 :meth:`~fastNLP.DataSet.apply` 方法 | |||
| 对 :class:`~fastNLP.io.DataBundle` 中所有的 dataset 使用 :method:`~fastNLP.DataSet.apply` 方法 | |||
| 对DataBundle中所有的dataset使用apply方法 | |||
| @@ -309,10 +309,10 @@ class DataBundle: | |||
| def apply_more(self, func: Callable, modify_fields=True, num_proc: int = 0, | |||
| progress_desc: str = '', show_progress_bar: bool = True): | |||
| r""" | |||
| 对 :class:`~fastNLP.io.DataBundle` 中所有的 dataset 使用 :meth:`~fastNLP.DataSet.apply_more` 方法 | |||
| 对 :class:`~fastNLP.io.DataBundle` 中所有的 dataset 使用 :method:`~fastNLP.DataSet.apply_more` 方法 | |||
| .. note:: | |||
| ``apply_more`` 与 ``apply`` 的区别参考 :meth:`fastNLP.DataSet.apply_more` 中关于 ``apply_more`` 与 | |||
| ``apply_more`` 与 ``apply`` 的区别参考 :method:`fastNLP.DataSet.apply_more` 中关于 ``apply_more`` 与 | |||
| ``apply`` 区别的介绍。 | |||
| :param callable func: 参数是 ``DataSet`` 中的 ``Instance`` ,返回值是一个字典,key 是field 的名字,value 是对应的结果 | |||
| @@ -87,7 +87,7 @@ class CLSBasePipe(Pipe): | |||
| def process_from_file(self, paths) -> DataBundle: | |||
| r""" | |||
| 传入文件路径,生成处理好的DataBundle对象。paths支持的路径形式可以参考 ::meth:`fastNLP.io.Loader.load()` | |||
| 传入文件路径,生成处理好的DataBundle对象。paths支持的路径形式可以参考 ::method:`fastNLP.io.Loader.load()` | |||
| :param paths: | |||
| :return: DataBundle | |||
| @@ -164,7 +164,7 @@ class GraphBuilderBase: | |||
| def build_graph_from_file(self, path: str): | |||
| r""" | |||
| 传入文件路径,生成处理好的scipy_sparse_matrix对象。paths支持的路径形式可以参考 ::meth:`fastNLP.io.Loader.load()` | |||
| 传入文件路径,生成处理好的scipy_sparse_matrix对象。paths支持的路径形式可以参考 ::method:`fastNLP.io.Loader.load()` | |||
| :param path: | |||
| :return: scipy_sparse_matrix | |||
| @@ -33,7 +33,7 @@ class Pipe: | |||
| def process_from_file(self, paths: str) -> DataBundle: | |||
| r""" | |||
| 传入文件路径,生成处理好的DataBundle对象。paths支持的路径形式可以参考 ::meth:`fastNLP.io.Loader.load()` | |||
| 传入文件路径,生成处理好的DataBundle对象。paths支持的路径形式可以参考 ::method:`fastNLP.io.Loader.load()` | |||
| :param str paths: | |||
| :return: DataBundle | |||
| @@ -0,0 +1,187 @@ | |||
| from functools import partial | |||
| import pytest | |||
| from fastNLP.core.utils.utils import auto_param_call, _check_valid_parameters_number, _get_fun_msg | |||
| from fastNLP.core.metrics import Metric | |||
class TestAutoParamCall:
    """Behavioural tests for ``auto_param_call``."""

    def test_basic(self):
        """Walk through the argument-matching rules of ``auto_param_call``."""
        # Only the keys named in the signature are forwarded.
        def fn(x):
            return x

        inputs = {'x': 3, 'y': 4}
        assert auto_param_call(fn, inputs) == 3

        # Arguments may be spread over several dicts.
        def fn(x0, x1, x2, x3):
            return x0 + x1 + x2 + x3

        spread = [{f'x{i}': i} for i in range(10)]
        total = auto_param_call(fn, *spread)
        assert total == 0 + 1 + 2 + 3

        # A needed key that appears in several inputs with different values is an error.
        def fn(chongfu1, chongfu2, buChongFu):
            pass

        first = {'chongfu1': 3, "chongfu2":4, 'buChongFu':2}
        second = {'chongfu1': 1, 'chongfu2':2, 'buChongFu':2}
        with pytest.raises(BaseException) as exc_info:
            auto_param_call(fn, first, second)
        message = exc_info.value.args[0]
        assert 'The following key present in several inputs' in message
        assert 'chongfu1' in message and 'chongfu2' in message

        # A duplicated key that the function does not use is not an error.
        def fn(chongfu1, buChongFu):
            pass

        auto_param_call(fn, {'chongfu1': 1, "chongfu2":4, 'buChongFu':2},
                        {'chongfu1': 1, 'chongfu2':2, 'buChongFu':2})

        # The signature can be taken from a different function via signature_fn.
        def fn1(**kwargs):
            kwargs.pop('x')
            kwargs.pop('y')
            assert len(kwargs)==0

        def fn(x, y):
            pass

        inputs = {'x': 3, 'y': 4}
        auto_param_call(fn1, inputs, signature_fn=fn)

        # Missing required parameters are reported by name.
        def fn(meiti1, meiti2, tigong):
            pass

        with pytest.raises(BaseException) as exc_info:
            auto_param_call(fn, {'tigong':1})
        message = exc_info.value.args[0]
        assert 'meiti1' in message and 'meiti2' in message

        # Defaults are overridden when supplied and used when not.
        def fn(x, y=100):
            return x + y

        assert auto_param_call(fn, {'x': 10, 'y': 20}) == 30
        assert auto_param_call(fn, {'x': 10, 'z': 20})==110

        # mapping renames input keys before matching.
        def fn(x, y=100):
            return x + y

        renames = {'x1': 'x', 'y1': 'y', 'meiyong': 'meiyong'}
        assert auto_param_call(fn, {'x1': 10, 'y1': 20}, mapping=renames) == 30

        # A function without parameters can still be called.
        def fn():
            return 1

        assert 1 == auto_param_call(fn, {'x':1})

        # Bound methods work as well.
        assert 2==auto_param_call(self.call_this, {'x':1 ,'y':1})
        assert 2==auto_param_call(self.call_this, {'x':1,'y':1, 'z':1},mapping={'z': 'self'})

    def test_msg(self):
        """Error messages must identify the offending function and its file."""
        with pytest.raises(BaseException) as exc_info:
            auto_param_call(self.call_this, {'x':1})
        message = exc_info.value.args[0]
        assert 'TestAutoParamCall.call_this' in message

        with pytest.raises(BaseException) as exc_info:
            auto_param_call(call_this_for_auto_param_call, {'x':1})
        message = exc_info.value.args[0]
        assert __file__ in message
        assert 'call_this_for_auto_param_call' in message

        with pytest.raises(BaseException) as exc_info:
            auto_param_call(self.call_this_two, {'x':1})
        message = exc_info.value.args[0]
        assert __file__ in message

        with pytest.raises(BaseException) as exc_info:
            auto_param_call(call_this_for_auto_param_call, {'x':1}, signature_fn=self.call_this)
        message = exc_info.value.args[0]
        # The message should describe the signature function, not the called one.
        assert 'TestAutoParamCall.call_this' in message

    def call_this(self, x, y):
        """Helper target: a method with two required parameters."""
        return x + y

    def call_this_two(self, x, y, z=pytest, **kwargs):
        """Helper target: a method with a non-trivial default and ``**kwargs``."""
        return x + y

    def test_metric_auto_param_call(self):
        """A missing argument for a metric's ``update`` must raise."""
        metric = AutoParamCallMetric()
        with pytest.raises(BaseException):
            auto_param_call(metric.update, {'y':1}, signature_fn=metric.update.__wrapped__)
class AutoParamCallMetric(Metric):
    """Minimal metric whose ``update`` takes a single argument ``x``."""

    def update(self, x):
        """Accept ``x`` and do nothing."""
        return None
def call_this_for_auto_param_call(x, y):
    """Module-level helper target for the tests: return ``x + y``."""
    result = x + y
    return result
class TestCheckNumberOfParameters:
    """Tests for ``_check_valid_parameters_number``."""

    def test_validate_every(self):
        """Cover matching, too-few, too-many, partial, variadic and method cases."""
        one_expected = ['trainer']
        two_expected = ['trainer', 'other']
        three_expected = ['trainer', 'other', 'more']

        # Exact match passes.
        def validate_every(trainer):
            pass

        _check_valid_parameters_number(validate_every, expected_params=one_expected)

        # An extra required parameter without a default is an error.
        def validate_every(trainer, other):
            pass

        with pytest.raises(RuntimeError) as exc_info:
            _check_valid_parameters_number(validate_every, expected_params=one_expected)
        message = exc_info.value.args[0]
        assert "2 parameters" in message
        print(message)

        # An extra parameter with a default is fine.
        def validate_every(trainer, other=1):
            pass

        _check_valid_parameters_number(validate_every, expected_params=one_expected)

        # More arguments will be supplied than the function accepts.
        def validate_every(trainer):
            pass

        with pytest.raises(RuntimeError) as exc_info:
            _check_valid_parameters_number(validate_every, expected_params=two_expected)
        message = exc_info.value.args[0]
        assert "accepts 1 parameters" in message
        print(message)

        # functools.partial objects are supported.
        def validate_every(trainer, other):
            pass

        bound_other = partial(validate_every, other=1)
        _check_valid_parameters_number(bound_other, expected_params=one_expected)
        _check_valid_parameters_number(partial(validate_every, other=1), expected_params=two_expected)
        with pytest.raises(RuntimeError) as exc_info:
            _check_valid_parameters_number(partial(validate_every, other=1), expected_params=three_expected)
        message = exc_info.value.args[0]
        assert 'accepts 2 parameters' in message
        print(message)

        # With *args or **kwargs, surplus arguments are not an error.
        def validate_every(trainer, *args):
            pass

        _check_valid_parameters_number(validate_every, expected_params=three_expected)

        def validate_every(trainer, **kwargs):
            pass

        _check_valid_parameters_number(partial(validate_every, trainer=1), expected_params=three_expected)

        # For bound methods, ``self`` does not count as a parameter.
        class InnerClass:
            def demo(self, x):
                pass

            def no_param(self):
                pass

            def param_kwargs(self, **kwargs):
                pass

        inner = InnerClass()
        with pytest.raises(RuntimeError) as exc_info:
            _check_valid_parameters_number(inner.demo, expected_params=three_expected)
        message = exc_info.value.args[0]
        assert 'accepts 1 parameters' in message

        _check_valid_parameters_number(inner.demo, expected_params=one_expected)
def test_get_fun_msg():
    """``_get_fun_msg`` must report the function's name and signature in backticks."""
    def demo(x):
        pass

    # A locally defined function: the qualified name ends with ``demo`` and the
    # signature ``(x)`` is appended directly after it.
    demo_msg = _get_fun_msg(demo)
    assert demo_msg.startswith('`')
    assert 'demo(x)' in demo_msg

    # A module-level function is described by its own name.
    self_msg = _get_fun_msg(_get_fun_msg)
    assert '_get_fun_msg' in self_msg
    print(self_msg)
| @@ -2,37 +2,19 @@ import os | |||
| import sys | |||
| import __main__ | |||
| from functools import wraps | |||
| import inspect | |||
| from inspect import ismethod | |||
| import functools | |||
| from copy import deepcopy | |||
| from io import StringIO | |||
| import time | |||
| import numpy as np | |||
| from fastNLP.core.utils.utils import get_class_that_defined_method | |||
| from fastNLP.envs.env import FASTNLP_GLOBAL_RANK | |||
| from fastNLP.core.drivers.utils import distributed_open_proc | |||
| from fastNLP.core.log import logger | |||
| def get_class_that_defined_method(meth): | |||
| if isinstance(meth, functools.partial): | |||
| return get_class_that_defined_method(meth.func) | |||
| if inspect.ismethod(meth) or (inspect.isbuiltin(meth) and getattr(meth, '__self__', None) is not None and getattr(meth.__self__, '__class__', None)): | |||
| for cls in inspect.getmro(meth.__self__.__class__): | |||
| if meth.__name__ in cls.__dict__: | |||
| return cls | |||
| meth = getattr(meth, '__func__', meth) # fallback to __qualname__ parsing | |||
| if inspect.isfunction(meth): | |||
| cls = getattr(inspect.getmodule(meth), | |||
| meth.__qualname__.split('.<locals>', 1)[0].rsplit('.', 1)[0], | |||
| None) | |||
| if isinstance(cls, type): | |||
| return cls | |||
| return getattr(meth, '__objclass__', None) # handle special descriptor objects | |||
| def recover_logger(fn): | |||
| @wraps(fn) | |||
| def wrapper(*args, **kwargs): | |||