@@ -4,8 +4,7 @@ __all__ = [ | |||||
'EventsList', | 'EventsList', | ||||
'Filter', | 'Filter', | ||||
'CallbackManager', | 'CallbackManager', | ||||
'ModelCheckpointCallback', | |||||
'TrainerCheckpointCallback', | |||||
'CheckpointCallback', | |||||
'choose_progress_callback', | 'choose_progress_callback', | ||||
'ProgressCallback', | 'ProgressCallback', | ||||
'RichCallback', | 'RichCallback', | ||||
@@ -13,18 +12,21 @@ __all__ = [ | |||||
'LoadBestModelCallback', | 'LoadBestModelCallback', | ||||
"EarlyStopCallback", | "EarlyStopCallback", | ||||
'MoreEvaluateCallback', | |||||
"TorchWarmupCallback", | "TorchWarmupCallback", | ||||
"TorchGradClipCallback" | |||||
"TorchGradClipCallback", | |||||
] | ] | ||||
from .callback import Callback | from .callback import Callback | ||||
from .callback_events import EventsList, Events, Filter | from .callback_events import EventsList, Events, Filter | ||||
from .callback_manager import CallbackManager | from .callback_manager import CallbackManager | ||||
from .checkpoint_callback import ModelCheckpointCallback, TrainerCheckpointCallback | |||||
from .checkpoint_callback import CheckpointCallback | |||||
from .progress_callback import choose_progress_callback, ProgressCallback, RichCallback | from .progress_callback import choose_progress_callback, ProgressCallback, RichCallback | ||||
from .lr_scheduler_callback import LRSchedCallback | from .lr_scheduler_callback import LRSchedCallback | ||||
from .load_best_model_callback import LoadBestModelCallback | from .load_best_model_callback import LoadBestModelCallback | ||||
from .early_stop_callback import EarlyStopCallback | from .early_stop_callback import EarlyStopCallback | ||||
from .torch_callbacks import * | from .torch_callbacks import * | ||||
from .more_evaluate_callback import MoreEvaluateCallback | |||||
@@ -236,7 +236,7 @@ class Callback: | |||||
结束 validate 时调用,并把 validate 的结果传入。 | 结束 validate 时调用,并把 validate 的结果传入。 | ||||
:param trainer: | :param trainer: | ||||
:param results: | |||||
:param results: Evaluate 的结果,一般是个 dict 。 | |||||
:return: | :return: | ||||
""" | """ | ||||
pass | pass | ||||
@@ -250,6 +250,15 @@ class Callback: | |||||
""" | """ | ||||
return self.__class__.__name__ | return self.__class__.__name__ | ||||
@property | |||||
def need_reproducible_sampler(self) -> bool: | |||||
""" | |||||
当前 callback 是否需要能够复现的 sampler 。一般用于 checkpoint 类的 callback 。 | |||||
:return: | |||||
""" | |||||
return False | |||||
class _CallbackWrapper(Callback): | class _CallbackWrapper(Callback): | ||||
""" | """ | ||||
@@ -8,7 +8,6 @@ __all__ = [ | |||||
from .callback_events import Events | from .callback_events import Events | ||||
from .callback import Callback | from .callback import Callback | ||||
from .checkpoint_callback import TrainerCheckpointCallback | |||||
from .progress_callback import ProgressCallback, choose_progress_callback | from .progress_callback import ProgressCallback, choose_progress_callback | ||||
from fastNLP.core.log import logger | from fastNLP.core.log import logger | ||||
@@ -45,7 +44,7 @@ class CallbackManager: | |||||
:param callbacks: 初始化时可以传入的一系列 callback 类,通常为用户在初始化 'Trainer' 时直接传入的 callback 类; | :param callbacks: 初始化时可以传入的一系列 callback 类,通常为用户在初始化 'Trainer' 时直接传入的 callback 类; | ||||
""" | """ | ||||
self._has_trainer_checkpoint = False | |||||
self._need_reproducible_sampler = False | |||||
_has_progress_callback = False | _has_progress_callback = False | ||||
_callbacks = [] | _callbacks = [] | ||||
@@ -98,8 +97,7 @@ class CallbackManager: | |||||
:return: | :return: | ||||
""" | """ | ||||
for each_callback in self.class_callbacks: | for each_callback in self.class_callbacks: | ||||
if isinstance(each_callback, TrainerCheckpointCallback): | |||||
self._has_trainer_checkpoint = True | |||||
self._need_reproducible_sampler |= each_callback.need_reproducible_sampler | |||||
self.dissect_one_callback(each_callback) | self.dissect_one_callback(each_callback) | ||||
def dissect_one_callback(self, callback: Callback): | def dissect_one_callback(self, callback: Callback): | ||||
@@ -211,7 +209,7 @@ class CallbackManager: | |||||
@property | @property | ||||
def has_trainer_checkpoint(self) -> bool: | def has_trainer_checkpoint(self) -> bool: | ||||
return self._has_trainer_checkpoint | |||||
return self._need_reproducible_sampler | |||||
@_transfer | @_transfer | ||||
def on_after_trainer_initialized(self, trainer): | def on_after_trainer_initialized(self, trainer): | ||||
@@ -1,339 +1,151 @@ | |||||
__all__ = [ | __all__ = [ | ||||
'ModelCheckpointCallback', | |||||
'TrainerCheckpointCallback' | |||||
'CheckpointCallback' | |||||
] | ] | ||||
import os | |||||
from typing import Union, Optional, Callable, Dict, Sequence, Any, Mapping | |||||
from typing import Union, Optional, Callable, Dict, Sequence | |||||
from pathlib import Path | from pathlib import Path | ||||
import sys | import sys | ||||
from copy import deepcopy | |||||
import fastNLP | |||||
from .has_monitor_callback import HasMonitorCallback | |||||
from fastNLP.core.log import logger | from fastNLP.core.log import logger | ||||
from fastNLP.envs import FASTNLP_LAUNCH_TIME, FASTNLP_GLOBAL_RANK | |||||
from fastNLP.core.utils import synchronize_safe_rm, synchronize_mkdir | |||||
from .topk_saver import TopkSaver | |||||
from .callback import Callback | |||||
class CheckpointCallback(HasMonitorCallback): | |||||
def __init__( | |||||
self, | |||||
monitor:Optional[Union[str, Callable]]=None, | |||||
save_folder: Optional[Union[str, Path]] = None, | |||||
save_every_n_epochs: Optional[int] = None, | |||||
save_every_n_batches: Optional[int] = None, | |||||
save_last: bool = False, | |||||
save_topk: Optional[int] = None, | |||||
save_on_exception: Optional[Union[BaseException, Sequence[BaseException]]] = None, | |||||
larger_better: bool = True, | |||||
only_state_dict: bool = True, | |||||
model_save_fn: Optional[Callable] = None, | |||||
**kwargs, | |||||
): | |||||
class CheckpointCallback(Callback): | |||||
def __init__(self, folder: Optional[Union[str, Path]] = None, every_n_epochs: Optional[int] = None, | |||||
every_n_batches: Optional[int] = None, last: bool = False, | |||||
on_exceptions: Optional[Union[BaseException, Sequence[BaseException]]] = None, topk: int = 0, | |||||
monitor: Optional[Union[str, Callable]] = None, larger_better: bool = True, | |||||
only_state_dict: bool = True, model_save_fn: Optional[Callable] = None, save_object: str = 'model', | |||||
save_evaluate_results=True, **kwargs): | |||||
""" | """ | ||||
请使用 ModelCheckpointCallback 与 TrainerCheckpointCallback 。 | |||||
保存模型 checkpoint 的 callback ,其保存的文件目录以及文件名命名规则如下 | |||||
- folder/ | |||||
- YYYY-mm-dd-HH_MM_SS_fffff/ # 自动根据当前脚本的启动时间创建的 | |||||
- {save_object}-epoch_{epoch_idx}/ # 满足 every_n_epochs 条件保存的模型 | |||||
- {save_object}-epoch_{epoch_idx}-batch_{global_batch_idx}/ # 满足 every_n_batches 保存的模型 | |||||
- {save_object}-last/ # 最后一个 epoch 的保存 | |||||
- {save_object}-epoch_{epoch_idx}-batch_{global_batch_idx}-exception_{exception_type}/ # exception时保存。 | |||||
- {save_object}-epoch_{epoch_idx}-batch_{global_batch_idx}-{monitor}_{monitor_value}/ # 满足topk条件存储文件名 | |||||
model_save_fn 为 None ,则以上每个 folder 中,将生成 fastnlp_model.pkl.tar 文件。 | |||||
若 model_save_fn 不为 None,则 fastNLP 将 folder 绝对路径传递给该函数,fastNLP 在该 folder 下不进行模型保存。 | |||||
:param monitor: 监控的 metric 值。如果在 evaluation 结果中没有找到完全一致的名称,将使用 最短公共字符串算法 找到最匹配 | :param monitor: 监控的 metric 值。如果在 evaluation 结果中没有找到完全一致的名称,将使用 最短公共字符串算法 找到最匹配 | ||||
的那个作为 monitor 。如果为 None,将尝试使用 Trainer 设置的 monitor 。也可以传入一个函数,接受参数为 evaluation 的结 | 的那个作为 monitor 。如果为 None,将尝试使用 Trainer 设置的 monitor 。也可以传入一个函数,接受参数为 evaluation 的结 | ||||
果(字典类型),返回一个 float 值作为 monitor 的结果。 | |||||
:param save_folder: 保存的文件夹,fastNLP 将在该文件下以时间戳创建子文件夹,并在里面保存。因此不同次运行可以将被保存到不同的 | |||||
果(字典类型),返回一个 float 值作为 monitor 的结果,如果当前结果中没有相关的 monitor 值请返回 None 。 | |||||
:param folder: 保存的文件夹,fastNLP 将在该文件下以时间戳创建子文件夹,并在里面保存。因此不同次运行可以将被保存到不同的 | |||||
时间戳文件夹中。如果为 None ,默认使用当前文件夹。 | 时间戳文件夹中。如果为 None ,默认使用当前文件夹。 | ||||
:param save_every_n_epochs: 多少个 epoch 保存一次。 | |||||
:param save_every_n_batches: 多少个 batch 保存一次。 | |||||
:param save_last: 如果为 True ,将在每次 epoch 运行结束都保存一次,会覆盖之前的保存。 | |||||
:param save_topk: 保存 monitor 结果 topK 个。 | |||||
:param save_on_exception: 在出异常信息时,是否保存。传入需要捕获的异常的类。 | |||||
:param every_n_epochs: 多少个 epoch 保存一次。 | |||||
:param every_n_batches: 多少个 batch 保存一次。 | |||||
:param last: 如果为 True ,将在每次 epoch 运行结束都保存一次,会覆盖之前的保存。 | |||||
:param topk: 保存 monitor 结果 topK 个。 | |||||
:param on_exceptions: 在出异常信息时,是否保存。传入需要捕获的异常的类。 | |||||
:param larger_better: monitor 的值是否时越大越好。 | :param larger_better: monitor 的值是否时越大越好。 | ||||
:param only_state_dict: 保存模型时是否只保存 state_dict 。当 model_save_fn 不为 None 时,该参数无效。 | :param only_state_dict: 保存模型时是否只保存 state_dict 。当 model_save_fn 不为 None 时,该参数无效。 | ||||
:param model_save_fn: 个性化的保存函数,当触发保存操作时,就调用这个函数,这个函数应当接受一个文件夹作为参数,不返回任何东西。 | :param model_save_fn: 个性化的保存函数,当触发保存操作时,就调用这个函数,这个函数应当接受一个文件夹作为参数,不返回任何东西。 | ||||
如果传入了 model_save_fn 函数,fastNLP 将不再进行模型相关的保存。在多卡场景下,我们只在 rank 0 上会运行该函数。 | 如果传入了 model_save_fn 函数,fastNLP 将不再进行模型相关的保存。在多卡场景下,我们只在 rank 0 上会运行该函数。 | ||||
:param save_object: 可选 ['trainer', 'model'],表示在保存时的保存对象为 trainer+model 还是 只是model 。 | |||||
:param save_evaluate_results: 是否保存 evaluate 的结果。如果为 True ,在保存 topk 模型的 folder 中还将额外保存一个 | |||||
fastnlp_evaluate_results.json 文件,记录当前的 results。仅在设置了 topk 的场景下有用,默认为 True 。 | |||||
:param kwargs: | :param kwargs: | ||||
""" | """ | ||||
super().__init__(monitor=monitor, larger_better=larger_better, | |||||
must_have_monitor=save_topk is not None) | |||||
if save_folder is None: | |||||
super().__init__() | |||||
if folder is None: | |||||
logger.warning( | logger.warning( | ||||
"Parameter `path` is None, and we will use the current work directory to find and load your model.") | |||||
save_folder = Path.cwd() | |||||
save_folder = Path(save_folder) | |||||
if not save_folder.exists(): | |||||
raise NotADirectoryError(f"Path '{save_folder.absolute()}' is not existed!") | |||||
elif save_folder.is_file(): | |||||
raise ValueError("Parameter `save_folder` should be a directory instead of a file.") | |||||
if save_every_n_epochs is not None: | |||||
if not isinstance(save_every_n_epochs, int) or save_every_n_epochs < 1: | |||||
raise ValueError("parameter save_after_epoch_num should be an int and greater than or equal to 1.") | |||||
"Parameter `folder` is None, and we will use the current work directory to find and load your model.") | |||||
folder = Path.cwd() | |||||
folder = Path(folder) | |||||
if not folder.exists(): | |||||
raise NotADirectoryError(f"Path '{folder.absolute()}' is not existed!") | |||||
elif folder.is_file(): | |||||
raise ValueError("Parameter `folder` should be a directory instead of a file.") | |||||
if every_n_epochs is not None: | |||||
if not isinstance(every_n_epochs, int) or every_n_epochs < 1: | |||||
raise ValueError("Parameter `every_n_epochs` should be an int and greater than or equal to 1.") | |||||
else: | else: | ||||
save_every_n_epochs = sys.maxsize # 使得没有数字可以整除 | |||||
every_n_epochs = sys.maxsize # 使得没有数字可以整除 | |||||
if save_every_n_batches is not None: | |||||
if not isinstance(save_every_n_batches, int) or save_every_n_batches < 1: | |||||
raise ValueError( | |||||
"parameter save_every_n_batches should be an int and greater than or equal to 1.") | |||||
if every_n_batches is not None: | |||||
if not isinstance(every_n_batches, int) or every_n_batches < 1: | |||||
raise ValueError("Parameter `every_n_batches` should be an int and greater than or equal to 1.") | |||||
else: | else: | ||||
save_every_n_batches = sys.maxsize # 使得没有数字可以整除 | |||||
every_n_batches = sys.maxsize # 使得没有数字可以整除 | |||||
if save_topk is not None: | |||||
if not isinstance(save_topk, int) or save_topk < 1: | |||||
raise ValueError("parameter save_topk should be an int and greater than or equal to 1.") | |||||
if topk is not None: | |||||
if not isinstance(topk, int): | |||||
raise ValueError("Parameter `topk` should be an int.") | |||||
else: | |||||
topk = 0 | |||||
if save_on_exception is not None: | |||||
if not isinstance(save_on_exception, Sequence): | |||||
save_on_exception = [save_on_exception] | |||||
if on_exceptions is not None: | |||||
if not isinstance(on_exceptions, Sequence): | |||||
on_exceptions = [on_exceptions] | |||||
for exception in save_on_exception: | |||||
for exception in on_exceptions: | |||||
if not issubclass(exception, BaseException): | if not issubclass(exception, BaseException): | ||||
raise TypeError("Each exception in parameter `save_on_exception` can only be " | |||||
raise TypeError("Each exception in parameter `on_exception` can only be " | |||||
"`BaseException` type.") | "`BaseException` type.") | ||||
else: | else: | ||||
save_on_exception = [] | |||||
on_exceptions = [] | |||||
self.save_folder = save_folder | |||||
self.save_every_n_epochs = save_every_n_epochs | |||||
self.save_every_n_batches = save_every_n_batches | |||||
self.save_last = save_last | |||||
self.save_topk = save_topk | |||||
self.only_state_dict = only_state_dict | |||||
self.model_save_fn = model_save_fn | |||||
self.save_on_exception = save_on_exception | |||||
self.kwargs = kwargs | |||||
self.topk_saver = TopkSaver(topk, monitor, larger_better, folder, only_state_dict, | |||||
model_save_fn, save_evaluate_results, | |||||
save_object, **kwargs) | |||||
self.topk = topk | |||||
self.save_object = save_object | |||||
# 这些参数是专门留给 topk 模式专门使用的; | |||||
self._topk_model = {} | |||||
self._topn = 0 # 表示目前已经保存了几个最好的模型; | |||||
self.every_n_epochs = every_n_epochs | |||||
self.every_n_batches = every_n_batches | |||||
self.last = last | |||||
self.exceptions = on_exceptions | |||||
# 注意这里应当保证只有进程 0 在执行这个操作,因为当用户使用 python -m torch.distributed.launch 来拉起进程的时候, | |||||
# FASTNLP_LAUNCH_TIME 在每一个进程上的值是不一样的; | |||||
self.timestamp_path = self.save_folder.joinpath(os.environ[FASTNLP_LAUNCH_TIME]) | |||||
# 该 folder 只在保存真的要发生的时候再创建。 | |||||
@property | |||||
def need_reproducible_sampler(self) -> bool: | |||||
return self.save_object == 'trainer' | |||||
def on_after_trainer_initialized(self, trainer, driver): | def on_after_trainer_initialized(self, trainer, driver): | ||||
if self.save_topk is not None: | |||||
super().on_after_trainer_initialized(trainer, driver) | |||||
if self.save_topk is not None and trainer.evaluator is None: | |||||
logger.warning("You set `save_topk`, but `evaluate_dataloaders` is not set in Trainer.") | |||||
if self.topk_saver.topk_queue: # 需要设置 monitor | |||||
if self.topk_saver.monitor is None: | |||||
self.topk_saver.set_monitor(monitor=trainer.monitor, larger_better=trainer.larger_better) | |||||
if self.topk_saver.topk_queue and trainer.evaluator is None: | |||||
logger.warning(f"You set `topk={self.topk}`, but `evaluate_dataloaders` is not set in Trainer.") | |||||
def on_validate_end(self, trainer, results): | def on_validate_end(self, trainer, results): | ||||
self._save_topk(trainer, results) | |||||
# 如果发生了保存,则返回的 folder 不为 None | |||||
folder = self.topk_saver.save_topk(trainer, results) | |||||
def on_train_epoch_end(self, trainer: "fastNLP.Trainer"): | def on_train_epoch_end(self, trainer: "fastNLP.Trainer"): | ||||
if trainer.cur_epoch_idx % self.save_every_n_epochs == 0: | |||||
folder_name = f'{self.folder_prefix}-epoch_{trainer.cur_epoch_idx}' | |||||
self.save(trainer, folder_name=folder_name) | |||||
if self.save_last: | |||||
folder_name = f'{self.folder_prefix}-last' | |||||
self.save(trainer, folder_name=folder_name) | |||||
if trainer.cur_epoch_idx % self.every_n_epochs == 0: | |||||
folder_name = f'{self.save_object}-epoch_{trainer.cur_epoch_idx}' | |||||
self.topk_saver.save(trainer, folder_name=folder_name) | |||||
if self.last: | |||||
folder_name = f'{self.save_object}-last' | |||||
self.topk_saver.save(trainer, folder_name=folder_name) | |||||
def on_train_batch_end(self, trainer): | def on_train_batch_end(self, trainer): | ||||
if trainer.global_forward_batches % self.save_every_n_batches == 0: | |||||
folder_name = f'{self.folder_prefix}-epoch_{trainer.cur_epoch_idx}-batch_{trainer.global_forward_batches}' | |||||
self.save(trainer, folder_name=folder_name) | |||||
if trainer.global_forward_batches % self.every_n_batches == 0: | |||||
folder_name = f'{self.save_object}-epoch_{trainer.cur_epoch_idx}-batch_{trainer.global_forward_batches}' | |||||
self.topk_saver.save(trainer, folder_name=folder_name) | |||||
def on_exception(self, trainer, exception: BaseException): | def on_exception(self, trainer, exception: BaseException): | ||||
if exception.__class__ in self.save_on_exception: | |||||
folder_name = f'{self.folder_prefix}-epoch_{trainer.cur_epoch_idx}-batch_{trainer.global_forward_batches}-' \ | |||||
f'exception_{exception.__class__.__name__}' | |||||
self.save(trainer=trainer, folder_name=folder_name) | |||||
if exception.__class__ in self.exceptions: | |||||
folder_name = f'{self.save_object}-epoch_{trainer.cur_epoch_idx}-batch_{trainer.global_forward_batches}-' \ | |||||
f'exception_{exception.__class__.__name__}' | |||||
self.topk_saver.save(trainer, folder_name=folder_name) | |||||
def on_save_checkpoint(self, trainer) -> Dict: | def on_save_checkpoint(self, trainer) -> Dict: | ||||
""" | """ | ||||
保存 timestamp_path 使得之后可以继续训练并保存到该文件夹。 | |||||
topk_model的状态 | |||||
_real_monitor的值 | |||||
保存状态,以便之后可以继续使用 | |||||
""" | """ | ||||
states = {} | states = {} | ||||
states['timestamp_path'] = str(self.timestamp_path.absolute()) | |||||
states['_topk_model'] = deepcopy(self._topk_model) | |||||
states['save_topk'] = 0 if self.save_topk is None else self.save_topk | |||||
if isinstance(self._real_monitor, str): | |||||
states['_real_monitor'] = self._real_monitor | |||||
states['topk_saver'] = self.topk_saver.state_dict() | |||||
return states | return states | ||||
def on_load_checkpoint(self, trainer, states: Optional[Dict]): | def on_load_checkpoint(self, trainer, states: Optional[Dict]): | ||||
timestamp_path = states['timestamp_path'] | |||||
if not os.path.exists(timestamp_path): | |||||
logger.info(f"The resuming checkpoint folder {timestamp_path} is not exists, will checkpoint save to " | |||||
f" {self.timestamp_path.absolute()}.") | |||||
else: | |||||
logger.info(f"Resume to checkpoint in path: {timestamp_path}.") | |||||
self.timestamp_path = Path(timestamp_path) | |||||
_topk_model = states['_topk_model'] | |||||
save_topk = None if int(states['save_topk']) == 0 else int(states['save_topk']) | |||||
if save_topk is not None and self.save_topk is not None: | |||||
assert self.save_topk == save_topk, f"The checkpoint set save_topk={save_topk}, while this callback set it " \ | |||||
f"as {save_topk}." | |||||
self._topk_model.update(self._topk_model) | |||||
self._real_monitor = states["_real_monitor"] | |||||
def _save_topk(self, trainer: "fastNLP.Trainer", results: Dict): | |||||
""" | |||||
根据validate_res决定保存哪些model的函数。会自动移除掉不满足topk的文件夹。 | |||||
:param trainer: | |||||
:param results: | |||||
:return: | |||||
""" | |||||
if self.save_topk is not None: | |||||
monitor_value = self.get_monitor_value(results=results) | |||||
if monitor_value is None: | |||||
return | |||||
folder_name = f"{self.folder_prefix}-epoch_{trainer.cur_epoch_idx}-batch_{trainer.global_forward_batches}" \ | |||||
f"-{self._real_monitor}_{monitor_value}" | |||||
_should_save = False | |||||
if self._topn < self.save_topk: | |||||
self._topk_model[folder_name] = monitor_value | |||||
self._topn += 1 | |||||
_should_save = True | |||||
else: | |||||
_least_valuable_model = (min if self.larger_better else max)(self._topk_model, | |||||
key=lambda x: self._topk_model[x]) | |||||
if self.is_former_monitor_value_better(monitor_value, self._topk_model[_least_valuable_model]): | |||||
self._topk_model[folder_name] = monitor_value | |||||
_should_save = True | |||||
self._topk_model.pop(_least_valuable_model) | |||||
synchronize_safe_rm(self.timestamp_path.joinpath(_least_valuable_model)) | |||||
assert len(self._topk_model) == self.save_topk == self._topn | |||||
if _should_save: | |||||
self.save(trainer, folder_name=folder_name) | |||||
def save(self, trainer, folder_name): | |||||
""" | |||||
执行保存的函数,将数据保存在 save_folder/timestamp/folder_name 下。 | |||||
:param trainer: | |||||
:param folder_name: | |||||
:return: | |||||
""" | |||||
folder = self.timestamp_path.joinpath(folder_name) | |||||
if int(os.environ.get(FASTNLP_GLOBAL_RANK, 0)) == 0: # 只在进程0上创建 | |||||
synchronize_mkdir(folder) | |||||
_fn = getattr(trainer, self.save_fn_name) | |||||
_fn( | |||||
folder=folder, | |||||
only_state_dict=self.only_state_dict, | |||||
model_save_fn=self.model_save_fn, | |||||
**self.kwargs | |||||
) | |||||
@property | |||||
def folder_prefix(self): | |||||
raise NotImplementedError("The `folder_prefix` is not specified") | |||||
@property | |||||
def save_fn_name(self): | |||||
raise NotImplementedError("The `save_fn_name` is not specified.") | |||||
class ModelCheckpointCallback(CheckpointCallback): | |||||
""" | |||||
保存模型 checkpoint 的 callback ,其保存的文件目录以及文件名命名规则如下 | |||||
- save_folder/ | |||||
- YYYY-mm-dd-HH_MM_SS_fffff/ # 自动根据当前脚本的启动时间创建的 | |||||
- model-epoch_{epoch_idx}/ # 满足 save_every_n_epochs 条件保存的模型 | |||||
- model-epoch_{epoch_idx}-batch_{global_batch_idx}/ # 满足 save_every_n_batches 保存的模型 | |||||
- model-last/ # 最后一个 epoch 的保存 | |||||
- model-epoch_{epoch_idx}-batch_{global_batch_idx}-exception_{exception_type}/ # exception时保存。 | |||||
- model-epoch_{epoch_idx}-batch_{global_batch_idx}-{monitor}_{monitor_value}/ # 满足topk条件存储文件名 | |||||
topk_saver_states = states['topk_saver'] | |||||
self.topk_saver.load_state_dict(topk_saver_states) | |||||
model_save_fn 为 None ,则以上每个 folder 中,将生成 fastnlp_model.pkl.tar 文件。 | |||||
若 model_save_fn 不为 None,则 fastNLP 将 folder 绝对路径传递给该函数,fastNLP 不在该 folder 下创建任何文件。 | |||||
:param monitor: 监控的 metric 值。如果在 evaluation 结果中没有找到完全一致的名称,将使用 最短公共字符串算法 找到最匹配 | |||||
的那个作为 monitor 。如果为 None,将尝试使用 Trainer 设置的 monitor 。也可以传入一个函数,接受参数为 evaluation 的结 | |||||
果(字典类型),返回一个 float 值作为 monitor 的结果。 | |||||
:param save_folder: 保存的文件夹,fastNLP 将在该文件下以时间戳创建子文件夹,并在里面保存。因此不同次运行可以将被保存到不同的 | |||||
时间戳文件夹中。如果为 None ,默认使用当前文件夹。 | |||||
:param save_every_n_epochs: 多少个 epoch 保存一次。 | |||||
:param save_every_n_batches: 多少个 batch 保存一次。 | |||||
:param save_last: 如果为 True ,将在每次 epoch 运行结束都保存一次,会覆盖之前的保存。 | |||||
:param save_topk: 保存 monitor 结果 topK 个。 | |||||
:param save_on_exception: 在出异常信息时,是否保存。传入需要捕获的异常的类。 | |||||
:param larger_better: monitor 的值是否时越大越好。 | |||||
:param only_state_dict: 保存模型时是否只保存 state_dict 。当 model_save_fn 不为 None 时,该参数无效。 | |||||
:param model_save_fn: 个性化的保存函数,当触发保存操作时,就调用这个函数,这个函数应当接受一个文件夹作为参数,不返回任何东西。 | |||||
如果传入了 model_save_fn 函数,fastNLP 将不再进行模型相关的保存。在多卡场景下,我们只在 rank 0 上会运行该函数。 | |||||
:param kwargs: | |||||
""" | |||||
@property | |||||
def save_fn_name(self): | |||||
""" | |||||
调用 Trainer 中的哪个函数。 | |||||
:return: | |||||
""" | |||||
return 'save_model' | |||||
@property | |||||
def callback_name(self): | |||||
""" | |||||
通过该值决定两个 CheckpointCallback 实例是否可以共用断点重训的状态; | |||||
:return: | |||||
""" | |||||
return f"model_checkpoint#monitor-{self.monitor_name}#topK-{self.save_topk}#only_state_dict-{self.only_state_dict}" | |||||
@property | |||||
def folder_prefix(self): | |||||
return 'model' | |||||
class TrainerCheckpointCallback(CheckpointCallback): | |||||
""" | |||||
保存 Trainer checkpoint 的 callback ,其保存的文件目录以及文件名命名规则如下 | |||||
- save_folder/ | |||||
- YYYY-mm-dd-HH_MM_SS_fffff/ # 自动根据当前脚本的启动时间创建的 | |||||
- trainer-epoch_{epoch_idx}/ # 满足 save_every_n_epochs 条件保存的模型 | |||||
- trainer-epoch_{epoch_idx}-batch_{global_batch_idx}/ # 满足 save_every_n_batches 保存的模型 | |||||
- trainer-last/ # 最后一个 epoch 的保存 | |||||
- trainer-epoch_{epoch_idx}-batch_{global_batch_idx}-exception_{exception_type}/ # exception时保存。 | |||||
- trainer-epoch_{epoch_idx}-batch_{global_batch_idx}-{monitor}_{monitor_value}/ # 满足topk条件存储文件名 | |||||
model_save_fn 为 None ,则以上每个 folder 中,将生成两个文件:fastnlp_trainer.pkl.tar 以及 fastnlp_model.pkl.tar 。 | |||||
若 model_save_fn 不为 None,则 fastNLP 只会在每个 folder 下生成 fastnlp_trainer.pkl.tar 文件。 | |||||
:param monitor: 监控的 metric 值。如果在 evaluation 结果中没有找到完全一致的名称,将使用 最短公共字符串算法 找到最匹配 | |||||
的那个作为 monitor 。如果为 None,将尝试使用 Trainer 设置的 monitor 。也可以传入一个函数,接受参数为 evaluation 的结 | |||||
果(字典类型),返回一个 float 值作为 monitor 的结果。 | |||||
:param save_folder: 保存的文件夹,fastNLP 将在该文件下以时间戳创建子文件夹,并在里面保存。因此不同次运行可以将被保存到不同的 | |||||
时间戳文件夹中。如果为 None ,默认使用当前文件夹。 | |||||
:param save_every_n_epochs: 多少个 epoch 保存一次。 | |||||
:param save_every_n_batches: 多少个 batch 保存一次。 | |||||
:param save_last: 如果为 True ,将在每次 epoch 运行结束都保存一次,会覆盖之前的保存。 | |||||
:param save_topk: 保存 monitor 结果 topK 个。 | |||||
:param save_on_exception: 在出异常信息时,是否保存。 | |||||
:param larger_better: monitor 的值是否时越大越好。 | |||||
:param only_state_dict: 保存模型时是否只保存 state_dict 。当 model_save_fn 不为 None 时,该参数无意义。 | |||||
:param model_save_fn: 个性化的保存函数,当触发保存操作时,就调用这个函数,这个函数应当接受一个文件夹作为参数,不返回任何东西。 | |||||
如果传入了 model_save_fn 函数,fastNLP 将不再进行模型相关的保存。在多卡场景下,我们只在 rank 0 上会运行该函数。 | |||||
:param kwargs: | |||||
""" | |||||
@property | |||||
def save_fn_name(self): | |||||
""" | |||||
调用 Trainer 中的哪个函数。 | |||||
:return: | |||||
""" | |||||
return 'save' | |||||
@property | |||||
def callback_name(self): | |||||
""" | |||||
通过该值决定两个 CheckpointCallback 实例是否可以共用断点重训的状态; | |||||
:return: | |||||
""" | |||||
return f"trainer_checkpoint#monitor-{self.monitor_name}#topK-{self.save_topk}#only_state_dict-{self.only_state_dict}" | |||||
@property | |||||
def folder_prefix(self): | |||||
return 'trainer' |
@@ -1,10 +1,12 @@ | |||||
__all__ = [ | __all__ = [ | ||||
'HasMonitorCallback', | 'HasMonitorCallback', | ||||
'ExecuteOnceBetterMonitor' | |||||
'ExecuteOnceBetterMonitor', | |||||
'MonitorUtility' | |||||
] | ] | ||||
from typing import Dict, Union, Any | from typing import Dict, Union, Any | ||||
from abc import ABC | from abc import ABC | ||||
import functools | |||||
from fastNLP.core.utils import apply_to_collection | from fastNLP.core.utils import apply_to_collection | ||||
from fastNLP.core.callbacks import Callback | from fastNLP.core.callbacks import Callback | ||||
@@ -27,21 +29,13 @@ class CanItemDataType(ABC): | |||||
return NotImplemented | return NotImplemented | ||||
class MonitorUtility: | |||||
""" | |||||
计算 monitor 的相关函数 | |||||
class HasMonitorCallback(Callback): | |||||
def __init__(self, monitor, larger_better, must_have_monitor=False): | |||||
""" | |||||
该 callback 不直接进行使用,作为其它相关 callback 的父类使用,如果 callback 有使用 monitor 可以继承该函数里面实现了 | |||||
(1)判断monitor合法性;(2)在需要时, 根据trainer的monitor设置自己的monitor名称。 | |||||
:param monitor: 监控的 metric 值。如果在 evaluation 结果中没有找到完全一致的名称,将使用 最短公共字符串算法 找到最匹配 | |||||
的那个作为 monitor 。如果为 None,将尝试使用 Trainer 设置的 monitor 。也可以传入一个函数,接受参数为 evaluation 的结 | |||||
果(字典类型),返回一个 float 值作为 monitor 的结果。 | |||||
:param larger_better: monitor 是否时越大越好 | |||||
:param must_have_monitor: 这个 callback 是否必须有 monitor 设置。如果设置为 True ,且没检测到设置 monitor 会报错。 | |||||
""" | |||||
""" | |||||
def __init__(self, monitor, larger_better): | |||||
self.set_monitor(monitor, larger_better) | self.set_monitor(monitor, larger_better) | ||||
self.must_have_moinitor = must_have_monitor | |||||
def set_monitor(self, monitor, larger_better): | def set_monitor(self, monitor, larger_better): | ||||
if callable(monitor): # 检查是否能够接受一个参数 | if callable(monitor): # 检查是否能够接受一个参数 | ||||
@@ -57,26 +51,14 @@ class HasMonitorCallback(Callback): | |||||
self.monitor_value = float('inf') | self.monitor_value = float('inf') | ||||
self._real_monitor = self.monitor | self._real_monitor = self.monitor | ||||
def on_after_trainer_initialized(self, trainer, driver): | |||||
def itemize_results(self, results): | |||||
""" | """ | ||||
如果本身的 monitor 没有设置,则根据 Trainer 中的 monitor 设置 monitor 。 | |||||
同时对于必须要有 monitor 设置的 callback ,该函数会进行检查。 | |||||
将结果中有 .item() 方法的都调用一下,使得可以结果可以保存 | |||||
:param trainer: | |||||
:param driver: | |||||
:param results: | |||||
:return: | :return: | ||||
""" | """ | ||||
if self.monitor is None and trainer.monitor is not None: | |||||
self.set_monitor(monitor=trainer.monitor, larger_better=trainer.larger_better) | |||||
if self.must_have_moinitor and self.monitor is None: | |||||
raise RuntimeError(f"No `monitor` is set for {self.__class__.__name__}. " | |||||
f"You can set it in the initialization or through Trainer.") | |||||
def on_sanity_check_end(self, trainer, sanity_check_res): | |||||
# 主要核对一下 monitor 是否存在。 | |||||
if self.monitor is not None: | |||||
self.get_monitor_value(results=sanity_check_res) | |||||
return apply_to_collection(results, dtype=CanItemDataType, function=lambda x: x.item()) | |||||
def get_monitor_value(self, results:Dict)->Union[float, None]: | def get_monitor_value(self, results:Dict)->Union[float, None]: | ||||
""" | """ | ||||
@@ -85,10 +67,10 @@ class HasMonitorCallback(Callback): | |||||
:param results: | :param results: | ||||
:return: 如果为 None ,表明此次没有找到合适的monitor | :return: 如果为 None ,表明此次没有找到合适的monitor | ||||
""" | """ | ||||
if len(results)==0: | |||||
if len(results) == 0 or self.monitor is None: | |||||
return None | return None | ||||
# 保证所有的 tensor 都被转换为了 python 特定的类型 | # 保证所有的 tensor 都被转换为了 python 特定的类型 | ||||
results = apply_to_collection(results, dtype=CanItemDataType, function=lambda x: x.item()) | |||||
results = self.itemize_results(results) | |||||
use_monitor, monitor_value = _get_monitor_value(monitor=self.monitor, | use_monitor, monitor_value = _get_monitor_value(monitor=self.monitor, | ||||
real_monitor=self._real_monitor, | real_monitor=self._real_monitor, | ||||
res=results) | res=results) | ||||
@@ -97,7 +79,7 @@ class HasMonitorCallback(Callback): | |||||
# 第一次运行 | # 第一次运行 | ||||
if isinstance(self.monitor, str) and self._real_monitor == self.monitor and use_monitor != self.monitor: | if isinstance(self.monitor, str) and self._real_monitor == self.monitor and use_monitor != self.monitor: | ||||
logger.warning(f"We can not find `{self.monitor}` in the evaluation result (with keys as {list(results.keys())}), " | logger.warning(f"We can not find `{self.monitor}` in the evaluation result (with keys as {list(results.keys())}), " | ||||
f"we use the `{use_monitor}` as the monitor for `{self.__class__.__name__}`.") | |||||
f"we use the `{use_monitor}` as the monitor for `{self.__class__.__name__}`.") | |||||
# 检测到此次和上次不同。 | # 检测到此次和上次不同。 | ||||
elif isinstance(self.monitor, str) and self._real_monitor != self.monitor and use_monitor != self._real_monitor: | elif isinstance(self.monitor, str) and self._real_monitor != self.monitor and use_monitor != self._real_monitor: | ||||
logger.warning(f"Change of monitor detected for `{self.__class__.__name__}`. " | logger.warning(f"Change of monitor detected for `{self.__class__.__name__}`. " | ||||
@@ -165,7 +147,10 @@ class HasMonitorCallback(Callback): | |||||
""" | """ | ||||
if callable(self.monitor): | if callable(self.monitor): | ||||
try: | try: | ||||
monitor_name = self.monitor.__qualname__ | |||||
monitor = self.monitor | |||||
while isinstance(monitor, functools.partial): | |||||
monitor = monitor.func | |||||
monitor_name = monitor.__qualname__ | |||||
except: | except: | ||||
monitor_name = self.monitor.__name__ | monitor_name = self.monitor.__name__ | ||||
elif self.monitor is None: | elif self.monitor is None: | ||||
@@ -176,6 +161,46 @@ class HasMonitorCallback(Callback): | |||||
return monitor_name | return monitor_name | ||||
class HasMonitorCallback(MonitorUtility, Callback):
    """
    Base class for callbacks that watch a monitored metric. Not meant to be used
    directly: subclasses inherit (1) validation of the monitor and (2) falling
    back to the Trainer's monitor when none was given.
    """

    def __init__(self, monitor, larger_better, must_have_monitor=False):
        """
        :param monitor: the metric to watch. If no exact key is found in the
            evaluation results, the closest match (longest common substring) is
            used. If None, the Trainer's monitor is adopted. May also be a
            callable taking the evaluation results (a dict) and returning a
            float, or None when the results contain no relevant value.
        :param larger_better: whether larger monitor values are better.
        :param must_have_monitor: whether this callback requires a monitor; if
            True and none can be resolved, an error is raised.
        """
        super().__init__(monitor, larger_better)
        self.must_have_monitor = must_have_monitor

    def on_after_trainer_initialized(self, trainer, driver):
        """
        Adopt the Trainer's monitor when this callback has none of its own, and
        enforce `must_have_monitor` if requested.

        :param trainer:
        :param driver:
        :return:
        """
        if self.monitor is None and trainer.monitor is not None:
            self.set_monitor(monitor=trainer.monitor, larger_better=trainer.larger_better)
        if not self.must_have_monitor:
            return
        if self.monitor is None:
            raise RuntimeError(f"No `monitor` is set for {self.__class__.__name__}. "
                               f"You can set it in the initialization or through Trainer.")
        if trainer.evaluator is None:
            raise RuntimeError(f"No `evaluate_dataloaders` is set for Trainer. But Callback: {self.__class__.__name__}"
                               f" need to watch the monitor:`{self.monitor_name}`.")

    def on_sanity_check_end(self, trainer, sanity_check_res):
        # Only verify that the monitor can be resolved from the sanity-check results.
        if self.monitor is not None:
            self.get_monitor_value(results=sanity_check_res)
class ExecuteOnceBetterMonitor(HasMonitorCallback): | class ExecuteOnceBetterMonitor(HasMonitorCallback): | ||||
def __init__(self, monitor, larger_better, execute_fn): | def __init__(self, monitor, larger_better, execute_fn): | ||||
""" | """ | ||||
@@ -183,13 +208,13 @@ class ExecuteOnceBetterMonitor(HasMonitorCallback): | |||||
:param monitor: 监控的 metric 值。如果在 evaluation 结果中没有找到完全一致的名称,将使用 最短公共字符串算法 找到最匹配 | :param monitor: 监控的 metric 值。如果在 evaluation 结果中没有找到完全一致的名称,将使用 最短公共字符串算法 找到最匹配 | ||||
的那个作为 monitor 。如果为 None,将尝试使用 Trainer 设置的 monitor 。也可以传入一个函数,接受参数为 evaluation 的结 | 的那个作为 monitor 。如果为 None,将尝试使用 Trainer 设置的 monitor 。也可以传入一个函数,接受参数为 evaluation 的结 | ||||
果(字典类型),返回一个 float 值作为 monitor 的结果。 | |||||
果(字典类型),返回一个 float 值作为 monitor 的结果,如果当前结果中没有相关的 monitor 值请返回 None 。 | |||||
:param larger_better: monitor 是否时越大越好 | :param larger_better: monitor 是否时越大越好 | ||||
:param execute_fn: 一个可执行的函数,不接受任何参数,不反回值。在 monitor 取得更好结果的时候会调用。 | :param execute_fn: 一个可执行的函数,不接受任何参数,不反回值。在 monitor 取得更好结果的时候会调用。 | ||||
""" | """ | ||||
super().__init__(monitor, larger_better, must_have_monitor=True) | super().__init__(monitor, larger_better, must_have_monitor=True) | ||||
_check_valid_parameters_number(execute_fn, expected_params=[], fn_name='execute_fn') | _check_valid_parameters_number(execute_fn, expected_params=[], fn_name='execute_fn') | ||||
self.execute_fn = execute_fn() | |||||
self.execute_fn = execute_fn | |||||
def on_validate_end(self, trainer, results): | def on_validate_end(self, trainer, results): | ||||
if self.is_better_results(results): | if self.is_better_results(results): |
@@ -23,7 +23,7 @@ class LoadBestModelCallback(HasMonitorCallback): | |||||
:param str monitor: 监控的 metric 值。如果在 evaluation 结果中没有找到完全一致的名称,将使用 最短公共字符串算法 找到最匹配 | :param str monitor: 监控的 metric 值。如果在 evaluation 结果中没有找到完全一致的名称,将使用 最短公共字符串算法 找到最匹配 | ||||
的那个作为 monitor 。如果为 None,将尝试使用 Trainer 设置的 monitor 。也可以传入一个函数,接受参数为 evaluation 的结 | 的那个作为 monitor 。如果为 None,将尝试使用 Trainer 设置的 monitor 。也可以传入一个函数,接受参数为 evaluation 的结 | ||||
果(字典类型),返回一个 float 值作为 monitor 的结果。 | |||||
果(字典类型),返回一个 float 值作为 monitor 的结果,如果当前结果中没有相关的 monitor 值请返回 None 。 | |||||
:param larger_better: 该 metric 值是否是越大越好。 | :param larger_better: 该 metric 值是否是越大越好。 | ||||
:param save_folder: 保存的文件夹,如果为空,则保存在内存中。不为空,则保存一份权重到文件中,当为多机训练,且本值不为空时,请确保 | :param save_folder: 保存的文件夹,如果为空,则保存在内存中。不为空,则保存一份权重到文件中,当为多机训练,且本值不为空时,请确保 | ||||
不同的机器均可访问当该路径。当 model_save_fn 不为 None 时该值一定不能为空。 | 不同的机器均可访问当该路径。当 model_save_fn 不为 None 时该值一定不能为空。 | ||||
@@ -72,7 +72,7 @@ class LoadBestModelCallback(HasMonitorCallback): | |||||
logger.debug(f"Synchronize best model save folder: {self.real_save_folder} for LoadBestModelCallback.") | logger.debug(f"Synchronize best model save folder: {self.real_save_folder} for LoadBestModelCallback.") | ||||
except NotImplementedError: | except NotImplementedError: | ||||
raise RuntimeError(f"Currently {driver.__class__.__name__} does not support using `save_folder` to " | raise RuntimeError(f"Currently {driver.__class__.__name__} does not support using `save_folder` to " | ||||
f"save best model when launch using script.") | |||||
f"save best model when launch using module.") | |||||
super().on_after_trainer_initialized(trainer, driver) | super().on_after_trainer_initialized(trainer, driver) | ||||
@@ -87,7 +87,7 @@ class LoadBestModelCallback(HasMonitorCallback): | |||||
trainer.save_model(folder=self.buffer, only_state_dict=self.only_state_dict) | trainer.save_model(folder=self.buffer, only_state_dict=self.only_state_dict) | ||||
def on_train_end(self, trainer): | def on_train_end(self, trainer): | ||||
logger.info(f"Loading best model with {self._real_monitor}: {self.monitor_value}...") | |||||
logger.info(f"Loading best model with {self.monitor_name}: {self.monitor_value}...") | |||||
if self.real_save_folder: | if self.real_save_folder: | ||||
trainer.load_model(folder=self.real_save_folder, only_state_dict=self.only_state_dict, | trainer.load_model(folder=self.real_save_folder, only_state_dict=self.only_state_dict, | ||||
model_load_fn=self.model_load_fn) | model_load_fn=self.model_load_fn) | ||||
@@ -0,0 +1,174 @@ | |||||
__all__ = [ | |||||
'MoreEvaluateCallback' | |||||
] | |||||
import os | |||||
from typing import Union, Callable, Optional, Dict | |||||
from fastNLP.core.log import logger | |||||
from .has_monitor_callback import HasMonitorCallback | |||||
from .topk_saver import TopkSaver | |||||
class MoreEvaluateCallback(HasMonitorCallback):
    def __init__(self, dataloaders, metrics:Dict, evaluate_every:Optional[Union[int, Callable]]=-1,
                 watch_monitor:Union[str, Callable]=None, watch_monitor_larger_better:bool=True,
                 evaluate_fn=None, num_eval_sanity_batch=2,
                 topk=0, topk_monitor=None, topk_larger_better=True,
                 folder=None, only_state_dict=True, save_object='model', model_save_fn=None,
                 save_evaluate_results=True, save_kwargs=None,
                 **kwargs):
        """
        Runs an extra evaluation loop with its own ``evaluate_fn``. Useful when the
        Trainer's built-in evaluation is not enough, e.g. generation tasks that use
        the training loss as the regular evaluation but, after some epochs, also
        need a full decoding pass scored with BLEU. To save models based on the
        results of this callback, pass ``topk`` and ``topk_monitor``. Evaluation is
        triggered either by ``evaluate_every`` or by ``watch_monitor``.

        When topk saving is enabled, files are organized as

        - folder/
            - YYYY-mm-dd-HH_MM_SS_fffff/  # created from the launch time of the current script
                - {save_object}-epoch_{epoch_idx}-batch_{global_batch_idx}-{topk_monitor}_{monitor_value}/

        :param dataloaders: the data to evaluate on.
        :param metrics: the metrics to compute.
        :param evaluate_every: negative int -> evaluate every ``-evaluate_every``
            epochs; positive int -> evaluate every ``evaluate_every`` batches;
            callable -> called after every batch with the trainer object and must
            return a bool, True meaning "evaluate now". Ignored when
            ``watch_monitor`` is set.
        :param watch_monitor: watches the Trainer's own evaluate results; whenever
            the watched value improves, this callback runs one evaluation. Either
            (1) a str metric name (closest match by longest common substring when no
            exact key exists), or (2) a callable taking the evaluation results
            (a dict) and returning a float, or None when no relevant value exists.
        :param watch_monitor_larger_better: whether larger ``watch_monitor`` values are better.
        :param evaluate_fn: which model function the ``Evaluator`` calls in its
            forward pass, e.g. ``model.evaluate_step`` vs ``model.forward``.
            (1) None -> try ``evaluate_step`` first, falling back to ``model.forward``;
            (2) str -> look up that method on the model, raising if absent.
        :param num_eval_sanity_batch: number of sanity-check batches to run right
            after the Evaluator is created.
        :param topk: controls saving based on this callback's results: -1 -> save
            after every evaluation; 0 (default) -> never save; positive -> keep the
            best ``topk`` saves.
        :param topk_monitor: the monitor (looked up in this callback's results) used
            to rank saves when ``topk`` is enabled.
        :param topk_larger_better: whether larger ``topk_monitor`` values are better.
        :param folder: root folder for saving; a timestamp subfolder is created
            inside so different runs do not collide. Defaults to the current folder.
        :param only_state_dict: save only the state dict; ignored when
            ``model_save_fn`` is given.
        :param save_object: one of ['trainer', 'model'] — save trainer+model or model only.
        :param model_save_fn: custom save function taking a folder and returning
            nothing; when given, fastNLP performs no model saving itself. Only rank 0
            runs it in distributed settings.
        :param save_evaluate_results: whether to additionally write a
            fastnlp_evaluate_results.json with the current results into each topk
            folder. Only meaningful when ``topk`` is set. Default True.
        :param save_kwargs: dict of extra saving-related arguments.
        :param kwargs: other Evaluator init arguments; anything missing is taken
            from the Trainer. Pay particular attention to ``evaluate_fn``.
        """
        super(MoreEvaluateCallback, self).__init__(watch_monitor, watch_monitor_larger_better,
                                                   must_have_monitor=False)
        if watch_monitor is None and evaluate_every is None:
            raise RuntimeError("`evaluate_every` and `watch_monitor` cannot be None at the same time.")
        # NOTE: when watch_monitor is set, evaluate_every is simply ignored (see the
        # early returns in on_train_epoch_end / on_train_batch_end); raising when both
        # are set would make watch_monitor unusable with the default evaluate_every=-1.
        self.watch_monitor = watch_monitor
        if topk_monitor is not None and topk == 0:
            raise RuntimeError("`topk_monitor` is set, but `topk` is 0.")
        if topk != 0 and topk_monitor is None:
            raise RuntimeError("`topk` is set, but `topk_monitor` is None.")
        assert save_object in ['trainer', 'model']

        self.dataloaders = dataloaders
        self.metrics = metrics
        self.evaluate_every = evaluate_every
        self.evaluate_fn = evaluate_fn
        self.num_eval_sanity_batch = num_eval_sanity_batch
        if save_kwargs is None:
            save_kwargs = {}
        self.topk_saver = TopkSaver(topk=topk, monitor=topk_monitor, larger_better=topk_larger_better,
                                    folder=folder, only_state_dict=only_state_dict,
                                    model_save_fn=model_save_fn, save_evaluate_results=save_evaluate_results,
                                    save_object=save_object, **save_kwargs)
        self.kwargs = kwargs

    @property
    def need_reproducible_sampler(self) -> bool:
        # Saving a whole trainer requires a resumable (reproducible) sampler.
        return self.topk_saver.save_object == 'trainer'

    def on_after_trainer_initialized(self, trainer, driver):
        # Watching a Trainer monitor only makes sense if the Trainer evaluates at all.
        if self.watch_monitor is not None:
            assert trainer.evaluator is not None, f"You set `watch_monitor={self.watch_monitor}`, but no " \
                                                  f"evaluate_dataloaders is provided in Trainer."

        if trainer.evaluate_fn is self.evaluate_fn:
            logger.warning_once("The `evaluate_fn` is the same as in Trainer, there seems no need to use "
                                "`MoreEvaluateCallback`.")

        # Build our own Evaluator, defaulting every unset option to the Trainer's;
        # deliberately avoid calling super() so the monitor is not overwritten.
        kwargs = {
            'model': self.kwargs.get('model', trainer.model),
            'dataloaders': self.dataloaders,
            'metrics': self.metrics,
            'driver': self.kwargs.get('driver', trainer.driver),
            'device': self.kwargs.get('device', trainer.device),
            'batch_step_fn': self.kwargs.get('batch_step_fn', trainer.evaluate_batch_step_fn),
            'evaluate_fn': self.evaluate_fn,
            'input_mapping': self.kwargs.get('input_mapping', trainer.input_mapping),
            'output_mapping': self.kwargs.get('output_mapping', trainer.output_mapping),
            'fp16': self.kwargs.get('fp16', trainer.fp16),
            'use_dist_sampler': self.kwargs.get('use_dist_sampler',
                                                trainer.kwargs.get('eval_use_dist_sampler', None)),
            'progress_bar': self.kwargs.get('progress_bar', trainer.kwargs.get('progress_bar', 'auto')),
            'verbose': self.kwargs.get('verbose', 1)
        }
        for key, value in self.kwargs.items():
            if key not in kwargs:
                kwargs[key] = value
        from fastNLP.core.controllers.evaluator import Evaluator
        self.evaluator = Evaluator(**kwargs)
        if self.num_eval_sanity_batch > 0:
            results = self.evaluator.run(num_eval_batch_per_dl=self.num_eval_sanity_batch)
            # warm up the monitor matching so mismatches surface early
            self.topk_saver.get_monitor_value(results)

    def on_validate_end(self, trainer, results):
        # Triggered by the Trainer's own evaluation: evaluate here whenever the
        # watched monitor improved.
        if self.is_better_results(results, keep_if_better=True):
            results = self.evaluator.run()
            self.topk_saver.save_topk(trainer, results)

    def on_train_epoch_end(self, trainer):
        if self.watch_monitor is not None:
            return
        if isinstance(self.evaluate_every, int) and self.evaluate_every < 0:
            validate_every = -self.evaluate_every
            if trainer.cur_epoch_idx % validate_every == 0:
                results = self.evaluator.run()
                self.topk_saver.save_topk(trainer, results)

    def on_train_batch_end(self, trainer):
        if self.watch_monitor is not None:
            return
        if callable(self.evaluate_every):
            # The user callable is documented to receive the trainer (not this
            # callback); previously `self` was passed here by mistake.
            if self.evaluate_every(trainer):
                results = self.evaluator.run()
                self.topk_saver.save_topk(trainer, results)
        elif self.evaluate_every > 0 and trainer.global_forward_batches % self.evaluate_every == 0:
            results = self.evaluator.run()
            self.topk_saver.save_topk(trainer, results)

    def on_save_checkpoint(self, trainer) -> Dict:
        states = {'topk_saver': self.topk_saver.state_dict()}
        if isinstance(self._real_monitor, str):
            states['_real_monitor'] = self._real_monitor
        states['monitor_value'] = self.monitor_value
        return states

    def on_load_checkpoint(self, trainer, states: Optional[Dict]):
        topk_saver_states = states['topk_saver']
        self.topk_saver.load_state_dict(topk_saver_states)
        if '_real_monitor' in states:
            self._real_monitor = states["_real_monitor"]
        self.monitor_value = states['monitor_value']

    @property
    def callback_name(self):
        # Distinguishes multiple instances of this callback in checkpoints.
        metric_names = '+'.join(sorted(self.metrics.keys()))
        return f'more_evaluate_callback#metric_name-{metric_names}#monitor-{self.monitor_name}#topk_saver:{self.topk_saver}'
@@ -9,7 +9,6 @@ __all__ = [ | |||||
] | ] | ||||
from .has_monitor_callback import HasMonitorCallback | from .has_monitor_callback import HasMonitorCallback | ||||
from fastNLP.core.callbacks.utils import _get_monitor_value | |||||
from fastNLP.core.utils import f_rich_progress | from fastNLP.core.utils import f_rich_progress | ||||
from fastNLP.core.log import logger | from fastNLP.core.log import logger | ||||
@@ -42,7 +41,8 @@ class RichCallback(ProgressCallback): | |||||
:param loss_round_ndigit: 显示的 loss 保留多少位有效数字 | :param loss_round_ndigit: 显示的 loss 保留多少位有效数字 | ||||
:param monitor: 当检测到这个key的结果更好时,会打印出不同的颜色进行提示。监控的 metric 值。如果在 evaluation 结果中没有找到 | :param monitor: 当检测到这个key的结果更好时,会打印出不同的颜色进行提示。监控的 metric 值。如果在 evaluation 结果中没有找到 | ||||
完全一致的名称,将使用 最短公共字符串算法 找到最匹配的那个作为 monitor 。如果为 None,将尝试使用 Trainer 设置的 monitor | 完全一致的名称,将使用 最短公共字符串算法 找到最匹配的那个作为 monitor 。如果为 None,将尝试使用 Trainer 设置的 monitor | ||||
。也可以传入一个函数,接受参数为 evaluation 的结果(字典类型),返回一个 float 值作为 monitor 的结果。 | |||||
。也可以传入一个函数,接受参数为 evaluation 的结果(字典类型),返回一个 float 值作为 monitor 的结果,如果当前结果中没有 | |||||
相关的 monitor 值请返回 None 。 | |||||
:param larger_better: 是否是 monitor 的结果越大越好。 | :param larger_better: 是否是 monitor 的结果越大越好。 | ||||
:param format_json: 是否格式化 json 再打印 | :param format_json: 是否格式化 json 再打印 | ||||
""" | """ | ||||
@@ -135,7 +135,8 @@ class RawTextCallback(ProgressCallback): | |||||
:param loss_round_ndigit: 显示的 loss 保留多少位有效数字 | :param loss_round_ndigit: 显示的 loss 保留多少位有效数字 | ||||
:param monitor: 当检测到这个key的结果更好时,会打印出不同的颜色进行提示。监控的 metric 值。如果在 evaluation 结果中没有找到 | :param monitor: 当检测到这个key的结果更好时,会打印出不同的颜色进行提示。监控的 metric 值。如果在 evaluation 结果中没有找到 | ||||
完全一致的名称,将使用 最短公共字符串算法 找到最匹配的那个作为 monitor 。如果为 None,将尝试使用 Trainer 设置的 monitor | 完全一致的名称,将使用 最短公共字符串算法 找到最匹配的那个作为 monitor 。如果为 None,将尝试使用 Trainer 设置的 monitor | ||||
。也可以传入一个函数,接受参数为 evaluation 的结果(字典类型),返回一个 float 值作为 monitor 的结果。 | |||||
。也可以传入一个函数,接受参数为 evaluation 的结果(字典类型),返回一个 float 值作为 monitor 的结果,如果当前结果中没有 | |||||
相关的 monitor 值请返回 None 。 | |||||
:param larger_better: 是否是monitor的结果越大越好。 | :param larger_better: 是否是monitor的结果越大越好。 | ||||
:param format_json: 是否format json再打印 | :param format_json: 是否format json再打印 | ||||
""" | """ | ||||
@@ -0,0 +1,246 @@ | |||||
import json | |||||
import os | |||||
from copy import deepcopy | |||||
from pathlib import Path | |||||
from typing import Optional, Dict, Tuple | |||||
from fastNLP.core.utils import rank_zero_rm | |||||
from fastNLP.core.log import logger | |||||
from fastNLP.envs import FASTNLP_LAUNCH_TIME | |||||
from fastNLP.envs import rank_zero_call | |||||
from fastNLP.envs.env import FASTNLP_EVALUATE_RESULT_FILENAME | |||||
from .has_monitor_callback import MonitorUtility | |||||
class Saver:
    """
    The object that performs the actual saving. The on-disk layout is

    - folder                           # given at construction time
        - YYYY-mm-dd-HH_MM_SS_fffff/   # derived from the launch time of the current script
            - folder_name              # supplied per save() call
    """

    def __init__(self, folder, only_state_dict, model_save_fn, **kwargs):
        """
        :param folder: root directory to save under; defaults to the current working directory.
        :param only_state_dict: whether only the state dict should be saved.
        :param model_save_fn: optional user-provided save function.
        :param kwargs: extra arguments forwarded to the save function.
        """
        if folder is None:
            logger.warning(
                "Parameter `folder` is None, and we will use the current work directory to find and load your model.")
            folder = Path.cwd()
        root = Path(folder)
        if not root.exists():
            raise NotADirectoryError(f"Path '{root.absolute()}' is not existed!")
        if root.is_file():
            raise ValueError("Parameter `folder` should be a directory instead of a file.")

        self.folder = root
        self.only_state_dict = only_state_dict
        self.model_save_fn = model_save_fn
        self.kwargs = kwargs
        self.eval_results = kwargs.get('eval_results', True)
        # every run writes under a subfolder named after the script launch time
        self.timestamp_path = self.folder / os.environ[FASTNLP_LAUNCH_TIME]

    @rank_zero_call
    def save(self, save_fn, folder_name):
        """
        Save into folder/timestamp/folder_name, where folder was fixed at
        construction time and timestamp is the launch time of the script.

        :param save_fn: the function doing the work; it must accept
            folder:str, only_state_dict: bool, model_save_fn: callable, kwargs.
        :param folder_name: name of the target folder, created if missing.
        :return: absolute path of the folder actually written to; None if nothing was created.
        """
        target = self.timestamp_path / folder_name
        target.mkdir(parents=True, exist_ok=True)
        save_fn(
            folder=target,
            only_state_dict=self.only_state_dict,
            model_save_fn=self.model_save_fn,
            **self.kwargs
        )
        return str(os.path.abspath(target))

    @rank_zero_call
    def save_json(self, results, path):
        """
        Dump `results` to `path` as JSON.

        :param results:
        :param path:
        :return:
        """
        with open(path, 'w', encoding='utf8') as fp:
            json.dump(results, fp, indent=2)

    @rank_zero_call
    def rm(self, folder_name):
        """
        Remove folder/timestamp/folder_name, where folder was fixed at
        construction time and timestamp is the launch time of the script.

        :param folder_name:
        :return:
        """
        rank_zero_rm(self.timestamp_path / folder_name)

    def state_dict(self):
        # only the timestamp folder needs to survive a checkpoint/resume cycle
        return {'timestamp_path': str(self.timestamp_path)}

    def load_state_dict(self, states):
        timestamp_path = states['timestamp_path']
        if not os.path.exists(timestamp_path):
            logger.info(f"The resuming checkpoint folder {timestamp_path} is not exists, checkpoint will save to "
                        f" {self.timestamp_path.absolute()}.")
        else:
            logger.info(f"Resume to save checkpoint in path: {timestamp_path}.")
            self.timestamp_path = Path(timestamp_path)

    def __str__(self):
        return 'saver'  # the saver is stateless, so no specific name is needed
class TopkQueue:
    """Keeps track of the (key, value) pairs whose values currently rank in the top-k."""

    def __init__(self, topk):
        """
        :param int topk: -1 means every entry counts as top-k; 0 means no entry ever does.
        """
        assert isinstance(topk, int)
        self.topk = topk
        self.topk_dict = {}  # key -> value for the entries currently held

    def push(self, key, value) -> Optional[Tuple[str, float]]:
        """
        Offer a key/value pair, ranked by value. Return value tells the caller what
        happened: (1) the very pair that was pushed -> rejected, not good enough for
        top-k; (2) (None, None) -> accepted, nothing evicted; (3) some other pair ->
        accepted, and that pair was evicted from the top-k. Ranking only considers
        "larger is better"; negate values beforehand for smaller-is-better metrics.

        :param str key:
        :param float value: when None, nothing happens at all.
        :return: see above.
        """
        if value is None:
            return key, value
        if self.topk < 0:  # everything qualifies; no bookkeeping required
            return None, None
        if self.topk == 0:  # nothing ever qualifies
            return key, value
        if len(self.topk_dict) < self.topk:
            self.topk_dict[key] = value
            return None, None
        weakest = min(self.topk_dict, key=self.topk_dict.get)
        if value < self.topk_dict[weakest]:
            return key, value
        evicted = self.topk_dict.pop(weakest)
        self.topk_dict[key] = value
        return weakest, evicted

    def state_dict(self):
        return deepcopy(self.topk_dict)

    def load_state_dict(self, states):
        self.topk_dict.update(states)

    def __str__(self):
        return f'topk-{self.topk}'

    def __bool__(self):
        # a queue with topk == 0 is meaningless
        return self.topk != 0
class TopkSaver(MonitorUtility, Saver):
    def __init__(self, topk, monitor, larger_better, folder, only_state_dict,
                 model_save_fn, save_evaluate_results,
                 save_object, **kwargs):
        """
        Detects results ranking in the top-k and saves the model/trainer for them.

        :param topk: -1 -> save on every call; 0 -> never save; positive -> keep the best topk saves.
        :param monitor: metric name (closest match by longest common substring when
            no exact key exists) or a callable taking the results dict and returning
            a float (or None when the results contain no relevant value).
        :param larger_better: whether larger monitor values are better.
        :param folder: root folder; a launch-timestamp subfolder is created inside it.
        :param only_state_dict: save only the state dict; ignored when model_save_fn is given.
        :param model_save_fn: custom save function taking a folder; disables fastNLP's own saving.
        :param save_evaluate_results: also dump the evaluate results as JSON next to each topk save.
        :param save_object: one of ['trainer', 'model'].
        :param kwargs: forwarded to Saver.
        """
        MonitorUtility.__init__(self, monitor, larger_better)
        Saver.__init__(self, folder, only_state_dict, model_save_fn, **kwargs)
        if monitor is not None and topk == 0:
            raise RuntimeError("`monitor` is set, but `topk` is 0.")
        if topk != 0 and monitor is None:
            raise RuntimeError("`topk` is set, but `monitor` is None.")
        assert save_object in ['trainer', 'model']

        # Mirror of the inherited Saver state, kept for backward compatibility
        # (referenced by __str__); saving itself goes through the inherited
        # Saver methods, not through this mirror.
        self.saver = Saver(folder, only_state_dict, model_save_fn, **kwargs)
        self.topk_queue = TopkQueue(topk)
        self.save_evaluate_results = save_evaluate_results
        self.save_object = save_object
        self.save_fn_name = 'save' if save_object == 'trainer' else 'save_model'

    @rank_zero_call
    def save_topk(self, trainer, results: Dict) -> Optional[str]:
        """
        Decide from `results` whether the topk criterion is met and, if so, save;
        return the folder written to, or None when nothing was saved.

        :param trainer:
        :param results:
        :return:
        """
        if self.monitor is not None and self.topk_queue:
            monitor_value = self.get_monitor_value(results)
            if monitor_value is None:
                return
            key = f"{self.save_object}-epoch_{trainer.cur_epoch_idx}-batch_{trainer.global_forward_batches}" \
                  f"-{self.monitor_name}_{monitor_value}"
            # the queue ranks by larger value, so negate when smaller is better
            pop_key, pop_value = self.topk_queue.push(key, monitor_value if self.larger_better else -monitor_value)
            if pop_key == key:  # pushed straight back: not good enough for topk
                return None
            folder = self.save(trainer, key)
            if self.save_evaluate_results and folder:
                try:
                    self.save_json(self.itemize_results(results),
                                   os.path.join(folder, FASTNLP_EVALUATE_RESULT_FILENAME))
                except Exception:
                    # bare `except:` previously swallowed KeyboardInterrupt too
                    logger.exception(f"Fail to save evaluate results to {folder}")
            if pop_key and pop_key != key:  # an older entry fell out of topk: remove it
                self.rm(pop_key)
            return folder

    def save(self, trainer, folder_name):
        # Dispatch to trainer.save or trainer.save_model depending on save_object.
        fn = getattr(trainer, self.save_fn_name)
        return super().save(fn, folder_name)

    def state_dict(self):
        # Record the state of the Saver actually used for saving (the inherited
        # one). Previously the unused mirror `self.saver` was recorded/restored,
        # so after resuming, saves went to a fresh timestamp folder instead of
        # the restored one. The 'saver' key is kept for checkpoint compatibility.
        states = {
            'topk_queue': self.topk_queue.state_dict(),
            'saver': Saver.state_dict(self)
        }
        if isinstance(self._real_monitor, str):
            states['_real_monitor'] = self._real_monitor
        return states

    def load_state_dict(self, states):
        self.topk_queue.load_state_dict(states['topk_queue'])
        # Restore into the inherited Saver state, which save() actually uses.
        Saver.load_state_dict(self, states['saver'])
        # Keep the backward-compat mirror in sync (silently, to avoid double logging).
        self.saver.timestamp_path = self.timestamp_path
        if '_real_monitor' in states:
            self._real_monitor = states["_real_monitor"]

    def __str__(self):
        return f'topk-{self.topk_queue}#saver-{self.saver}#save_object-{self.save_object}'
@@ -1,4 +1,6 @@ | |||||
from typing import Optional, Union | from typing import Optional, Union | ||||
import os | |||||
from fastNLP.core.log.logger import logger | from fastNLP.core.log.logger import logger | ||||
from difflib import SequenceMatcher | from difflib import SequenceMatcher | ||||
from fastNLP.core.utils.utils import _get_fun_msg | from fastNLP.core.utils.utils import _get_fun_msg | ||||
@@ -15,7 +17,7 @@ def _get_monitor_value(monitor: Union[callable, str], real_monitor: Optional[str | |||||
:return: 返回两个值(str, value),其中str就是最终要到的key,value就是这个key对应的value。如果value为None说明当前results中没有 | :return: 返回两个值(str, value),其中str就是最终要到的key,value就是这个key对应的value。如果value为None说明当前results中没有 | ||||
找到对应的 monitor | 找到对应的 monitor | ||||
""" | """ | ||||
if len(res)==0: | |||||
if len(res) == 0 or monitor is None: | |||||
return monitor, None | return monitor, None | ||||
if callable(monitor): | if callable(monitor): | ||||
@@ -56,4 +58,3 @@ def _match_length(a:str, b:str)->int: | |||||
match = SequenceMatcher(None, short, long).find_longest_match(0, len(short), 0, len(long)) | match = SequenceMatcher(None, short, long).find_longest_match(0, len(short), 0, len(long)) | ||||
return match.size | return match.size | ||||
@@ -38,7 +38,7 @@ class Evaluator: | |||||
driver: Union[str, Driver] = 'torch', | driver: Union[str, Driver] = 'torch', | ||||
device: Optional[Union[int, List[int], str]] = None, | device: Optional[Union[int, List[int], str]] = None, | ||||
batch_step_fn: Optional[callable] = None, | batch_step_fn: Optional[callable] = None, | ||||
evaluate_fn: Optional[str] = None, # 首先尝试找 evaluate_step, 找不到 forward, callable | |||||
evaluate_fn: Optional[str] = None, | |||||
input_mapping: Optional[Union[Callable, Dict]] = None, | input_mapping: Optional[Union[Callable, Dict]] = None, | ||||
output_mapping: Optional[Union[Callable, Dict]] = None, | output_mapping: Optional[Union[Callable, Dict]] = None, | ||||
model_wo_auto_param_call: bool = False, | model_wo_auto_param_call: bool = False, | ||||
@@ -57,8 +57,9 @@ class Evaluator: | |||||
:param batch_step_fn: callable的对象,接受 (evaluator, batch) 作为参数,其中 evaluator 为 Evaluator 对象,batch 为 | :param batch_step_fn: callable的对象,接受 (evaluator, batch) 作为参数,其中 evaluator 为 Evaluator 对象,batch 为 | ||||
DataLoader 中返回的对象。一个 batch_step_fn 的例子可参考 fastNLP.core.controller.loops.evaluate_batch_loop 的 | DataLoader 中返回的对象。一个 batch_step_fn 的例子可参考 fastNLP.core.controller.loops.evaluate_batch_loop 的 | ||||
batch_step_fn 函数。 | batch_step_fn 函数。 | ||||
:param evaluate_fn: 用来控制 `Evaluator` 在评测的前向传播过程中是调用哪一个函数,例如是 `model.evaluate_step` 还是 `model.forward`; | |||||
默认为 None,如果该值是 None,那么我们会默认使用 `evaluate_step` 当做前向传播的函数,如果在模型中没有找到该方法,则使用 `model.forward` 函数; | |||||
:param evaluate_fn: 用来控制 `Evaluator` 在评测的前向传播过程中是调用哪一个函数,例如是 `model.evaluate_step` 还是 | |||||
`model.forward`;(1) 如果该值是 None,那么我们会默认使用 `evaluate_step` 当做前向传播的函数,如果在模型中没有 | |||||
找到该方法,则使用 `model.forward` 函数;(2) 如果为 str 类型,则尝试从 model 中寻找该方法,找不到则报错。 | |||||
:param input_mapping: 对 dataloader 中输出的内容将通过 input_mapping 处理之后再输入到 model 以及 metric 中 | :param input_mapping: 对 dataloader 中输出的内容将通过 input_mapping 处理之后再输入到 model 以及 metric 中 | ||||
:param output_mapping: 对 model 输出的内容,将通过 output_mapping 处理之后再输入到 metric 中。 | :param output_mapping: 对 model 输出的内容,将通过 output_mapping 处理之后再输入到 metric 中。 | ||||
:param model_wo_auto_param_call: 是否关闭在训练时调用我们的 auto_param_call 来自动匹配 batch 和 forward 函数的参数的行为; | :param model_wo_auto_param_call: 是否关闭在训练时调用我们的 auto_param_call 来自动匹配 batch 和 forward 函数的参数的行为; | ||||
@@ -69,6 +70,7 @@ class Evaluator: | |||||
:param kwargs: | :param kwargs: | ||||
bool model_use_eval_mode: 是否在 evaluate 的时候将 model 的状态设置成 eval 状态。在 eval 状态下,model 的dropout | bool model_use_eval_mode: 是否在 evaluate 的时候将 model 的状态设置成 eval 状态。在 eval 状态下,model 的dropout | ||||
与 batch normalization 将会关闭。默认为True。 | 与 batch normalization 将会关闭。默认为True。 | ||||
TODO 还没完成。 | |||||
Union[bool] auto_tensor_conversion_for_metric: 是否自动将输出中的 | Union[bool] auto_tensor_conversion_for_metric: 是否自动将输出中的 | ||||
tensor 适配到 metrics 支持的。例如 model 输出是 paddlepaddle 的 tensor ,但是想利用 torchmetrics 的metric对象, | tensor 适配到 metrics 支持的。例如 model 输出是 paddlepaddle 的 tensor ,但是想利用 torchmetrics 的metric对象, | ||||
当 auto_tensor_conversion_for_metric 为True时,fastNLP 将自动将输出中 paddle 的 tensor (其它非 tensor 的参数 | 当 auto_tensor_conversion_for_metric 为True时,fastNLP 将自动将输出中 paddle 的 tensor (其它非 tensor 的参数 | ||||
@@ -119,10 +121,6 @@ class Evaluator: | |||||
self._metric_wrapper = None | self._metric_wrapper = None | ||||
_ = self.metrics_wrapper # 触发检查 | _ = self.metrics_wrapper # 触发检查 | ||||
if self._dist_sampler is not None and not self.driver.is_distributed(): | |||||
logger.warning_once("Running in a non-distributed driver, but with distributed sampler, it may cause " | |||||
"different process evaluating on different data.") | |||||
if evaluate_fn is not None and not isinstance(evaluate_fn, str): | if evaluate_fn is not None and not isinstance(evaluate_fn, str): | ||||
raise TypeError("Parameter `evaluate_fn` can only be `str` type when it is not None.") | raise TypeError("Parameter `evaluate_fn` can only be `str` type when it is not None.") | ||||
self._evaluate_step, self._evaluate_step_signature_fn = \ | self._evaluate_step, self._evaluate_step_signature_fn = \ | ||||
@@ -14,10 +14,10 @@ __all__ = [ | |||||
from .loops import Loop, TrainBatchLoop | from .loops import Loop, TrainBatchLoop | ||||
from .utils import State, TrainerState | from .utils import State, TrainerState | ||||
from .utils.utils import check_validate_every | |||||
from .utils.utils import check_evaluate_every | |||||
from .evaluator import Evaluator | from .evaluator import Evaluator | ||||
from fastNLP.core.controllers.utils.utils import TrainerEventTrigger, _TruncatedDataLoader | from fastNLP.core.controllers.utils.utils import TrainerEventTrigger, _TruncatedDataLoader | ||||
from fastNLP.core.callbacks import Callback, CallbackManager, Events, EventsList, Filter | |||||
from fastNLP.core.callbacks import Callback, CallbackManager, Events, EventsList | |||||
from fastNLP.core.callbacks.callback import _CallbackWrapper | from fastNLP.core.callbacks.callback import _CallbackWrapper | ||||
from fastNLP.core.callbacks.callback_events import _SingleEventState | from fastNLP.core.callbacks.callback_events import _SingleEventState | ||||
from fastNLP.core.drivers import Driver | from fastNLP.core.drivers import Driver | ||||
@@ -26,7 +26,7 @@ from fastNLP.core.utils import get_fn_arg_names, match_and_substitute_params, nu | |||||
from fastNLP.core.utils.utils import _check_valid_parameters_number | from fastNLP.core.utils.utils import _check_valid_parameters_number | ||||
from fastNLP.envs import rank_zero_call | from fastNLP.envs import rank_zero_call | ||||
from fastNLP.core.log import logger | from fastNLP.core.log import logger | ||||
from fastNLP.envs import FASTNLP_MODEL_FILENAME | |||||
from fastNLP.envs import FASTNLP_MODEL_FILENAME, FASTNLP_CHECKPOINT_FILENAME | |||||
from fastNLP.core.utils.exceptions import EarlyStopException | from fastNLP.core.utils.exceptions import EarlyStopException | ||||
@@ -94,9 +94,9 @@ class Trainer(TrainerEventTrigger): | |||||
evaluate_step 这个函数,如果没有则使用 forward 函数。 | evaluate_step 这个函数,如果没有则使用 forward 函数。 | ||||
:param callbacks: 训练当中触发的 callback 类,该参数应当为一个列表,其中的每一个元素都应当继承 `Callback` 类; | :param callbacks: 训练当中触发的 callback 类,该参数应当为一个列表,其中的每一个元素都应当继承 `Callback` 类; | ||||
:param metrics: 应当为一个字典,其中 key 表示 monitor,例如 {"acc1": AccMetric(), "acc2": AccMetric()}; | :param metrics: 应当为一个字典,其中 key 表示 monitor,例如 {"acc1": AccMetric(), "acc2": AccMetric()}; | ||||
:param evaluate_every: 可以为负数、正数或者函数;为负数时表示每隔几个 epoch validate 一次;为正数则表示每隔几个 batch validate 一次; | |||||
为函数时表示用户自己传入的用于控制 Trainer 中的 validate 的频率的函数,该函数的应该接受当前 trainer 对象作为参数,并 | |||||
返回一个 bool 值,返回为 True 说明需要进行 validate ;将在每个 batch 结束后调用该函数判断是否需要 validate 。 | |||||
:param evaluate_every: 可以为负数、正数或者函数;为负数时表示每隔几个 epoch evaluate 一次;为正数则表示每隔几个 batch evaluate 一次; | |||||
为函数时表示用户自己传入的用于控制 Trainer 中的 evaluate 的频率的函数,该函数的应该接受当前 trainer 对象作为参数,并 | |||||
返回一个 bool 值,返回为 True 说明需要进行 evaluate ;将在每个 batch 结束后调用该函数判断是否需要 evaluate 。 | |||||
:param input_mapping: 应当为一个字典或者一个函数,表示在当前 step 拿到一个 batch 的训练数据后,应当做怎样的映射处理;如果其是 | :param input_mapping: 应当为一个字典或者一个函数,表示在当前 step 拿到一个 batch 的训练数据后,应当做怎样的映射处理;如果其是 | ||||
一个字典,并且 batch 也是一个 `Dict`,那么我们会把 batch 中同样在 input_mapping 中的 key 修改为 input_mapping 的对应 key 的 | 一个字典,并且 batch 也是一个 `Dict`,那么我们会把 batch 中同样在 input_mapping 中的 key 修改为 input_mapping 的对应 key 的 | ||||
value;如果 batch 是一个 `dataclass`,那么我们会先将该 dataclass 转换为一个 Dict,然后再进行上述转换;如果 batch 此时是其它 | value;如果 batch 是一个 `dataclass`,那么我们会先将该 dataclass 转换为一个 Dict,然后再进行上述转换;如果 batch 此时是其它 | ||||
@@ -124,7 +124,7 @@ class Trainer(TrainerEventTrigger): | |||||
set_grad_to_none: 是否在训练过程中在每一次 optimizer 更新后将 grad 置为 None; | set_grad_to_none: 是否在训练过程中在每一次 optimizer 更新后将 grad 置为 None; | ||||
use_dist_sampler: 表示是否使用分布式的 sampler 。在多卡时,分布式 sampler 将自动决定每张卡上读取的 sample ,使得一个epoch | use_dist_sampler: 表示是否使用分布式的 sampler 。在多卡时,分布式 sampler 将自动决定每张卡上读取的 sample ,使得一个epoch | ||||
内所有卡的 sample 加起来为一整个数据集的 sample。默认会根据 driver 是否为分布式进行设置。 | 内所有卡的 sample 加起来为一整个数据集的 sample。默认会根据 driver 是否为分布式进行设置。 | ||||
use_eval_dist_sampler: 表示在 Evaluator 中在使用 TorchDDPDriver 的时候是否将 dataloader 的 sampler 替换为分布式的 sampler;默认为 True; | |||||
eval_use_dist_sampler: 表示在 Evaluator 中在使用 TorchDDPDriver 的时候是否将 dataloader 的 sampler 替换为分布式的 sampler;默认为 True; | |||||
output_from_new_proc: 应当为一个字符串,表示在多进程的 driver 中其它进程的输出流应当被做如何处理;其值应当为以下之一: | output_from_new_proc: 应当为一个字符串,表示在多进程的 driver 中其它进程的输出流应当被做如何处理;其值应当为以下之一: | ||||
["all", "ignore", "only_error"];当该参数的值不是以上值时,该值应当表示一个文件夹的名字,我们会将其他 rank 的输出流重定向到 | ["all", "ignore", "only_error"];当该参数的值不是以上值时,该值应当表示一个文件夹的名字,我们会将其他 rank 的输出流重定向到 | ||||
log 文件中,然后将 log 文件保存在通过该参数值设定的文件夹中;默认为 "only_error"; | log 文件中,然后将 log 文件保存在通过该参数值设定的文件夹中;默认为 "only_error"; | ||||
@@ -214,13 +214,13 @@ class Trainer(TrainerEventTrigger): | |||||
""" 设置内部的 Evaluator """ | """ 设置内部的 Evaluator """ | ||||
if metrics is None and evaluate_dataloaders is not None: | if metrics is None and evaluate_dataloaders is not None: | ||||
raise ValueError("You have set 'evaluate_dataloader' but forget to set 'metrics'.") | |||||
raise ValueError("You have set 'evaluate_dataloaders' but forget to set 'metrics'.") | |||||
if metrics is not None and evaluate_dataloaders is None: | if metrics is not None and evaluate_dataloaders is None: | ||||
raise ValueError("You have set 'metrics' but forget to set 'evaluate_dataloader'.") | |||||
raise ValueError("You have set 'metrics' but forget to set 'evaluate_dataloaders'.") | |||||
self.metrics = metrics | self.metrics = metrics | ||||
self.validate_every = evaluate_every | |||||
self.evaluate_every = evaluate_every | |||||
self.driver.setup() | self.driver.setup() | ||||
self.driver.barrier() | self.driver.barrier() | ||||
@@ -235,7 +235,7 @@ class Trainer(TrainerEventTrigger): | |||||
self.monitor = monitor | self.monitor = monitor | ||||
self.larger_better = larger_better | self.larger_better = larger_better | ||||
if metrics is not None and evaluate_dataloaders is not None: | if metrics is not None and evaluate_dataloaders is not None: | ||||
check_validate_every(evaluate_every) | |||||
check_evaluate_every(evaluate_every) | |||||
self.evaluator = Evaluator( | self.evaluator = Evaluator( | ||||
model=model, | model=model, | ||||
dataloaders=evaluate_dataloaders, | dataloaders=evaluate_dataloaders, | ||||
@@ -248,7 +248,7 @@ class Trainer(TrainerEventTrigger): | |||||
output_mapping=output_mapping, | output_mapping=output_mapping, | ||||
fp16=fp16, | fp16=fp16, | ||||
verbose=0, | verbose=0, | ||||
use_dist_sampler=kwargs.get("use_eval_dist_sampler", None), | |||||
use_dist_sampler=kwargs.get("eval_use_dist_sampler", None), | |||||
progress_bar=kwargs.get('progress_bar', 'auto') | progress_bar=kwargs.get('progress_bar', 'auto') | ||||
) | ) | ||||
@@ -261,11 +261,14 @@ class Trainer(TrainerEventTrigger): | |||||
self.driver.set_deterministic_dataloader(self.dataloader) | self.driver.set_deterministic_dataloader(self.dataloader) | ||||
self.dataloader = self.driver.set_dist_repro_dataloader(dataloader=self.train_dataloader, dist=_dist_sampler, | self.dataloader = self.driver.set_dist_repro_dataloader(dataloader=self.train_dataloader, dist=_dist_sampler, | ||||
reproducible=self.callback_manager.has_trainer_checkpoint) | |||||
reproducible=self.callback_manager._need_reproducible_sampler) | |||||
self.set_grad_to_none = kwargs.get("set_grad_to_none", True) | self.set_grad_to_none = kwargs.get("set_grad_to_none", True) | ||||
self.on_after_trainer_initialized(self.driver) | |||||
self.evaluate_batch_step_fn = evaluate_batch_step_fn | |||||
self.kwargs = kwargs | |||||
self.on_after_trainer_initialized(self.driver) | |||||
self.driver.barrier() | self.driver.barrier() | ||||
def run(self, num_train_batch_per_epoch: int = -1, num_eval_batch_per_dl: int = -1, | def run(self, num_train_batch_per_epoch: int = -1, num_eval_batch_per_dl: int = -1, | ||||
@@ -364,10 +367,10 @@ class Trainer(TrainerEventTrigger): | |||||
:return: | :return: | ||||
""" | """ | ||||
if self.evaluator is not None: | if self.evaluator is not None: | ||||
if callable(self.validate_every): | |||||
if self.validate_every(self): | |||||
if callable(self.evaluate_every): | |||||
if self.evaluate_every(self): | |||||
self.run_evaluate() | self.run_evaluate() | ||||
elif self.validate_every > 0 and self.global_forward_batches % self.validate_every == 0: | |||||
elif self.evaluate_every > 0 and self.global_forward_batches % self.evaluate_every == 0: | |||||
self.run_evaluate() | self.run_evaluate() | ||||
def epoch_validate(self): | def epoch_validate(self): | ||||
@@ -377,8 +380,8 @@ class Trainer(TrainerEventTrigger): | |||||
:return: | :return: | ||||
""" | """ | ||||
if self.evaluator is not None: | if self.evaluator is not None: | ||||
if isinstance(self.validate_every, int) and self.validate_every < 0: | |||||
validate_every = -self.validate_every | |||||
if isinstance(self.evaluate_every, int) and self.evaluate_every < 0: | |||||
validate_every = -self.evaluate_every | |||||
if self.cur_epoch_idx % validate_every == 0: | if self.cur_epoch_idx % validate_every == 0: | ||||
self.run_evaluate() | self.run_evaluate() | ||||
@@ -427,7 +430,7 @@ class Trainer(TrainerEventTrigger): | |||||
self._custom_callbacks[None] = [] | self._custom_callbacks[None] = [] | ||||
if self.marker is not None: | if self.marker is not None: | ||||
if len(self._custom_callbacks[self.marker]) == 0: | if len(self._custom_callbacks[self.marker]) == 0: | ||||
print(f"You have set `trainer.marker = {self.marker}`, but there are no callback function matched " | |||||
logger.info(f"You have set `trainer.marker = {self.marker}`, but there are no callback function matched " | |||||
f"`{self.marker}` that is added through function `Trainer.on`") | f"`{self.marker}` that is added through function `Trainer.on`") | ||||
_own_callbacks += self._custom_callbacks[self.marker] | _own_callbacks += self._custom_callbacks[self.marker] | ||||
for each_callback in _own_callbacks: | for each_callback in _own_callbacks: | ||||
@@ -528,10 +531,10 @@ class Trainer(TrainerEventTrigger): | |||||
r""" | r""" | ||||
用于帮助用户保存模型的辅助函数,具体实际的保存模型的操作由具体的 driver 实现; | 用于帮助用户保存模型的辅助函数,具体实际的保存模型的操作由具体的 driver 实现; | ||||
:param folder: 保存模型的地址; | |||||
:param only_state_dict: 是否只保存模型的 `state_dict`; | |||||
:param folder: 保存模型的文件夹。如果没有传入 model_save_fn 参数,则在这个文件夹下创建 fastnlp_model.pkl.tar 文件。 | |||||
:param only_state_dict: 仅在 model_save_fn 为空时,有效。是否只保存模型的 `state_dict`; | |||||
:param model_save_fn: 用户自己定制的用来替换该保存函数本身保存逻辑的函数; | :param model_save_fn: 用户自己定制的用来替换该保存函数本身保存逻辑的函数; | ||||
:param kwargs: 一些 driver 的保存模型的函数的参数另有其它; | |||||
:param kwargs: | |||||
""" | """ | ||||
self.on_save_model() | self.on_save_model() | ||||
@@ -568,14 +571,19 @@ class Trainer(TrainerEventTrigger): | |||||
self.on_load_model() | self.on_load_model() | ||||
self.driver.barrier() | self.driver.barrier() | ||||
if not isinstance(folder, (io.BytesIO, BinaryIO)): | if not isinstance(folder, (io.BytesIO, BinaryIO)): | ||||
if model_load_fn is not None: | |||||
if not callable(model_load_fn): | |||||
raise ValueError("Parameter `model_save_fn` should be `Callable` type when it is not None.") | |||||
rank_zero_call(model_load_fn)(folder) | |||||
else: | |||||
if isinstance(folder, str): | |||||
folder = Path(folder) | |||||
self.driver.load_model(folder.joinpath(FASTNLP_MODEL_FILENAME), only_state_dict, **kwargs) | |||||
try: | |||||
if model_load_fn is not None: | |||||
if not callable(model_load_fn): | |||||
raise ValueError("Parameter `model_save_fn` should be `Callable` type when it is not None.") | |||||
rank_zero_call(model_load_fn)(folder) | |||||
else: | |||||
if isinstance(folder, str): | |||||
folder = Path(folder) | |||||
self.driver.load_model(folder.joinpath(FASTNLP_MODEL_FILENAME), only_state_dict, **kwargs) | |||||
except FileNotFoundError as e: | |||||
if FASTNLP_MODEL_FILENAME not in os.listdir(folder): | |||||
logger.error(f"fastNLP model checkpoint file:{FASTNLP_MODEL_FILENAME} is not found in {folder}.") | |||||
raise e | |||||
else: | else: | ||||
if model_load_fn is not None: | if model_load_fn is not None: | ||||
raise RuntimeError("It is not allowed to specify a `model_save_fn` parameter with `folder` being " | raise RuntimeError("It is not allowed to specify a `model_save_fn` parameter with `folder` being " | ||||
@@ -585,11 +593,13 @@ class Trainer(TrainerEventTrigger): | |||||
def save(self, folder: Union[str, Path], only_state_dict: bool = True, model_save_fn: Optional[Callable] = None, **kwargs): | def save(self, folder: Union[str, Path], only_state_dict: bool = True, model_save_fn: Optional[Callable] = None, **kwargs): | ||||
r""" | r""" | ||||
用于断点重训 Trainer 的保存函数; | |||||
用于断点重训 Trainer 的保存函数。 | |||||
:param folder: | |||||
:param only_state_dict: | |||||
:param model_save_fn: | |||||
:param folder: 保存在哪个文件夹下,会在该文件下声称两个文件:fastnlp_checkpoint.pkl.tar 与 fastnlp_model.pkl.tar 。 | |||||
如果 model_save_fn 不为空,则没有 fastnlp_model.pkl.tar 文件。 | |||||
:param only_state_dict: 当 model_save_fn 为空时有效,表明是否仅保存模型的权重。 | |||||
:param model_save_fn: 如果模型保存比较特殊,可以传入该函数自定义保存过程,输入应该接受一个文件夹(实际上就是接受上面的 folder | |||||
参数),不必返回任何东西。 | |||||
:param kwargs: | :param kwargs: | ||||
:return: | :return: | ||||
""" | """ | ||||
@@ -602,17 +612,6 @@ class Trainer(TrainerEventTrigger): | |||||
'num_consumed_batches': self.batch_idx_in_epoch - getattr(self, 'start_batch_idx_in_epoch', 0) | 'num_consumed_batches': self.batch_idx_in_epoch - getattr(self, 'start_batch_idx_in_epoch', 0) | ||||
} | } | ||||
# 3. validate filter state; | |||||
if self.evaluator is not None: | |||||
val_filter_state = {} | |||||
if hasattr(self.step_validate, "__fastNLP_filter__"): | |||||
val_filter_state["step_validate"] = self.step_validate.__fastNLP_filter__.state_dict() | |||||
if hasattr(self.epoch_validate, "__fastNLP_filter__"): | |||||
val_filter_state["epoch_validate"] = self.epoch_validate.__fastNLP_filter__.state_dict() | |||||
states["val_filter_state"] = val_filter_state | |||||
else: | |||||
states["val_filter_state"] = None | |||||
if isinstance(folder, str): | if isinstance(folder, str): | ||||
folder = Path(folder) | folder = Path(folder) | ||||
@@ -649,32 +648,30 @@ class Trainer(TrainerEventTrigger): | |||||
dataloader = self.dataloader | dataloader = self.dataloader | ||||
if not resume_training: | if not resume_training: | ||||
dataloader = None | dataloader = None | ||||
if model_load_fn is not None: | |||||
if not callable(model_load_fn): | |||||
raise ValueError("Parameter `model_save_fn` should be `Callable`.") | |||||
rank_zero_call(model_load_fn)(folder) | |||||
states = self.driver.load(folder=folder, dataloader=dataloader, should_load_model=False, **kwargs) | |||||
else: | |||||
states = self.driver.load(folder=folder, dataloader=dataloader, only_state_dict=only_state_dict, should_load_model=True, **kwargs) | |||||
try: | |||||
if model_load_fn is not None: | |||||
if not callable(model_load_fn): | |||||
raise ValueError("Parameter `model_save_fn` should be `Callable`.") | |||||
rank_zero_call(model_load_fn)(folder) | |||||
states = self.driver.load(folder=folder, dataloader=dataloader, should_load_model=False, **kwargs) | |||||
else: | |||||
states = self.driver.load(folder=folder, dataloader=dataloader, only_state_dict=only_state_dict, should_load_model=True, **kwargs) | |||||
except FileNotFoundError as e: | |||||
if FASTNLP_CHECKPOINT_FILENAME not in os.listdir(folder) and FASTNLP_MODEL_FILENAME in os.listdir(folder): | |||||
logger.error("It seems that you are trying to load the trainer checkpoint from a model checkpoint folder.") | |||||
elif FASTNLP_CHECKPOINT_FILENAME not in os.listdir(folder): | |||||
logger.error(f"fastNLP Trainer checkpoint file:{FASTNLP_CHECKPOINT_FILENAME} is not found in {folder}.") | |||||
raise e | |||||
if not resume_training: | if not resume_training: | ||||
return | return | ||||
self.dataloader = states.pop('dataloader') | self.dataloader = states.pop('dataloader') | ||||
# 2. validate filter state; | |||||
if self.evaluator is not None: | |||||
val_filter_state = states["val_filter_state"] | |||||
if hasattr(self.step_validate, "__fastNLP_filter__"): | |||||
self.step_validate.__fastNLP_filter__.load_state_dict(val_filter_state["step_validate"]) | |||||
if hasattr(self.epoch_validate, "__fastNLP_filter__"): | |||||
self.epoch_validate.__fastNLP_filter__.load_state_dict(val_filter_state["epoch_validate"]) | |||||
# 3. 恢复 trainer_state 的状态; | |||||
# 1. 恢复 trainer_state 的状态; | |||||
self.trainer_state.load_state_dict(states["trainer_state"]) | self.trainer_state.load_state_dict(states["trainer_state"]) | ||||
# 4. 修改 trainer_state.batch_idx_in_epoch | |||||
# 2. 修改 trainer_state.batch_idx_in_epoch | |||||
# sampler 是类似 RandomSampler 的sampler,不是 batch_sampler; | # sampler 是类似 RandomSampler 的sampler,不是 batch_sampler; | ||||
# 这里的原则就是应当使得 '还会产生的batch数量' + 'batch_idx_in_epoch' = '原来不断点训练的batch的总数'。其中由于 | # 这里的原则就是应当使得 '还会产生的batch数量' + 'batch_idx_in_epoch' = '原来不断点训练的batch的总数'。其中由于 | ||||
# '还会产生的batch数量' 是由还剩多少 sample 决定的,因此只能通过调整 'batch_idx_in_epoch' 使得等式成立 | # '还会产生的batch数量' 是由还剩多少 sample 决定的,因此只能通过调整 'batch_idx_in_epoch' 使得等式成立 | ||||
@@ -126,7 +126,7 @@ class _TruncatedDataLoader: | |||||
return getattr(self.dataloader, item) | return getattr(self.dataloader, item) | ||||
def check_validate_every(validate_every): | |||||
def check_evaluate_every(validate_every): | |||||
if not callable(validate_every) and (not isinstance(validate_every, int) or validate_every == 0): | if not callable(validate_every) and (not isinstance(validate_every, int) or validate_every == 0): | ||||
raise ValueError("Parameter 'evaluate_every' should be set to 'int' type and either < 0 or > 0.") | raise ValueError("Parameter 'evaluate_every' should be set to 'int' type and either < 0 or > 0.") | ||||
if callable(validate_every): | if callable(validate_every): | ||||
@@ -11,6 +11,7 @@ from fastNLP.core.collators.collator import _MultiCollator | |||||
from fastNLP.core.utils.utils import indice_collate_wrapper | from fastNLP.core.utils.utils import indice_collate_wrapper | ||||
from fastNLP.io.data_bundle import DataBundle | from fastNLP.io.data_bundle import DataBundle | ||||
from fastNLP.envs.imports import _NEED_IMPORT_TORCH | from fastNLP.envs.imports import _NEED_IMPORT_TORCH | ||||
from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler, UnrepeatedSampler | |||||
if _NEED_IMPORT_TORCH: | if _NEED_IMPORT_TORCH: | ||||
from torch.utils.data import DataLoader, Sampler | from torch.utils.data import DataLoader, Sampler | ||||
@@ -48,8 +49,8 @@ class TorchDataLoader(DataLoader): | |||||
""" | """ | ||||
def __init__(self, dataset, batch_size: int = 1, | def __init__(self, dataset, batch_size: int = 1, | ||||
shuffle: bool = False, sampler: Optional["Sampler[int]"] = None, | |||||
batch_sampler: Optional["Sampler[Sequence[int]]"] = None, | |||||
shuffle: bool = False, sampler: Union["Sampler[int]", ReproducibleSampler, UnrepeatedSampler] = None, | |||||
batch_sampler: Union["Sampler[Sequence[int]]", ReproducibleBatchSampler] = None, | |||||
num_workers: int = 0, collate_fn: Optional[Callable] = None, | num_workers: int = 0, collate_fn: Optional[Callable] = None, | ||||
pin_memory: bool = False, drop_last: bool = False, | pin_memory: bool = False, drop_last: bool = False, | ||||
timeout: float = 0, worker_init_fn: Optional[Callable] = None, | timeout: float = 0, worker_init_fn: Optional[Callable] = None, | ||||
@@ -380,7 +380,6 @@ class Driver(ABC): | |||||
""" | """ | ||||
# 单卡 driver 不需要这个函数; | # 单卡 driver 不需要这个函数; | ||||
if self._pids is not None: | if self._pids is not None: | ||||
exc_type, exc_value, exc_traceback_obj = sys.exc_info() | exc_type, exc_value, exc_traceback_obj = sys.exc_info() | ||||
_write_exc_info = { | _write_exc_info = { | ||||
'exc_type': str(exc_type.__name__), | 'exc_type': str(exc_type.__name__), | ||||
@@ -526,7 +526,7 @@ class TorchDDPDriver(TorchDriver): | |||||
def barrier(self): | def barrier(self): | ||||
if int(os.environ.get(FASTNLP_NO_SYNC, 0)) < 1: # 当 FASTNLP_NO_SYNC 小于 1 时实际执行 | if int(os.environ.get(FASTNLP_NO_SYNC, 0)) < 1: # 当 FASTNLP_NO_SYNC 小于 1 时实际执行 | ||||
torch.distributed.barrier(async_op=True) | |||||
torch.distributed.barrier(async_op=False) | |||||
def is_distributed(self): | def is_distributed(self): | ||||
return True | return True | ||||
@@ -9,8 +9,9 @@ from fastNLP.envs.imports import _NEED_IMPORT_TORCH | |||||
from pathlib import Path | from pathlib import Path | ||||
if _NEED_IMPORT_TORCH: | if _NEED_IMPORT_TORCH: | ||||
import torch | import torch | ||||
from torch.utils.data import DataLoader, IterableDataset, RandomSampler, Sampler, BatchSampler, Dataset | |||||
from torch.utils.data import DataLoader, IterableDataset, Sampler, BatchSampler, Dataset | |||||
from torch.optim import Optimizer | from torch.optim import Optimizer | ||||
from torch.utils.data import RandomSampler as TorchRandomSampler | |||||
_reduces = { | _reduces = { | ||||
'sum': torch.max, | 'sum': torch.max, | ||||
'min': torch.min, | 'min': torch.min, | ||||
@@ -30,7 +31,7 @@ from fastNLP.core.utils import apply_to_collection, torch_move_data_to_device | |||||
from fastNLP.envs import rank_zero_call | from fastNLP.envs import rank_zero_call | ||||
from fastNLP.envs import FASTNLP_SEED_WORKERS, FASTNLP_GLOBAL_RANK, FASTNLP_MODEL_FILENAME, FASTNLP_CHECKPOINT_FILENAME | from fastNLP.envs import FASTNLP_SEED_WORKERS, FASTNLP_GLOBAL_RANK, FASTNLP_MODEL_FILENAME, FASTNLP_CHECKPOINT_FILENAME | ||||
from fastNLP.core.log import logger | from fastNLP.core.log import logger | ||||
from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler, RandomBatchSampler | |||||
from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler, RandomBatchSampler, RandomSampler | |||||
class TorchDriver(Driver): | class TorchDriver(Driver): | ||||
@@ -211,8 +212,8 @@ class TorchDriver(Driver): | |||||
states['sampler_states'] = sampler_states | states['sampler_states'] = sampler_states | ||||
else: | else: | ||||
raise RuntimeError( | |||||
'The sampler has no `state_dict()` method, it will fail to recover to the specific batch.') | |||||
raise RuntimeError('The sampler has no `state_dict()` method, fastNLP cannot save the training ' | |||||
'state.') | |||||
# 2. 保存模型的状态; | # 2. 保存模型的状态; | ||||
if should_save_model: | if should_save_model: | ||||
@@ -283,6 +284,9 @@ class TorchDriver(Driver): | |||||
sampler = dataloader_args.batch_sampler | sampler = dataloader_args.batch_sampler | ||||
elif isinstance(dataloader_args.sampler, ReproducibleSampler): | elif isinstance(dataloader_args.sampler, ReproducibleSampler): | ||||
sampler = dataloader_args.sampler | sampler = dataloader_args.sampler | ||||
elif isinstance(dataloader_args.sampler, TorchRandomSampler): | |||||
sampler = RandomSampler(dataloader_args.sampler.data_source) | |||||
logger.debug("Replace torch RandomSampler into fastNLP RandomSampler.") | |||||
elif self.is_distributed(): | elif self.is_distributed(): | ||||
raise RuntimeError("It is not allowed to use checkpoint retraining when you do not use our or " | raise RuntimeError("It is not allowed to use checkpoint retraining when you do not use our or " | ||||
"`ReproducibleSampler`.") | "`ReproducibleSampler`.") | ||||
@@ -19,7 +19,7 @@ class Accuracy(Metric): | |||||
:param str backend: 目前支持四种类型的backend, ['auto', 'torch', 'paddle', 'jittor']。其中 auto 表示根据实际调用 Metric.update() | :param str backend: 目前支持四种类型的backend, ['auto', 'torch', 'paddle', 'jittor']。其中 auto 表示根据实际调用 Metric.update() | ||||
函数时传入的参数决定具体的 backend ,一般情况下直接使用 'auto' 即可。 | 函数时传入的参数决定具体的 backend ,一般情况下直接使用 'auto' 即可。 | ||||
:param bool aggregate_when_get_metric: 在计算 metric 的时候是否自动将各个进程上的相同的 element 的数字聚合后再得到metric, | |||||
:param bool aggregate_when_get_metric: 在计算 metric 的时候是否自动将各个进程上的相同的 element 的数字聚合后再得到 metric, | |||||
当 backend 不支持分布式时,该参数无意义。如果为 None ,将在 Evaluator 中根据 sampler 是否使用分布式进行自动设置。 | 当 backend 不支持分布式时,该参数无意义。如果为 None ,将在 Evaluator 中根据 sampler 是否使用分布式进行自动设置。 | ||||
""" | """ | ||||
super(Accuracy, self).__init__(backend=backend, aggregate_when_get_metric=aggregate_when_get_metric) | super(Accuracy, self).__init__(backend=backend, aggregate_when_get_metric=aggregate_when_get_metric) | ||||
@@ -84,6 +84,8 @@ class UnrepeatedRandomSampler(UnrepeatedSampler): | |||||
:param rank: | :param rank: | ||||
:return: | :return: | ||||
""" | """ | ||||
assert num_replicas<=len(self.dataset), f"The number of replicas({num_replicas}) should be lesser than the " \ | |||||
f"number of samples({len(self.dataset)})." | |||||
assert num_replicas>0 and isinstance(num_replicas, int) | assert num_replicas>0 and isinstance(num_replicas, int) | ||||
assert isinstance(rank, int) and 0<=rank<num_replicas | assert isinstance(rank, int) and 0<=rank<num_replicas | ||||
# 注意初始化该函数时,所有的状态都应当默认是一个 epoch 刚开始训练的状态; | # 注意初始化该函数时,所有的状态都应当默认是一个 epoch 刚开始训练的状态; | ||||
@@ -24,8 +24,8 @@ __all__ = [ | |||||
'indice_collate_wrapper', | 'indice_collate_wrapper', | ||||
'deprecated', | 'deprecated', | ||||
'seq_len_to_mask', | 'seq_len_to_mask', | ||||
'synchronize_safe_rm', | |||||
'synchronize_mkdir' | |||||
'rank_zero_rm', | |||||
'rank_zero_mkdir' | |||||
] | ] | ||||
from .cache_results import cache_results | from .cache_results import cache_results | ||||
@@ -37,6 +37,6 @@ from .torch_paddle_utils import torch_paddle_move_data_to_device | |||||
from .torch_utils import torch_move_data_to_device | from .torch_utils import torch_move_data_to_device | ||||
from .utils import get_fn_arg_names, auto_param_call, check_user_specific_params, \ | from .utils import get_fn_arg_names, auto_param_call, check_user_specific_params, \ | ||||
dataclass_to_dict, match_and_substitute_params, apply_to_collection, nullcontext, pretty_table_printer, Option, \ | dataclass_to_dict, match_and_substitute_params, apply_to_collection, nullcontext, pretty_table_printer, Option, \ | ||||
indice_collate_wrapper, deprecated, seq_len_to_mask, synchronize_safe_rm, synchronize_mkdir | |||||
indice_collate_wrapper, deprecated, seq_len_to_mask, rank_zero_rm, rank_zero_mkdir | |||||
@@ -38,8 +38,8 @@ __all__ = [ | |||||
'indice_collate_wrapper', | 'indice_collate_wrapper', | ||||
'deprecated', | 'deprecated', | ||||
'seq_len_to_mask', | 'seq_len_to_mask', | ||||
'synchronize_safe_rm', | |||||
'synchronize_mkdir' | |||||
'rank_zero_rm', | |||||
'rank_zero_mkdir' | |||||
] | ] | ||||
@@ -629,7 +629,7 @@ def wait_filepath(path, exist=True): | |||||
def synchronize_safe_rm(path: Optional[Union[str, Path]]): | |||||
def rank_zero_rm(path: Optional[Union[str, Path]]): | |||||
""" | """ | ||||
这个是因为在分布式文件系统中可能会发生错误,rank0下发删除成功后就运行走了,但实际的删除需要rank0的机器发送到远程文件系统再去执行,这个时候 | 这个是因为在分布式文件系统中可能会发生错误,rank0下发删除成功后就运行走了,但实际的删除需要rank0的机器发送到远程文件系统再去执行,这个时候 | ||||
在rank0那里,确实已经删除成功了,但是在远程文件系统那里这个操作还没完成,rank1读取的时候还是读取到存在这个文件; | 在rank0那里,确实已经删除成功了,但是在远程文件系统那里这个操作还没完成,rank1读取的时候还是读取到存在这个文件; | ||||
@@ -638,15 +638,14 @@ def synchronize_safe_rm(path: Optional[Union[str, Path]]): | |||||
:param path: | :param path: | ||||
:return: | :return: | ||||
""" | """ | ||||
if path is None: | |||||
return | |||||
if isinstance(path, str): | |||||
path = Path(path) | |||||
if not path.exists(): | |||||
return | |||||
if int(os.environ.get(FASTNLP_GLOBAL_RANK, 0)) == 0: | if int(os.environ.get(FASTNLP_GLOBAL_RANK, 0)) == 0: | ||||
if path is None: | |||||
return | |||||
if isinstance(path, str): | |||||
path = Path(path) | |||||
if not path.exists(): | |||||
return | |||||
_recursive_rm(path) | _recursive_rm(path) | ||||
wait_filepath(path, exist=False) | |||||
def _recursive_rm(path: Path): | def _recursive_rm(path: Path): | ||||
@@ -662,21 +661,19 @@ def _recursive_rm(path: Path): | |||||
path.rmdir() | path.rmdir() | ||||
def synchronize_mkdir(path: Optional[Union[str, Path]]): | |||||
def rank_zero_mkdir(path: Optional[Union[str, Path]]): | |||||
""" | """ | ||||
注意该函数是用来创建文件夹,如果需要创建一个文件,不要使用该函数; | 注意该函数是用来创建文件夹,如果需要创建一个文件,不要使用该函数; | ||||
该函数会保证所有进程都检测到 path 创建之后才退出,请保证不同进程上 path 是完全一样的,否则会陷入死锁状态。 | 该函数会保证所有进程都检测到 path 创建之后才退出,请保证不同进程上 path 是完全一样的,否则会陷入死锁状态。 | ||||
""" | """ | ||||
if path is None: | |||||
return | |||||
if isinstance(path, str): | |||||
path = Path(path) | |||||
if int(os.environ.get(FASTNLP_GLOBAL_RANK, 0)) == 0: | if int(os.environ.get(FASTNLP_GLOBAL_RANK, 0)) == 0: | ||||
path.mkdir(parents=True, exist_ok=True) | |||||
if path is None: | |||||
return | |||||
if isinstance(path, str): | |||||
path = Path(path) | |||||
wait_filepath(path, exist=True) | |||||
path.mkdir(parents=True, exist_ok=True) | |||||
def get_class_that_defined_method(method): | def get_class_that_defined_method(method): | ||||
@@ -49,7 +49,7 @@ FASTNLP_BACKEND_LAUNCH = "FASTNLP_BACKEND_LAUNCH" | |||||
# 为 '2' 表示 barrier 与 gather/broadcast 都关闭。 | # 为 '2' 表示 barrier 与 gather/broadcast 都关闭。 | ||||
FASTNLP_NO_SYNC = 'FASTNLP_NO_SYNC' | FASTNLP_NO_SYNC = 'FASTNLP_NO_SYNC' | ||||
# todo 注释 直接使用的变量 | |||||
# 保存各种内容时的默认名称 | |||||
FASTNLP_MODEL_FILENAME = "fastnlp_model.pkl.tar" | FASTNLP_MODEL_FILENAME = "fastnlp_model.pkl.tar" | ||||
FASTNLP_CHECKPOINT_FILENAME = "fastnlp_checkpoint.pkl.tar" | FASTNLP_CHECKPOINT_FILENAME = "fastnlp_checkpoint.pkl.tar" | ||||
FASTNLP_EVALUATE_RESULT_FILENAME = 'fastnlp_evaluate_results.json' |
@@ -7,13 +7,14 @@ from torch.optim import SGD | |||||
import torch.distributed as dist | import torch.distributed as dist | ||||
from pathlib import Path | from pathlib import Path | ||||
import re | import re | ||||
import time | |||||
from fastNLP.core.callbacks.checkpoint_callback import ModelCheckpointCallback, TrainerCheckpointCallback | |||||
from fastNLP.core.callbacks.checkpoint_callback import CheckpointCallback | |||||
from fastNLP.core.controllers.trainer import Trainer | from fastNLP.core.controllers.trainer import Trainer | ||||
from fastNLP.envs import FASTNLP_LAUNCH_TIME, FASTNLP_DISTRIBUTED_CHECK | from fastNLP.envs import FASTNLP_LAUNCH_TIME, FASTNLP_DISTRIBUTED_CHECK | ||||
from tests.helpers.utils import magic_argv_env_context | from tests.helpers.utils import magic_argv_env_context | ||||
from fastNLP.core import synchronize_safe_rm | |||||
from fastNLP.core import rank_zero_rm | |||||
from tests.helpers.models.torch_model import TorchNormalModel_Classification_1 | from tests.helpers.models.torch_model import TorchNormalModel_Classification_1 | ||||
from tests.helpers.datasets.torch_data import TorchArgMaxDatset | from tests.helpers.datasets.torch_data import TorchArgMaxDatset | ||||
from torchmetrics import Accuracy | from torchmetrics import Accuracy | ||||
@@ -80,44 +81,21 @@ def test_model_checkpoint_callback_1( | |||||
version, | version, | ||||
only_state_dict | only_state_dict | ||||
): | ): | ||||
# def test_model_checkpoint_callback_1( | |||||
# model_and_optimizers: TrainerParameters, | |||||
# driver='torch_ddp', | |||||
# device=[0, 1], | |||||
# version=1, | |||||
# only_state_dict=True | |||||
# ): | |||||
path = Path.cwd().joinpath(f"test_model_checkpoint") | |||||
path.mkdir(exist_ok=True, parents=True) | |||||
try: | |||||
path = Path.cwd().joinpath(f"test_model_checkpoint") | |||||
path.mkdir(exist_ok=True, parents=True) | |||||
if version == 0: | |||||
callbacks = [ | |||||
ModelCheckpointCallback( | |||||
monitor="acc", | |||||
save_folder=path, | |||||
save_every_n_epochs=1, | |||||
save_every_n_batches=123, # 避免和 epoch 的保存重复; | |||||
save_topk=None, | |||||
save_last=False, | |||||
save_on_exception=None, | |||||
only_state_dict=only_state_dict | |||||
) | |||||
] | |||||
elif version == 1: | |||||
callbacks = [ | |||||
ModelCheckpointCallback( | |||||
monitor="acc", | |||||
save_folder=path, | |||||
save_every_n_epochs=3, | |||||
save_every_n_batches=None, | |||||
save_topk=2, | |||||
save_last=True, | |||||
save_on_exception=None, | |||||
only_state_dict=only_state_dict | |||||
) | |||||
] | |||||
if version == 0: | |||||
callbacks = [ | |||||
CheckpointCallback(folder=path, every_n_epochs=1, every_n_batches=123, last=False, on_exceptions=None, topk=0, | |||||
monitor=None, only_state_dict=only_state_dict, save_object='model') | |||||
] | |||||
elif version == 1: | |||||
callbacks = [ | |||||
CheckpointCallback(folder=path, every_n_epochs=3, every_n_batches=None, last=True, on_exceptions=None, topk=2, | |||||
monitor="acc", only_state_dict=only_state_dict, save_object='model') | |||||
] | |||||
try: | |||||
trainer = Trainer( | trainer = Trainer( | ||||
model=model_and_optimizers.model, | model=model_and_optimizers.model, | ||||
driver=driver, | driver=driver, | ||||
@@ -134,7 +112,7 @@ def test_model_checkpoint_callback_1( | |||||
) | ) | ||||
trainer.run() | trainer.run() | ||||
print("Finish train") | |||||
all_saved_model_paths = {w.name: w for w in path.joinpath(os.environ[FASTNLP_LAUNCH_TIME]).iterdir()} | all_saved_model_paths = {w.name: w for w in path.joinpath(os.environ[FASTNLP_LAUNCH_TIME]).iterdir()} | ||||
# 检查生成保存模型文件的数量是不是正确的; | # 检查生成保存模型文件的数量是不是正确的; | ||||
if version == 0: | if version == 0: | ||||
@@ -217,8 +195,7 @@ def test_model_checkpoint_callback_1( | |||||
trainer.run() | trainer.run() | ||||
finally: | finally: | ||||
synchronize_safe_rm(path) | |||||
pass | |||||
rank_zero_rm(path) | |||||
if dist.is_initialized(): | if dist.is_initialized(): | ||||
dist.destroy_process_group() | dist.destroy_process_group() | ||||
@@ -233,30 +210,23 @@ def test_model_checkpoint_callback_2( | |||||
device, | device, | ||||
only_state_dict | only_state_dict | ||||
): | ): | ||||
path = Path.cwd().joinpath("test_model_checkpoint") | |||||
path.mkdir(exist_ok=True, parents=True) | |||||
try: | |||||
path = Path.cwd().joinpath("test_model_checkpoint") | |||||
path.mkdir(exist_ok=True, parents=True) | |||||
from fastNLP.core.callbacks.callback_events import Events | |||||
@Trainer.on(Events.on_train_epoch_end) | |||||
def raise_exception(trainer): | |||||
if trainer.driver.get_local_rank() == 0 and trainer.cur_epoch_idx == 4: | |||||
raise NotImplementedError | |||||
callbacks = [ | |||||
ModelCheckpointCallback( | |||||
monitor="acc1", | |||||
save_folder=path, | |||||
save_every_n_epochs=None, | |||||
save_every_n_batches=None, | |||||
save_topk=None, | |||||
save_last=False, | |||||
save_on_exception=NotImplementedError, | |||||
only_state_dict=only_state_dict | |||||
), | |||||
] | |||||
from fastNLP.core.callbacks.callback_events import Events | |||||
@Trainer.on(Events.on_train_epoch_end) | |||||
def raise_exception(trainer): | |||||
if trainer.driver.get_local_rank() == 0 and trainer.cur_epoch_idx == 4: | |||||
raise NotImplementedError | |||||
callbacks = [ | |||||
CheckpointCallback(folder=path, every_n_epochs=None, every_n_batches=None, last=False, | |||||
on_exceptions=NotImplementedError, topk=None, monitor=None, only_state_dict=only_state_dict, | |||||
save_object='model'), | |||||
] | |||||
try: | |||||
with pytest.raises(NotImplementedError): | with pytest.raises(NotImplementedError): | ||||
trainer = Trainer( | trainer = Trainer( | ||||
model=model_and_optimizers.model, | model=model_and_optimizers.model, | ||||
@@ -315,14 +285,14 @@ def test_model_checkpoint_callback_2( | |||||
trainer.run() | trainer.run() | ||||
finally: | finally: | ||||
synchronize_safe_rm(path) | |||||
rank_zero_rm(path) | |||||
# pass | # pass | ||||
if dist.is_initialized(): | if dist.is_initialized(): | ||||
dist.destroy_process_group() | dist.destroy_process_group() | ||||
@pytest.mark.parametrize("driver,device", [("torch", "cpu"), ("torch_ddp", [6, 7]), ("torch", 7)]) # ("torch", "cpu"), ("torch_ddp", [0, 1]), ("torch", 1) | |||||
@pytest.mark.parametrize("driver,device", [("torch", "cpu"), ("torch_ddp", [0, 1]), ("torch", 0)]) # ("torch", "cpu"), ("torch_ddp", [0, 1]), ("torch", 1) | |||||
@pytest.mark.parametrize("version", [0, 1]) | @pytest.mark.parametrize("version", [0, 1]) | ||||
@pytest.mark.parametrize("only_state_dict", [True, False]) | @pytest.mark.parametrize("only_state_dict", [True, False]) | ||||
@magic_argv_env_context | @magic_argv_env_context | ||||
@@ -333,37 +303,21 @@ def test_trainer_checkpoint_callback_1( | |||||
version, | version, | ||||
only_state_dict | only_state_dict | ||||
): | ): | ||||
path = Path.cwd().joinpath(f"test_model_checkpoint") | |||||
path.mkdir(exist_ok=True, parents=True) | |||||
try: | |||||
path = Path.cwd().joinpath(f"test_model_checkpoint") | |||||
path.mkdir(exist_ok=True, parents=True) | |||||
if version == 0: | |||||
callbacks = [ | |||||
TrainerCheckpointCallback( | |||||
monitor="acc", | |||||
save_folder=path, | |||||
save_every_n_epochs=7, | |||||
save_every_n_batches=123, # 避免和 epoch 的保存重复; | |||||
save_topk=None, | |||||
save_last=False, | |||||
save_on_exception=None, | |||||
only_state_dict=only_state_dict | |||||
) | |||||
] | |||||
elif version == 1: | |||||
callbacks = [ | |||||
TrainerCheckpointCallback( | |||||
monitor="acc", | |||||
save_folder=path, | |||||
save_every_n_epochs=None, | |||||
save_every_n_batches=None, | |||||
save_topk=2, | |||||
save_last=True, | |||||
save_on_exception=None, | |||||
only_state_dict=only_state_dict | |||||
) | |||||
] | |||||
if version == 0: | |||||
callbacks = [ | |||||
CheckpointCallback(folder=path, every_n_epochs=7, every_n_batches=123, last=False, on_exceptions=None, topk=0, | |||||
monitor=None, only_state_dict=only_state_dict, save_object='trainer') | |||||
] | |||||
elif version == 1: | |||||
callbacks = [ | |||||
CheckpointCallback(folder=path, every_n_epochs=None, every_n_batches=None, last=True, on_exceptions=None, | |||||
topk=2, monitor="acc", only_state_dict=only_state_dict, save_object='trainer') | |||||
] | |||||
try: | |||||
trainer = Trainer( | trainer = Trainer( | ||||
model=model_and_optimizers.model, | model=model_and_optimizers.model, | ||||
driver=driver, | driver=driver, | ||||
@@ -461,8 +415,7 @@ def test_trainer_checkpoint_callback_1( | |||||
trainer.run() | trainer.run() | ||||
finally: | finally: | ||||
synchronize_safe_rm(path) | |||||
pass | |||||
rank_zero_rm(path) | |||||
if dist.is_initialized(): | if dist.is_initialized(): | ||||
dist.destroy_process_group() | dist.destroy_process_group() | ||||
@@ -594,12 +547,12 @@ def test_trainer_checkpoint_callback_2( | |||||
callbacks = [ | callbacks = [ | ||||
TrainerCheckpointCallback( | TrainerCheckpointCallback( | ||||
monitor="acc", | monitor="acc", | ||||
save_folder=path, | |||||
save_every_n_epochs=None, | |||||
save_every_n_batches=50, | |||||
save_topk=None, | |||||
save_last=False, | |||||
save_on_exception=None, | |||||
folder=path, | |||||
every_n_epochs=None, | |||||
every_n_batches=50, | |||||
topk=None, | |||||
last=False, | |||||
on_exception=None, | |||||
model_save_fn=model_save_fn | model_save_fn=model_save_fn | ||||
) | ) | ||||
] | ] | ||||
@@ -607,12 +560,12 @@ def test_trainer_checkpoint_callback_2( | |||||
callbacks = [ | callbacks = [ | ||||
TrainerCheckpointCallback( | TrainerCheckpointCallback( | ||||
monitor="acc", | monitor="acc", | ||||
save_folder=path, | |||||
save_every_n_epochs=None, | |||||
save_every_n_batches=None, | |||||
save_topk=1, | |||||
save_last=True, | |||||
save_on_exception=None, | |||||
folder=path, | |||||
every_n_epochs=None, | |||||
every_n_batches=None, | |||||
topk=1, | |||||
last=True, | |||||
on_exception=None, | |||||
model_save_fn=model_save_fn | model_save_fn=model_save_fn | ||||
) | ) | ||||
] | ] | ||||
@@ -710,7 +663,7 @@ def test_trainer_checkpoint_callback_2( | |||||
trainer.run() | trainer.run() | ||||
finally: | finally: | ||||
synchronize_safe_rm(path) | |||||
rank_zero_rm(path) | |||||
# pass | # pass | ||||
if dist.is_initialized(): | if dist.is_initialized(): | ||||
@@ -0,0 +1,263 @@ | |||||
""" | |||||
测试 more_evaluate_callback | |||||
(1)能不能正确 evaluate ; | |||||
(2) 能不能保存 topk 并load进来进行训练 | |||||
""" | |||||
import pytest | |||||
import os | |||||
import pytest | |||||
from typing import Any | |||||
from dataclasses import dataclass | |||||
from torch.utils.data import DataLoader | |||||
from torch.optim import SGD | |||||
import torch.distributed as dist | |||||
from pathlib import Path | |||||
import re | |||||
from fastNLP.core.controllers.trainer import Trainer | |||||
from fastNLP.envs import FASTNLP_LAUNCH_TIME, FASTNLP_DISTRIBUTED_CHECK | |||||
from tests.helpers.utils import magic_argv_env_context | |||||
from fastNLP.core import rank_zero_rm | |||||
from tests.helpers.models.torch_model import TorchNormalModel_Classification_1 | |||||
from tests.helpers.datasets.torch_data import TorchArgMaxDatset | |||||
from torchmetrics import Accuracy | |||||
from fastNLP.core.metrics import Metric | |||||
from fastNLP.core.log import logger | |||||
from fastNLP.core.callbacks import MoreEvaluateCallback | |||||
@dataclass
class ArgMaxDatasetConfig:
    """Hyper-parameters used to build the synthetic argmax dataset and its dataloader."""
    num_labels: int = 10  # number of target classes
    feature_dimension: int = 10  # size of each input feature vector
    data_num: int = 100  # total number of samples in the dataset
    seed: int = 0  # seed for reproducible data generation
    batch_size: int = 4  # dataloader batch size
    shuffle: bool = True  # whether the dataloader should shuffle samples
@dataclass
class TrainerParameters:
    """Bundle of the arguments the tests pass to ``Trainer``; filled in by the fixture below."""
    model: Any = None  # the torch model under test
    optimizers: Any = None  # optimizer(s) for the model
    train_dataloader: Any = None  # dataloader used for training
    evaluate_dataloaders: Any = None  # dataloader(s) used for evaluation
    input_mapping: Any = None  # optional mapping applied to batch before the model
    output_mapping: Any = None  # optional mapping applied to model output
    metrics: Any = None  # metrics passed to the Trainer's own evaluator
    more_metrics: Any = None  # metrics passed to MoreEvaluateCallback
@pytest.fixture(scope="module", params=[0], autouse=True)
def model_and_optimizers(request):
    """Build the module-scoped bundle of model, optimizer, dataloaders and metrics
    shared by the MoreEvaluateCallback tests.

    :param request: pytest fixture request (params unused beyond forcing one run).
    :return: a populated :class:`TrainerParameters`.
    """
    trainer_params = TrainerParameters()

    trainer_params.model = TorchNormalModel_Classification_1(
        num_labels=ArgMaxDatasetConfig.num_labels,
        feature_dimension=ArgMaxDatasetConfig.feature_dimension
    )
    trainer_params.optimizers = SGD(trainer_params.model.parameters(), lr=0.001)
    dataset = TorchArgMaxDatset(
        feature_dimension=ArgMaxDatasetConfig.feature_dimension,
        data_num=ArgMaxDatasetConfig.data_num,
        seed=ArgMaxDatasetConfig.seed
    )
    _dataloader = DataLoader(
        dataset=dataset,
        batch_size=ArgMaxDatasetConfig.batch_size,
        # was hard-coded to True; read it from the config so the two stay consistent
        shuffle=ArgMaxDatasetConfig.shuffle
    )

    class LossMetric(Metric):
        """Accumulates the scalar training loss across ``update`` calls."""

        def __init__(self):
            super().__init__()
            self.register_element('loss')

        def update(self, loss):
            # accumulate the python scalar, not the tensor, to avoid holding the graph
            self.loss += loss.item()

        def get_metric(self) -> float:
            # NOTE(review): the original annotation said ``-> dict`` but the method
            # returns a scalar via ``.item()``; annotation corrected to match behavior.
            return self.loss.item()

    trainer_params.train_dataloader = _dataloader
    trainer_params.evaluate_dataloaders = _dataloader
    trainer_params.metrics = {'loss': LossMetric()}
    trainer_params.more_metrics = {"acc": Accuracy()}

    return trainer_params
@pytest.mark.parametrize("driver,device", [("torch", "cpu"), ("torch_ddp", [0, 1]), ("torch", 1)])  # ("torch", "cpu"), ("torch_ddp", [0, 1]), ("torch", 1)
@pytest.mark.parametrize("version", [0, 1])
@pytest.mark.parametrize("only_state_dict", [True, False])
@magic_argv_env_context
def test_model_more_evaluate_callback_1(
    model_and_optimizers: TrainerParameters,
    driver,
    device,
    version,
    only_state_dict
):
    """MoreEvaluateCallback with ``save_object='model'``.

    version 0: evaluate every epoch, keep every result (topk=-1) -> one folder per epoch;
    version 1: evaluate when 'loss' improves, keep only the best 'acc' (topk=1) -> one folder.
    Every saved model is then loaded back via ``trainer.load_model`` and trained again.
    """
    # Create the folder before entering try so the finally-cleanup never sees an
    # unbound `path` (previously `path` was assigned inside the try block).
    path = Path.cwd().joinpath("test_model_checkpoint")  # dropped pointless f-prefix
    path.mkdir(exist_ok=True, parents=True)
    try:
        if version == 0:
            callbacks = [
                MoreEvaluateCallback(dataloaders=model_and_optimizers.evaluate_dataloaders,
                                     metrics=model_and_optimizers.more_metrics,
                                     evaluate_every=-1,
                                     folder=path, topk=-1,
                                     topk_monitor='acc', only_state_dict=only_state_dict, save_object='model')
            ]
        elif version == 1:
            callbacks = [
                MoreEvaluateCallback(dataloaders=model_and_optimizers.evaluate_dataloaders,
                                     metrics=model_and_optimizers.more_metrics,
                                     evaluate_every=None, watch_monitor='loss', watch_monitor_larger_better=False,
                                     folder=path, topk=1, topk_monitor='acc', only_state_dict=only_state_dict,
                                     save_object='model')
            ]
        n_epochs = 5
        trainer = Trainer(
            model=model_and_optimizers.model,
            driver=driver,
            device=device,
            optimizers=model_and_optimizers.optimizers,
            train_dataloader=model_and_optimizers.train_dataloader,
            evaluate_dataloaders=model_and_optimizers.evaluate_dataloaders,
            input_mapping=model_and_optimizers.input_mapping,
            output_mapping=model_and_optimizers.output_mapping,
            metrics=model_and_optimizers.metrics,
            n_epochs=n_epochs,
            callbacks=callbacks,
            output_from_new_proc="all",
            evaluate_fn='train_step'
        )

        trainer.run()

        all_saved_model_paths = {w.name: w for w in path.joinpath(os.environ[FASTNLP_LAUNCH_TIME]).iterdir()}
        # Check that the number of saved model folders is correct for each version.
        if version == 0:
            assert len(all_saved_model_paths) == n_epochs
        elif version == 1:
            assert len(all_saved_model_paths) == 1

        # Reload each saved model and make sure training can continue from it.
        for folder in all_saved_model_paths:
            trainer = Trainer(
                model=model_and_optimizers.model,
                driver=driver,
                device=device,
                optimizers=model_and_optimizers.optimizers,
                train_dataloader=model_and_optimizers.train_dataloader,
                evaluate_dataloaders=model_and_optimizers.evaluate_dataloaders,
                input_mapping=model_and_optimizers.input_mapping,
                output_mapping=model_and_optimizers.output_mapping,
                metrics=model_and_optimizers.metrics,
                n_epochs=2,
                output_from_new_proc="all",
                evaluate_fn='train_step'
            )
            folder = path.joinpath(os.environ[FASTNLP_LAUNCH_TIME]).joinpath(folder)
            trainer.load_model(folder, only_state_dict=only_state_dict)

            trainer.run()

    finally:
        rank_zero_rm(path)

    if dist.is_initialized():
        dist.destroy_process_group()
@pytest.mark.parametrize("driver,device", [("torch", "cpu"), ("torch_ddp", [0, 1]), ("torch", 0)])  # ("torch", "cpu"), ("torch_ddp", [0, 1]), ("torch", 1)
@pytest.mark.parametrize("version", [0, 1])
@pytest.mark.parametrize("only_state_dict", [True, False])
@magic_argv_env_context
def test_trainer_checkpoint_callback_1(
    model_and_optimizers: TrainerParameters,
    driver,
    device,
    version,
    only_state_dict
):
    """MoreEvaluateCallback with ``save_object='trainer'``.

    NOTE(review): the name suggests a checkpoint-callback test, but it exercises
    ``MoreEvaluateCallback`` saving full trainer state; kept unchanged so pytest
    collection/selection is unaffected.

    version 0: evaluate every epoch, keep every result (topk=-1) -> one folder per epoch;
    version 1: evaluate when 'loss' improves, keep only the best 'acc' (topk=1) -> one folder.
    Every saved checkpoint is then restored via ``trainer.load`` and training resumed.
    """
    # Create the folder before entering try so the finally-cleanup never sees an
    # unbound `path` (previously `path` was assigned inside the try block).
    path = Path.cwd().joinpath("test_model_checkpoint")  # dropped pointless f-prefix
    path.mkdir(exist_ok=True, parents=True)
    try:
        if version == 0:
            callbacks = [
                MoreEvaluateCallback(dataloaders=model_and_optimizers.evaluate_dataloaders,
                                     metrics=model_and_optimizers.more_metrics,
                                     evaluate_every=-1,
                                     folder=path, topk=-1,
                                     topk_monitor='acc', only_state_dict=only_state_dict, save_object='trainer')
            ]
        elif version == 1:
            callbacks = [
                MoreEvaluateCallback(dataloaders=model_and_optimizers.evaluate_dataloaders,
                                     metrics=model_and_optimizers.more_metrics,
                                     evaluate_every=None, watch_monitor='loss', watch_monitor_larger_better=False,
                                     folder=path, topk=1, topk_monitor='acc', only_state_dict=only_state_dict,
                                     save_object='trainer')
            ]
        n_epochs = 5
        trainer = Trainer(
            model=model_and_optimizers.model,
            driver=driver,
            device=device,
            optimizers=model_and_optimizers.optimizers,
            train_dataloader=model_and_optimizers.train_dataloader,
            evaluate_dataloaders=model_and_optimizers.evaluate_dataloaders,
            input_mapping=model_and_optimizers.input_mapping,
            output_mapping=model_and_optimizers.output_mapping,
            metrics=model_and_optimizers.metrics,
            n_epochs=n_epochs,
            callbacks=callbacks,
            output_from_new_proc="all",
            evaluate_fn='train_step'
        )

        trainer.run()

        all_saved_model_paths = {w.name: w for w in path.joinpath(os.environ[FASTNLP_LAUNCH_TIME]).iterdir()}
        # Check that the number of saved checkpoint folders is correct for each version.
        if version == 0:
            assert len(all_saved_model_paths) == n_epochs
        elif version == 1:
            assert len(all_saved_model_paths) == 1

        # Restore each checkpoint and make sure training can resume from it.
        for folder in all_saved_model_paths:
            trainer = Trainer(
                model=model_and_optimizers.model,
                driver=driver,
                device=device,
                optimizers=model_and_optimizers.optimizers,
                train_dataloader=model_and_optimizers.train_dataloader,
                evaluate_dataloaders=model_and_optimizers.evaluate_dataloaders,
                input_mapping=model_and_optimizers.input_mapping,
                output_mapping=model_and_optimizers.output_mapping,
                metrics=model_and_optimizers.metrics,
                n_epochs=7,
                output_from_new_proc="all",
                evaluate_fn='train_step'
            )
            folder = path.joinpath(os.environ[FASTNLP_LAUNCH_TIME]).joinpath(folder)
            trainer.load(folder, only_state_dict=only_state_dict)

            trainer.run()

    finally:
        rank_zero_rm(path)

    if dist.is_initialized():
        dist.destroy_process_group()
@@ -15,7 +15,7 @@ from tests.helpers.datasets.torch_data import TorchNormalDataset_Classification | |||||
from tests.helpers.callbacks.helper_callbacks import RecordLossCallback | from tests.helpers.callbacks.helper_callbacks import RecordLossCallback | ||||
from tests.helpers.callbacks.helper_callbacks_torch import RecordAccumulationStepsCallback_Torch | from tests.helpers.callbacks.helper_callbacks_torch import RecordAccumulationStepsCallback_Torch | ||||
from tests.helpers.utils import magic_argv_env_context, Capturing | from tests.helpers.utils import magic_argv_env_context, Capturing | ||||
from fastNLP.core import synchronize_safe_rm | |||||
from fastNLP.core import rank_zero_rm | |||||
@dataclass | @dataclass | ||||
@@ -239,7 +239,7 @@ def test_trainer_output_from_new_proc( | |||||
assert err_path.exists() | assert err_path.exists() | ||||
path = Path(os.path.abspath(output_from_new_proc)) | path = Path(os.path.abspath(output_from_new_proc)) | ||||
synchronize_safe_rm(path) | |||||
rank_zero_rm(path) | |||||
@pytest.mark.parametrize("driver,device", [("torch", [1, 2])]) | @pytest.mark.parametrize("driver,device", [("torch", [1, 2])]) | ||||
@@ -11,7 +11,7 @@ from tests.helpers.models.paddle_model import PaddleNormalModel_Classification_1 | |||||
from tests.helpers.datasets.paddle_data import PaddleNormalDataset, PaddleRandomMaxDataset | from tests.helpers.datasets.paddle_data import PaddleNormalDataset, PaddleRandomMaxDataset | ||||
from tests.helpers.datasets.torch_data import TorchNormalDataset | from tests.helpers.datasets.torch_data import TorchNormalDataset | ||||
from tests.helpers.models.torch_model import TorchNormalModel_Classification_1 | from tests.helpers.models.torch_model import TorchNormalModel_Classification_1 | ||||
from fastNLP.core import synchronize_safe_rm | |||||
from fastNLP.core import rank_zero_rm | |||||
import paddle | import paddle | ||||
from paddle.io import DataLoader, BatchSampler | from paddle.io import DataLoader, BatchSampler | ||||
@@ -578,11 +578,11 @@ def test_save_and_load_model(prepare_test_save_load, only_state_dict): | |||||
assert paddle.equal_all(res1["pred"], res2["pred"]) | assert paddle.equal_all(res1["pred"], res2["pred"]) | ||||
finally: | finally: | ||||
if only_state_dict: | if only_state_dict: | ||||
synchronize_safe_rm(path) | |||||
rank_zero_rm(path) | |||||
else: | else: | ||||
synchronize_safe_rm(path + ".pdiparams") | |||||
synchronize_safe_rm(path + ".pdiparams.info") | |||||
synchronize_safe_rm(path + ".pdmodel") | |||||
rank_zero_rm(path + ".pdiparams") | |||||
rank_zero_rm(path + ".pdiparams.info") | |||||
rank_zero_rm(path + ".pdmodel") | |||||
@pytest.mark.parametrize("only_state_dict", ([True, False])) | @pytest.mark.parametrize("only_state_dict", ([True, False])) | ||||
def test_save_and_load_with_randombatchsampler(only_state_dict): | def test_save_and_load_with_randombatchsampler(only_state_dict): | ||||
@@ -652,7 +652,7 @@ def test_save_and_load_with_randombatchsampler(only_state_dict): | |||||
assert len(left_y_batches) + len(already_seen_y_set) == len(dataset) | assert len(left_y_batches) + len(already_seen_y_set) == len(dataset) | ||||
assert len(left_y_batches | already_seen_y_set) == len(dataset) | assert len(left_y_batches | already_seen_y_set) == len(dataset) | ||||
finally: | finally: | ||||
synchronize_safe_rm(path) | |||||
rank_zero_rm(path) | |||||
@pytest.mark.parametrize("only_state_dict", ([True, False])) | @pytest.mark.parametrize("only_state_dict", ([True, False])) | ||||
def test_save_and_load_with_randomsampler(only_state_dict): | def test_save_and_load_with_randomsampler(only_state_dict): | ||||
@@ -730,4 +730,4 @@ def test_save_and_load_with_randomsampler(only_state_dict): | |||||
assert len(left_y_batches) + len(already_seen_y_set) == len(dataset) | assert len(left_y_batches) + len(already_seen_y_set) == len(dataset) | ||||
assert len(left_y_batches | already_seen_y_set) == len(dataset) | assert len(left_y_batches | already_seen_y_set) == len(dataset) | ||||
finally: | finally: | ||||
synchronize_safe_rm(path) | |||||
rank_zero_rm(path) |
@@ -6,7 +6,7 @@ import logging | |||||
import re | import re | ||||
from fastNLP.envs.env import FASTNLP_LAUNCH_TIME | from fastNLP.envs.env import FASTNLP_LAUNCH_TIME | ||||
from fastNLP.core import synchronize_safe_rm | |||||
from fastNLP.core import rank_zero_rm | |||||
from fastNLP.core.log.logger import logger | from fastNLP.core.log.logger import logger | ||||
from tests.helpers.utils import magic_argv_env_context, recover_logger | from tests.helpers.utils import magic_argv_env_context, recover_logger | ||||
@@ -56,7 +56,7 @@ def test_add_file_ddp_1_torch(): | |||||
pattern = re.compile(msg) | pattern = re.compile(msg) | ||||
assert len(pattern.findall(line)) == 1 | assert len(pattern.findall(line)) == 1 | ||||
synchronize_safe_rm(filepath) | |||||
rank_zero_rm(filepath) | |||||
dist.barrier() | dist.barrier() | ||||
dist.destroy_process_group() | dist.destroy_process_group() | ||||
@@ -105,7 +105,7 @@ def test_add_file_ddp_2_torch(): | |||||
pattern = re.compile(msg) | pattern = re.compile(msg) | ||||
assert len(pattern.findall(line)) == 1 | assert len(pattern.findall(line)) == 1 | ||||
finally: | finally: | ||||
synchronize_safe_rm(path) | |||||
rank_zero_rm(path) | |||||
dist.barrier() | dist.barrier() | ||||
dist.destroy_process_group() | dist.destroy_process_group() | ||||
@@ -155,7 +155,7 @@ def test_add_file_ddp_3_torch(): | |||||
pattern = re.compile(msg) | pattern = re.compile(msg) | ||||
assert len(pattern.findall(line)) == 1 | assert len(pattern.findall(line)) == 1 | ||||
synchronize_safe_rm(file) | |||||
rank_zero_rm(file) | |||||
dist.barrier() | dist.barrier() | ||||
dist.destroy_process_group() | dist.destroy_process_group() | ||||
@@ -202,7 +202,7 @@ def test_add_file_ddp_4_torch(): | |||||
pattern = re.compile(msg) | pattern = re.compile(msg) | ||||
assert len(pattern.findall(line)) == 1 | assert len(pattern.findall(line)) == 1 | ||||
finally: | finally: | ||||
synchronize_safe_rm(path) | |||||
rank_zero_rm(path) | |||||
dist.barrier() | dist.barrier() | ||||
dist.destroy_process_group() | dist.destroy_process_group() | ||||
@@ -225,7 +225,7 @@ class TestLogger: | |||||
line = ''.join([l for l in f]) | line = ''.join([l for l in f]) | ||||
assert self.msg in line | assert self.msg in line | ||||
finally: | finally: | ||||
synchronize_safe_rm(path) | |||||
rank_zero_rm(path) | |||||
@recover_logger | @recover_logger | ||||
def test_add_file_2(self): | def test_add_file_2(self): | ||||
@@ -243,7 +243,7 @@ class TestLogger: | |||||
line = ''.join([l for l in f]) | line = ''.join([l for l in f]) | ||||
assert self.msg in line | assert self.msg in line | ||||
finally: | finally: | ||||
synchronize_safe_rm(origin_path) | |||||
rank_zero_rm(origin_path) | |||||
@recover_logger | @recover_logger | ||||
def test_add_file_3(self): | def test_add_file_3(self): | ||||
@@ -279,7 +279,7 @@ class TestLogger: | |||||
line = ''.join([l for l in f]) | line = ''.join([l for l in f]) | ||||
assert self.msg in line | assert self.msg in line | ||||
finally: | finally: | ||||
synchronize_safe_rm(path) | |||||
rank_zero_rm(path) | |||||
@recover_logger | @recover_logger | ||||
def test_stdout(self, capsys): | def test_stdout(self, capsys): | ||||
@@ -8,7 +8,7 @@ import sys | |||||
from fastNLP.core.utils.cache_results import cache_results | from fastNLP.core.utils.cache_results import cache_results | ||||
from tests.helpers.common.utils import check_time_elapse | from tests.helpers.common.utils import check_time_elapse | ||||
from fastNLP.core import synchronize_safe_rm | |||||
from fastNLP.core import rank_zero_rm | |||||
def get_subprocess_results(cmd): | def get_subprocess_results(cmd): | ||||
@@ -56,7 +56,7 @@ class TestCacheResults: | |||||
res = demo() | res = demo() | ||||
finally: | finally: | ||||
synchronize_safe_rm(cache_fp) | |||||
rank_zero_rm(cache_fp) | |||||
def test_cache_save_refresh(self): | def test_cache_save_refresh(self): | ||||
cache_fp = 'demo.pkl' | cache_fp = 'demo.pkl' | ||||
@@ -70,7 +70,7 @@ class TestCacheResults: | |||||
with check_time_elapse(1, op='ge'): | with check_time_elapse(1, op='ge'): | ||||
res = demo() | res = demo() | ||||
finally: | finally: | ||||
synchronize_safe_rm(cache_fp) | |||||
rank_zero_rm(cache_fp) | |||||
def test_cache_no_func_change(self): | def test_cache_no_func_change(self): | ||||
cache_fp = os.path.abspath('demo.pkl') | cache_fp = os.path.abspath('demo.pkl') | ||||
@@ -91,7 +91,7 @@ class TestCacheResults: | |||||
with check_time_elapse(1, op='lt'): | with check_time_elapse(1, op='lt'): | ||||
res = demo() | res = demo() | ||||
finally: | finally: | ||||
synchronize_safe_rm('demo.pkl') | |||||
rank_zero_rm('demo.pkl') | |||||
def test_cache_func_change(self, capsys): | def test_cache_func_change(self, capsys): | ||||
cache_fp = 'demo.pkl' | cache_fp = 'demo.pkl' | ||||
@@ -121,7 +121,7 @@ class TestCacheResults: | |||||
assert 'is different from its last cache' not in output[0] | assert 'is different from its last cache' not in output[0] | ||||
finally: | finally: | ||||
synchronize_safe_rm('demo.pkl') | |||||
rank_zero_rm('demo.pkl') | |||||
def test_cache_check_hash(self): | def test_cache_check_hash(self): | ||||
cache_fp = 'demo.pkl' | cache_fp = 'demo.pkl' | ||||
@@ -152,7 +152,7 @@ class TestCacheResults: | |||||
assert 'is different from its last cache' in output[0] | assert 'is different from its last cache' in output[0] | ||||
finally: | finally: | ||||
synchronize_safe_rm('demo.pkl') | |||||
rank_zero_rm('demo.pkl') | |||||
# 外部 function 改变也会 导致改变 | # 外部 function 改变也会 导致改变 | ||||
def test_refer_fun_change(self): | def test_refer_fun_change(self): | ||||
@@ -177,7 +177,7 @@ class TestCacheResults: | |||||
assert 'is different from its last cache' in res | assert 'is different from its last cache' in res | ||||
finally: | finally: | ||||
synchronize_safe_rm(cache_fp) | |||||
rank_zero_rm(cache_fp) | |||||
# 外部 method 改变也会 导致改变 | # 外部 method 改变也会 导致改变 | ||||
def test_refer_class_method_change(self): | def test_refer_class_method_change(self): | ||||
@@ -202,7 +202,7 @@ class TestCacheResults: | |||||
assert 'is different from its last cache' in res | assert 'is different from its last cache' in res | ||||
finally: | finally: | ||||
synchronize_safe_rm(cache_fp) | |||||
rank_zero_rm(cache_fp) | |||||
def test_duplicate_keyword(self): | def test_duplicate_keyword(self): | ||||
with pytest.raises(RuntimeError): | with pytest.raises(RuntimeError): | ||||
@@ -240,7 +240,7 @@ class TestCacheResults: | |||||
results = cache() | results = cache() | ||||
assert (1, 2) == results | assert (1, 2) == results | ||||
finally: | finally: | ||||
synchronize_safe_rm('demo/') | |||||
rank_zero_rm('demo/') | |||||
def test_result_none_error(self): | def test_result_none_error(self): | ||||
@cache_results('demo.pkl') | @cache_results('demo.pkl') | ||||
@@ -251,7 +251,7 @@ class TestCacheResults: | |||||
with pytest.raises(RuntimeError): | with pytest.raises(RuntimeError): | ||||
results = cache() | results = cache() | ||||
finally: | finally: | ||||
synchronize_safe_rm('demo.pkl') | |||||
rank_zero_rm('demo.pkl') | |||||
if __name__ == '__main__': | if __name__ == '__main__': | ||||
@@ -2,7 +2,7 @@ import os | |||||
from fastNLP.envs.set_backend import dump_fastnlp_backend | from fastNLP.envs.set_backend import dump_fastnlp_backend | ||||
from tests.helpers.utils import Capturing | from tests.helpers.utils import Capturing | ||||
from fastNLP.core import synchronize_safe_rm | |||||
from fastNLP.core import rank_zero_rm | |||||
def test_dump_fastnlp_envs(): | def test_dump_fastnlp_envs(): | ||||
@@ -14,4 +14,4 @@ def test_dump_fastnlp_envs(): | |||||
assert filepath in output[0] | assert filepath in output[0] | ||||
assert os.path.exists(filepath) | assert os.path.exists(filepath) | ||||
finally: | finally: | ||||
synchronize_safe_rm(filepath) | |||||
rank_zero_rm(filepath) |
@@ -9,7 +9,7 @@ import numpy as np | |||||
from fastNLP.modules.mix_modules.mix_module import MixModule | from fastNLP.modules.mix_modules.mix_module import MixModule | ||||
from fastNLP.modules.mix_modules.utils import paddle2torch, torch2paddle | from fastNLP.modules.mix_modules.utils import paddle2torch, torch2paddle | ||||
from fastNLP.core import synchronize_safe_rm | |||||
from fastNLP.core import rank_zero_rm | |||||
############################################################################ | ############################################################################ | ||||
@@ -227,7 +227,7 @@ class TorchPaddleMixModuleTestCase(unittest.TestCase): | |||||
self.assertDictEqual(state_dict, new_state_dict) | self.assertDictEqual(state_dict, new_state_dict) | ||||
finally: | finally: | ||||
synchronize_safe_rm(path) | |||||
rank_zero_rm(path) | |||||
def if_device_correct(self, device): | def if_device_correct(self, device): | ||||