@@ -10,6 +10,7 @@ from .utils import _get_monitor_value | |||
from fastNLP.core.callbacks.callback_events import _SingleEventState | |||
from fastNLP.core.log import logger | |||
from fastNLP.core.utils import apply_to_collection | |||
from fastNLP.core.utils.utils import _check_valid_parameters_number | |||
class Callback: | |||
@@ -299,7 +300,11 @@ class HasMonitorCallback(Callback): | |||
self.must_have_monitor = must_have_monitor
def set_monitor(self, monitor, larger_better): | |||
if callable(monitor): # check that the function can accept a single argument
_check_valid_parameters_number(monitor, expected_params=['results'], fn_name='monitor') | |||
self.monitor = monitor | |||
else: | |||
self.monitor = str(monitor) if monitor is not None else None | |||
self.larger_better = bool(larger_better) | |||
if larger_better: | |||
self.monitor_value = float('-inf') | |||
@@ -322,24 +327,33 @@ class HasMonitorCallback(Callback): | |||
raise RuntimeError(f"No `monitor` is set for {self.__class__.__name__}. " | |||
f"You can set it in the initialization or through Trainer.") | |||
def get_monitor_value(self, results:Dict)->Union[float, None]: | |||
""" | |||
Fetch the value of the monitor. If `monitor` is not found directly, fuzzy matching is attempted, and the matched key is recorded on the `self._real_monitor` attribute.
:param results:
:return: a return value of None means no suitable monitor was found this time.
""" | |||
if len(results)==0: | |||
return None | |||
# make sure every tensor has been converted to a plain Python type
results = apply_to_collection(results, dtype=CanItemDataType, function=lambda x: x.item()) | |||
use_monitor, monitor_value = _get_monitor_value(monitor=self.monitor, | |||
real_monitor=self._real_monitor, | |||
res=results) | |||
if monitor_value is None: | |||
return monitor_value | |||
# first run
if isinstance(self.monitor, str) and self._real_monitor == self.monitor and use_monitor != self.monitor: | |||
logger.warning(f"We can not find `{self.monitor}` in the evaluation result (with keys as {list(results.keys())}), " | |||
f"we use the `{use_monitor}` as the monitor for `{self.__class__.__name__}`.") | |||
# the monitor found this time differs from the one used last time.
elif isinstance(self.monitor, str) and self._real_monitor != self.monitor and use_monitor != self._real_monitor: | |||
logger.warning(f"Change of monitor detected for `{self.__class__.__name__}`. " | |||
f"The expected monitor is:`{self.monitor}`, last used monitor is:" | |||
f"`{self._real_monitor}` and current monitor is:`{use_monitor}`. Please consider using a " | |||
f"customized monitor function when the evaluation results are varying between validation.") | |||
self._real_monitor = use_monitor | |||
return monitor_value | |||
@@ -347,10 +361,12 @@ class HasMonitorCallback(Callback): | |||
""" | |||
Check whether monitor_value is better than the best value recorded so far.
:param monitor_value: the monitor_value to check. If None, False is returned.
:param keep_if_better: if the passed-in monitor_value is better, keep it as the new best value.
:return: | |||
""" | |||
if monitor_value is None: | |||
return False | |||
better = self.is_former_monitor_value_better(monitor_value, self.monitor_value) | |||
if keep_if_better and better: | |||
self.monitor_value = monitor_value | |||
@@ -364,6 +380,12 @@ class HasMonitorCallback(Callback): | |||
:param monitor_value2: | |||
:return: | |||
""" | |||
if monitor_value1 is None and monitor_value2 is None: | |||
return True | |||
if monitor_value1 is None: | |||
return False | |||
if monitor_value2 is None: | |||
return True | |||
better = False | |||
if (self.larger_better and monitor_value1 > monitor_value2) or \ | |||
(not self.larger_better and monitor_value1 < monitor_value2): | |||
@@ -10,8 +10,7 @@ from copy import deepcopy | |||
import fastNLP | |||
from .callback import HasMonitorCallback | |||
from fastNLP.core.log import logger | |||
from fastNLP.envs import FASTNLP_LAUNCH_TIME | |||
from fastNLP.core.utils import synchronize_safe_rm, synchronize_mkdir | |||
@@ -166,6 +165,8 @@ class CheckpointCallback(HasMonitorCallback): | |||
""" | |||
if self.save_topk is not None: | |||
monitor_value = self.get_monitor_value(results=results) | |||
if monitor_value is None: | |||
return | |||
folder_name = f"{self.folder_prefix}-epoch_{trainer.cur_epoch_idx}-batch_{trainer.global_forward_batches}" \ | |||
f"-{self._real_monitor}_{monitor_value}" | |||
@@ -231,7 +232,8 @@ class ModelCheckpointCallback(CheckpointCallback): | |||
If model_save_fn is not None, fastNLP passes the absolute path of folder to that function and creates no files under that folder itself.
:param monitor: name of the metric to monitor. If no exact match is found in the evaluation results, a substring-matching heuristic picks the closest key as the monitor. If None, the value is taken from the Trainer. A function may also be passed in; it receives the evaluation results (a dict) and returns a float to use as the monitor value.
:param save_folder: folder to save into. fastNLP creates a timestamped subfolder under it and saves there, so different runs end up in different timestamped folders. If None, the current folder is used.
:param save_every_n_epochs: save once every this many epochs.
@@ -278,7 +280,8 @@ class TrainerCheckpointCallback(CheckpointCallback): | |||
If model_save_fn is not None, fastNLP only generates a fastnlp_trainer.pkl.tar file under each folder.
:param monitor: name of the metric to monitor. If no exact match is found in the evaluation results, a substring-matching heuristic picks the closest key as the monitor. If None, the value is taken from the Trainer. A function may also be passed in; it receives the evaluation results (a dict) and returns a float to use as the monitor value.
:param save_folder: folder to save into. fastNLP creates a timestamped subfolder under it and saves there, so different runs end up in different timestamped folders. If None, the current folder is used.
:param save_every_n_epochs: save once every this many epochs.
@@ -2,17 +2,18 @@ __all__ = [ | |||
'EarlyStopCallback' | |||
] | |||
from typing import Dict, Union, Callable | |||
from .callback import HasMonitorCallback | |||
from fastNLP.core.utils.exceptions import EarlyStopException | |||
class EarlyStopCallback(HasMonitorCallback): | |||
def __init__(self, monitor:Union[str, Callable]=None, larger_better:bool=True, patience:int=10): | |||
""" | |||
:param monitor: the metric to monitor. If None, the monitor set on the Trainer is used. A function may also be passed in; it receives the evaluation results (a dict) and returns a float to use as the monitor value (see the usage sketch below).
:param larger_better: whether a larger monitor value is better.
:param patience: stop after this many validations without improvement.
""" | |||
@@ -21,9 +22,9 @@ class EarlyStopCallback(HasMonitorCallback): | |||
self.patience = patience | |||
def on_validate_end(self, trainer, results): | |||
monitor_value = self.get_monitor_value(results) | |||
if monitor_value is None: | |||
return | |||
if self.is_better_monitor_value(monitor_value, keep_if_better=True): | |||
self.wait = 0 | |||
else: | |||
@@ -3,7 +3,7 @@ __all__ = [ | |||
] | |||
import os | |||
from typing import Optional, Callable, Union | |||
from .callback import HasMonitorCallback | |||
from io import BytesIO | |||
import shutil | |||
@@ -14,14 +14,15 @@ from fastNLP.envs import all_rank_call | |||
class LoadBestModelCallback(HasMonitorCallback): | |||
def __init__(self, monitor:Union[str, Callable]=None, larger_better:bool = True, only_state_dict:bool = True, | |||
save_folder:Optional[str] = None, model_save_fn:Optional[Callable] = None, | |||
model_load_fn:Optional[Callable] = None, | |||
delete_after_train:bool = True): | |||
""" | |||
Save the model with the best monitor value and reload it when training ends. The best model can only be loaded when training finishes normally.
:param monitor: the metric to monitor. If None, the monitor set on the Trainer is used. A function may also be passed in; it receives the evaluation results (a dict) and returns a float to use as the monitor value.
:param larger_better: whether a larger value of this metric is better.
:param save_folder: folder to save into. If empty, the weights are kept in memory; otherwise a copy of the weights is saved to file. For multi-machine training with a non-empty value, make sure every machine can access the path. This value must not be empty when model_save_fn is not None.
@@ -78,9 +79,9 @@ class LoadBestModelCallback(HasMonitorCallback): | |||
self.get_monitor_value(sanity_check_res) | |||
def on_validate_end(self, trainer, results): | |||
monitor_value = self.get_monitor_value(results) | |||
if monitor_value is None: | |||
return | |||
if self.is_better_monitor_value(monitor_value, keep_if_better=True): | |||
if self.real_save_folder: | |||
trainer.save_model(folder=self.real_save_folder, only_state_dict=self.only_state_dict, | |||
@@ -45,6 +45,7 @@ class RichCallback(ProgressCallback): | |||
:param print_every: update the display every this many batches.
:param loss_round_ndigit: number of digits to keep when displaying the loss
:param monitor: when the result under this key improves, it is printed in a different color as a hint. If None, the monitor set on the trainer is used. A function may also be passed in; it receives the evaluation results (a dict) and returns a float to use as the monitor value.
:param larger_better: whether a larger monitor result is better.
:param format_json: whether to format the json before printing
""" | |||
@@ -135,7 +136,8 @@ class RawTextCallback(ProgressCallback): | |||
:param print_every: update the display every this many batches.
:param loss_round_ndigit: number of digits to keep when displaying the loss
:param monitor: when the result under this key improves, it is printed in a different color as a hint. A function may also be passed in; it receives the evaluation results (a dict) and returns a float to use as the monitor value.
:param larger_better: whether a larger monitor result is better.
:param format_json: whether to format the json before printing
""" | |||
@@ -1,9 +1,10 @@ | |||
from typing import Optional, Union | |||
from fastNLP.core.log.logger import logger | |||
from difflib import SequenceMatcher | |||
from fastNLP.core.utils.utils import _get_fun_msg | |||
def _get_monitor_value(monitor: Union[callable, str], real_monitor: Optional[str], res: dict) ->(str, float): | |||
""" | |||
Look up `monitor` in `res` and return it. If `monitor` is not found, try `_real_monitor`; if `_real_monitor` is None, fall back to matching on the value of `monitor`.
@@ -11,10 +12,19 @@ def _get_monitor_value(monitor: str, real_monitor: Optional[str], res: dict) ->( | |||
:param monitor: | |||
:param real_monitor: | |||
:param res: | |||
:return: two values (str, value): the str is the key that was finally used, and value is the value under that key. A value of None means the current results contain no matching monitor.
""" | |||
if len(res)==0: | |||
return monitor, None | |||
if callable(monitor): | |||
try: | |||
monitor_value = monitor(res) | |||
except BaseException as e: | |||
logger.error(f"Exception happens when calling customized monitor function:{_get_fun_msg(monitor)}.") | |||
raise e | |||
return monitor, monitor_value | |||
if monitor in res: | |||
return monitor, res[monitor] | |||
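# Editor's sketch of the fuzzy fallback that follows this hunk (an assumption
# based on the SequenceMatcher import above, not the verbatim implementation):
# pick the key in `res` that shares the longest matching block with `monitor`:
#
# def _best_fuzzy_key(monitor: str, res: dict) -> str:
#     def match_size(key: str) -> int:
#         sm = SequenceMatcher(a=monitor, b=key)
#         return sm.find_longest_match(0, len(monitor), 0, len(key)).size
#     return max(res, key=match_size)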
@@ -5,7 +5,7 @@ __all__ = [ | |||
from abc import ABCMeta, abstractmethod | |||
from typing import Any, Dict, List, Callable, Union, Tuple | |||
from numbers import Number | |||
import warnings | |||
@@ -35,7 +35,7 @@ class SetInputOrTargetException(Exception): | |||
self.field_name = field_name # name of the current field
def _get_ele_type_and_dim(cell: Any, dim=0) -> Tuple[Any, int]: | |||
r""" | |||
Identify the element type of cell and its number of dimensions
@@ -206,7 +206,7 @@ class AutoCollator(Collator): | |||
def __init__(self, as_numpy: bool): | |||
super(AutoCollator, self).__init__() | |||
self.pad_field_value = {} # user-defined padding value per field; defaults to 0
self.need_inputs = set() # names of the required fields
self.field_dtypes = None # dtype of each field's data cells
self.field_dims = None # dimensionality of each field's data cells
self.as_numpy = as_numpy | |||
@@ -214,10 +214,17 @@ class AutoCollator(Collator): | |||
def __call__(self, ins_lst: List[Dict]) -> dict: | |||
if len(self.need_inputs) == 0: | |||
raise ValueError({"set_inputs is None, you should use set_inputs method first!!"}) | |||
# TODO first check which fields need padding, then check whether they can be padded
# case 1: a value was configured via set_input
# case 2: decide whether to pad based on the data type
if self.field_dtypes is None and self.field_dims is None: | |||
field_dtypes, field_dims = {}, {} | |||
for key, value in ins_lst[0].items(): | |||
if key in self.need_inputs and self.pad_field_value.get(key, 0) is not None: | |||
field_dtypes[key], field_dims[key] = _get_ele_type_and_dim(value) | |||
self.field_dtypes = field_dtypes | |||
self.field_dims = field_dims | |||
pack_ins_lst, pad_ins_lst = {field_name: [] | |||
for field_name in ins_lst[0].keys() if field_name in self.need_inputs}, {} | |||
@@ -233,13 +240,13 @@ class AutoCollator(Collator): | |||
if len(self.pad_field_value.keys()) > 0: | |||
# drop the columns that need no padding; ignore set_input columns that do not exist
non_pad_field_names = [] | |||
for k, v in self.pad_field_value.items(): | |||
if v is None: | |||
non_pad_field_names.append(k) | |||
for field_name in non_pad_field_names: | |||
field_array = pack_ins_lst.pop(field_name) | |||
pad_ins_lst[field_name] = np.array(field_array) | |||
@@ -269,7 +276,7 @@ class AutoCollator(Collator): | |||
def set_input(self, *field_names): | |||
for field_name in field_names: | |||
self.need_inputs.add(field_name) | |||
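# Editor's usage sketch for AutoCollator (field names are illustrative):
#
# collator = AutoCollator(as_numpy=True)
# collator.set_input('words', 'seq_len', 'target')
# batch = collator([{'words': [3, 4, 5], 'seq_len': 3, 'target': 0},
#                   {'words': [6, 7], 'seq_len': 2, 'target': 1}])
# # 'words' is padded to the longest length in the batch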
def pad_content(content, field_name: str, field_type, field_dim: int, pad_val: int, as_numpy: bool): | |||
@@ -11,11 +11,12 @@ __all__ = [ | |||
from fastNLP.core.drivers import Driver | |||
from fastNLP.core.drivers.utils import choose_driver | |||
from .loops import Loop, EvaluateBatchLoop | |||
from fastNLP.core.utils import auto_param_call, dataclass_to_dict, \ | |||
match_and_substitute_params, f_rich_progress | |||
from fastNLP.core.metrics import Metric | |||
from fastNLP.core.metrics.utils import _is_torchmetrics_metric, _is_paddle_metric, _is_allennlp_metric | |||
from fastNLP.core.controllers.utils.utils import _TruncatedDataLoader | |||
from fastNLP.core.utils.utils import _check_valid_parameters_number | |||
from fastNLP.core.log import logger | |||
@@ -38,11 +39,11 @@ class Evaluator: | |||
driver: Union[str, Driver] = 'single', | |||
device: Optional[Union[int, List[int], str]] = None, | |||
batch_step_fn: Optional[callable] = None, | |||
mode: str = "validate", | |||
mode: Optional[Union[str, callable]] = 'validate', # 首先尝试找 evaluate_step, 找不到 forward, callable | |||
input_mapping: Optional[Union[Callable, Dict]] = None, | |||
output_mapping: Optional[Union[Callable, Dict]] = None, | |||
model_wo_auto_param_call: bool = False, | |||
fp16: bool = False, | |||
verbose: int = 1, | |||
**kwargs | |||
): | |||
@@ -92,8 +93,8 @@ class Evaluator: | |||
self.device = device | |||
self.verbose = verbose | |||
if batch_step_fn is not None: | |||
_check_valid_parameters_number(batch_step_fn, ['trainer', 'batch'], fn_name='batch_step_fn') | |||
self.batch_step_fn = batch_step_fn | |||
self.mode = mode | |||
@@ -135,6 +136,7 @@ class Evaluator: | |||
if self.progress_bar == 'auto': | |||
self.progress_bar = 'rich' if (sys.stdin and sys.stdin.isatty()) else 'raw' | |||
self.driver.check_evaluator_mode(self.mode) | |||
self.driver.barrier() | |||
def run(self, num_eval_batch_per_dl: int = -1, **kwargs) -> Dict: | |||
@@ -154,8 +156,6 @@ class Evaluator: | |||
assert isinstance(num_eval_batch_per_dl, int), "num_eval_batch_per_dl must be of int type." | |||
assert num_eval_batch_per_dl > 0 or num_eval_batch_per_dl == -1, "num_eval_batch_per_dl must be -1 or larger than 0." | |||
if self.mode == 'validate': | |||
assert self.driver.has_validate_dataloaders() | |||
else: | |||
@@ -367,9 +367,10 @@ class _MetricsWrapper: | |||
raise RuntimeError(f"The output of your model is of type:`{type(outputs)}`, please either directly" | |||
f" return a dict from your model or use `output_mapping` to convert it into dict type.") | |||
if isinstance(metric, Metric): | |||
# this way the error from auto_param_call stays informative.
auto_param_call(metric.update, outputs, *args, signature_fn=metric.update.__wrapped__) | |||
elif _is_torchmetrics_metric(metric): | |||
auto_param_call(metric.update, outputs, *args, signature_fn=metric.update.__wrapped__) | |||
elif _is_allennlp_metric(metric): | |||
auto_param_call(metric.__call__, outputs, *args) | |||
elif _is_paddle_metric(metric): | |||
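# Editor's note: `signature_fn` lets auto_param_call extract arguments according
# to one function's signature while invoking a different callable, which is why
# the wrapped metric's original `update` (via __wrapped__) is passed above. A
# sketch mirroring the tests further below:
#
# def fn1(**kwargs):            # the callable actually invoked
#     return kwargs['x'] + kwargs['y']
#
# def fn(x, y):                 # only its signature is consulted
#     pass
#
# auto_param_call(fn1, {'x': 3, 'y': 4, 'z': 5}, signature_fn=fn)  # -> 7, 'z' ignored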
@@ -14,6 +14,7 @@ __all__ = [ | |||
from .loops import Loop, TrainBatchLoop | |||
from .utils import State, TrainerState | |||
from .utils.utils import check_validate_every | |||
from .evaluator import Evaluator | |||
from fastNLP.core.controllers.utils.utils import TrainerEventTrigger, _TruncatedDataLoader | |||
from fastNLP.core.callbacks import Callback, CallbackManager, Events, EventsList, Filter | |||
@@ -21,7 +22,8 @@ from fastNLP.core.callbacks.callback import _CallbackWrapper | |||
from fastNLP.core.callbacks.callback_events import _SingleEventState | |||
from fastNLP.core.drivers import Driver | |||
from fastNLP.core.drivers.utils import choose_driver | |||
from fastNLP.core.utils import get_fn_arg_names, match_and_substitute_params, nullcontext | |||
from fastNLP.core.utils.utils import _check_valid_parameters_number | |||
from fastNLP.envs import rank_zero_call | |||
from fastNLP.core.log import logger | |||
from fastNLP.envs import FASTNLP_MODEL_FILENAME | |||
@@ -42,7 +44,7 @@ class Trainer(TrainerEventTrigger): | |||
validate_dataloaders=None, | |||
batch_step_fn: Optional[Callable] = None, | |||
validate_batch_step_fn: Optional[Callable] = None, | |||
validate_mode: str = "validate", | |||
validate_mode: Union[str, callable] = 'validate', | |||
callbacks: Union[List[Callback], Callback, None] = None, | |||
metrics: Optional[dict] = None, | |||
validate_every: Optional[Union[int, callable]] = -1, | |||
@@ -51,7 +53,7 @@ class Trainer(TrainerEventTrigger): | |||
model_wo_auto_param_call: bool = False, | |||
accumulation_steps: int = 1, | |||
fp16: bool = False, | |||
monitor: Union[str, callable] = None, | |||
larger_better: bool = True, | |||
marker: Optional[str] = None, | |||
**kwargs | |||
@@ -90,11 +92,8 @@ class Trainer(TrainerEventTrigger): | |||
:param callbacks: callbacks triggered during training; this parameter should be a list whose every element inherits from the `Callback` class;
:param metrics: should be a dict whose keys serve as monitors, e.g. {"acc1": AccMetric(), "acc2": AccMetric()};
:param validate_every: may be a negative number, a positive number, or a function; a negative value validates every that many epochs; a positive value validates every that many batches; a function hands control of the validation frequency to the user: it should accept the current trainer object as its argument and return a bool, where True means a validation should run; it is called at the end of every batch.
:param input_mapping: should be a dict or a function describing how to transform a batch of training data once it is fetched at the current step; if it is a dict and the batch is also a `Dict`, keys of the batch that also appear in input_mapping are renamed to the corresponding input_mapping values; if the batch is a `dataclass`, it is first converted to a Dict and then transformed as above; if the batch is of some other
@@ -111,7 +110,7 @@ class Trainer(TrainerEventTrigger): | |||
:param fp16: whether to enable mixed-precision training; defaults to False;
:param monitor: the default monitor metric name when validate_dataloaders is present. Callbacks that take a monitor parameter and did not set it at initialization adopt this value. If no exact match is found in the evaluation results, a substring-matching heuristic picks the closest key as the monitor. A function may also be passed in; it receives the evaluation results (a dict) and returns a float to use as the monitor value.
:param larger_better: whether a larger monitor value is better.
:param marker: marks a Trainer instance so that when the user calls the `Trainer.on` function, the callback can be attributed to a specific 'trainer' instance; defaults to None;
:param kwargs: other parameters that may be needed;
@@ -142,10 +141,9 @@ class Trainer(TrainerEventTrigger): | |||
self.input_mapping = input_mapping | |||
self.output_mapping = output_mapping | |||
self.batch_step_fn = batch_step_fn | |||
if batch_step_fn is not None: | |||
_check_valid_parameters_number(batch_step_fn, ['trainer', 'batch'], fn_name='batch_step_fn') | |||
self.check_batch_step_fn = partial(self._check_callback_called_legality, check_mode=True) | |||
else: | |||
self.check_batch_step_fn = lambda *args, **kwargs: ... | |||
@@ -221,18 +219,11 @@ class Trainer(TrainerEventTrigger): | |||
if metrics is not None and validate_dataloaders is None: | |||
raise ValueError("You have set 'metrics' but forget to set 'validate_dataloader'.") | |||
self.evaluator = None | |||
self.monitor = monitor | |||
self.larger_better = larger_better | |||
if metrics is not None and validate_dataloaders is not None: | |||
check_validate_every(validate_every) | |||
self.evaluator = Evaluator( | |||
model=model, | |||
dataloaders=validate_dataloaders, | |||
@@ -352,33 +343,32 @@ class Trainer(TrainerEventTrigger): | |||
_validate_res: dict = validate_fn() | |||
trainer.on_validate_end(_validate_res) | |||
self.run_evaluate = partial(_validate_fn, self, partial(self.evaluator.run, num_eval_batch_per_dl)) | |||
def step_validate(self): | |||
""" | |||
Called at the end of every batch; runs evaluation according to the configuration.
:return: | |||
""" | |||
if self.evaluator is not None: | |||
if callable(self.validate_every): | |||
if self.validate_every(self): | |||
self.run_evaluate() | |||
elif self.validate_every > 0 and self.global_forward_batches % self.validate_every == 0: | |||
self.run_evaluate() | |||
def epoch_validate(self): | |||
""" | |||
Called at the end of every epoch; runs evaluation according to the configuration.
:return: | |||
""" | |||
if self.evaluator is not None: | |||
if isinstance(self.validate_every, int) and self.validate_every < 0: | |||
validate_every = -self.validate_every | |||
if self.cur_epoch_idx % validate_every == 0: | |||
self.run_evaluate() | |||
def add_callback_fn(self, event: Optional[Union[Events, EventsList]], fn: Callable): | |||
r""" | |||
@@ -410,9 +400,7 @@ class Trainer(TrainerEventTrigger): | |||
def wrapper(fn: Callable) -> Callable: | |||
cls._custom_callbacks[marker].append((event, fn)) | |||
callback_fn_args = get_fn_arg_names(getattr(Callback, event.value))[1:] | |||
_check_valid_parameters_number(fn, callback_fn_args) | |||
return fn | |||
return wrapper | |||
@@ -1,8 +1,9 @@ | |||
from collections.abc import Iterator | |||
import inspect | |||
from typing import Dict | |||
from fastNLP.core.callbacks import CallbackManager | |||
from .state import TrainerState | |||
from fastNLP.core.utils.utils import _check_valid_parameters_number | |||
class TrainerEventTrigger: | |||
@@ -125,5 +126,8 @@ class _TruncatedDataLoader: | |||
return getattr(self.dataloader, item) | |||
def check_validate_every(validate_every): | |||
if not callable(validate_every) and (not isinstance(validate_every, int) or validate_every == 0): | |||
raise ValueError("Parameter 'validate_every' should be set to 'int' type and either < 0 or > 0.") | |||
if callable(validate_every): | |||
_check_valid_parameters_number(validate_every, expected_params=['trainer']) |
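# Editor's sketch of a callable `validate_every`: cur_epoch_idx and
# global_forward_batches appear in the Trainer code above; num_batches_per_epoch
# is an assumed attribute. Validate at each epoch boundary after epoch 10:
#
# def my_validate_every(trainer) -> bool:
#     return (trainer.cur_epoch_idx > 10
#             and trainer.global_forward_batches % trainer.num_batches_per_epoch == 0)
#
# trainer = Trainer(..., validate_every=my_validate_every)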
@@ -178,10 +178,11 @@ class DataSet: | |||
elif isinstance(idx, slice): | |||
if idx.start is not None and (idx.start >= len(self) or idx.start <= -len(self)): | |||
raise RuntimeError(f"Start index {idx.start} out of range 0-{len(self) - 1}") | |||
dataset = DataSet() | |||
for field_name, field in self.field_arrays.items(): | |||
dataset.add_field(field_name=field_name, fields=field.content[idx]) | |||
dataset.collate_fns = deepcopy(self.collate_fns) | |||
return dataset | |||
elif isinstance(idx, str): | |||
if idx not in self: | |||
raise KeyError("No such field called {} in DataSet.".format(idx)) | |||
@@ -192,6 +193,7 @@ class DataSet: | |||
assert isinstance(i, int), "Only int index allowed." | |||
instance = self[i] | |||
dataset.append(instance) | |||
dataset.collate_fns = deepcopy(self.collate_fns) | |||
return dataset | |||
else: | |||
raise KeyError("Unrecognized type {} for idx in __getitem__ method".format(type(idx))) | |||
@@ -674,6 +676,8 @@ class DataSet: | |||
dev_set.append(self[idx]) | |||
for idx in train_indices: | |||
train_set.append(self[idx]) | |||
dev_set.collate_fns = deepcopy(self.collate_fns) | |||
train_set.collate_fns = deepcopy(self.collate_fns) | |||
return dev_set, train_set | |||
@@ -795,7 +799,7 @@ class DataSet: | |||
:param val: defaults to 0. If None, the field is not padded.
:return: | |||
""" | |||
# TODO must not be empty
for field_name in field_names: | |||
self.collate_fns.set_pad_val(field_name, val=val) | |||
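# Editor's usage sketch (field names are illustrative):
#
# ds.set_pad_val('raw_words', val=None)  # val=None: this field is not padded
# ds.set_pad_val('target', val=-100)     # pad the target field with an ignore index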
@@ -66,7 +66,7 @@ class JittorDriver(Driver): | |||
if mode == "validate": | |||
if not hasattr(model, "validate_step"): | |||
if hasattr(model, "test_step"): | |||
logger.warning_once( | |||
"Your model does not have 'validate_step' method but has 'test_step' method, but you" | |||
"are using 'mode=validate', we are going to use 'test_step' to substitute for" | |||
"'validate_step'.") | |||
@@ -74,7 +74,7 @@ class JittorDriver(Driver): | |||
else: | |||
if not hasattr(model, "test_step"): | |||
if hasattr(model, "validate_step"): | |||
logger.warning("Your model does not have 'test_step' method but has 'validate' method, but you" | |||
logger.warning_once("Your model does not have 'test_step' method but has 'validate' method, but you" | |||
"are using 'mode=test', we are going to use 'validate_step' to substitute for" | |||
"'test_step'.") | |||
@@ -133,7 +133,7 @@ class PaddleDriver(Driver): | |||
else: | |||
if not hasattr(model, "test_step"): | |||
if hasattr(model, "validate_step"): | |||
logger.warning("Your model does not have 'test_step' method but has 'validate' method, but you" | |||
logger.warning_once("Your model does not have 'test_step' method but has 'validate' method, but you" | |||
"are using 'Evaluator.test', we are going to use 'validate_step' to substitute for" | |||
"'test_step'.") | |||
@@ -333,10 +333,8 @@ def all_gather_object(object_list, obj, group=None): | |||
>>> output | |||
['foo', 12, {1: 2}] | |||
""" | |||
if dist.distributed_c10d._rank_not_in_group(group): | |||
return | |||
if _TORCH_GREATER_EQUAL_1_8: | |||
current_device = torch.device("cpu") | |||
is_nccl_backend = _check_for_nccl_backend(group) | |||
@@ -345,10 +343,11 @@ def all_gather_object(object_list, obj, group=None): | |||
# We cannot simply use my_rank since rank == device is not necessarily | |||
# true. | |||
current_device = torch.device("cuda", torch.cuda.current_device()) | |||
else: | |||
current_device = torch.cuda.current_device() | |||
input_tensor, local_size = _object_to_tensor(obj, device=current_device) | |||
# Gather all local sizes. This is so that we can find the max size, and index | |||
# until the correct size when deserializing the tensors. | |||
group_size = dist.get_world_size(group=group) | |||
@@ -379,3 +378,4 @@ def all_gather_object(object_list, obj, group=None): | |||
tensor = tensor.cpu() | |||
tensor_size = object_size_list[i] | |||
object_list[i] = _tensor_to_object(tensor, tensor_size) | |||
return object_list |
@@ -113,7 +113,7 @@ class TorchDriver(Driver): | |||
if mode == "validate": | |||
if not hasattr(model, "validate_step"): | |||
if hasattr(model, "test_step"): | |||
logger.warning_once( | |||
"Your model does not have 'validate_step' method but has 'test_step' method, but you" | |||
"are using 'mode=validate', we are going to use 'test_step' to substitute for" | |||
"'validate_step'.") | |||
@@ -125,9 +125,9 @@ class FastNLPLogger(logging.Logger, metaclass=LoggerSingleton): | |||
self._warning_msgs.add(msg) | |||
def warn(self, msg, *args, **kwargs): | |||
warnings.warn("The 'warn' method is deprecated, " | |||
"use 'warning' instead", DeprecationWarning, 2) | |||
self.warning(msg, *args, **kwargs) | |||
if self.isEnabledFor(WARNING): | |||
kwargs = self._add_rank_info(kwargs) | |||
self._log(WARNING, msg, args, **kwargs) | |||
def error(self, msg, *args, **kwargs): | |||
""" | |||
@@ -14,8 +14,7 @@ from fastNLP.core.utils.utils import seq_len_to_mask | |||
class Accuracy(Metric): | |||
def __init__(self, backend: Union[str, Backend, None] = 'auto', aggregate_when_get_metric: bool = True): | |||
super(Accuracy, self).__init__(backend=backend, aggregate_when_get_metric=aggregate_when_get_metric) | |||
self.register_element(name='correct', value=0, aggregate_method='sum', backend=backend) | |||
self.register_element(name='total', value=0, aggregate_method="sum", backend=backend) | |||
@@ -64,7 +63,7 @@ class Accuracy(Metric): | |||
warnings.warn("You are not passing `seq_len` to exclude pad when calculate accuracy.") | |||
else: | |||
raise RuntimeError(f"when pred havesize:{pred.shape}, target should have size: {pred.shape} or " | |||
raise RuntimeError(f"when pred have size:{pred.shape}, target should have size: {pred.shape} or " | |||
f"{pred.shape[:-1]}, got {target.shape}.") | |||
if masks is not None: | |||
@@ -23,14 +23,14 @@ __all__ = [ | |||
"BucketedBatchSampler", | |||
"ReproducibleBatchSampler", | |||
"re_instantiate_sampler", | |||
"conversion_between_reproducible_and_unrepeated_sampler" | |||
"re_instantiate_sampler" | |||
] | |||
from .sampler import BucketSampler, SortedSampler, ConstTokenNumSampler, ConstantTokenNumSampler | |||
from .unrepeated_sampler import UnrepeatedSampler, UnrepeatedRandomSampler, UnrepeatedSortedSampler, UnrepeatedSequentialSampler | |||
from .mix_sampler import MixSampler, DopedSampler, MixSequentialSampler, PollingSampler | |||
from .reproducible_sampler import ReproducibleSampler, RandomSampler, SequentialSampler, SortedSampler | |||
from .utils import re_instantiate_sampler | |||
from .conversion_utils import conversion_between_reproducible_and_unrepeated_sampler | |||
from .reproducible_batch_sampler import RandomBatchSampler, BucketedBatchSampler, ReproducibleBatchSampler | |||
@@ -0,0 +1,33 @@ | |||
from fastNLP.core.samplers import re_instantiate_sampler | |||
from fastNLP.core.samplers.reproducible_sampler import ReproducibleSampler, RandomSampler, SequentialSampler, \ | |||
SortedSampler | |||
from fastNLP.core.samplers.unrepeated_sampler import UnrepeatedSampler, UnrepeatedRandomSampler, \ | |||
UnrepeatedSequentialSampler, UnrepeatedSortedSampler | |||
def conversion_between_reproducible_and_unrepeated_sampler(sampler): | |||
""" | |||
Replace the sampler with its corresponding reproducible or unrepeated version. If no counterpart exists for the input sampler (in either direction), a TypeError is raised.
:param sampler: | |||
:return: | |||
""" | |||
assert isinstance(sampler, UnrepeatedSampler) or isinstance(sampler, ReproducibleSampler), \ | |||
"The sampler must be UnrepeatedSampler or ReproducibleSampler" | |||
if isinstance(sampler, UnrepeatedSampler): | |||
if isinstance(sampler, UnrepeatedRandomSampler): | |||
return re_instantiate_sampler(sampler, new_sampler_class=RandomSampler) | |||
elif isinstance(sampler, UnrepeatedSequentialSampler): | |||
return re_instantiate_sampler(sampler, new_sampler_class=SequentialSampler) | |||
elif isinstance(sampler, UnrepeatedSortedSampler): | |||
return re_instantiate_sampler(sampler, new_sampler_class=SortedSampler) | |||
raise TypeError(f"{sampler.__class__} has no unrepeated version.") | |||
else: | |||
if isinstance(sampler, RandomSampler): | |||
return re_instantiate_sampler(sampler, new_sampler_class=UnrepeatedRandomSampler) | |||
elif isinstance(sampler, SequentialSampler): | |||
return re_instantiate_sampler(sampler, new_sampler_class=UnrepeatedSequentialSampler) | |||
elif isinstance(sampler, SortedSampler): | |||
return re_instantiate_sampler(sampler, new_sampler_class=UnrepeatedSortedSampler) | |||
raise TypeError(f"{sampler.__class__} has no reproducible version.") |
@@ -378,7 +378,7 @@ class BucketedBatchSampler(ReproducibleBatchSampler): | |||
batch_indices = list(batch_indices[:-1]) | |||
rng = np.random.default_rng(abs(seed)) # guard the RNG state against buckets of different lengths
rng.shuffle(batch_indices) # shuffle across batches as well; this keeps per-batch lengths close on every card.
batches = (np.array(batches, dtype=object)[batch_indices]).tolist() | |||
if last_batches: | |||
batches = batches + last_batches | |||
return batches | |||
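# Editor's note on dtype=object above: bucketed batches can be ragged, and a plain
# np.array would try to build a rectangular array from them. Minimal illustration:
#
# ragged = [[0, 1, 2], [3, 4], [5]]
# np.array(ragged, dtype=object)[[2, 0, 1]].tolist()  # -> [[5], [0, 1, 2], [3, 4]]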
@@ -387,19 +387,12 @@ class BucketedBatchSampler(ReproducibleBatchSampler): | |||
if self.old_batch_size != self.batch_size or self.old_num_batch_per_bucket != self.num_batch_per_bucket: | |||
raise RuntimeError("BucketedBatchSampler does not support saving before last checkpoint states have been" | |||
" consumed. ") | |||
states = {'seed': self.seed, 'epoch': self.epoch, 'num_consumed_samples': self.num_consumed_samples,  # counted over all ranks
'sampler_type': self.__class__.__name__, 'length': len(self.dataset), 'shuffle': self.shuffle, | |||
'batch_size': self.batch_size, 'num_batch_per_bucket': self.num_batch_per_bucket, | |||
'num_replicas': self.num_replicas, | |||
'num_consumed_samples_array': getattr(self, 'num_consumed_samples_array', None)} | |||
return states | |||
def load_state_dict(self, states: Dict): | |||
@@ -1,3 +1,10 @@ | |||
__all__ = [ | |||
'ReproducibleSampler', | |||
'RandomSampler', | |||
"SortedSampler", | |||
"SequentialSampler" | |||
] | |||
from typing import Dict, List, Union | |||
import math | |||
import os | |||
@@ -10,13 +17,6 @@ from fastNLP.envs.env import FASTNLP_DEQUE_SIZE | |||
from .utils import NumConsumedSamplesArray | |||
class ReproducibleSampler: | |||
""" | |||
@@ -1,42 +1,10 @@ | |||
__all__ = [ | |||
're_instantiate_sampler' | |||
] | |||
from array import array | |||
from typing import Sequence | |||
from collections import deque | |||
def re_instantiate_sampler(sampler, new_sampler_class=None): | |||
all_attributes = vars(sampler) | |||
@@ -13,7 +13,6 @@ __all__ = [ | |||
'torch_paddle_move_data_to_device', | |||
'torch_move_data_to_device', | |||
'get_fn_arg_names', | |||
'auto_param_call', | |||
'check_user_specific_params', | |||
'dataclass_to_dict', | |||
@@ -36,7 +35,7 @@ from .paddle_utils import paddle_to, paddle_move_data_to_device, get_paddle_devi | |||
from .rich_progress import f_rich_progress | |||
from .torch_paddle_utils import torch_paddle_move_data_to_device | |||
from .torch_utils import torch_move_data_to_device | |||
from .utils import get_fn_arg_names, auto_param_call, check_user_specific_params, \ | |||
dataclass_to_dict, match_and_substitute_params, apply_to_collection, nullcontext, pretty_table_printer, Option, \ | |||
indice_collate_wrapper, deprecated, seq_len_to_mask, synchronize_safe_rm, synchronize_mkdir | |||
@@ -1,3 +1,4 @@ | |||
import functools | |||
import inspect | |||
from inspect import Parameter | |||
import dataclasses | |||
@@ -24,10 +25,8 @@ from fastNLP.core.log import logger | |||
from fastNLP.envs import FASTNLP_GLOBAL_RANK | |||
__all__ = [ | |||
'get_fn_arg_names', | |||
'auto_param_call', | |||
'check_user_specific_params', | |||
'dataclass_to_dict', | |||
@@ -54,30 +53,6 @@ def get_fn_arg_names(fn: Callable) -> List[str]: | |||
return list(inspect.signature(fn).parameters) | |||
def auto_param_call(fn: Callable, *args, signature_fn: Optional[Callable] = None, | |||
mapping: Optional[Dict[AnyStr, AnyStr]] = None) -> Any: | |||
r""" | |||
@@ -95,7 +70,6 @@ def auto_param_call(fn: Callable, *args, signature_fn: Optional[Callable] = None | |||
:param signature_fn: a function used to replace the signature of `fn`: if this parameter is not None, the signature is first extracted from it, parameter values are collected according to that signature, and only then is `fn` actually called;
:param mapping: a dict used to rename the keys of the preceding dicts;
:return: the result of running `fn`;
@@ -123,7 +97,8 @@ def auto_param_call(fn: Callable, *args, signature_fn: Optional[Callable] = None | |||
_kwargs = None | |||
for _name, _param in _need_params.items(): | |||
if _param.kind == Parameter.VAR_POSITIONAL: | |||
raise ValueError(f"It is not allowed to have parameter `*args` in your function:{fn.__name__}.") | |||
fn_msg = _get_fun_msg(fn if signature_fn is None else signature_fn) | |||
raise ValueError(f"It is not allowed to have parameter `*args` in your function:{fn_msg}.") | |||
if _param.kind == Parameter.VAR_KEYWORD: | |||
_kwargs = (_name, _param) | |||
@@ -136,12 +111,17 @@ def auto_param_call(fn: Callable, *args, signature_fn: Optional[Callable] = None | |||
_default_params[_name] = _param.default | |||
if mapping is not None: | |||
assert isinstance(mapping, Dict), f"Parameter `mapping` should be of 'Dict' type, instead of {type(mapping)}." | |||
fn_msg = _get_fun_msg(fn if signature_fn is None else signature_fn) | |||
assert isinstance(mapping, Dict), f"Exception happens when calling {fn_msg}. " \ | |||
f"Parameter `mapping` should be of 'Dict' type, instead of {type(mapping)}." | |||
_has_params = {} | |||
duplicate_names = [] | |||
for arg in args: | |||
if not isinstance(arg, Dict): | |||
fn_msg = _get_fun_msg(fn if signature_fn is None else signature_fn) | |||
raise TypeError(f"Exception happens when calling {fn_msg}. " | |||
f"The input part of function `auto_param_call` must be `Dict` type, instead of {type(arg)}.") | |||
for _name, _value in arg.items(): | |||
if mapping is not None and _name in mapping: | |||
_name = mapping[_name] | |||
@@ -153,7 +133,8 @@ def auto_param_call(fn: Callable, *args, signature_fn: Optional[Callable] = None | |||
elif _name in _need_params and not (_has_params[_name] is _value): | |||
duplicate_names.append(_name) | |||
if duplicate_names: | |||
raise ValueError(f"The following key present in several inputs:{duplicate_names}") | |||
fn_msg = _get_fun_msg(fn if signature_fn is None else signature_fn) | |||
raise ValueError(f"The following key present in several inputs:{duplicate_names} when calling {fn_msg}.") | |||
# pass along parameters that have default values and were not overridden by the inputs;
for _name, _value in _default_params.items(): | |||
@@ -162,11 +143,89 @@ def auto_param_call(fn: Callable, *args, signature_fn: Optional[Callable] = None | |||
if len(_has_params)<len(_need_params): | |||
miss_params = list(set(_need_params.keys()) - set(_has_params.keys())) | |||
raise ValueError(f"The parameters:`{miss_params}` needed by function:{fn.__name__} are not found in the input.") | |||
fn_msg = _get_fun_msg(fn if signature_fn is None else signature_fn) | |||
_provided_keys = _get_keys(args) | |||
raise ValueError(f"The parameters:`{miss_params}` needed by function:{fn_msg} " | |||
f"are not found in the input keys({_provided_keys}).") | |||
return fn(**_has_params) | |||
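# Editor's behavior sketch, mirroring the tests further below:
#
# def fn(x, y=100): return x + y
# auto_param_call(fn, {'x': 10, 'y': 20})  # -> 30
# auto_param_call(fn, {'x': 10, 'z': 20})  # -> 110: 'z' is ignored, y keeps its default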
def _get_keys(args:List[Dict]) -> List[List[str]]: | |||
""" | |||
Return the keys of every dict.
:param args: | |||
:return: | |||
""" | |||
_provided_keys = [] | |||
for arg in args: | |||
_provided_keys.append(list(arg.keys())) | |||
return _provided_keys | |||
def _get_fun_msg(fn)->str: | |||
""" | |||
Collect basic information about a function to make error reporting clearer.
ex: | |||
print(_get_fun_msg(_get_fun_msg)) | |||
# `_get_fun_msg(fn) -> str`(In file:/Users/hnyan/Desktop/projects/fastNLP/fastNLP/fastNLP/core/utils/utils.py) | |||
:param callable fn: | |||
:return: | |||
""" | |||
if isinstance(fn, functools.partial): | |||
return _get_fun_msg(fn.func) | |||
try: | |||
fn_name = fn.__qualname__ + str(inspect.signature(fn)) | |||
except: | |||
fn_name = str(fn) | |||
try: | |||
fp = '(In file:' + os.path.abspath(inspect.getfile(fn)) + ')' | |||
except: | |||
fp = '' | |||
msg = f'`{fn_name}`' + fp | |||
return msg | |||
def _check_valid_parameters_number(fn, expected_params:List[str], fn_name=None): | |||
""" | |||
Check whether a function can accept the expected_params parameters (i.e. whether the counts match), after removing self (for methods), parameters with defaults, and so on. If they cannot be matched, an error is raised.
:param fn: the function to check; may be a method or a plain function.
:param expected_params: the parameters the function is expected to support.
:param fn_name: the name of fn; makes the error clearer when the passed-in fn is not callable.
:return: | |||
""" | |||
if fn_name is not None: | |||
assert callable(fn), f"{fn_name} should be callable, instead of {type(fn)}." | |||
parameters = list(inspect.signature(fn).parameters.values()) | |||
if inspect.ismethod(fn): | |||
if len(parameters)>0 and parameters[0].name == 'self': | |||
parameters = parameters[1:] # drop self
no_var_param = True # no *args/**kwargs-style parameters
number_param_need_value = 0 | |||
for param in parameters: | |||
if param.kind is param.VAR_POSITIONAL: | |||
no_var_param = False | |||
elif param.kind is param.VAR_KEYWORD: | |||
no_var_param = False | |||
else: | |||
if param.default is param.empty: | |||
number_param_need_value += 1 | |||
if len(parameters)<len(expected_params) and no_var_param: | |||
raise RuntimeError(f"The function:{_get_fun_msg(fn)} accepts {len(parameters)} parameters, " | |||
f"but {len(expected_params)} parameters:{expected_params} will be provided.") | |||
if number_param_need_value>len(expected_params): | |||
raise RuntimeError(f"The function:{_get_fun_msg(fn)} expects {len(parameters)} parameters, but only" | |||
f" {len(expected_params)} parameters:{expected_params} will be provided.") | |||
def check_user_specific_params(user_params: Dict, fn: Callable): | |||
""" | |||
Use the user-provided values to assign the parameters of the given function;
@@ -592,4 +651,24 @@ def synchronize_mkdir(path: Optional[Union[str, Path]]): | |||
wait_to_success(path.exists) | |||
def get_class_that_defined_method(method): | |||
""" | |||
Given a method, return the class object on which the method is defined
:param method: | |||
:return: | |||
""" | |||
if isinstance(method, functools.partial): | |||
return get_class_that_defined_method(method.func) | |||
if inspect.ismethod(method) or (inspect.isbuiltin(method) and getattr(method, '__self__', None) is not None and getattr(method.__self__, '__class__', None)): | |||
for cls in inspect.getmro(method.__self__.__class__): | |||
if method.__name__ in cls.__dict__: | |||
return cls | |||
method = getattr(method, '__func__', method) # fallback to __qualname__ parsing | |||
if inspect.isfunction(method): | |||
cls = getattr(inspect.getmodule(method), | |||
method.__qualname__.split('.<locals>', 1)[0].rsplit('.', 1)[0], | |||
None) | |||
if isinstance(cls, type): | |||
return cls | |||
return getattr(method, '__objclass__', None) # handle special descriptor objects |
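# Editor's usage sketch:
#
# class Foo:
#     def bar(self): ...
# assert get_class_that_defined_method(Foo().bar) is Foo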
@@ -251,10 +251,10 @@ class DataBundle: | |||
def apply_field_more(self, func: Callable, field_name: str, num_proc: int = 0, modify_fields=True, | |||
ignore_miss_dataset=True, progress_desc: str = '', show_progress_bar: bool = True): | |||
r""" | |||
Apply the :meth:`~fastNLP.DataSet.apply_field_more` method to every dataset in the :class:`~fastNLP.io.DataBundle`.
.. note:: | |||
For the difference between ``apply_field_more`` and ``apply_field``, see the discussion of ``apply_more`` versus ``apply`` in :meth:`fastNLP.DataSet.apply_more`.
:param callable func: takes an ``Instance`` from the ``DataSet`` and returns a dict whose keys are field names and whose values are the corresponding results
@@ -285,7 +285,7 @@ class DataBundle: | |||
def apply(self, func: Callable, new_field_name: str, num_proc: int = 0, | |||
progress_desc: str = '', show_progress_bar: bool = True, _apply_field: str = None): | |||
r""" | |||
Apply the :meth:`~fastNLP.DataSet.apply` method to every dataset in the :class:`~fastNLP.io.DataBundle`.
@@ -309,10 +309,10 @@ class DataBundle: | |||
def apply_more(self, func: Callable, modify_fields=True, num_proc: int = 0, | |||
progress_desc: str = '', show_progress_bar: bool = True): | |||
r""" | |||
Apply the :meth:`~fastNLP.DataSet.apply_more` method to every dataset in the :class:`~fastNLP.io.DataBundle`.
.. note:: | |||
For the difference between ``apply_more`` and ``apply``, see the discussion in :meth:`fastNLP.DataSet.apply_more`.
:param callable func: takes an ``Instance`` from the ``DataSet`` and returns a dict whose keys are field names and whose values are the corresponding results
@@ -87,7 +87,7 @@ class CLSBasePipe(Pipe): | |||
def process_from_file(self, paths) -> DataBundle: | |||
r""" | |||
Take file paths and build a processed DataBundle object. For the path formats `paths` supports, see :meth:`fastNLP.io.Loader.load`.
:param paths: | |||
:return: DataBundle | |||
@@ -164,7 +164,7 @@ class GraphBuilderBase: | |||
def build_graph_from_file(self, path: str): | |||
r""" | |||
Take a file path and build a processed scipy_sparse_matrix object. For the path formats `paths` supports, see :meth:`fastNLP.io.Loader.load`.
:param path: | |||
:return: scipy_sparse_matrix | |||
@@ -33,7 +33,7 @@ class Pipe: | |||
def process_from_file(self, paths: str) -> DataBundle: | |||
r""" | |||
Take file paths and build a processed DataBundle object. For the path formats `paths` supports, see :meth:`fastNLP.io.Loader.load`.
:param str paths: | |||
:return: DataBundle | |||
@@ -0,0 +1,187 @@ | |||
from functools import partial | |||
import pytest | |||
from fastNLP.core.utils.utils import auto_param_call, _check_valid_parameters_number, _get_fun_msg | |||
from fastNLP.core.metrics import Metric | |||
class TestAutoParamCall: | |||
def test_basic(self): | |||
def fn(x): | |||
return x | |||
x = {'x': 3, 'y': 4} | |||
r = auto_param_call(fn, x) | |||
assert r==3 | |||
xs = [] | |||
for i in range(10): | |||
xs.append({f'x{i}': i}) | |||
def fn(x0, x1, x2, x3): | |||
return x0 + x1 + x2 + x3 | |||
r = auto_param_call(fn, *xs) | |||
assert r == 0 + 1+ 2+ 3 | |||
def fn(chongfu1, chongfu2, buChongFu): | |||
pass | |||
with pytest.raises(BaseException) as exc_info: | |||
auto_param_call(fn, {'chongfu1': 3, "chongfu2":4, 'buChongFu':2}, | |||
{'chongfu1': 1, 'chongfu2':2, 'buChongFu':2}) | |||
assert 'The following key present in several inputs' in exc_info.value.args[0] | |||
assert 'chongfu1' in exc_info.value.args[0] and 'chongfu2' in exc_info.value.args[0] | |||
# unused duplicated keys do not raise
def fn(chongfu1, buChongFu): | |||
pass | |||
auto_param_call(fn, {'chongfu1': 1, "chongfu2":4, 'buChongFu':2}, | |||
{'chongfu1': 1, 'chongfu2':2, 'buChongFu':2}) | |||
# a custom signature_fn can be supplied
def fn1(**kwargs): | |||
kwargs.pop('x') | |||
kwargs.pop('y') | |||
assert len(kwargs)==0 | |||
def fn(x, y): | |||
pass | |||
x = {'x': 3, 'y': 4} | |||
r = auto_param_call(fn1, x, signature_fn=fn) | |||
# raises when required parameters are not provided
def fn(meiti1, meiti2, tigong): | |||
pass | |||
with pytest.raises(BaseException) as exc_info: | |||
auto_param_call(fn, {'tigong':1}) | |||
assert 'meiti1' in exc_info.value.args[0] and 'meiti2' in exc_info.value.args[0] | |||
# default values are used or overridden
def fn(x, y=100): | |||
return x + y | |||
r = auto_param_call(fn, {'x': 10, 'y': 20}) | |||
assert r==30 | |||
assert auto_param_call(fn, {'x': 10, 'z': 20})==110 | |||
# test the use of mapping
def fn(x, y=100): | |||
return x + y | |||
r = auto_param_call(fn, {'x1': 10, 'y1': 20}, mapping={'x1': 'x', 'y1': 'y', 'meiyong': 'meiyong'}) | |||
assert r==30 | |||
# a function that needs no parameters at all
def fn(): | |||
return 1 | |||
assert 1 == auto_param_call(fn, {'x':1}) | |||
# calling a bound method of a class works
assert 2==auto_param_call(self.call_this, {'x':1 ,'y':1}) | |||
assert 2==auto_param_call(self.call_this, {'x':1,'y':1, 'z':1},mapping={'z': 'self'}) | |||
def test_msg(self): | |||
with pytest.raises(BaseException) as exc_info: | |||
auto_param_call(self.call_this, {'x':1}) | |||
assert 'TestAutoParamCall.call_this' in exc_info.value.args[0] | |||
with pytest.raises(BaseException) as exc_info: | |||
auto_param_call(call_this_for_auto_param_call, {'x':1}) | |||
assert __file__ in exc_info.value.args[0] | |||
assert 'call_this_for_auto_param_call' in exc_info.value.args[0] | |||
with pytest.raises(BaseException) as exc_info: | |||
auto_param_call(self.call_this_two, {'x':1}) | |||
assert __file__ in exc_info.value.args[0] | |||
with pytest.raises(BaseException) as exc_info: | |||
auto_param_call(call_this_for_auto_param_call, {'x':1}, signature_fn=self.call_this) | |||
assert 'TestAutoParamCall.call_this' in exc_info.value.args[0] # the message should reflect the signature function
def call_this(self, x, y): | |||
return x + y | |||
def call_this_two(self, x, y, z=pytest, **kwargs): | |||
return x + y | |||
def test_metric_auto_param_call(self): | |||
metric = AutoParamCallMetric() | |||
with pytest.raises(BaseException): | |||
auto_param_call(metric.update, {'y':1}, signature_fn=metric.update.__wrapped__) | |||
class AutoParamCallMetric(Metric): | |||
def update(self, x): | |||
pass | |||
def call_this_for_auto_param_call(x, y): | |||
return x + y | |||
class TestCheckNumberOfParameters: | |||
def test_validate_every(self): | |||
def validate_every(trainer): | |||
pass | |||
_check_valid_parameters_number(validate_every, expected_params=['trainer']) | |||
# an extra parameter without a default raises
def validate_every(trainer, other): | |||
pass | |||
with pytest.raises(RuntimeError) as exc_info: | |||
_check_valid_parameters_number(validate_every, expected_params=['trainer']) | |||
assert "2 parameters" in exc_info.value.args[0] | |||
print(exc_info.value.args[0]) | |||
# a default value makes the extra parameter ok
def validate_every(trainer, other=1): | |||
pass | |||
_check_valid_parameters_number(validate_every, expected_params=['trainer']) | |||
# more expected parameters than the function accepts
def validate_every(trainer): | |||
pass | |||
with pytest.raises(RuntimeError) as exc_info: | |||
_check_valid_parameters_number(validate_every, expected_params=['trainer', 'other']) | |||
assert "accepts 1 parameters" in exc_info.value.args[0] | |||
print(exc_info.value.args[0]) | |||
# works with functools.partial
def validate_every(trainer, other): | |||
pass | |||
_check_valid_parameters_number(partial(validate_every, other=1), expected_params=['trainer']) | |||
_check_valid_parameters_number(partial(validate_every, other=1), expected_params=['trainer', 'other']) | |||
with pytest.raises(RuntimeError) as exc_info: | |||
_check_valid_parameters_number(partial(validate_every, other=1), expected_params=['trainer', 'other', 'more']) | |||
assert 'accepts 2 parameters' in exc_info.value.args[0] | |||
print(exc_info.value.args[0]) | |||
# with *args or **kwargs, extra expected parameters do not raise
def validate_every(trainer, *args): | |||
pass | |||
_check_valid_parameters_number(validate_every, expected_params=['trainer', 'other', 'more']) | |||
def validate_every(trainer, **kwargs): | |||
pass | |||
_check_valid_parameters_number(partial(validate_every, trainer=1), expected_params=['trainer', 'other', 'more']) | |||
# self is removed for class methods
class InnerClass: | |||
def demo(self, x): | |||
pass | |||
def no_param(self): | |||
pass | |||
def param_kwargs(self, **kwargs): | |||
pass | |||
inner = InnerClass() | |||
with pytest.raises(RuntimeError) as exc_info: | |||
_check_valid_parameters_number(inner.demo, expected_params=['trainer', 'other', 'more']) | |||
assert 'accepts 1 parameters' in exc_info.value.args[0] | |||
_check_valid_parameters_number(inner.demo, expected_params=['trainer']) | |||
def test_get_fun_msg(): | |||
def demo(x): | |||
pass | |||
print(_get_fun_msg(_get_fun_msg)) |
@@ -2,37 +2,19 @@ import os | |||
import sys | |||
import __main__ | |||
from functools import wraps | |||
from inspect import ismethod | |||
from copy import deepcopy | |||
from io import StringIO | |||
import time | |||
import numpy as np | |||
from fastNLP.core.utils.utils import get_class_that_defined_method | |||
from fastNLP.envs.env import FASTNLP_GLOBAL_RANK | |||
from fastNLP.core.drivers.utils import distributed_open_proc | |||
from fastNLP.core.log import logger | |||
def recover_logger(fn): | |||
@wraps(fn) | |||
def wrapper(*args, **kwargs): | |||