@@ -10,6 +10,7 @@ from .utils import _get_monitor_value | |||||
from fastNLP.core.callbacks.callback_events import _SingleEventState | from fastNLP.core.callbacks.callback_events import _SingleEventState | ||||
from fastNLP.core.log import logger | from fastNLP.core.log import logger | ||||
from fastNLP.core.utils import apply_to_collection | from fastNLP.core.utils import apply_to_collection | ||||
from fastNLP.core.utils.utils import _check_valid_parameters_number | |||||
class Callback: | class Callback: | ||||
@@ -299,7 +300,11 @@ class HasMonitorCallback(Callback): | |||||
self.must_have_moinitor = must_have_monitor | self.must_have_moinitor = must_have_monitor | ||||
def set_monitor(self, monitor, larger_better): | def set_monitor(self, monitor, larger_better): | ||||
self.monitor = str(monitor) if monitor is not None else None | |||||
if callable(monitor): # 检查是否能够接受一个参数 | |||||
_check_valid_parameters_number(monitor, expected_params=['results'], fn_name='monitor') | |||||
self.monitor = monitor | |||||
else: | |||||
self.monitor = str(monitor) if monitor is not None else None | |||||
self.larger_better = bool(larger_better) | self.larger_better = bool(larger_better) | ||||
if larger_better: | if larger_better: | ||||
self.monitor_value = float('-inf') | self.monitor_value = float('-inf') | ||||
@@ -322,24 +327,33 @@ class HasMonitorCallback(Callback): | |||||
raise RuntimeError(f"No `monitor` is set for {self.__class__.__name__}. " | raise RuntimeError(f"No `monitor` is set for {self.__class__.__name__}. " | ||||
f"You can set it in the initialization or through Trainer.") | f"You can set it in the initialization or through Trainer.") | ||||
def get_monitor_value(self, results:Dict)->float: | |||||
def get_monitor_value(self, results:Dict)->Union[float, None]: | |||||
""" | """ | ||||
获取 monitor 的值,如果 monitor 没有直接找到,会尝试使用匹配的方式寻找,并把匹配到的设置到 self._real_monitor 属性上。 | 获取 monitor 的值,如果 monitor 没有直接找到,会尝试使用匹配的方式寻找,并把匹配到的设置到 self._real_monitor 属性上。 | ||||
:param results: | :param results: | ||||
:return: | |||||
:return: 如果为 None ,表明此次没有找到合适的monitor | |||||
""" | """ | ||||
if len(results)==0: | if len(results)==0: | ||||
return 0 | |||||
return None | |||||
# 保证所有的 tensor 都被转换为了 python 特定的类型 | # 保证所有的 tensor 都被转换为了 python 特定的类型 | ||||
results = apply_to_collection(results, dtype=CanItemDataType, function=lambda x: x.item()) | results = apply_to_collection(results, dtype=CanItemDataType, function=lambda x: x.item()) | ||||
use_monitor, monitor_value = _get_monitor_value(monitor=self.monitor, | use_monitor, monitor_value = _get_monitor_value(monitor=self.monitor, | ||||
real_monitor=self._real_monitor, | real_monitor=self._real_monitor, | ||||
res=results) | res=results) | ||||
if self._real_monitor != use_monitor: # 发生了替换需要打印 | |||||
logger.warning( | |||||
f"We can not find `{self.monitor}` in the evaluation result (with keys as {list(results.keys())}), " | |||||
f"we use the `{use_monitor}` as the monitor for {self.__class__.__name__}.") | |||||
if monitor_value is None: | |||||
return monitor_value | |||||
# 第一次运行 | |||||
if isinstance(self.monitor, str) and self._real_monitor == self.monitor and use_monitor != self.monitor: | |||||
logger.warning(f"We can not find `{self.monitor}` in the evaluation result (with keys as {list(results.keys())}), " | |||||
f"we use the `{use_monitor}` as the monitor for `{self.__class__.__name__}`.") | |||||
# 检测到此次和上次不同。 | |||||
elif isinstance(self.monitor, str) and self._real_monitor != self.monitor and use_monitor != self._real_monitor: | |||||
logger.warning(f"Change of monitor detected for `{self.__class__.__name__}`. " | |||||
f"The expected monitor is:`{self.monitor}`, last used monitor is:" | |||||
f"`{self._real_monitor}` and current monitor is:`{use_monitor}`. Please consider using a " | |||||
f"customized monitor function when the evaluation results are varying between validation.") | |||||
self._real_monitor = use_monitor | self._real_monitor = use_monitor | ||||
return monitor_value | return monitor_value | ||||
@@ -347,10 +361,12 @@ class HasMonitorCallback(Callback): | |||||
""" | """ | ||||
检测 monitor_value 是否是更好的 | 检测 monitor_value 是否是更好的 | ||||
:param monitor_value: | |||||
:param monitor_value: 待检查的 monitor_value 。如果为 None ,返回 False | |||||
:param keep_if_better: 如果传入的 monitor_value 值更好,则将其保存下来。 | :param keep_if_better: 如果传入的 monitor_value 值更好,则将其保存下来。 | ||||
:return: | :return: | ||||
""" | """ | ||||
if monitor_value is None: | |||||
return False | |||||
better = self.is_former_monitor_value_better(monitor_value, self.monitor_value) | better = self.is_former_monitor_value_better(monitor_value, self.monitor_value) | ||||
if keep_if_better and better: | if keep_if_better and better: | ||||
self.monitor_value = monitor_value | self.monitor_value = monitor_value | ||||
@@ -364,6 +380,12 @@ class HasMonitorCallback(Callback): | |||||
:param monitor_value2: | :param monitor_value2: | ||||
:return: | :return: | ||||
""" | """ | ||||
if monitor_value1 is None and monitor_value2 is None: | |||||
return True | |||||
if monitor_value1 is None: | |||||
return False | |||||
if monitor_value2 is None: | |||||
return True | |||||
better = False | better = False | ||||
if (self.larger_better and monitor_value1 > monitor_value2) or \ | if (self.larger_better and monitor_value1 > monitor_value2) or \ | ||||
(not self.larger_better and monitor_value1 < monitor_value2): | (not self.larger_better and monitor_value1 < monitor_value2): | ||||
@@ -10,8 +10,7 @@ from copy import deepcopy | |||||
import fastNLP | import fastNLP | ||||
from .callback import Callback, HasMonitorCallback | |||||
from fastNLP.core.callbacks.utils import _get_monitor_value | |||||
from .callback import HasMonitorCallback | |||||
from fastNLP.core.log import logger | from fastNLP.core.log import logger | ||||
from fastNLP.envs import FASTNLP_LAUNCH_TIME | from fastNLP.envs import FASTNLP_LAUNCH_TIME | ||||
from fastNLP.core.utils import synchronize_safe_rm, synchronize_mkdir | from fastNLP.core.utils import synchronize_safe_rm, synchronize_mkdir | ||||
@@ -166,6 +165,8 @@ class CheckpointCallback(HasMonitorCallback): | |||||
""" | """ | ||||
if self.save_topk is not None: | if self.save_topk is not None: | ||||
monitor_value = self.get_monitor_value(results=results) | monitor_value = self.get_monitor_value(results=results) | ||||
if monitor_value is None: | |||||
return | |||||
folder_name = f"{self.folder_prefix}-epoch_{trainer.cur_epoch_idx}-batch_{trainer.global_forward_batches}" \ | folder_name = f"{self.folder_prefix}-epoch_{trainer.cur_epoch_idx}-batch_{trainer.global_forward_batches}" \ | ||||
f"-{self._real_monitor}_{monitor_value}" | f"-{self._real_monitor}_{monitor_value}" | ||||
@@ -231,7 +232,8 @@ class ModelCheckpointCallback(CheckpointCallback): | |||||
若 model_save_fn 不为 None,则 fastNLP 将 folder 绝对路径传递给该函数,fastNLP 不在该 folder 下创建任何文件。 | 若 model_save_fn 不为 None,则 fastNLP 将 folder 绝对路径传递给该函数,fastNLP 不在该 folder 下创建任何文件。 | ||||
:param monitor: 监控的 metric 的名称。如果在 evaluation 结果中没有找到完全一致的名称,将使用 最短公共字符串算法 找到最匹配 | :param monitor: 监控的 metric 的名称。如果在 evaluation 结果中没有找到完全一致的名称,将使用 最短公共字符串算法 找到最匹配 | ||||
的那个作为 monitor 。如果为 None 将尝试从 Trainer 中获取该值。 | |||||
的那个作为 monitor 。如果为 None 将尝试从 Trainer 中获取该值。也可以传入一个函数,接受参数为 evaluation 的结果(字典类型), | |||||
返回一个 float 值作为 monitor 的结果。 | |||||
:param save_folder: 保存的文件夹,fastNLP 将在该文件下以时间戳创建子文件夹,并在里面保存。因此不同次运行可以将被保存到不同的 | :param save_folder: 保存的文件夹,fastNLP 将在该文件下以时间戳创建子文件夹,并在里面保存。因此不同次运行可以将被保存到不同的 | ||||
时间戳文件夹中。如果为 None ,默认使用当前文件夹。 | 时间戳文件夹中。如果为 None ,默认使用当前文件夹。 | ||||
:param save_every_n_epochs: 多少个 epoch 保存一次。 | :param save_every_n_epochs: 多少个 epoch 保存一次。 | ||||
@@ -278,7 +280,8 @@ class TrainerCheckpointCallback(CheckpointCallback): | |||||
若 model_save_fn 不为 None,则 fastNLP 只会在每个 folder 下生成 fastnlp_trainer.pkl.tar 文件。 | 若 model_save_fn 不为 None,则 fastNLP 只会在每个 folder 下生成 fastnlp_trainer.pkl.tar 文件。 | ||||
:param monitor: 监控的 metric 的名称。如果在 evaluation 结果中没有找到完全一致的名称,将使用 最短公共字符串算法 找到最匹配 | :param monitor: 监控的 metric 的名称。如果在 evaluation 结果中没有找到完全一致的名称,将使用 最短公共字符串算法 找到最匹配 | ||||
的那个作为 monitor 。如果为 None 将尝试从 Trainer 中获取该值。 | |||||
的那个作为 monitor 。如果为 None 将尝试从 Trainer 中获取该值。也可以传入一个函数,接受参数为 evaluation 的结果(字典类型), | |||||
返回一个 float 值作为 monitor 的结果。 | |||||
:param save_folder: 保存的文件夹,fastNLP 将在该文件下以时间戳创建子文件夹,并在里面保存。因此不同次运行可以将被保存到不同的 | :param save_folder: 保存的文件夹,fastNLP 将在该文件下以时间戳创建子文件夹,并在里面保存。因此不同次运行可以将被保存到不同的 | ||||
时间戳文件夹中。如果为 None ,默认使用当前文件夹。 | 时间戳文件夹中。如果为 None ,默认使用当前文件夹。 | ||||
:param save_every_n_epochs: 多少个 epoch 保存一次。 | :param save_every_n_epochs: 多少个 epoch 保存一次。 | ||||
@@ -2,17 +2,18 @@ __all__ = [ | |||||
'EarlyStopCallback' | 'EarlyStopCallback' | ||||
] | ] | ||||
from typing import Dict | |||||
from typing import Dict, Union, Callable | |||||
from .callback import HasMonitorCallback | from .callback import HasMonitorCallback | ||||
from fastNLP.core.utils.exceptions import EarlyStopException | from fastNLP.core.utils.exceptions import EarlyStopException | ||||
class EarlyStopCallback(HasMonitorCallback): | class EarlyStopCallback(HasMonitorCallback): | ||||
def __init__(self, monitor:str=None, larger_better:bool=True, patience:int=10): | |||||
def __init__(self, monitor:Union[str, Callable]=None, larger_better:bool=True, patience:int=10): | |||||
""" | """ | ||||
:param str monitor: 监控的 metric 值。如果为 None,将尝试使用 Trainer 设置的 monitor 。 | |||||
:param str monitor: 监控的 metric 值。如果为 None,将尝试使用 Trainer 设置的 monitor 。也可以传入一个函数,接受参数为 | |||||
evaluation 的结果(字典类型),返回一个 float 值作为 monitor 的结果。 | |||||
:param larger_better: monitor 的值是否是越大越好。 | :param larger_better: monitor 的值是否是越大越好。 | ||||
:param patience: 多少次 validate 不没有提升就停止。 | :param patience: 多少次 validate 不没有提升就停止。 | ||||
""" | """ | ||||
@@ -21,9 +22,9 @@ class EarlyStopCallback(HasMonitorCallback): | |||||
self.patience = patience | self.patience = patience | ||||
def on_validate_end(self, trainer, results): | def on_validate_end(self, trainer, results): | ||||
if len(results)==0: | |||||
return | |||||
monitor_value = self.get_monitor_value(results) | monitor_value = self.get_monitor_value(results) | ||||
if monitor_value is None: | |||||
return | |||||
if self.is_better_monitor_value(monitor_value, keep_if_better=True): | if self.is_better_monitor_value(monitor_value, keep_if_better=True): | ||||
self.wait = 0 | self.wait = 0 | ||||
else: | else: | ||||
@@ -3,7 +3,7 @@ __all__ = [ | |||||
] | ] | ||||
import os | import os | ||||
from typing import Optional, Callable | |||||
from typing import Optional, Callable, Union | |||||
from .callback import HasMonitorCallback | from .callback import HasMonitorCallback | ||||
from io import BytesIO | from io import BytesIO | ||||
import shutil | import shutil | ||||
@@ -14,14 +14,15 @@ from fastNLP.envs import all_rank_call | |||||
class LoadBestModelCallback(HasMonitorCallback): | class LoadBestModelCallback(HasMonitorCallback): | ||||
def __init__(self, monitor:str=None, larger_better:bool = True, only_state_dict:bool = True, | |||||
def __init__(self, monitor:Union[str, Callable]=None, larger_better:bool = True, only_state_dict:bool = True, | |||||
save_folder:Optional[str] = None, model_save_fn:Optional[Callable] = None, | save_folder:Optional[str] = None, model_save_fn:Optional[Callable] = None, | ||||
model_load_fn:Optional[Callable] = None, | model_load_fn:Optional[Callable] = None, | ||||
delete_after_train:bool = True): | delete_after_train:bool = True): | ||||
""" | """ | ||||
保存最佳的 monitor 值最佳的模型,并在训练结束的时候重新加载模型。仅在训练正常结束的时候才能加载最好的模型。 | 保存最佳的 monitor 值最佳的模型,并在训练结束的时候重新加载模型。仅在训练正常结束的时候才能加载最好的模型。 | ||||
:param str monitor: 监控的 metric 值。如果为 None,将尝试使用 Trainer 设置的 monitor 。 | |||||
:param str monitor: 监控的 metric 值。如果为 None,将尝试使用 Trainer 设置的 monitor 。也可以传入一个函数,接受参数为 | |||||
evaluation 的结果(字典类型),返回一个 float 值作为 monitor 的结果。 | |||||
:param larger_better: 该 metric 值是否是越大越好。 | :param larger_better: 该 metric 值是否是越大越好。 | ||||
:param save_folder: 保存的文件夹,如果为空,则保存在内存中。不为空,则保存一份权重到文件中,当为多机训练,且本值不为空时,请确保 | :param save_folder: 保存的文件夹,如果为空,则保存在内存中。不为空,则保存一份权重到文件中,当为多机训练,且本值不为空时,请确保 | ||||
不同的机器均可访问当该路径。当 model_save_fn 不为 None 时该值一定不能为空。 | 不同的机器均可访问当该路径。当 model_save_fn 不为 None 时该值一定不能为空。 | ||||
@@ -78,9 +79,9 @@ class LoadBestModelCallback(HasMonitorCallback): | |||||
self.get_monitor_value(sanity_check_res) | self.get_monitor_value(sanity_check_res) | ||||
def on_validate_end(self, trainer, results): | def on_validate_end(self, trainer, results): | ||||
if len(results)==0: | |||||
return | |||||
monitor_value = self.get_monitor_value(results) | monitor_value = self.get_monitor_value(results) | ||||
if monitor_value is None: | |||||
return | |||||
if self.is_better_monitor_value(monitor_value, keep_if_better=True): | if self.is_better_monitor_value(monitor_value, keep_if_better=True): | ||||
if self.real_save_folder: | if self.real_save_folder: | ||||
trainer.save_model(folder=self.real_save_folder, only_state_dict=self.only_state_dict, | trainer.save_model(folder=self.real_save_folder, only_state_dict=self.only_state_dict, | ||||
@@ -45,6 +45,7 @@ class RichCallback(ProgressCallback): | |||||
:param print_every: 多少个 batch 更新一次显示。 | :param print_every: 多少个 batch 更新一次显示。 | ||||
:param loss_round_ndigit: 显示的 loss 保留多少位有效数字 | :param loss_round_ndigit: 显示的 loss 保留多少位有效数字 | ||||
:param monitor: 当检测到这个key的结果更好时,会打印出不同的颜色进行提示。如果为 None ,会尝试使用 trainer 中设置的 monitor 。 | :param monitor: 当检测到这个key的结果更好时,会打印出不同的颜色进行提示。如果为 None ,会尝试使用 trainer 中设置的 monitor 。 | ||||
也可以传入一个函数,接受参数为 evaluation 的结果(字典类型),返回一个 float 值作为 monitor 的结果。 | |||||
:param larger_better: 是否是monitor的结果越大越好。 | :param larger_better: 是否是monitor的结果越大越好。 | ||||
:param format_json: 是否format json再打印 | :param format_json: 是否format json再打印 | ||||
""" | """ | ||||
@@ -135,7 +136,8 @@ class RawTextCallback(ProgressCallback): | |||||
:param print_every: 多少个 batch 更新一次显示。 | :param print_every: 多少个 batch 更新一次显示。 | ||||
:param loss_round_ndigit: 显示的 loss 保留多少位有效数字 | :param loss_round_ndigit: 显示的 loss 保留多少位有效数字 | ||||
:param monitor: 当检测到这个key的结果更好时,会打印出不同的颜色进行提示。 | |||||
:param monitor: 当检测到这个key的结果更好时,会打印出不同的颜色进行提示。也可以传入一个函数,接受参数为 evaluation 的结果( | |||||
字典类型),返回一个 float 值作为 monitor 的结果。 | |||||
:param larger_better: 是否是monitor的结果越大越好。 | :param larger_better: 是否是monitor的结果越大越好。 | ||||
:param format_json: 是否format json再打印 | :param format_json: 是否format json再打印 | ||||
""" | """ | ||||
@@ -1,9 +1,10 @@ | |||||
from typing import Optional | |||||
from typing import Optional, Union | |||||
from fastNLP.core.log.logger import logger | from fastNLP.core.log.logger import logger | ||||
from difflib import SequenceMatcher | from difflib import SequenceMatcher | ||||
from fastNLP.core.utils.utils import _get_fun_msg | |||||
def _get_monitor_value(monitor: str, real_monitor: Optional[str], res: dict) ->(str, float): | |||||
def _get_monitor_value(monitor: Union[callable, str], real_monitor: Optional[str], res: dict) ->(str, float): | |||||
""" | """ | ||||
从res中寻找 monitor 并返回。如果 monitor 没找到则尝试用 _real_monitor ,若 _real_monitor 为 None 则尝试使用 monitor 的值进行 | 从res中寻找 monitor 并返回。如果 monitor 没找到则尝试用 _real_monitor ,若 _real_monitor 为 None 则尝试使用 monitor 的值进行 | ||||
匹配。 | 匹配。 | ||||
@@ -11,10 +12,19 @@ def _get_monitor_value(monitor: str, real_monitor: Optional[str], res: dict) ->( | |||||
:param monitor: | :param monitor: | ||||
:param real_monitor: | :param real_monitor: | ||||
:param res: | :param res: | ||||
:return: 返回两个值(str, value),其中str就是最终要到的key,value就是这个key对应的value | |||||
:return: 返回两个值(str, value),其中str就是最终要到的key,value就是这个key对应的value。如果value为None说明当前results中没有 | |||||
找到对应的 monitor | |||||
""" | """ | ||||
if len(res)==0: | if len(res)==0: | ||||
return monitor, 0 | |||||
return monitor, None | |||||
if callable(monitor): | |||||
try: | |||||
monitor_value = monitor(res) | |||||
except BaseException as e: | |||||
logger.error(f"Exception happens when calling customized monitor function:{_get_fun_msg(monitor)}.") | |||||
raise e | |||||
return monitor, monitor_value | |||||
if monitor in res: | if monitor in res: | ||||
return monitor, res[monitor] | return monitor, res[monitor] | ||||
@@ -5,7 +5,7 @@ __all__ = [ | |||||
from abc import ABCMeta, abstractmethod | from abc import ABCMeta, abstractmethod | ||||
from typing import Any, Dict, List, Callable, Union | |||||
from typing import Any, Dict, List, Callable, Union, Tuple | |||||
from numbers import Number | from numbers import Number | ||||
import warnings | import warnings | ||||
@@ -35,7 +35,7 @@ class SetInputOrTargetException(Exception): | |||||
self.field_name = field_name # 标示当前 field 的名称 | self.field_name = field_name # 标示当前 field 的名称 | ||||
def _get_ele_type_and_dim(cell: Any, dim=0): | |||||
def _get_ele_type_and_dim(cell: Any, dim=0) -> Tuple[Any, int]: | |||||
r""" | r""" | ||||
识别cell的类别与dimension的数量 | 识别cell的类别与dimension的数量 | ||||
@@ -206,7 +206,7 @@ class AutoCollator(Collator): | |||||
def __init__(self, as_numpy: bool): | def __init__(self, as_numpy: bool): | ||||
super(AutoCollator, self).__init__() | super(AutoCollator, self).__init__() | ||||
self.pad_field_value = {} # field padding 自定义的 padding 值, 默认为0 | self.pad_field_value = {} # field padding 自定义的 padding 值, 默认为0 | ||||
self.need_inputs = [] # 需要的 field name | |||||
self.need_inputs = set() # 需要的 field name | |||||
self.field_dtypes = None # 每列数据单元的 dtype 类型 | self.field_dtypes = None # 每列数据单元的 dtype 类型 | ||||
self.field_dims = None # 每列数据单元维度 | self.field_dims = None # 每列数据单元维度 | ||||
self.as_numpy = as_numpy | self.as_numpy = as_numpy | ||||
@@ -214,10 +214,17 @@ class AutoCollator(Collator): | |||||
def __call__(self, ins_lst: List[Dict]) -> dict: | def __call__(self, ins_lst: List[Dict]) -> dict: | ||||
if len(self.need_inputs) == 0: | if len(self.need_inputs) == 0: | ||||
raise ValueError({"set_inputs is None, you should use set_inputs method first!!"}) | raise ValueError({"set_inputs is None, you should use set_inputs method first!!"}) | ||||
# TODO 这里应该是先 check 有哪些需要 padding,然后check这些是否是可以pad的 | |||||
# 第一种情况,设置了 set_input 的值 | # 第一种情况,设置了 set_input 的值 | ||||
# 第二种情况, 根据数据的类型的判断是否 padding | # 第二种情况, 根据数据的类型的判断是否 padding | ||||
if self.field_dtypes is None and self.field_dims is None: | if self.field_dtypes is None and self.field_dims is None: | ||||
self.field_dtypes, self.field_dims = _get_ds_type_dim(ins_lst[0]) | |||||
field_dtypes, field_dims = {}, {} | |||||
for key, value in ins_lst[0].items(): | |||||
if key in self.need_inputs and self.pad_field_value.get(key, 0) is not None: | |||||
field_dtypes[key], field_dims[key] = _get_ele_type_and_dim(value) | |||||
self.field_dtypes = field_dtypes | |||||
self.field_dims = field_dims | |||||
pack_ins_lst, pad_ins_lst = {field_name: [] | pack_ins_lst, pad_ins_lst = {field_name: [] | ||||
for field_name in ins_lst[0].keys() if field_name in self.need_inputs}, {} | for field_name in ins_lst[0].keys() if field_name in self.need_inputs}, {} | ||||
@@ -233,13 +240,13 @@ class AutoCollator(Collator): | |||||
if len(self.pad_field_value.keys()) > 0: | if len(self.pad_field_value.keys()) > 0: | ||||
# 去掉不需要 pad 的列,如果 set_input 的列不存在则忽略 | # 去掉不需要 pad 的列,如果 set_input 的列不存在则忽略 | ||||
drop_field_names = [] | |||||
non_pad_field_names = [] | |||||
for k, v in self.pad_field_value.items(): | for k, v in self.pad_field_value.items(): | ||||
if v is None: | if v is None: | ||||
drop_field_names.append(k) | |||||
non_pad_field_names.append(k) | |||||
# drop_field_names = list(set(list(ins_lst[0].keys())) - set(drop_fields)) | # drop_field_names = list(set(list(ins_lst[0].keys())) - set(drop_fields)) | ||||
for field_name in drop_field_names: | |||||
for field_name in non_pad_field_names: | |||||
field_array = pack_ins_lst.pop(field_name) | field_array = pack_ins_lst.pop(field_name) | ||||
pad_ins_lst[field_name] = np.array(field_array) | pad_ins_lst[field_name] = np.array(field_array) | ||||
@@ -269,7 +276,7 @@ class AutoCollator(Collator): | |||||
def set_input(self, *field_names): | def set_input(self, *field_names): | ||||
for field_name in field_names: | for field_name in field_names: | ||||
self.need_inputs.append(field_name) | |||||
self.need_inputs.add(field_name) | |||||
def pad_content(content, field_name: str, field_type, field_dim: int, pad_val: int, as_numpy: bool): | def pad_content(content, field_name: str, field_type, field_dim: int, pad_val: int, as_numpy: bool): | ||||
@@ -11,11 +11,12 @@ __all__ = [ | |||||
from fastNLP.core.drivers import Driver | from fastNLP.core.drivers import Driver | ||||
from fastNLP.core.drivers.utils import choose_driver | from fastNLP.core.drivers.utils import choose_driver | ||||
from .loops import Loop, EvaluateBatchLoop | from .loops import Loop, EvaluateBatchLoop | ||||
from fastNLP.core.utils import check_fn_not_empty_params, auto_param_call, dataclass_to_dict, \ | |||||
from fastNLP.core.utils import auto_param_call, dataclass_to_dict, \ | |||||
match_and_substitute_params, f_rich_progress | match_and_substitute_params, f_rich_progress | ||||
from fastNLP.core.metrics import Metric | from fastNLP.core.metrics import Metric | ||||
from fastNLP.core.metrics.utils import _is_torchmetrics_metric, _is_paddle_metric, _is_allennlp_metric | from fastNLP.core.metrics.utils import _is_torchmetrics_metric, _is_paddle_metric, _is_allennlp_metric | ||||
from fastNLP.core.controllers.utils.utils import _TruncatedDataLoader | from fastNLP.core.controllers.utils.utils import _TruncatedDataLoader | ||||
from fastNLP.core.utils.utils import _check_valid_parameters_number | |||||
from fastNLP.core.log import logger | from fastNLP.core.log import logger | ||||
@@ -38,11 +39,11 @@ class Evaluator: | |||||
driver: Union[str, Driver] = 'single', | driver: Union[str, Driver] = 'single', | ||||
device: Optional[Union[int, List[int], str]] = None, | device: Optional[Union[int, List[int], str]] = None, | ||||
batch_step_fn: Optional[callable] = None, | batch_step_fn: Optional[callable] = None, | ||||
mode: str = "validate", | |||||
mode: Optional[Union[str, callable]] = 'validate', # 首先尝试找 evaluate_step, 找不到 forward, callable | |||||
input_mapping: Optional[Union[Callable, Dict]] = None, | input_mapping: Optional[Union[Callable, Dict]] = None, | ||||
output_mapping: Optional[Union[Callable, Dict]] = None, | output_mapping: Optional[Union[Callable, Dict]] = None, | ||||
model_wo_auto_param_call: bool = False, | model_wo_auto_param_call: bool = False, | ||||
fp16: Optional[bool] = False, | |||||
fp16: bool = False, | |||||
verbose: int = 1, | verbose: int = 1, | ||||
**kwargs | **kwargs | ||||
): | ): | ||||
@@ -92,8 +93,8 @@ class Evaluator: | |||||
self.device = device | self.device = device | ||||
self.verbose = verbose | self.verbose = verbose | ||||
assert check_fn_not_empty_params(batch_step_fn, 2), "Parameter `batch_step_fn` should be a callable object with " \ | |||||
"two parameters." | |||||
if batch_step_fn is not None: | |||||
_check_valid_parameters_number(batch_step_fn, ['trainer', 'batch'], fn_name='batch_step_fn') | |||||
self.batch_step_fn = batch_step_fn | self.batch_step_fn = batch_step_fn | ||||
self.mode = mode | self.mode = mode | ||||
@@ -135,6 +136,7 @@ class Evaluator: | |||||
if self.progress_bar == 'auto': | if self.progress_bar == 'auto': | ||||
self.progress_bar = 'rich' if (sys.stdin and sys.stdin.isatty()) else 'raw' | self.progress_bar = 'rich' if (sys.stdin and sys.stdin.isatty()) else 'raw' | ||||
self.driver.check_evaluator_mode(self.mode) | |||||
self.driver.barrier() | self.driver.barrier() | ||||
def run(self, num_eval_batch_per_dl: int = -1, **kwargs) -> Dict: | def run(self, num_eval_batch_per_dl: int = -1, **kwargs) -> Dict: | ||||
@@ -154,8 +156,6 @@ class Evaluator: | |||||
assert isinstance(num_eval_batch_per_dl, int), "num_eval_batch_per_dl must be of int type." | assert isinstance(num_eval_batch_per_dl, int), "num_eval_batch_per_dl must be of int type." | ||||
assert num_eval_batch_per_dl > 0 or num_eval_batch_per_dl == -1, "num_eval_batch_per_dl must be -1 or larger than 0." | assert num_eval_batch_per_dl > 0 or num_eval_batch_per_dl == -1, "num_eval_batch_per_dl must be -1 or larger than 0." | ||||
self.driver.check_evaluator_mode(self.mode) | |||||
if self.mode == 'validate': | if self.mode == 'validate': | ||||
assert self.driver.has_validate_dataloaders() | assert self.driver.has_validate_dataloaders() | ||||
else: | else: | ||||
@@ -367,9 +367,10 @@ class _MetricsWrapper: | |||||
raise RuntimeError(f"The output of your model is of type:`{type(outputs)}`, please either directly" | raise RuntimeError(f"The output of your model is of type:`{type(outputs)}`, please either directly" | ||||
f" return a dict from your model or use `output_mapping` to convert it into dict type.") | f" return a dict from your model or use `output_mapping` to convert it into dict type.") | ||||
if isinstance(metric, Metric): | if isinstance(metric, Metric): | ||||
auto_param_call(metric.update, outputs, *args) | |||||
# 这样在 auto_param_call 报错的时候才清晰。 | |||||
auto_param_call(metric.update, outputs, *args, signature_fn=metric.update.__wrapped__) | |||||
elif _is_torchmetrics_metric(metric): | elif _is_torchmetrics_metric(metric): | ||||
auto_param_call(metric.update, outputs, *args) | |||||
auto_param_call(metric.update, outputs, *args, signature_fn=metric.update.__wrapped__) | |||||
elif _is_allennlp_metric(metric): | elif _is_allennlp_metric(metric): | ||||
auto_param_call(metric.__call__, outputs, *args) | auto_param_call(metric.__call__, outputs, *args) | ||||
elif _is_paddle_metric(metric): | elif _is_paddle_metric(metric): | ||||
@@ -14,6 +14,7 @@ __all__ = [ | |||||
from .loops import Loop, TrainBatchLoop | from .loops import Loop, TrainBatchLoop | ||||
from .utils import State, TrainerState | from .utils import State, TrainerState | ||||
from .utils.utils import check_validate_every | |||||
from .evaluator import Evaluator | from .evaluator import Evaluator | ||||
from fastNLP.core.controllers.utils.utils import TrainerEventTrigger, _TruncatedDataLoader | from fastNLP.core.controllers.utils.utils import TrainerEventTrigger, _TruncatedDataLoader | ||||
from fastNLP.core.callbacks import Callback, CallbackManager, Events, EventsList, Filter | from fastNLP.core.callbacks import Callback, CallbackManager, Events, EventsList, Filter | ||||
@@ -21,7 +22,8 @@ from fastNLP.core.callbacks.callback import _CallbackWrapper | |||||
from fastNLP.core.callbacks.callback_events import _SingleEventState | from fastNLP.core.callbacks.callback_events import _SingleEventState | ||||
from fastNLP.core.drivers import Driver | from fastNLP.core.drivers import Driver | ||||
from fastNLP.core.drivers.utils import choose_driver | from fastNLP.core.drivers.utils import choose_driver | ||||
from fastNLP.core.utils import check_fn_not_empty_params, get_fn_arg_names, match_and_substitute_params, nullcontext | |||||
from fastNLP.core.utils import get_fn_arg_names, match_and_substitute_params, nullcontext | |||||
from fastNLP.core.utils.utils import _check_valid_parameters_number | |||||
from fastNLP.envs import rank_zero_call | from fastNLP.envs import rank_zero_call | ||||
from fastNLP.core.log import logger | from fastNLP.core.log import logger | ||||
from fastNLP.envs import FASTNLP_MODEL_FILENAME | from fastNLP.envs import FASTNLP_MODEL_FILENAME | ||||
@@ -42,7 +44,7 @@ class Trainer(TrainerEventTrigger): | |||||
validate_dataloaders=None, | validate_dataloaders=None, | ||||
batch_step_fn: Optional[Callable] = None, | batch_step_fn: Optional[Callable] = None, | ||||
validate_batch_step_fn: Optional[Callable] = None, | validate_batch_step_fn: Optional[Callable] = None, | ||||
validate_mode: str = "validate", | |||||
validate_mode: Union[str, callable] = 'validate', | |||||
callbacks: Union[List[Callback], Callback, None] = None, | callbacks: Union[List[Callback], Callback, None] = None, | ||||
metrics: Optional[dict] = None, | metrics: Optional[dict] = None, | ||||
validate_every: Optional[Union[int, callable]] = -1, | validate_every: Optional[Union[int, callable]] = -1, | ||||
@@ -51,7 +53,7 @@ class Trainer(TrainerEventTrigger): | |||||
model_wo_auto_param_call: bool = False, | model_wo_auto_param_call: bool = False, | ||||
accumulation_steps: int = 1, | accumulation_steps: int = 1, | ||||
fp16: bool = False, | fp16: bool = False, | ||||
monitor: str = None, | |||||
monitor: Union[str, callable] = None, | |||||
larger_better: bool = True, | larger_better: bool = True, | ||||
marker: Optional[str] = None, | marker: Optional[str] = None, | ||||
**kwargs | **kwargs | ||||
@@ -90,11 +92,8 @@ class Trainer(TrainerEventTrigger): | |||||
:param callbacks: 训练当中触发的 callback 类,该参数应当为一个列表,其中的每一个元素都应当继承 `Callback` 类; | :param callbacks: 训练当中触发的 callback 类,该参数应当为一个列表,其中的每一个元素都应当继承 `Callback` 类; | ||||
:param metrics: 应当为一个字典,其中 key 表示 monitor,例如 {"acc1": AccMetric(), "acc2": AccMetric()}; | :param metrics: 应当为一个字典,其中 key 表示 monitor,例如 {"acc1": AccMetric(), "acc2": AccMetric()}; | ||||
:param validate_every: 可以为负数、正数或者函数;为负数时表示每隔几个 epoch validate 一次;为正数则表示每隔几个 batch validate 一次; | :param validate_every: 可以为负数、正数或者函数;为负数时表示每隔几个 epoch validate 一次;为正数则表示每隔几个 batch validate 一次; | ||||
为函数时表示用户自己传入的用于控制 Trainer 中的 validate 的频率的函数,该函数的参数应该为 (filter, trainer) , 其中的 filter 对象 | |||||
中自动记录了两个变量: filter.num_called 表示有多少次尝试 validate (实际等同于到当前时刻 batch 的总数), filter.num_executed | |||||
表示 validate 实际被执行了多少次;trainer 参数即为 Trainer 对象。 函数返回值应为 bool ,返回为 True 说明需要进行 validate 。 | |||||
例如: (filter.num_called % trainer.num_batches_per_epoch == 0 and trainer.cur_epoch_idx > 10) 表示在第 10 个 epoch | |||||
之后,每个 epoch 结束进行一次 validate 。 | |||||
为函数时表示用户自己传入的用于控制 Trainer 中的 validate 的频率的函数,该函数的应该接受当前 trainer 对象作为参数,并 | |||||
返回一个 bool 值,返回为 True 说明需要进行 validate ;将在每个 batch 结束后调用该函数判断是否需要 validate 。 | |||||
:param input_mapping: 应当为一个字典或者一个函数,表示在当前 step 拿到一个 batch 的训练数据后,应当做怎样的映射处理;如果其是 | :param input_mapping: 应当为一个字典或者一个函数,表示在当前 step 拿到一个 batch 的训练数据后,应当做怎样的映射处理;如果其是 | ||||
一个字典,并且 batch 也是一个 `Dict`,那么我们会把 batch 中同样在 input_mapping 中的 key 修改为 input_mapping 的对应 key 的 | 一个字典,并且 batch 也是一个 `Dict`,那么我们会把 batch 中同样在 input_mapping 中的 key 修改为 input_mapping 的对应 key 的 | ||||
value;如果 batch 是一个 `dataclass`,那么我们会先将该 dataclass 转换为一个 Dict,然后再进行上述转换;如果 batch 此时是其它 | value;如果 batch 是一个 `dataclass`,那么我们会先将该 dataclass 转换为一个 Dict,然后再进行上述转换;如果 batch 此时是其它 | ||||
@@ -111,7 +110,7 @@ class Trainer(TrainerEventTrigger): | |||||
:param fp16: 是否开启混合精度训练;默认为 False; | :param fp16: 是否开启混合精度训练;默认为 False; | ||||
:param monitor: 当存在 validate_dataloaders 时,默认的 monitor metric 的名字。传入的 callback 如果有 monitor 参数且没有 | :param monitor: 当存在 validate_dataloaders 时,默认的 monitor metric 的名字。传入的 callback 如果有 monitor 参数且没有 | ||||
在 callback 初始化设定的,将采取这个值。如果在 evaluation 结果中没有找到完全一致的名称,将使用 最短公共字符串算法 找到最匹配 | 在 callback 初始化设定的,将采取这个值。如果在 evaluation 结果中没有找到完全一致的名称,将使用 最短公共字符串算法 找到最匹配 | ||||
的那个作为 monitor 。 | |||||
的那个作为 monitor 。也可以传入一个函数,接受参数为 evaluation 的结果(字典类型),返回一个 float 值作为 monitor 的结果。 | |||||
:param larger_better: monitor 的值是否是越大越好。 | :param larger_better: monitor 的值是否是越大越好。 | ||||
:param marker: 用于标记一个 Trainer 实例,从而在用户调用 `Trainer.on` 函数时,标记该 callback 函数属于哪一个具体的 'trainer' 实例;默认为 None; | :param marker: 用于标记一个 Trainer 实例,从而在用户调用 `Trainer.on` 函数时,标记该 callback 函数属于哪一个具体的 'trainer' 实例;默认为 None; | ||||
:param kwargs: 一些其它的可能需要的参数; | :param kwargs: 一些其它的可能需要的参数; | ||||
@@ -142,10 +141,9 @@ class Trainer(TrainerEventTrigger): | |||||
self.input_mapping = input_mapping | self.input_mapping = input_mapping | ||||
self.output_mapping = output_mapping | self.output_mapping = output_mapping | ||||
assert check_fn_not_empty_params(batch_step_fn, 2), "`batch_step_fn` should be a callable object with " \ | |||||
"two parameters." | |||||
self.batch_step_fn = batch_step_fn | self.batch_step_fn = batch_step_fn | ||||
if batch_step_fn is not None: | if batch_step_fn is not None: | ||||
_check_valid_parameters_number(batch_step_fn, ['trainer', 'batch'], fn_name='batch_step_fn') | |||||
self.check_batch_step_fn = partial(self._check_callback_called_legality, check_mode=True) | self.check_batch_step_fn = partial(self._check_callback_called_legality, check_mode=True) | ||||
else: | else: | ||||
self.check_batch_step_fn = lambda *args, **kwargs: ... | self.check_batch_step_fn = lambda *args, **kwargs: ... | ||||
@@ -221,18 +219,11 @@ class Trainer(TrainerEventTrigger): | |||||
if metrics is not None and validate_dataloaders is None: | if metrics is not None and validate_dataloaders is None: | ||||
raise ValueError("You have set 'metrics' but forget to set 'validate_dataloader'.") | raise ValueError("You have set 'metrics' but forget to set 'validate_dataloader'.") | ||||
# 为了在 train 的循环中每次都检查是否需要进行 validate,这里我们提前在 trainer 初始化的时候就将对应时间点需要运行的函数确定下来; | |||||
# _epoch_validate 表示每隔几个 epoch validate 一次;_step_validate 表示每隔几个 step validate 一次; | |||||
self.evaluator = None | self.evaluator = None | ||||
self.monitor = monitor | self.monitor = monitor | ||||
self.larger_better = larger_better | self.larger_better = larger_better | ||||
if metrics is not None and validate_dataloaders is not None: | if metrics is not None and validate_dataloaders is not None: | ||||
if not callable(validate_every) and (not isinstance(validate_every, int) or validate_every == 0): | |||||
raise ValueError("Parameter 'validate_every' should be set to 'int' type and either < 0 or > 0.") | |||||
if callable(validate_every): | |||||
logger.info("Notice you are using a 'filter function' as the value of parameter `validate_every`, " | |||||
"and in this way, the kind of controlling frequency is depending on the 'step'.") | |||||
check_validate_every(validate_every) | |||||
self.evaluator = Evaluator( | self.evaluator = Evaluator( | ||||
model=model, | model=model, | ||||
dataloaders=validate_dataloaders, | dataloaders=validate_dataloaders, | ||||
@@ -352,33 +343,32 @@ class Trainer(TrainerEventTrigger): | |||||
_validate_res: dict = validate_fn() | _validate_res: dict = validate_fn() | ||||
trainer.on_validate_end(_validate_res) | trainer.on_validate_end(_validate_res) | ||||
self.validate_fn = partial(_validate_fn, self, partial(self.evaluator.run, num_eval_batch_per_dl)) | |||||
self.run_evaluate = partial(_validate_fn, self, partial(self.evaluator.run, num_eval_batch_per_dl)) | |||||
def step_validate(self): | def step_validate(self): | ||||
if self.evaluator is not None: | |||||
should_run_validate = False | |||||
""" | |||||
在每个 batch 结束后调用,根据设置执行 evaluate 。 | |||||
:return: | |||||
""" | |||||
if self.evaluator is not None: | |||||
if callable(self.validate_every): | if callable(self.validate_every): | ||||
if self.validate_every(self): | if self.validate_every(self): | ||||
should_run_validate = True | |||||
elif self.validate_every > 0: | |||||
if self.global_forward_batches % self.validate_every == 0: | |||||
should_run_validate = True | |||||
if should_run_validate: | |||||
self.validate_fn() | |||||
self.run_evaluate() | |||||
elif self.validate_every > 0 and self.global_forward_batches % self.validate_every == 0: | |||||
self.run_evaluate() | |||||
def epoch_validate(self): | def epoch_validate(self): | ||||
if self.evaluator is not None: | |||||
should_run_validate = False | |||||
""" | |||||
在每个 epoch 结束后调用,根据设置执行 evaluate 。 | |||||
:return: | |||||
""" | |||||
if self.evaluator is not None: | |||||
if isinstance(self.validate_every, int) and self.validate_every < 0: | if isinstance(self.validate_every, int) and self.validate_every < 0: | ||||
validate_every = -self.validate_every | validate_every = -self.validate_every | ||||
if self.cur_epoch_idx % validate_every == 0: | if self.cur_epoch_idx % validate_every == 0: | ||||
should_run_validate = True | |||||
if should_run_validate: | |||||
self.validate_fn() | |||||
self.run_evaluate() | |||||
def add_callback_fn(self, event: Optional[Union[Events, EventsList]], fn: Callable): | def add_callback_fn(self, event: Optional[Union[Events, EventsList]], fn: Callable): | ||||
r""" | r""" | ||||
@@ -410,9 +400,7 @@ class Trainer(TrainerEventTrigger): | |||||
def wrapper(fn: Callable) -> Callable: | def wrapper(fn: Callable) -> Callable: | ||||
cls._custom_callbacks[marker].append((event, fn)) | cls._custom_callbacks[marker].append((event, fn)) | ||||
callback_fn_args = get_fn_arg_names(getattr(Callback, event.value))[1:] | callback_fn_args = get_fn_arg_names(getattr(Callback, event.value))[1:] | ||||
assert check_fn_not_empty_params(fn, len(callback_fn_args)), \ | |||||
f"The callback function at `{event.value.lower()}`'s parameters should be {callback_fn_args}, but your "\ | |||||
f"function {fn.__name__} only has these parameters: {get_fn_arg_names(fn)}." | |||||
_check_valid_parameters_number(fn, callback_fn_args) | |||||
return fn | return fn | ||||
return wrapper | return wrapper | ||||
@@ -1,8 +1,9 @@ | |||||
from collections.abc import Iterator | |||||
import inspect | |||||
from typing import Dict | from typing import Dict | ||||
from fastNLP.core.callbacks import CallbackManager | from fastNLP.core.callbacks import CallbackManager | ||||
from .state import TrainerState | from .state import TrainerState | ||||
from fastNLP.core.utils.utils import _check_valid_parameters_number | |||||
class TrainerEventTrigger: | class TrainerEventTrigger: | ||||
@@ -125,5 +126,8 @@ class _TruncatedDataLoader: | |||||
return getattr(self.dataloader, item) | return getattr(self.dataloader, item) | ||||
def check_validate_every(validate_every): | |||||
if not callable(validate_every) and (not isinstance(validate_every, int) or validate_every == 0): | |||||
raise ValueError("Parameter 'validate_every' should be set to 'int' type and either < 0 or > 0.") | |||||
if callable(validate_every): | |||||
_check_valid_parameters_number(validate_every, expected_params=['trainer']) |
@@ -178,10 +178,11 @@ class DataSet: | |||||
elif isinstance(idx, slice): | elif isinstance(idx, slice): | ||||
if idx.start is not None and (idx.start >= len(self) or idx.start <= -len(self)): | if idx.start is not None and (idx.start >= len(self) or idx.start <= -len(self)): | ||||
raise RuntimeError(f"Start index {idx.start} out of range 0-{len(self) - 1}") | raise RuntimeError(f"Start index {idx.start} out of range 0-{len(self) - 1}") | ||||
data_set = DataSet() | |||||
dataset = DataSet() | |||||
for field_name, field in self.field_arrays.items(): | for field_name, field in self.field_arrays.items(): | ||||
data_set.add_field(field_name=field_name, fields=field.content[idx]) | |||||
return data_set | |||||
dataset.add_field(field_name=field_name, fields=field.content[idx]) | |||||
dataset.collate_fns = deepcopy(self.collate_fns) | |||||
return dataset | |||||
elif isinstance(idx, str): | elif isinstance(idx, str): | ||||
if idx not in self: | if idx not in self: | ||||
raise KeyError("No such field called {} in DataSet.".format(idx)) | raise KeyError("No such field called {} in DataSet.".format(idx)) | ||||
@@ -192,6 +193,7 @@ class DataSet: | |||||
assert isinstance(i, int), "Only int index allowed." | assert isinstance(i, int), "Only int index allowed." | ||||
instance = self[i] | instance = self[i] | ||||
dataset.append(instance) | dataset.append(instance) | ||||
dataset.collate_fns = deepcopy(self.collate_fns) | |||||
return dataset | return dataset | ||||
else: | else: | ||||
raise KeyError("Unrecognized type {} for idx in __getitem__ method".format(type(idx))) | raise KeyError("Unrecognized type {} for idx in __getitem__ method".format(type(idx))) | ||||
@@ -674,6 +676,8 @@ class DataSet: | |||||
dev_set.append(self[idx]) | dev_set.append(self[idx]) | ||||
for idx in train_indices: | for idx in train_indices: | ||||
train_set.append(self[idx]) | train_set.append(self[idx]) | ||||
dev_set.collate_fns = deepcopy(self.collate_fns) | |||||
train_set.collate_fns = deepcopy(self.collate_fns) | |||||
return dev_set, train_set | return dev_set, train_set | ||||
@@ -795,7 +799,7 @@ class DataSet: | |||||
:param val: 默认为0。如果为 None ,则为不对 field 进行 padding 。 | :param val: 默认为0。如果为 None ,则为不对 field 进行 padding 。 | ||||
:return: | :return: | ||||
""" | """ | ||||
# TODO 需要去重复 | |||||
# TODO 不能为空 | |||||
for field_name in field_names: | for field_name in field_names: | ||||
self.collate_fns.set_pad_val(field_name, val=val) | self.collate_fns.set_pad_val(field_name, val=val) | ||||
@@ -66,7 +66,7 @@ class JittorDriver(Driver): | |||||
if mode == "validate": | if mode == "validate": | ||||
if not hasattr(model, "validate_step"): | if not hasattr(model, "validate_step"): | ||||
if hasattr(model, "test_step"): | if hasattr(model, "test_step"): | ||||
logger.warning( | |||||
logger.warning_once( | |||||
"Your model does not have 'validate_step' method but has 'test_step' method, but you" | "Your model does not have 'validate_step' method but has 'test_step' method, but you" | ||||
"are using 'mode=validate', we are going to use 'test_step' to substitute for" | "are using 'mode=validate', we are going to use 'test_step' to substitute for" | ||||
"'validate_step'.") | "'validate_step'.") | ||||
@@ -74,7 +74,7 @@ class JittorDriver(Driver): | |||||
else: | else: | ||||
if not hasattr(model, "test_step"): | if not hasattr(model, "test_step"): | ||||
if hasattr(model, "validate_step"): | if hasattr(model, "validate_step"): | ||||
logger.warning("Your model does not have 'test_step' method but has 'validate' method, but you" | |||||
logger.warning_once("Your model does not have 'test_step' method but has 'validate' method, but you" | |||||
"are using 'mode=test', we are going to use 'validate_step' to substitute for" | "are using 'mode=test', we are going to use 'validate_step' to substitute for" | ||||
"'test_step'.") | "'test_step'.") | ||||
@@ -133,7 +133,7 @@ class PaddleDriver(Driver): | |||||
else: | else: | ||||
if not hasattr(model, "test_step"): | if not hasattr(model, "test_step"): | ||||
if hasattr(model, "validate_step"): | if hasattr(model, "validate_step"): | ||||
logger.warning("Your model does not have 'test_step' method but has 'validate' method, but you" | |||||
logger.warning_once("Your model does not have 'test_step' method but has 'validate' method, but you" | |||||
"are using 'Evaluator.test', we are going to use 'validate_step' to substitute for" | "are using 'Evaluator.test', we are going to use 'validate_step' to substitute for" | ||||
"'test_step'.") | "'test_step'.") | ||||
@@ -333,10 +333,8 @@ def all_gather_object(object_list, obj, group=None): | |||||
>>> output | >>> output | ||||
['foo', 12, {1: 2}] | ['foo', 12, {1: 2}] | ||||
""" | """ | ||||
if dist._rank_not_in_group(group): | |||||
if dist.distributed_c10d._rank_not_in_group(group): | |||||
return | return | ||||
input_tensor, local_size = _object_to_tensor(obj) | |||||
if _TORCH_GREATER_EQUAL_1_8: | if _TORCH_GREATER_EQUAL_1_8: | ||||
current_device = torch.device("cpu") | current_device = torch.device("cpu") | ||||
is_nccl_backend = _check_for_nccl_backend(group) | is_nccl_backend = _check_for_nccl_backend(group) | ||||
@@ -345,10 +343,11 @@ def all_gather_object(object_list, obj, group=None): | |||||
# We cannot simply use my_rank since rank == device is not necessarily | # We cannot simply use my_rank since rank == device is not necessarily | ||||
# true. | # true. | ||||
current_device = torch.device("cuda", torch.cuda.current_device()) | current_device = torch.device("cuda", torch.cuda.current_device()) | ||||
input_tensor = input_tensor.to(current_device) | |||||
local_size = local_size.to(current_device) | |||||
else: | else: | ||||
current_device = torch.cuda.current_device() | current_device = torch.cuda.current_device() | ||||
input_tensor, local_size = _object_to_tensor(obj, device=current_device) | |||||
# Gather all local sizes. This is so that we can find the max size, and index | # Gather all local sizes. This is so that we can find the max size, and index | ||||
# until the correct size when deserializing the tensors. | # until the correct size when deserializing the tensors. | ||||
group_size = dist.get_world_size(group=group) | group_size = dist.get_world_size(group=group) | ||||
@@ -379,3 +378,4 @@ def all_gather_object(object_list, obj, group=None): | |||||
tensor = tensor.cpu() | tensor = tensor.cpu() | ||||
tensor_size = object_size_list[i] | tensor_size = object_size_list[i] | ||||
object_list[i] = _tensor_to_object(tensor, tensor_size) | object_list[i] = _tensor_to_object(tensor, tensor_size) | ||||
return object_list |
@@ -113,7 +113,7 @@ class TorchDriver(Driver): | |||||
if mode == "validate": | if mode == "validate": | ||||
if not hasattr(model, "validate_step"): | if not hasattr(model, "validate_step"): | ||||
if hasattr(model, "test_step"): | if hasattr(model, "test_step"): | ||||
logger.warning( | |||||
logger.warning_once( | |||||
"Your model does not have 'validate_step' method but has 'test_step' method, but you" | "Your model does not have 'validate_step' method but has 'test_step' method, but you" | ||||
"are using 'mode=validate', we are going to use 'test_step' to substitute for" | "are using 'mode=validate', we are going to use 'test_step' to substitute for" | ||||
"'validate_step'.") | "'validate_step'.") | ||||
@@ -125,9 +125,9 @@ class FastNLPLogger(logging.Logger, metaclass=LoggerSingleton): | |||||
self._warning_msgs.add(msg) | self._warning_msgs.add(msg) | ||||
def warn(self, msg, *args, **kwargs): | def warn(self, msg, *args, **kwargs): | ||||
warnings.warn("The 'warn' method is deprecated, " | |||||
"use 'warning' instead", DeprecationWarning, 2) | |||||
self.warning(msg, *args, **kwargs) | |||||
if self.isEnabledFor(WARNING): | |||||
kwargs = self._add_rank_info(kwargs) | |||||
self._log(WARNING, msg, args, **kwargs) | |||||
def error(self, msg, *args, **kwargs): | def error(self, msg, *args, **kwargs): | ||||
""" | """ | ||||
@@ -14,8 +14,7 @@ from fastNLP.core.utils.utils import seq_len_to_mask | |||||
class Accuracy(Metric): | class Accuracy(Metric): | ||||
def __init__(self, backend: Union[str, Backend, None] = 'auto', | |||||
aggregate_when_get_metric: bool = True): | |||||
def __init__(self, backend: Union[str, Backend, None] = 'auto', aggregate_when_get_metric: bool = True): | |||||
super(Accuracy, self).__init__(backend=backend, aggregate_when_get_metric=aggregate_when_get_metric) | super(Accuracy, self).__init__(backend=backend, aggregate_when_get_metric=aggregate_when_get_metric) | ||||
self.register_element(name='correct', value=0, aggregate_method='sum', backend=backend) | self.register_element(name='correct', value=0, aggregate_method='sum', backend=backend) | ||||
self.register_element(name='total', value=0, aggregate_method="sum", backend=backend) | self.register_element(name='total', value=0, aggregate_method="sum", backend=backend) | ||||
@@ -64,7 +63,7 @@ class Accuracy(Metric): | |||||
warnings.warn("You are not passing `seq_len` to exclude pad when calculate accuracy.") | warnings.warn("You are not passing `seq_len` to exclude pad when calculate accuracy.") | ||||
else: | else: | ||||
raise RuntimeError(f"when pred havesize:{pred.shape}, target should have size: {pred.shape} or " | |||||
raise RuntimeError(f"when pred have size:{pred.shape}, target should have size: {pred.shape} or " | |||||
f"{pred.shape[:-1]}, got {target.shape}.") | f"{pred.shape[:-1]}, got {target.shape}.") | ||||
if masks is not None: | if masks is not None: | ||||
@@ -23,14 +23,14 @@ __all__ = [ | |||||
"BucketedBatchSampler", | "BucketedBatchSampler", | ||||
"ReproducibleBatchSampler", | "ReproducibleBatchSampler", | ||||
"re_instantiate_sampler", | |||||
"conversion_between_reproducible_and_unrepeated_sampler" | |||||
"re_instantiate_sampler" | |||||
] | ] | ||||
from .sampler import BucketSampler, SortedSampler, ConstTokenNumSampler, ConstantTokenNumSampler | from .sampler import BucketSampler, SortedSampler, ConstTokenNumSampler, ConstantTokenNumSampler | ||||
from .unrepeated_sampler import UnrepeatedSampler, UnrepeatedRandomSampler, UnrepeatedSortedSampler, UnrepeatedSequentialSampler | from .unrepeated_sampler import UnrepeatedSampler, UnrepeatedRandomSampler, UnrepeatedSortedSampler, UnrepeatedSequentialSampler | ||||
from .mix_sampler import MixSampler, DopedSampler, MixSequentialSampler, PollingSampler | from .mix_sampler import MixSampler, DopedSampler, MixSequentialSampler, PollingSampler | ||||
from .reproducible_sampler import ReproducibleSampler, RandomSampler, SequentialSampler, SortedSampler | from .reproducible_sampler import ReproducibleSampler, RandomSampler, SequentialSampler, SortedSampler | ||||
from .utils import re_instantiate_sampler, conversion_between_reproducible_and_unrepeated_sampler | |||||
from .utils import re_instantiate_sampler | |||||
from .conversion_utils import conversion_between_reproducible_and_unrepeated_sampler | |||||
from .reproducible_batch_sampler import RandomBatchSampler, BucketedBatchSampler, ReproducibleBatchSampler | from .reproducible_batch_sampler import RandomBatchSampler, BucketedBatchSampler, ReproducibleBatchSampler | ||||
@@ -0,0 +1,33 @@ | |||||
from fastNLP.core.samplers import re_instantiate_sampler | |||||
from fastNLP.core.samplers.reproducible_sampler import ReproducibleSampler, RandomSampler, SequentialSampler, \ | |||||
SortedSampler | |||||
from fastNLP.core.samplers.unrepeated_sampler import UnrepeatedSampler, UnrepeatedRandomSampler, \ | |||||
UnrepeatedSequentialSampler, UnrepeatedSortedSampler | |||||
def conversion_between_reproducible_and_unrepeated_sampler(sampler): | |||||
""" | |||||
将 sampler 替换成其对应的 reproducible 版本或 unrepeated 版本。如果输入是 UnrepeatedSampler 但是没找到对应的 | |||||
ReproducibleSampler, | |||||
:param sampler: | |||||
:return: | |||||
""" | |||||
assert isinstance(sampler, UnrepeatedSampler) or isinstance(sampler, ReproducibleSampler), \ | |||||
"The sampler must be UnrepeatedSampler or ReproducibleSampler" | |||||
if isinstance(sampler, UnrepeatedSampler): | |||||
if isinstance(sampler, UnrepeatedRandomSampler): | |||||
return re_instantiate_sampler(sampler, new_sampler_class=RandomSampler) | |||||
elif isinstance(sampler, UnrepeatedSequentialSampler): | |||||
return re_instantiate_sampler(sampler, new_sampler_class=SequentialSampler) | |||||
elif isinstance(sampler, UnrepeatedSortedSampler): | |||||
return re_instantiate_sampler(sampler, new_sampler_class=SortedSampler) | |||||
raise TypeError(f"{sampler.__class__} has no unrepeated version.") | |||||
else: | |||||
if isinstance(sampler, RandomSampler): | |||||
return re_instantiate_sampler(sampler, new_sampler_class=UnrepeatedRandomSampler) | |||||
elif isinstance(sampler, SequentialSampler): | |||||
return re_instantiate_sampler(sampler, new_sampler_class=UnrepeatedSequentialSampler) | |||||
elif isinstance(sampler, SortedSampler): | |||||
return re_instantiate_sampler(sampler, new_sampler_class=UnrepeatedSortedSampler) | |||||
raise TypeError(f"{sampler.__class__} has no reproducible version.") |
@@ -378,7 +378,7 @@ class BucketedBatchSampler(ReproducibleBatchSampler): | |||||
batch_indices = list(batch_indices[:-1]) | batch_indices = list(batch_indices[:-1]) | ||||
rng = np.random.default_rng(abs(seed)) # 这里防止由于bucket长度不同,对随机数状态有影响 | rng = np.random.default_rng(abs(seed)) # 这里防止由于bucket长度不同,对随机数状态有影响 | ||||
rng.shuffle(batch_indices) # 不同的 batch 也 shuffle ,当前这种可以保证每张卡上每个 batch 长度都接近的。 | rng.shuffle(batch_indices) # 不同的 batch 也 shuffle ,当前这种可以保证每张卡上每个 batch 长度都接近的。 | ||||
batches = (np.array(batches)[batch_indices]).tolist() | |||||
batches = (np.array(batches, dtype=object)[batch_indices]).tolist() | |||||
if last_batches: | if last_batches: | ||||
batches = batches + last_batches | batches = batches + last_batches | ||||
return batches | return batches | ||||
@@ -387,19 +387,12 @@ class BucketedBatchSampler(ReproducibleBatchSampler): | |||||
if self.old_batch_size != self.batch_size or self.old_num_batch_per_bucket != self.num_batch_per_bucket: | if self.old_batch_size != self.batch_size or self.old_num_batch_per_bucket != self.num_batch_per_bucket: | ||||
raise RuntimeError("BucketedBatchSampler does not support saving before last checkpoint states have been" | raise RuntimeError("BucketedBatchSampler does not support saving before last checkpoint states have been" | ||||
" consumed. ") | " consumed. ") | ||||
states = { | |||||
'seed': self.seed, | |||||
'epoch': self.epoch, | |||||
'num_consumed_samples': self.num_consumed_samples, # 注意该值是计算所有 rank 上训练的所有数据; | |||||
'sampler_type': self.__class__.__name__, | |||||
'length': len(self.dataset), | |||||
'shuffle': self.shuffle, | |||||
'batch_size': self.batch_size, | |||||
'num_batch_per_bucket': self.num_batch_per_bucket, | |||||
'num_replicas': self.num_replicas | |||||
} | |||||
states = {'seed': self.seed, 'epoch': self.epoch, 'num_consumed_samples': self.num_consumed_samples, | |||||
'sampler_type': self.__class__.__name__, 'length': len(self.dataset), 'shuffle': self.shuffle, | |||||
'batch_size': self.batch_size, 'num_batch_per_bucket': self.num_batch_per_bucket, | |||||
'num_replicas': self.num_replicas, | |||||
'num_consumed_samples_array': getattr(self, 'num_consumed_samples_array', None)} | |||||
states['num_consumed_samples_array'] = getattr(self, 'num_consumed_samples_array', None) | |||||
return states | return states | ||||
def load_state_dict(self, states: Dict): | def load_state_dict(self, states: Dict): | ||||
@@ -1,3 +1,10 @@ | |||||
__all__ = [ | |||||
'ReproducibleSampler', | |||||
'RandomSampler', | |||||
"SortedSampler", | |||||
"SequentialSampler" | |||||
] | |||||
from typing import Dict, List, Union | from typing import Dict, List, Union | ||||
import math | import math | ||||
import os | import os | ||||
@@ -10,13 +17,6 @@ from fastNLP.envs.env import FASTNLP_DEQUE_SIZE | |||||
from .utils import NumConsumedSamplesArray | from .utils import NumConsumedSamplesArray | ||||
__all__ = [ | |||||
'ReproducibleSampler', | |||||
'RandomSampler', | |||||
"SortedSampler", | |||||
"SequentialSampler" | |||||
] | |||||
class ReproducibleSampler: | class ReproducibleSampler: | ||||
""" | """ | ||||
@@ -1,42 +1,10 @@ | |||||
__all__ = [ | __all__ = [ | ||||
're_instantiate_sampler', | |||||
'conversion_between_reproducible_and_unrepeated_sampler' | |||||
're_instantiate_sampler' | |||||
] | ] | ||||
from array import array | from array import array | ||||
from typing import Sequence | from typing import Sequence | ||||
from collections import deque | from collections import deque | ||||
from fastNLP.core.samplers.unrepeated_sampler import * | |||||
from fastNLP.core.samplers.reproducible_sampler import * | |||||
def conversion_between_reproducible_and_unrepeated_sampler(sampler): | |||||
""" | |||||
将 sampler 替换成其对应的 reproducible 版本或 unrepeated 版本。如果输入是 UnrepeatedSampler 但是没找到对应的 | |||||
ReproducibleSampler, | |||||
:param sampler: | |||||
:return: | |||||
""" | |||||
assert isinstance(sampler, UnrepeatedSampler) or isinstance(sampler, ReproducibleSampler), \ | |||||
"The sampler must be UnrepeatedSampler or ReproducibleSampler" | |||||
if isinstance(sampler, UnrepeatedSampler): | |||||
if isinstance(sampler, UnrepeatedRandomSampler): | |||||
return re_instantiate_sampler(sampler, new_sampler_class=RandomSampler) | |||||
elif isinstance(sampler, UnrepeatedSequentialSampler): | |||||
return re_instantiate_sampler(sampler, new_sampler_class=SequentialSampler) | |||||
elif isinstance(sampler, UnrepeatedSortedSampler): | |||||
return re_instantiate_sampler(sampler, new_sampler_class=SortedSampler) | |||||
raise TypeError(f"{sampler.__class__} has no unrepeated version.") | |||||
else: | |||||
if isinstance(sampler, RandomSampler): | |||||
return re_instantiate_sampler(sampler, new_sampler_class=UnrepeatedRandomSampler) | |||||
elif isinstance(sampler, SequentialSampler): | |||||
return re_instantiate_sampler(sampler, new_sampler_class=UnrepeatedSequentialSampler) | |||||
elif isinstance(sampler, SortedSampler): | |||||
return re_instantiate_sampler(sampler, new_sampler_class=UnrepeatedSortedSampler) | |||||
raise TypeError(f"{sampler.__class__} has no reproducible version.") | |||||
def re_instantiate_sampler(sampler, new_sampler_class=None): | def re_instantiate_sampler(sampler, new_sampler_class=None): | ||||
all_attributes = vars(sampler) | all_attributes = vars(sampler) | ||||
@@ -13,7 +13,6 @@ __all__ = [ | |||||
'torch_paddle_move_data_to_device', | 'torch_paddle_move_data_to_device', | ||||
'torch_move_data_to_device', | 'torch_move_data_to_device', | ||||
'get_fn_arg_names', | 'get_fn_arg_names', | ||||
'check_fn_not_empty_params', | |||||
'auto_param_call', | 'auto_param_call', | ||||
'check_user_specific_params', | 'check_user_specific_params', | ||||
'dataclass_to_dict', | 'dataclass_to_dict', | ||||
@@ -36,7 +35,7 @@ from .paddle_utils import paddle_to, paddle_move_data_to_device, get_paddle_devi | |||||
from .rich_progress import f_rich_progress | from .rich_progress import f_rich_progress | ||||
from .torch_paddle_utils import torch_paddle_move_data_to_device | from .torch_paddle_utils import torch_paddle_move_data_to_device | ||||
from .torch_utils import torch_move_data_to_device | from .torch_utils import torch_move_data_to_device | ||||
from .utils import get_fn_arg_names, check_fn_not_empty_params, auto_param_call, check_user_specific_params, \ | |||||
from .utils import get_fn_arg_names, auto_param_call, check_user_specific_params, \ | |||||
dataclass_to_dict, match_and_substitute_params, apply_to_collection, nullcontext, pretty_table_printer, Option, \ | dataclass_to_dict, match_and_substitute_params, apply_to_collection, nullcontext, pretty_table_printer, Option, \ | ||||
indice_collate_wrapper, deprecated, seq_len_to_mask, synchronize_safe_rm, synchronize_mkdir | indice_collate_wrapper, deprecated, seq_len_to_mask, synchronize_safe_rm, synchronize_mkdir | ||||
@@ -1,3 +1,4 @@ | |||||
import functools | |||||
import inspect | import inspect | ||||
from inspect import Parameter | from inspect import Parameter | ||||
import dataclasses | import dataclasses | ||||
@@ -24,10 +25,8 @@ from fastNLP.core.log import logger | |||||
from fastNLP.envs import FASTNLP_GLOBAL_RANK | from fastNLP.envs import FASTNLP_GLOBAL_RANK | ||||
__all__ = [ | __all__ = [ | ||||
'get_fn_arg_names', | 'get_fn_arg_names', | ||||
'check_fn_not_empty_params', | |||||
'auto_param_call', | 'auto_param_call', | ||||
'check_user_specific_params', | 'check_user_specific_params', | ||||
'dataclass_to_dict', | 'dataclass_to_dict', | ||||
@@ -54,30 +53,6 @@ def get_fn_arg_names(fn: Callable) -> List[str]: | |||||
return list(inspect.signature(fn).parameters) | return list(inspect.signature(fn).parameters) | ||||
def check_fn_not_empty_params(fn: Optional[Callable] = None, param_num: Optional[int] = None) -> bool: | |||||
r""" | |||||
检查传入的batch_step_fn是否是合法的:(1) 是否是 callable 的; (2) 没有默认值的参数是否只有指定个数; | |||||
用户也可以传进一个 partial 的函数进来,只要其保证留有 `trainer` 和 `batch` 的参数位置即可; | |||||
:param fn: 传入的用以代替 Loop 中 'step' 函数的函数; | |||||
:param param_num: 检测的函数的应当的没有默认值的参数的个数; | |||||
:return: bool,表示传入的 `batch_step_fn` 是否正确; | |||||
""" | |||||
if fn is None: | |||||
return True | |||||
if not callable(fn): | |||||
return False | |||||
else: | |||||
params = inspect.signature(fn).parameters | |||||
not_default_params = {} | |||||
for _name, _param in params.items(): | |||||
if _param.default == Parameter.empty: | |||||
not_default_params[_name] = _param | |||||
return len(not_default_params) == param_num | |||||
def auto_param_call(fn: Callable, *args, signature_fn: Optional[Callable] = None, | def auto_param_call(fn: Callable, *args, signature_fn: Optional[Callable] = None, | ||||
mapping: Optional[Dict[AnyStr, AnyStr]] = None) -> Any: | mapping: Optional[Dict[AnyStr, AnyStr]] = None) -> Any: | ||||
r""" | r""" | ||||
@@ -95,7 +70,6 @@ def auto_param_call(fn: Callable, *args, signature_fn: Optional[Callable] = None | |||||
:param signature_fn: 函数,用来替换 `fn` 的函数签名,如果该参数不为 None,那么我们首先会从该函数中提取函数签名,然后通过该函数签名提取 | :param signature_fn: 函数,用来替换 `fn` 的函数签名,如果该参数不为 None,那么我们首先会从该函数中提取函数签名,然后通过该函数签名提取 | ||||
参数值后,再传给 `fn` 进行实际的运算; | 参数值后,再传给 `fn` 进行实际的运算; | ||||
:param mapping: 一个字典,用来更改其前面的字典的键值; | :param mapping: 一个字典,用来更改其前面的字典的键值; | ||||
:param wo_auto_param_call: 是否关闭默认的参数匹配行为; | |||||
:return: 返回 `fn` 运行的结果; | :return: 返回 `fn` 运行的结果; | ||||
@@ -123,7 +97,8 @@ def auto_param_call(fn: Callable, *args, signature_fn: Optional[Callable] = None | |||||
_kwargs = None | _kwargs = None | ||||
for _name, _param in _need_params.items(): | for _name, _param in _need_params.items(): | ||||
if _param.kind == Parameter.VAR_POSITIONAL: | if _param.kind == Parameter.VAR_POSITIONAL: | ||||
raise ValueError(f"It is not allowed to have parameter `*args` in your function:{fn.__name__}.") | |||||
fn_msg = _get_fun_msg(fn if signature_fn is None else signature_fn) | |||||
raise ValueError(f"It is not allowed to have parameter `*args` in your function:{fn_msg}.") | |||||
if _param.kind == Parameter.VAR_KEYWORD: | if _param.kind == Parameter.VAR_KEYWORD: | ||||
_kwargs = (_name, _param) | _kwargs = (_name, _param) | ||||
@@ -136,12 +111,17 @@ def auto_param_call(fn: Callable, *args, signature_fn: Optional[Callable] = None | |||||
_default_params[_name] = _param.default | _default_params[_name] = _param.default | ||||
if mapping is not None: | if mapping is not None: | ||||
assert isinstance(mapping, Dict), f"Parameter `mapping` should be of 'Dict' type, instead of {type(mapping)}." | |||||
fn_msg = _get_fun_msg(fn if signature_fn is None else signature_fn) | |||||
assert isinstance(mapping, Dict), f"Exception happens when calling {fn_msg}. " \ | |||||
f"Parameter `mapping` should be of 'Dict' type, instead of {type(mapping)}." | |||||
_has_params = {} | _has_params = {} | ||||
duplicate_names = [] | duplicate_names = [] | ||||
for arg in args: | for arg in args: | ||||
assert isinstance(arg, Dict), "The input part of function `auto_param_call` can only be `Dict` type." | |||||
if not isinstance(arg, Dict): | |||||
fn_msg = _get_fun_msg(fn if signature_fn is None else signature_fn) | |||||
raise TypeError(f"Exception happens when calling {fn_msg}. " | |||||
f"The input part of function `auto_param_call` must be `Dict` type, instead of {type(arg)}.") | |||||
for _name, _value in arg.items(): | for _name, _value in arg.items(): | ||||
if mapping is not None and _name in mapping: | if mapping is not None and _name in mapping: | ||||
_name = mapping[_name] | _name = mapping[_name] | ||||
@@ -153,7 +133,8 @@ def auto_param_call(fn: Callable, *args, signature_fn: Optional[Callable] = None | |||||
elif _name in _need_params and not (_has_params[_name] is _value): | elif _name in _need_params and not (_has_params[_name] is _value): | ||||
duplicate_names.append(_name) | duplicate_names.append(_name) | ||||
if duplicate_names: | if duplicate_names: | ||||
raise ValueError(f"The following key present in several inputs:{duplicate_names}") | |||||
fn_msg = _get_fun_msg(fn if signature_fn is None else signature_fn) | |||||
raise ValueError(f"The following key present in several inputs:{duplicate_names} when calling {fn_msg}.") | |||||
# 将具有默认值但是没有被输入修改过的参数值传进去; | # 将具有默认值但是没有被输入修改过的参数值传进去; | ||||
for _name, _value in _default_params.items(): | for _name, _value in _default_params.items(): | ||||
@@ -162,11 +143,89 @@ def auto_param_call(fn: Callable, *args, signature_fn: Optional[Callable] = None | |||||
if len(_has_params)<len(_need_params): | if len(_has_params)<len(_need_params): | ||||
miss_params = list(set(_need_params.keys()) - set(_has_params.keys())) | miss_params = list(set(_need_params.keys()) - set(_has_params.keys())) | ||||
raise ValueError(f"The parameters:`{miss_params}` needed by function:{fn.__name__} are not found in the input.") | |||||
fn_msg = _get_fun_msg(fn if signature_fn is None else signature_fn) | |||||
_provided_keys = _get_keys(args) | |||||
raise ValueError(f"The parameters:`{miss_params}` needed by function:{fn_msg} " | |||||
f"are not found in the input keys({_provided_keys}).") | |||||
return fn(**_has_params) | return fn(**_has_params) | ||||
def _get_keys(args:List[Dict]) -> List[List[str]]: | |||||
""" | |||||
返回每个 dict 的 keys | |||||
:param args: | |||||
:return: | |||||
""" | |||||
_provided_keys = [] | |||||
for arg in args: | |||||
_provided_keys.append(list(arg.keys())) | |||||
return _provided_keys | |||||
def _get_fun_msg(fn)->str: | |||||
""" | |||||
获取函数的基本信息,帮助报错。 | |||||
ex: | |||||
print(_get_fun_msg(_get_fun_msg)) | |||||
# `_get_fun_msg(fn) -> str`(In file:/Users/hnyan/Desktop/projects/fastNLP/fastNLP/fastNLP/core/utils/utils.py) | |||||
:param callable fn: | |||||
:return: | |||||
""" | |||||
if isinstance(fn, functools.partial): | |||||
return _get_fun_msg(fn.func) | |||||
try: | |||||
fn_name = fn.__qualname__ + str(inspect.signature(fn)) | |||||
except: | |||||
fn_name = str(fn) | |||||
try: | |||||
fp = '(In file:' + os.path.abspath(inspect.getfile(fn)) + ')' | |||||
except: | |||||
fp = '' | |||||
msg = f'`{fn_name}`' + fp | |||||
return msg | |||||
def _check_valid_parameters_number(fn, expected_params:List[str], fn_name=None): | |||||
""" | |||||
检查一个函数是否需要 expected_params 参数(检测数量是否匹配)。除掉 self (如果是method),给定默认值的参数等。如果匹配不上,就会 | |||||
进行报错。 | |||||
:param fn: 需要检测的函数,可以是 method 或者 function 。 | |||||
:param expected_params: 期待应该支持的参数。 | |||||
:param fn_name: fn 的名字,当传入的 fn 不是 callable 的时候方便报错。 | |||||
:return: | |||||
""" | |||||
if fn_name is not None: | |||||
assert callable(fn), f"{fn_name} should be callable, instead of {type(fn)}." | |||||
parameters = list(inspect.signature(fn).parameters.values()) | |||||
if inspect.ismethod(fn): | |||||
if len(parameters)>0 and parameters[0].name == 'self': | |||||
parameters = parameters[1:] # 去掉self | |||||
no_var_param = True # 没有 * 这种参数 | |||||
number_param_need_value = 0 | |||||
for param in parameters: | |||||
if param.kind is param.VAR_POSITIONAL: | |||||
no_var_param = False | |||||
elif param.kind is param.VAR_KEYWORD: | |||||
no_var_param = False | |||||
else: | |||||
if param.default is param.empty: | |||||
number_param_need_value += 1 | |||||
if len(parameters)<len(expected_params) and no_var_param: | |||||
raise RuntimeError(f"The function:{_get_fun_msg(fn)} accepts {len(parameters)} parameters, " | |||||
f"but {len(expected_params)} parameters:{expected_params} will be provided.") | |||||
if number_param_need_value>len(expected_params): | |||||
raise RuntimeError(f"The function:{_get_fun_msg(fn)} expects {len(parameters)} parameters, but only" | |||||
f" {len(expected_params)} parameters:{expected_params} will be provided.") | |||||
def check_user_specific_params(user_params: Dict, fn: Callable): | def check_user_specific_params(user_params: Dict, fn: Callable): | ||||
""" | """ | ||||
该函数使用用户的输入来对指定函数的参数进行赋值; | 该函数使用用户的输入来对指定函数的参数进行赋值; | ||||
@@ -592,4 +651,24 @@ def synchronize_mkdir(path: Optional[Union[str, Path]]): | |||||
wait_to_success(path.exists) | wait_to_success(path.exists) | ||||
def get_class_that_defined_method(method): | |||||
""" | |||||
给定一个method,返回这个 method 的 class 的对象 | |||||
:param method: | |||||
:return: | |||||
""" | |||||
if isinstance(method, functools.partial): | |||||
return get_class_that_defined_method(method.func) | |||||
if inspect.ismethod(method) or (inspect.isbuiltin(method) and getattr(method, '__self__', None) is not None and getattr(method.__self__, '__class__', None)): | |||||
for cls in inspect.getmro(method.__self__.__class__): | |||||
if method.__name__ in cls.__dict__: | |||||
return cls | |||||
method = getattr(method, '__func__', method) # fallback to __qualname__ parsing | |||||
if inspect.isfunction(method): | |||||
cls = getattr(inspect.getmodule(method), | |||||
method.__qualname__.split('.<locals>', 1)[0].rsplit('.', 1)[0], | |||||
None) | |||||
if isinstance(cls, type): | |||||
return cls | |||||
return getattr(method, '__objclass__', None) # handle special descriptor objects |
@@ -251,10 +251,10 @@ class DataBundle: | |||||
def apply_field_more(self, func: Callable, field_name: str, num_proc: int = 0, modify_fields=True, | def apply_field_more(self, func: Callable, field_name: str, num_proc: int = 0, modify_fields=True, | ||||
ignore_miss_dataset=True, progress_desc: str = '', show_progress_bar: bool = True): | ignore_miss_dataset=True, progress_desc: str = '', show_progress_bar: bool = True): | ||||
r""" | r""" | ||||
对 :class:`~fastNLP.io.DataBundle` 中所有的 dataset 使用 :meth:`~fastNLP.DataSet.apply_field_more` 方法 | |||||
对 :class:`~fastNLP.io.DataBundle` 中所有的 dataset 使用 :method:`~fastNLP.DataSet.apply_field_more` 方法 | |||||
.. note:: | .. note:: | ||||
``apply_field_more`` 与 ``apply_field`` 的区别参考 :meth:`fastNLP.DataSet.apply_more` 中关于 ``apply_more`` 与 | |||||
``apply_field_more`` 与 ``apply_field`` 的区别参考 :method:`fastNLP.DataSet.apply_more` 中关于 ``apply_more`` 与 | |||||
``apply`` 区别的介绍。 | ``apply`` 区别的介绍。 | ||||
:param callable func: 参数是 ``DataSet`` 中的 ``Instance`` ,返回值是一个字典,key 是field 的名字,value 是对应的结果 | :param callable func: 参数是 ``DataSet`` 中的 ``Instance`` ,返回值是一个字典,key 是field 的名字,value 是对应的结果 | ||||
@@ -285,7 +285,7 @@ class DataBundle: | |||||
def apply(self, func: Callable, new_field_name: str, num_proc: int = 0, | def apply(self, func: Callable, new_field_name: str, num_proc: int = 0, | ||||
progress_desc: str = '', show_progress_bar: bool = True, _apply_field: str = None): | progress_desc: str = '', show_progress_bar: bool = True, _apply_field: str = None): | ||||
r""" | r""" | ||||
对 :class:`~fastNLP.io.DataBundle` 中所有的 dataset 使用 :meth:`~fastNLP.DataSet.apply` 方法 | |||||
对 :class:`~fastNLP.io.DataBundle` 中所有的 dataset 使用 :method:`~fastNLP.DataSet.apply` 方法 | |||||
对DataBundle中所有的dataset使用apply方法 | 对DataBundle中所有的dataset使用apply方法 | ||||
@@ -309,10 +309,10 @@ class DataBundle: | |||||
def apply_more(self, func: Callable, modify_fields=True, num_proc: int = 0, | def apply_more(self, func: Callable, modify_fields=True, num_proc: int = 0, | ||||
progress_desc: str = '', show_progress_bar: bool = True): | progress_desc: str = '', show_progress_bar: bool = True): | ||||
r""" | r""" | ||||
对 :class:`~fastNLP.io.DataBundle` 中所有的 dataset 使用 :meth:`~fastNLP.DataSet.apply_more` 方法 | |||||
对 :class:`~fastNLP.io.DataBundle` 中所有的 dataset 使用 :method:`~fastNLP.DataSet.apply_more` 方法 | |||||
.. note:: | .. note:: | ||||
``apply_more`` 与 ``apply`` 的区别参考 :meth:`fastNLP.DataSet.apply_more` 中关于 ``apply_more`` 与 | |||||
``apply_more`` 与 ``apply`` 的区别参考 :method:`fastNLP.DataSet.apply_more` 中关于 ``apply_more`` 与 | |||||
``apply`` 区别的介绍。 | ``apply`` 区别的介绍。 | ||||
:param callable func: 参数是 ``DataSet`` 中的 ``Instance`` ,返回值是一个字典,key 是field 的名字,value 是对应的结果 | :param callable func: 参数是 ``DataSet`` 中的 ``Instance`` ,返回值是一个字典,key 是field 的名字,value 是对应的结果 | ||||
@@ -87,7 +87,7 @@ class CLSBasePipe(Pipe): | |||||
def process_from_file(self, paths) -> DataBundle: | def process_from_file(self, paths) -> DataBundle: | ||||
r""" | r""" | ||||
传入文件路径,生成处理好的DataBundle对象。paths支持的路径形式可以参考 ::meth:`fastNLP.io.Loader.load()` | |||||
传入文件路径,生成处理好的DataBundle对象。paths支持的路径形式可以参考 ::method:`fastNLP.io.Loader.load()` | |||||
:param paths: | :param paths: | ||||
:return: DataBundle | :return: DataBundle | ||||
@@ -164,7 +164,7 @@ class GraphBuilderBase: | |||||
def build_graph_from_file(self, path: str): | def build_graph_from_file(self, path: str): | ||||
r""" | r""" | ||||
传入文件路径,生成处理好的scipy_sparse_matrix对象。paths支持的路径形式可以参考 ::meth:`fastNLP.io.Loader.load()` | |||||
传入文件路径,生成处理好的scipy_sparse_matrix对象。paths支持的路径形式可以参考 ::method:`fastNLP.io.Loader.load()` | |||||
:param path: | :param path: | ||||
:return: scipy_sparse_matrix | :return: scipy_sparse_matrix | ||||
@@ -33,7 +33,7 @@ class Pipe: | |||||
def process_from_file(self, paths: str) -> DataBundle: | def process_from_file(self, paths: str) -> DataBundle: | ||||
r""" | r""" | ||||
传入文件路径,生成处理好的DataBundle对象。paths支持的路径形式可以参考 ::meth:`fastNLP.io.Loader.load()` | |||||
传入文件路径,生成处理好的DataBundle对象。paths支持的路径形式可以参考 ::method:`fastNLP.io.Loader.load()` | |||||
:param str paths: | :param str paths: | ||||
:return: DataBundle | :return: DataBundle | ||||
@@ -0,0 +1,187 @@ | |||||
from functools import partial | |||||
import pytest | |||||
from fastNLP.core.utils.utils import auto_param_call, _check_valid_parameters_number, _get_fun_msg | |||||
from fastNLP.core.metrics import Metric | |||||
class TestAutoParamCall: | |||||
def test_basic(self): | |||||
def fn(x): | |||||
return x | |||||
x = {'x': 3, 'y': 4} | |||||
r = auto_param_call(fn, x) | |||||
assert r==3 | |||||
xs = [] | |||||
for i in range(10): | |||||
xs.append({f'x{i}': i}) | |||||
def fn(x0, x1, x2, x3): | |||||
return x0 + x1 + x2 + x3 | |||||
r = auto_param_call(fn, *xs) | |||||
assert r == 0 + 1+ 2+ 3 | |||||
def fn(chongfu1, chongfu2, buChongFu): | |||||
pass | |||||
with pytest.raises(BaseException) as exc_info: | |||||
auto_param_call(fn, {'chongfu1': 3, "chongfu2":4, 'buChongFu':2}, | |||||
{'chongfu1': 1, 'chongfu2':2, 'buChongFu':2}) | |||||
assert 'The following key present in several inputs' in exc_info.value.args[0] | |||||
assert 'chongfu1' in exc_info.value.args[0] and 'chongfu2' in exc_info.value.args[0] | |||||
# 没用到不报错 | |||||
def fn(chongfu1, buChongFu): | |||||
pass | |||||
auto_param_call(fn, {'chongfu1': 1, "chongfu2":4, 'buChongFu':2}, | |||||
{'chongfu1': 1, 'chongfu2':2, 'buChongFu':2}) | |||||
# 可以定制signature_fn | |||||
def fn1(**kwargs): | |||||
kwargs.pop('x') | |||||
kwargs.pop('y') | |||||
assert len(kwargs)==0 | |||||
def fn(x, y): | |||||
pass | |||||
x = {'x': 3, 'y': 4} | |||||
r = auto_param_call(fn1, x, signature_fn=fn) | |||||
# 没提供的时候报错 | |||||
def fn(meiti1, meiti2, tigong): | |||||
pass | |||||
with pytest.raises(BaseException) as exc_info: | |||||
auto_param_call(fn, {'tigong':1}) | |||||
assert 'meiti1' in exc_info.value.args[0] and 'meiti2' in exc_info.value.args[0] | |||||
# 默认值替换 | |||||
def fn(x, y=100): | |||||
return x + y | |||||
r = auto_param_call(fn, {'x': 10, 'y': 20}) | |||||
assert r==30 | |||||
assert auto_param_call(fn, {'x': 10, 'z': 20})==110 | |||||
# 测试mapping的使用 | |||||
def fn(x, y=100): | |||||
return x + y | |||||
r = auto_param_call(fn, {'x1': 10, 'y1': 20}, mapping={'x1': 'x', 'y1': 'y', 'meiyong': 'meiyong'}) | |||||
assert r==30 | |||||
# 测试不需要任何参数 | |||||
def fn(): | |||||
return 1 | |||||
assert 1 == auto_param_call(fn, {'x':1}) | |||||
# 测试调用类的方法没问题 | |||||
assert 2==auto_param_call(self.call_this, {'x':1 ,'y':1}) | |||||
assert 2==auto_param_call(self.call_this, {'x':1,'y':1, 'z':1},mapping={'z': 'self'}) | |||||
def test_msg(self): | |||||
with pytest.raises(BaseException) as exc_info: | |||||
auto_param_call(self.call_this, {'x':1}) | |||||
assert 'TestAutoParamCall.call_this' in exc_info.value.args[0] | |||||
with pytest.raises(BaseException) as exc_info: | |||||
auto_param_call(call_this_for_auto_param_call, {'x':1}) | |||||
assert __file__ in exc_info.value.args[0] | |||||
assert 'call_this_for_auto_param_call' in exc_info.value.args[0] | |||||
with pytest.raises(BaseException) as exc_info: | |||||
auto_param_call(self.call_this_two, {'x':1}) | |||||
assert __file__ in exc_info.value.args[0] | |||||
with pytest.raises(BaseException) as exc_info: | |||||
auto_param_call(call_this_for_auto_param_call, {'x':1}, signature_fn=self.call_this) | |||||
assert 'TestAutoParamCall.call_this' in exc_info.value.args[0] # 应该是signature的信息 | |||||
def call_this(self, x, y): | |||||
return x + y | |||||
def call_this_two(self, x, y, z=pytest, **kwargs): | |||||
return x + y | |||||
def test_metric_auto_param_call(self): | |||||
metric = AutoParamCallMetric() | |||||
with pytest.raises(BaseException): | |||||
auto_param_call(metric.update, {'y':1}, signature_fn=metric.update.__wrapped__) | |||||
class AutoParamCallMetric(Metric): | |||||
def update(self, x): | |||||
pass | |||||
def call_this_for_auto_param_call(x, y): | |||||
return x + y | |||||
class TestCheckNumberOfParameters: | |||||
def test_validate_every(self): | |||||
def validate_every(trainer): | |||||
pass | |||||
_check_valid_parameters_number(validate_every, expected_params=['trainer']) | |||||
# 无默认值,多了报错 | |||||
def validate_every(trainer, other): | |||||
pass | |||||
with pytest.raises(RuntimeError) as exc_info: | |||||
_check_valid_parameters_number(validate_every, expected_params=['trainer']) | |||||
assert "2 parameters" in exc_info.value.args[0] | |||||
print(exc_info.value.args[0]) | |||||
# 有默认值ok | |||||
def validate_every(trainer, other=1): | |||||
pass | |||||
_check_valid_parameters_number(validate_every, expected_params=['trainer']) | |||||
# 参数多了 | |||||
def validate_every(trainer): | |||||
pass | |||||
with pytest.raises(RuntimeError) as exc_info: | |||||
_check_valid_parameters_number(validate_every, expected_params=['trainer', 'other']) | |||||
assert "accepts 1 parameters" in exc_info.value.args[0] | |||||
print(exc_info.value.args[0]) | |||||
# 使用partial | |||||
def validate_every(trainer, other): | |||||
pass | |||||
_check_valid_parameters_number(partial(validate_every, other=1), expected_params=['trainer']) | |||||
_check_valid_parameters_number(partial(validate_every, other=1), expected_params=['trainer', 'other']) | |||||
with pytest.raises(RuntimeError) as exc_info: | |||||
_check_valid_parameters_number(partial(validate_every, other=1), expected_params=['trainer', 'other', 'more']) | |||||
assert 'accepts 2 parameters' in exc_info.value.args[0] | |||||
print(exc_info.value.args[0]) | |||||
# 如果存在 *args 或 *kwargs 不报错多的 | |||||
def validate_every(trainer, *args): | |||||
pass | |||||
_check_valid_parameters_number(validate_every, expected_params=['trainer', 'other', 'more']) | |||||
def validate_every(trainer, **kwargs): | |||||
pass | |||||
_check_valid_parameters_number(partial(validate_every, trainer=1), expected_params=['trainer', 'other', 'more']) | |||||
# class 的方法删掉self | |||||
class InnerClass: | |||||
def demo(self, x): | |||||
pass | |||||
def no_param(self): | |||||
pass | |||||
def param_kwargs(self, **kwargs): | |||||
pass | |||||
inner = InnerClass() | |||||
with pytest.raises(RuntimeError) as exc_info: | |||||
_check_valid_parameters_number(inner.demo, expected_params=['trainer', 'other', 'more']) | |||||
assert 'accepts 1 parameters' in exc_info.value.args[0] | |||||
_check_valid_parameters_number(inner.demo, expected_params=['trainer']) | |||||
def test_get_fun_msg(): | |||||
def demo(x): | |||||
pass | |||||
print(_get_fun_msg(_get_fun_msg)) |
@@ -2,37 +2,19 @@ import os | |||||
import sys | import sys | ||||
import __main__ | import __main__ | ||||
from functools import wraps | from functools import wraps | ||||
import inspect | |||||
from inspect import ismethod | from inspect import ismethod | ||||
import functools | |||||
from copy import deepcopy | from copy import deepcopy | ||||
from io import StringIO | from io import StringIO | ||||
import time | import time | ||||
import numpy as np | import numpy as np | ||||
from fastNLP.core.utils.utils import get_class_that_defined_method | |||||
from fastNLP.envs.env import FASTNLP_GLOBAL_RANK | from fastNLP.envs.env import FASTNLP_GLOBAL_RANK | ||||
from fastNLP.core.drivers.utils import distributed_open_proc | from fastNLP.core.drivers.utils import distributed_open_proc | ||||
from fastNLP.core.log import logger | from fastNLP.core.log import logger | ||||
def get_class_that_defined_method(meth): | |||||
if isinstance(meth, functools.partial): | |||||
return get_class_that_defined_method(meth.func) | |||||
if inspect.ismethod(meth) or (inspect.isbuiltin(meth) and getattr(meth, '__self__', None) is not None and getattr(meth.__self__, '__class__', None)): | |||||
for cls in inspect.getmro(meth.__self__.__class__): | |||||
if meth.__name__ in cls.__dict__: | |||||
return cls | |||||
meth = getattr(meth, '__func__', meth) # fallback to __qualname__ parsing | |||||
if inspect.isfunction(meth): | |||||
cls = getattr(inspect.getmodule(meth), | |||||
meth.__qualname__.split('.<locals>', 1)[0].rsplit('.', 1)[0], | |||||
None) | |||||
if isinstance(cls, type): | |||||
return cls | |||||
return getattr(meth, '__objclass__', None) # handle special descriptor objects | |||||
def recover_logger(fn): | def recover_logger(fn): | ||||
@wraps(fn) | @wraps(fn) | ||||
def wrapper(*args, **kwargs): | def wrapper(*args, **kwargs): | ||||