From aa95513055555a796bc9fe0ca61e1904b8fdc54e Mon Sep 17 00:00:00 2001 From: yh_cc Date: Tue, 19 Apr 2022 22:38:35 +0800 Subject: [PATCH] =?UTF-8?q?1.merge=20ModelCheckPointCallback=E5=92=8CTrain?= =?UTF-8?q?erCheckpointCallback;2.=E6=96=B0=E5=A2=9EMoreEvaluateCallback?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- fastNLP/core/callbacks/__init__.py | 10 +- fastNLP/core/callbacks/callback.py | 11 +- fastNLP/core/callbacks/callback_manager.py | 8 +- fastNLP/core/callbacks/checkpoint_callback.py | 386 +++++------------- .../core/callbacks/has_monitor_callback.py | 97 +++-- .../callbacks/load_best_model_callback.py | 6 +- .../core/callbacks/more_evaluate_callback.py | 174 ++++++++ fastNLP/core/callbacks/progress_callback.py | 7 +- fastNLP/core/callbacks/topk_saver.py | 246 +++++++++++ fastNLP/core/callbacks/utils.py | 5 +- fastNLP/core/controllers/evaluator.py | 12 +- fastNLP/core/controllers/trainer.py | 125 +++--- fastNLP/core/controllers/utils/utils.py | 2 +- .../core/dataloaders/torch_dataloader/fdl.py | 5 +- fastNLP/core/drivers/driver.py | 1 - fastNLP/core/drivers/torch_driver/ddp.py | 2 +- .../core/drivers/torch_driver/torch_driver.py | 12 +- fastNLP/core/metrics/accuracy.py | 2 +- fastNLP/core/samplers/unrepeated_sampler.py | 2 + fastNLP/core/utils/__init__.py | 6 +- fastNLP/core/utils/utils.py | 33 +- fastNLP/envs/env.py | 4 +- .../test_checkpoint_callback_torch.py | 171 +++----- .../callbacks/test_more_evaluate_callback.py | 263 ++++++++++++ .../test_trainer_wo_evaluator_torch.py | 4 +- .../paddle_driver/test_single_device.py | 14 +- tests/core/log/test_logger.py | 16 +- tests/core/utils/test_cache_results.py | 20 +- tests/envs/test_set_backend.py | 4 +- tests/modules/mix_modules/test_mix_module.py | 4 +- 30 files changed, 1067 insertions(+), 585 deletions(-) create mode 100644 fastNLP/core/callbacks/more_evaluate_callback.py create mode 100644 fastNLP/core/callbacks/topk_saver.py create mode 100644 tests/core/callbacks/test_more_evaluate_callback.py diff --git a/fastNLP/core/callbacks/__init__.py b/fastNLP/core/callbacks/__init__.py index 58de0319..bbce73e0 100644 --- a/fastNLP/core/callbacks/__init__.py +++ b/fastNLP/core/callbacks/__init__.py @@ -4,8 +4,7 @@ __all__ = [ 'EventsList', 'Filter', 'CallbackManager', - 'ModelCheckpointCallback', - 'TrainerCheckpointCallback', + 'CheckpointCallback', 'choose_progress_callback', 'ProgressCallback', 'RichCallback', @@ -13,18 +12,21 @@ __all__ = [ 'LoadBestModelCallback', "EarlyStopCallback", + 'MoreEvaluateCallback', + "TorchWarmupCallback", - "TorchGradClipCallback" + "TorchGradClipCallback", ] from .callback import Callback from .callback_events import EventsList, Events, Filter from .callback_manager import CallbackManager -from .checkpoint_callback import ModelCheckpointCallback, TrainerCheckpointCallback +from .checkpoint_callback import CheckpointCallback from .progress_callback import choose_progress_callback, ProgressCallback, RichCallback from .lr_scheduler_callback import LRSchedCallback from .load_best_model_callback import LoadBestModelCallback from .early_stop_callback import EarlyStopCallback from .torch_callbacks import * +from .more_evaluate_callback import MoreEvaluateCallback diff --git a/fastNLP/core/callbacks/callback.py b/fastNLP/core/callbacks/callback.py index 117cb524..1d3d1f11 100644 --- a/fastNLP/core/callbacks/callback.py +++ b/fastNLP/core/callbacks/callback.py @@ -236,7 +236,7 @@ class Callback: 结束 validate 时调用,并把 validate 的结果传入。 :param trainer: - :param results: + :param results: Evaluate 的结果,一般是个 dict 。 :return: """ pass @@ -250,6 +250,15 @@ class Callback: """ return self.__class__.__name__ + @property + def need_reproducible_sampler(self) -> bool: + """ + 当前 callback 是否需要能够复现的 sampler 。一般用于 checkpoint 类的 callback 。 + + :return: + """ + return False + class _CallbackWrapper(Callback): """ diff --git a/fastNLP/core/callbacks/callback_manager.py b/fastNLP/core/callbacks/callback_manager.py index a962fe9f..4aa822ad 100644 --- a/fastNLP/core/callbacks/callback_manager.py +++ b/fastNLP/core/callbacks/callback_manager.py @@ -8,7 +8,6 @@ __all__ = [ from .callback_events import Events from .callback import Callback -from .checkpoint_callback import TrainerCheckpointCallback from .progress_callback import ProgressCallback, choose_progress_callback from fastNLP.core.log import logger @@ -45,7 +44,7 @@ class CallbackManager: :param callbacks: 初始化时可以传入的一系列 callback 类,通常为用户在初始化 'Trainer' 时直接传入的 callback 类; """ - self._has_trainer_checkpoint = False + self._need_reproducible_sampler = False _has_progress_callback = False _callbacks = [] @@ -98,8 +97,7 @@ class CallbackManager: :return: """ for each_callback in self.class_callbacks: - if isinstance(each_callback, TrainerCheckpointCallback): - self._has_trainer_checkpoint = True + self._need_reproducible_sampler |= each_callback.need_reproducible_sampler self.dissect_one_callback(each_callback) def dissect_one_callback(self, callback: Callback): @@ -211,7 +209,7 @@ class CallbackManager: @property def has_trainer_checkpoint(self) -> bool: - return self._has_trainer_checkpoint + return self._need_reproducible_sampler @_transfer def on_after_trainer_initialized(self, trainer): diff --git a/fastNLP/core/callbacks/checkpoint_callback.py b/fastNLP/core/callbacks/checkpoint_callback.py index 2cb3510e..0f6dcd6a 100644 --- a/fastNLP/core/callbacks/checkpoint_callback.py +++ b/fastNLP/core/callbacks/checkpoint_callback.py @@ -1,339 +1,151 @@ __all__ = [ - 'ModelCheckpointCallback', - 'TrainerCheckpointCallback' + 'CheckpointCallback' ] -import os -from typing import Union, Optional, Callable, Dict, Sequence, Any, Mapping + +from typing import Union, Optional, Callable, Dict, Sequence from pathlib import Path import sys -from copy import deepcopy - -import fastNLP -from .has_monitor_callback import HasMonitorCallback from fastNLP.core.log import logger -from fastNLP.envs import FASTNLP_LAUNCH_TIME, FASTNLP_GLOBAL_RANK -from fastNLP.core.utils import synchronize_safe_rm, synchronize_mkdir +from .topk_saver import TopkSaver +from .callback import Callback -class CheckpointCallback(HasMonitorCallback): - def __init__( - self, - monitor:Optional[Union[str, Callable]]=None, - save_folder: Optional[Union[str, Path]] = None, - save_every_n_epochs: Optional[int] = None, - save_every_n_batches: Optional[int] = None, - save_last: bool = False, - save_topk: Optional[int] = None, - save_on_exception: Optional[Union[BaseException, Sequence[BaseException]]] = None, - larger_better: bool = True, - only_state_dict: bool = True, - model_save_fn: Optional[Callable] = None, - **kwargs, - ): +class CheckpointCallback(Callback): + def __init__(self, folder: Optional[Union[str, Path]] = None, every_n_epochs: Optional[int] = None, + every_n_batches: Optional[int] = None, last: bool = False, + on_exceptions: Optional[Union[BaseException, Sequence[BaseException]]] = None, topk: int = 0, + monitor: Optional[Union[str, Callable]] = None, larger_better: bool = True, + only_state_dict: bool = True, model_save_fn: Optional[Callable] = None, save_object: str = 'model', + save_evaluate_results=True, **kwargs): """ - 请使用 ModelCheckpointCallback 与 TrainerCheckpointCallback 。 + 保存模型 checkpoint 的 callback ,其保存的文件目录以及文件名命名规则如下 + + - folder/ + - YYYY-mm-dd-HH_MM_SS_fffff/ # 自动根据当前脚本的启动时间创建的 + - {save_object}-epoch_{epoch_idx}/ # 满足 every_n_epochs 条件保存的模型 + - {save_object}-epoch_{epoch_idx}-batch_{global_batch_idx}/ # 满足 every_n_batches 保存的模型 + - {save_object}-last/ # 最后一个 epoch 的保存 + - {save_object}-epoch_{epoch_idx}-batch_{global_batch_idx}-exception_{exception_type}/ # exception时保存。 + - {save_object}-epoch_{epoch_idx}-batch_{global_batch_idx}-{monitor}_{monitor_value}/ # 满足topk条件存储文件名 + + model_save_fn 为 None ,则以上每个 folder 中,将生成 fastnlp_model.pkl.tar 文件。 + 若 model_save_fn 不为 None,则 fastNLP 将 folder 绝对路径传递给该函数,fastNLP 在该 folder 下不进行模型保存。 :param monitor: 监控的 metric 值。如果在 evaluation 结果中没有找到完全一致的名称,将使用 最短公共字符串算法 找到最匹配 的那个作为 monitor 。如果为 None,将尝试使用 Trainer 设置的 monitor 。也可以传入一个函数,接受参数为 evaluation 的结 - 果(字典类型),返回一个 float 值作为 monitor 的结果。 - :param save_folder: 保存的文件夹,fastNLP 将在该文件下以时间戳创建子文件夹,并在里面保存。因此不同次运行可以将被保存到不同的 + 果(字典类型),返回一个 float 值作为 monitor 的结果,如果当前结果中没有相关的 monitor 值请返回 None 。 + :param folder: 保存的文件夹,fastNLP 将在该文件下以时间戳创建子文件夹,并在里面保存。因此不同次运行可以将被保存到不同的 时间戳文件夹中。如果为 None ,默认使用当前文件夹。 - :param save_every_n_epochs: 多少个 epoch 保存一次。 - :param save_every_n_batches: 多少个 batch 保存一次。 - :param save_last: 如果为 True ,将在每次 epoch 运行结束都保存一次,会覆盖之前的保存。 - :param save_topk: 保存 monitor 结果 topK 个。 - :param save_on_exception: 在出异常信息时,是否保存。传入需要捕获的异常的类。 + :param every_n_epochs: 多少个 epoch 保存一次。 + :param every_n_batches: 多少个 batch 保存一次。 + :param last: 如果为 True ,将在每次 epoch 运行结束都保存一次,会覆盖之前的保存。 + :param topk: 保存 monitor 结果 topK 个。 + :param on_exceptions: 在出异常信息时,是否保存。传入需要捕获的异常的类。 :param larger_better: monitor 的值是否时越大越好。 :param only_state_dict: 保存模型时是否只保存 state_dict 。当 model_save_fn 不为 None 时,该参数无效。 :param model_save_fn: 个性化的保存函数,当触发保存操作时,就调用这个函数,这个函数应当接受一个文件夹作为参数,不返回任何东西。 如果传入了 model_save_fn 函数,fastNLP 将不再进行模型相关的保存。在多卡场景下,我们只在 rank 0 上会运行该函数。 + :param save_object: 可选 ['trainer', 'model'],表示在保存时的保存对象为 trainer+model 还是 只是model 。 + :param save_evaluate_results: 是否保存 evaluate 的结果。如果为 True ,在保存 topk 模型的 folder 中还将额外保存一个 + fastnlp_evaluate_results.json 文件,记录当前的 results。仅在设置了 topk 的场景下有用,默认为 True 。 :param kwargs: """ - super().__init__(monitor=monitor, larger_better=larger_better, - must_have_monitor=save_topk is not None) - if save_folder is None: + super().__init__() + if folder is None: logger.warning( - "Parameter `path` is None, and we will use the current work directory to find and load your model.") - save_folder = Path.cwd() - save_folder = Path(save_folder) - if not save_folder.exists(): - raise NotADirectoryError(f"Path '{save_folder.absolute()}' is not existed!") - elif save_folder.is_file(): - raise ValueError("Parameter `save_folder` should be a directory instead of a file.") - - if save_every_n_epochs is not None: - if not isinstance(save_every_n_epochs, int) or save_every_n_epochs < 1: - raise ValueError("parameter save_after_epoch_num should be an int and greater than or equal to 1.") - + "Parameter `folder` is None, and we will use the current work directory to find and load your model.") + folder = Path.cwd() + folder = Path(folder) + if not folder.exists(): + raise NotADirectoryError(f"Path '{folder.absolute()}' is not existed!") + elif folder.is_file(): + raise ValueError("Parameter `folder` should be a directory instead of a file.") + + if every_n_epochs is not None: + if not isinstance(every_n_epochs, int) or every_n_epochs < 1: + raise ValueError("Parameter `every_n_epochs` should be an int and greater than or equal to 1.") else: - save_every_n_epochs = sys.maxsize # 使得没有数字可以整除 + every_n_epochs = sys.maxsize # 使得没有数字可以整除 - if save_every_n_batches is not None: - if not isinstance(save_every_n_batches, int) or save_every_n_batches < 1: - raise ValueError( - "parameter save_every_n_batches should be an int and greater than or equal to 1.") + if every_n_batches is not None: + if not isinstance(every_n_batches, int) or every_n_batches < 1: + raise ValueError("Parameter `every_n_batches` should be an int and greater than or equal to 1.") else: - save_every_n_batches = sys.maxsize # 使得没有数字可以整除 + every_n_batches = sys.maxsize # 使得没有数字可以整除 - if save_topk is not None: - if not isinstance(save_topk, int) or save_topk < 1: - raise ValueError("parameter save_topk should be an int and greater than or equal to 1.") + if topk is not None: + if not isinstance(topk, int): + raise ValueError("Parameter `topk` should be an int.") + else: + topk = 0 - if save_on_exception is not None: - if not isinstance(save_on_exception, Sequence): - save_on_exception = [save_on_exception] + if on_exceptions is not None: + if not isinstance(on_exceptions, Sequence): + on_exceptions = [on_exceptions] - for exception in save_on_exception: + for exception in on_exceptions: if not issubclass(exception, BaseException): - raise TypeError("Each exception in parameter `save_on_exception` can only be " + raise TypeError("Each exception in parameter `on_exception` can only be " "`BaseException` type.") else: - save_on_exception = [] + on_exceptions = [] - self.save_folder = save_folder - self.save_every_n_epochs = save_every_n_epochs - self.save_every_n_batches = save_every_n_batches - self.save_last = save_last - self.save_topk = save_topk - self.only_state_dict = only_state_dict - self.model_save_fn = model_save_fn - self.save_on_exception = save_on_exception - self.kwargs = kwargs + self.topk_saver = TopkSaver(topk, monitor, larger_better, folder, only_state_dict, + model_save_fn, save_evaluate_results, + save_object, **kwargs) + self.topk = topk + self.save_object = save_object - # 这些参数是专门留给 topk 模式专门使用的; - self._topk_model = {} - self._topn = 0 # 表示目前已经保存了几个最好的模型; + self.every_n_epochs = every_n_epochs + self.every_n_batches = every_n_batches + self.last = last + self.exceptions = on_exceptions - # 注意这里应当保证只有进程 0 在执行这个操作,因为当用户使用 python -m torch.distributed.launch 来拉起进程的时候, - # FASTNLP_LAUNCH_TIME 在每一个进程上的值是不一样的; - self.timestamp_path = self.save_folder.joinpath(os.environ[FASTNLP_LAUNCH_TIME]) - # 该 folder 只在保存真的要发生的时候再创建。 + @property + def need_reproducible_sampler(self) -> bool: + return self.save_object == 'trainer' def on_after_trainer_initialized(self, trainer, driver): - if self.save_topk is not None: - super().on_after_trainer_initialized(trainer, driver) - if self.save_topk is not None and trainer.evaluator is None: - logger.warning("You set `save_topk`, but `evaluate_dataloaders` is not set in Trainer.") + if self.topk_saver.topk_queue: # 需要设置 monitor + if self.topk_saver.monitor is None: + self.topk_saver.set_monitor(monitor=trainer.monitor, larger_better=trainer.larger_better) + if self.topk_saver.topk_queue and trainer.evaluator is None: + logger.warning(f"You set `topk={self.topk}`, but `evaluate_dataloaders` is not set in Trainer.") def on_validate_end(self, trainer, results): - self._save_topk(trainer, results) + # 如果发生了保存,则返回的 folder 不为 None + folder = self.topk_saver.save_topk(trainer, results) def on_train_epoch_end(self, trainer: "fastNLP.Trainer"): - if trainer.cur_epoch_idx % self.save_every_n_epochs == 0: - folder_name = f'{self.folder_prefix}-epoch_{trainer.cur_epoch_idx}' - self.save(trainer, folder_name=folder_name) - if self.save_last: - folder_name = f'{self.folder_prefix}-last' - self.save(trainer, folder_name=folder_name) + if trainer.cur_epoch_idx % self.every_n_epochs == 0: + folder_name = f'{self.save_object}-epoch_{trainer.cur_epoch_idx}' + self.topk_saver.save(trainer, folder_name=folder_name) + if self.last: + folder_name = f'{self.save_object}-last' + self.topk_saver.save(trainer, folder_name=folder_name) def on_train_batch_end(self, trainer): - if trainer.global_forward_batches % self.save_every_n_batches == 0: - folder_name = f'{self.folder_prefix}-epoch_{trainer.cur_epoch_idx}-batch_{trainer.global_forward_batches}' - self.save(trainer, folder_name=folder_name) + if trainer.global_forward_batches % self.every_n_batches == 0: + folder_name = f'{self.save_object}-epoch_{trainer.cur_epoch_idx}-batch_{trainer.global_forward_batches}' + self.topk_saver.save(trainer, folder_name=folder_name) def on_exception(self, trainer, exception: BaseException): - if exception.__class__ in self.save_on_exception: - folder_name = f'{self.folder_prefix}-epoch_{trainer.cur_epoch_idx}-batch_{trainer.global_forward_batches}-' \ - f'exception_{exception.__class__.__name__}' - self.save(trainer=trainer, folder_name=folder_name) + if exception.__class__ in self.exceptions: + folder_name = f'{self.save_object}-epoch_{trainer.cur_epoch_idx}-batch_{trainer.global_forward_batches}-' \ + f'exception_{exception.__class__.__name__}' + self.topk_saver.save(trainer, folder_name=folder_name) def on_save_checkpoint(self, trainer) -> Dict: """ - 保存 timestamp_path 使得之后可以继续训练并保存到该文件夹。 - topk_model的状态 - _real_monitor的值 + 保存状态,以便之后可以继续使用 """ states = {} - states['timestamp_path'] = str(self.timestamp_path.absolute()) - states['_topk_model'] = deepcopy(self._topk_model) - states['save_topk'] = 0 if self.save_topk is None else self.save_topk - if isinstance(self._real_monitor, str): - states['_real_monitor'] = self._real_monitor + states['topk_saver'] = self.topk_saver.state_dict() return states def on_load_checkpoint(self, trainer, states: Optional[Dict]): - timestamp_path = states['timestamp_path'] - if not os.path.exists(timestamp_path): - logger.info(f"The resuming checkpoint folder {timestamp_path} is not exists, will checkpoint save to " - f" {self.timestamp_path.absolute()}.") - else: - logger.info(f"Resume to checkpoint in path: {timestamp_path}.") - self.timestamp_path = Path(timestamp_path) - _topk_model = states['_topk_model'] - save_topk = None if int(states['save_topk']) == 0 else int(states['save_topk']) - if save_topk is not None and self.save_topk is not None: - assert self.save_topk == save_topk, f"The checkpoint set save_topk={save_topk}, while this callback set it " \ - f"as {save_topk}." - self._topk_model.update(self._topk_model) - - self._real_monitor = states["_real_monitor"] - - def _save_topk(self, trainer: "fastNLP.Trainer", results: Dict): - """ - 根据validate_res决定保存哪些model的函数。会自动移除掉不满足topk的文件夹。 - - :param trainer: - :param results: - :return: - """ - if self.save_topk is not None: - monitor_value = self.get_monitor_value(results=results) - if monitor_value is None: - return - folder_name = f"{self.folder_prefix}-epoch_{trainer.cur_epoch_idx}-batch_{trainer.global_forward_batches}" \ - f"-{self._real_monitor}_{monitor_value}" - - _should_save = False - if self._topn < self.save_topk: - self._topk_model[folder_name] = monitor_value - self._topn += 1 - _should_save = True - else: - _least_valuable_model = (min if self.larger_better else max)(self._topk_model, - key=lambda x: self._topk_model[x]) - if self.is_former_monitor_value_better(monitor_value, self._topk_model[_least_valuable_model]): - self._topk_model[folder_name] = monitor_value - _should_save = True - self._topk_model.pop(_least_valuable_model) - synchronize_safe_rm(self.timestamp_path.joinpath(_least_valuable_model)) - - assert len(self._topk_model) == self.save_topk == self._topn - - if _should_save: - self.save(trainer, folder_name=folder_name) - - def save(self, trainer, folder_name): - """ - 执行保存的函数,将数据保存在 save_folder/timestamp/folder_name 下。 - - :param trainer: - :param folder_name: - :return: - """ - folder = self.timestamp_path.joinpath(folder_name) - if int(os.environ.get(FASTNLP_GLOBAL_RANK, 0)) == 0: # 只在进程0上创建 - synchronize_mkdir(folder) - _fn = getattr(trainer, self.save_fn_name) - _fn( - folder=folder, - only_state_dict=self.only_state_dict, - model_save_fn=self.model_save_fn, - **self.kwargs - ) - - @property - def folder_prefix(self): - raise NotImplementedError("The `folder_prefix` is not specified") - - @property - def save_fn_name(self): - raise NotImplementedError("The `save_fn_name` is not specified.") - - -class ModelCheckpointCallback(CheckpointCallback): - """ - 保存模型 checkpoint 的 callback ,其保存的文件目录以及文件名命名规则如下 - - - save_folder/ - - YYYY-mm-dd-HH_MM_SS_fffff/ # 自动根据当前脚本的启动时间创建的 - - model-epoch_{epoch_idx}/ # 满足 save_every_n_epochs 条件保存的模型 - - model-epoch_{epoch_idx}-batch_{global_batch_idx}/ # 满足 save_every_n_batches 保存的模型 - - model-last/ # 最后一个 epoch 的保存 - - model-epoch_{epoch_idx}-batch_{global_batch_idx}-exception_{exception_type}/ # exception时保存。 - - model-epoch_{epoch_idx}-batch_{global_batch_idx}-{monitor}_{monitor_value}/ # 满足topk条件存储文件名 + topk_saver_states = states['topk_saver'] + self.topk_saver.load_state_dict(topk_saver_states) - model_save_fn 为 None ,则以上每个 folder 中,将生成 fastnlp_model.pkl.tar 文件。 - 若 model_save_fn 不为 None,则 fastNLP 将 folder 绝对路径传递给该函数,fastNLP 不在该 folder 下创建任何文件。 - - :param monitor: 监控的 metric 值。如果在 evaluation 结果中没有找到完全一致的名称,将使用 最短公共字符串算法 找到最匹配 - 的那个作为 monitor 。如果为 None,将尝试使用 Trainer 设置的 monitor 。也可以传入一个函数,接受参数为 evaluation 的结 - 果(字典类型),返回一个 float 值作为 monitor 的结果。 - :param save_folder: 保存的文件夹,fastNLP 将在该文件下以时间戳创建子文件夹,并在里面保存。因此不同次运行可以将被保存到不同的 - 时间戳文件夹中。如果为 None ,默认使用当前文件夹。 - :param save_every_n_epochs: 多少个 epoch 保存一次。 - :param save_every_n_batches: 多少个 batch 保存一次。 - :param save_last: 如果为 True ,将在每次 epoch 运行结束都保存一次,会覆盖之前的保存。 - :param save_topk: 保存 monitor 结果 topK 个。 - :param save_on_exception: 在出异常信息时,是否保存。传入需要捕获的异常的类。 - :param larger_better: monitor 的值是否时越大越好。 - :param only_state_dict: 保存模型时是否只保存 state_dict 。当 model_save_fn 不为 None 时,该参数无效。 - :param model_save_fn: 个性化的保存函数,当触发保存操作时,就调用这个函数,这个函数应当接受一个文件夹作为参数,不返回任何东西。 - 如果传入了 model_save_fn 函数,fastNLP 将不再进行模型相关的保存。在多卡场景下,我们只在 rank 0 上会运行该函数。 - :param kwargs: - """ - @property - def save_fn_name(self): - """ - 调用 Trainer 中的哪个函数。 - - :return: - """ - return 'save_model' - - @property - def callback_name(self): - """ - 通过该值决定两个 CheckpointCallback 实例是否可以共用断点重训的状态; - :return: - """ - return f"model_checkpoint#monitor-{self.monitor_name}#topK-{self.save_topk}#only_state_dict-{self.only_state_dict}" - - @property - def folder_prefix(self): - return 'model' - - -class TrainerCheckpointCallback(CheckpointCallback): - """ - 保存 Trainer checkpoint 的 callback ,其保存的文件目录以及文件名命名规则如下 - - - save_folder/ - - YYYY-mm-dd-HH_MM_SS_fffff/ # 自动根据当前脚本的启动时间创建的 - - trainer-epoch_{epoch_idx}/ # 满足 save_every_n_epochs 条件保存的模型 - - trainer-epoch_{epoch_idx}-batch_{global_batch_idx}/ # 满足 save_every_n_batches 保存的模型 - - trainer-last/ # 最后一个 epoch 的保存 - - trainer-epoch_{epoch_idx}-batch_{global_batch_idx}-exception_{exception_type}/ # exception时保存。 - - trainer-epoch_{epoch_idx}-batch_{global_batch_idx}-{monitor}_{monitor_value}/ # 满足topk条件存储文件名 - - model_save_fn 为 None ,则以上每个 folder 中,将生成两个文件:fastnlp_trainer.pkl.tar 以及 fastnlp_model.pkl.tar 。 - 若 model_save_fn 不为 None,则 fastNLP 只会在每个 folder 下生成 fastnlp_trainer.pkl.tar 文件。 - - :param monitor: 监控的 metric 值。如果在 evaluation 结果中没有找到完全一致的名称,将使用 最短公共字符串算法 找到最匹配 - 的那个作为 monitor 。如果为 None,将尝试使用 Trainer 设置的 monitor 。也可以传入一个函数,接受参数为 evaluation 的结 - 果(字典类型),返回一个 float 值作为 monitor 的结果。 - :param save_folder: 保存的文件夹,fastNLP 将在该文件下以时间戳创建子文件夹,并在里面保存。因此不同次运行可以将被保存到不同的 - 时间戳文件夹中。如果为 None ,默认使用当前文件夹。 - :param save_every_n_epochs: 多少个 epoch 保存一次。 - :param save_every_n_batches: 多少个 batch 保存一次。 - :param save_last: 如果为 True ,将在每次 epoch 运行结束都保存一次,会覆盖之前的保存。 - :param save_topk: 保存 monitor 结果 topK 个。 - :param save_on_exception: 在出异常信息时,是否保存。 - :param larger_better: monitor 的值是否时越大越好。 - :param only_state_dict: 保存模型时是否只保存 state_dict 。当 model_save_fn 不为 None 时,该参数无意义。 - :param model_save_fn: 个性化的保存函数,当触发保存操作时,就调用这个函数,这个函数应当接受一个文件夹作为参数,不返回任何东西。 - 如果传入了 model_save_fn 函数,fastNLP 将不再进行模型相关的保存。在多卡场景下,我们只在 rank 0 上会运行该函数。 - :param kwargs: - """ - @property - def save_fn_name(self): - """ - 调用 Trainer 中的哪个函数。 - - :return: - """ - return 'save' - - @property - def callback_name(self): - """ - 通过该值决定两个 CheckpointCallback 实例是否可以共用断点重训的状态; - :return: - """ - - return f"trainer_checkpoint#monitor-{self.monitor_name}#topK-{self.save_topk}#only_state_dict-{self.only_state_dict}" - - @property - def folder_prefix(self): - return 'trainer' diff --git a/fastNLP/core/callbacks/has_monitor_callback.py b/fastNLP/core/callbacks/has_monitor_callback.py index 74fa3aaf..b13f9dd6 100644 --- a/fastNLP/core/callbacks/has_monitor_callback.py +++ b/fastNLP/core/callbacks/has_monitor_callback.py @@ -1,10 +1,12 @@ __all__ = [ 'HasMonitorCallback', - 'ExecuteOnceBetterMonitor' + 'ExecuteOnceBetterMonitor', + 'MonitorUtility' ] from typing import Dict, Union, Any from abc import ABC +import functools from fastNLP.core.utils import apply_to_collection from fastNLP.core.callbacks import Callback @@ -27,21 +29,13 @@ class CanItemDataType(ABC): return NotImplemented +class MonitorUtility: + """ + 计算 monitor 的相关函数 -class HasMonitorCallback(Callback): - def __init__(self, monitor, larger_better, must_have_monitor=False): - """ - 该 callback 不直接进行使用,作为其它相关 callback 的父类使用,如果 callback 有使用 monitor 可以继承该函数里面实现了 - (1)判断monitor合法性;(2)在需要时, 根据trainer的monitor设置自己的monitor名称。 - - :param monitor: 监控的 metric 值。如果在 evaluation 结果中没有找到完全一致的名称,将使用 最短公共字符串算法 找到最匹配 - 的那个作为 monitor 。如果为 None,将尝试使用 Trainer 设置的 monitor 。也可以传入一个函数,接受参数为 evaluation 的结 - 果(字典类型),返回一个 float 值作为 monitor 的结果。 - :param larger_better: monitor 是否时越大越好 - :param must_have_monitor: 这个 callback 是否必须有 monitor 设置。如果设置为 True ,且没检测到设置 monitor 会报错。 - """ + """ + def __init__(self, monitor, larger_better): self.set_monitor(monitor, larger_better) - self.must_have_moinitor = must_have_monitor def set_monitor(self, monitor, larger_better): if callable(monitor): # 检查是否能够接受一个参数 @@ -57,26 +51,14 @@ class HasMonitorCallback(Callback): self.monitor_value = float('inf') self._real_monitor = self.monitor - def on_after_trainer_initialized(self, trainer, driver): + def itemize_results(self, results): """ - 如果本身的 monitor 没有设置,则根据 Trainer 中的 monitor 设置 monitor 。 - 同时对于必须要有 monitor 设置的 callback ,该函数会进行检查。 + 将结果中有 .item() 方法的都调用一下,使得可以结果可以保存 - :param trainer: - :param driver: + :param results: :return: """ - if self.monitor is None and trainer.monitor is not None: - self.set_monitor(monitor=trainer.monitor, larger_better=trainer.larger_better) - if self.must_have_moinitor and self.monitor is None: - raise RuntimeError(f"No `monitor` is set for {self.__class__.__name__}. " - f"You can set it in the initialization or through Trainer.") - - - def on_sanity_check_end(self, trainer, sanity_check_res): - # 主要核对一下 monitor 是否存在。 - if self.monitor is not None: - self.get_monitor_value(results=sanity_check_res) + return apply_to_collection(results, dtype=CanItemDataType, function=lambda x: x.item()) def get_monitor_value(self, results:Dict)->Union[float, None]: """ @@ -85,10 +67,10 @@ class HasMonitorCallback(Callback): :param results: :return: 如果为 None ,表明此次没有找到合适的monitor """ - if len(results)==0: + if len(results) == 0 or self.monitor is None: return None # 保证所有的 tensor 都被转换为了 python 特定的类型 - results = apply_to_collection(results, dtype=CanItemDataType, function=lambda x: x.item()) + results = self.itemize_results(results) use_monitor, monitor_value = _get_monitor_value(monitor=self.monitor, real_monitor=self._real_monitor, res=results) @@ -97,7 +79,7 @@ class HasMonitorCallback(Callback): # 第一次运行 if isinstance(self.monitor, str) and self._real_monitor == self.monitor and use_monitor != self.monitor: logger.warning(f"We can not find `{self.monitor}` in the evaluation result (with keys as {list(results.keys())}), " - f"we use the `{use_monitor}` as the monitor for `{self.__class__.__name__}`.") + f"we use the `{use_monitor}` as the monitor for `{self.__class__.__name__}`.") # 检测到此次和上次不同。 elif isinstance(self.monitor, str) and self._real_monitor != self.monitor and use_monitor != self._real_monitor: logger.warning(f"Change of monitor detected for `{self.__class__.__name__}`. " @@ -165,7 +147,10 @@ class HasMonitorCallback(Callback): """ if callable(self.monitor): try: - monitor_name = self.monitor.__qualname__ + monitor = self.monitor + while isinstance(monitor, functools.partial): + monitor = monitor.func + monitor_name = monitor.__qualname__ except: monitor_name = self.monitor.__name__ elif self.monitor is None: @@ -176,6 +161,46 @@ class HasMonitorCallback(Callback): return monitor_name + +class HasMonitorCallback(MonitorUtility, Callback): + def __init__(self, monitor, larger_better, must_have_monitor=False): + """ + 该 callback 不直接进行使用,作为其它相关 callback 的父类使用,如果 callback 有使用 monitor 可以继承该函数里面实现了 + (1)判断monitor合法性;(2)在需要时, 根据trainer的monitor设置自己的monitor名称。 + + :param monitor: 监控的 metric 值。如果在 evaluation 结果中没有找到完全一致的名称,将使用 最短公共字符串算法 找到最匹配 + 的那个作为 monitor 。如果为 None,将尝试使用 Trainer 设置的 monitor 。也可以传入一个函数,接受参数为 evaluation 的结 + 果(字典类型),返回一个 float 值作为 monitor 的结果,如果当前结果中没有相关的 monitor 值请返回 None 。 + :param larger_better: monitor 是否时越大越好 + :param must_have_monitor: 这个 callback 是否必须有 monitor 设置。如果设置为 True ,且没检测到设置 monitor 会报错。 + """ + super().__init__(monitor, larger_better) + self.must_have_monitor = must_have_monitor + + def on_after_trainer_initialized(self, trainer, driver): + """ + 如果本身的 monitor 没有设置,则根据 Trainer 中的 monitor 设置 monitor 。 + 同时对于必须要有 monitor 设置的 callback ,该函数会进行检查。 + + :param trainer: + :param driver: + :return: + """ + if self.monitor is None and trainer.monitor is not None: + self.set_monitor(monitor=trainer.monitor, larger_better=trainer.larger_better) + if self.must_have_monitor and self.monitor is None: + raise RuntimeError(f"No `monitor` is set for {self.__class__.__name__}. " + f"You can set it in the initialization or through Trainer.") + if self.must_have_monitor and self.monitor is not None and trainer.evaluator is None: + raise RuntimeError(f"No `evaluate_dataloaders` is set for Trainer. But Callback: {self.__class__.__name__}" + f" need to watch the monitor:`{self.monitor_name}`.") + + def on_sanity_check_end(self, trainer, sanity_check_res): + # 主要核对一下 monitor 是否存在。 + if self.monitor is not None: + self.get_monitor_value(results=sanity_check_res) + + class ExecuteOnceBetterMonitor(HasMonitorCallback): def __init__(self, monitor, larger_better, execute_fn): """ @@ -183,13 +208,13 @@ class ExecuteOnceBetterMonitor(HasMonitorCallback): :param monitor: 监控的 metric 值。如果在 evaluation 结果中没有找到完全一致的名称,将使用 最短公共字符串算法 找到最匹配 的那个作为 monitor 。如果为 None,将尝试使用 Trainer 设置的 monitor 。也可以传入一个函数,接受参数为 evaluation 的结 - 果(字典类型),返回一个 float 值作为 monitor 的结果。 + 果(字典类型),返回一个 float 值作为 monitor 的结果,如果当前结果中没有相关的 monitor 值请返回 None 。 :param larger_better: monitor 是否时越大越好 :param execute_fn: 一个可执行的函数,不接受任何参数,不反回值。在 monitor 取得更好结果的时候会调用。 """ super().__init__(monitor, larger_better, must_have_monitor=True) _check_valid_parameters_number(execute_fn, expected_params=[], fn_name='execute_fn') - self.execute_fn = execute_fn() + self.execute_fn = execute_fn def on_validate_end(self, trainer, results): if self.is_better_results(results): diff --git a/fastNLP/core/callbacks/load_best_model_callback.py b/fastNLP/core/callbacks/load_best_model_callback.py index 93b95667..0caf22d1 100644 --- a/fastNLP/core/callbacks/load_best_model_callback.py +++ b/fastNLP/core/callbacks/load_best_model_callback.py @@ -23,7 +23,7 @@ class LoadBestModelCallback(HasMonitorCallback): :param str monitor: 监控的 metric 值。如果在 evaluation 结果中没有找到完全一致的名称,将使用 最短公共字符串算法 找到最匹配 的那个作为 monitor 。如果为 None,将尝试使用 Trainer 设置的 monitor 。也可以传入一个函数,接受参数为 evaluation 的结 - 果(字典类型),返回一个 float 值作为 monitor 的结果。 + 果(字典类型),返回一个 float 值作为 monitor 的结果,如果当前结果中没有相关的 monitor 值请返回 None 。 :param larger_better: 该 metric 值是否是越大越好。 :param save_folder: 保存的文件夹,如果为空,则保存在内存中。不为空,则保存一份权重到文件中,当为多机训练,且本值不为空时,请确保 不同的机器均可访问当该路径。当 model_save_fn 不为 None 时该值一定不能为空。 @@ -72,7 +72,7 @@ class LoadBestModelCallback(HasMonitorCallback): logger.debug(f"Synchronize best model save folder: {self.real_save_folder} for LoadBestModelCallback.") except NotImplementedError: raise RuntimeError(f"Currently {driver.__class__.__name__} does not support using `save_folder` to " - f"save best model when launch using script.") + f"save best model when launch using module.") super().on_after_trainer_initialized(trainer, driver) @@ -87,7 +87,7 @@ class LoadBestModelCallback(HasMonitorCallback): trainer.save_model(folder=self.buffer, only_state_dict=self.only_state_dict) def on_train_end(self, trainer): - logger.info(f"Loading best model with {self._real_monitor}: {self.monitor_value}...") + logger.info(f"Loading best model with {self.monitor_name}: {self.monitor_value}...") if self.real_save_folder: trainer.load_model(folder=self.real_save_folder, only_state_dict=self.only_state_dict, model_load_fn=self.model_load_fn) diff --git a/fastNLP/core/callbacks/more_evaluate_callback.py b/fastNLP/core/callbacks/more_evaluate_callback.py new file mode 100644 index 00000000..c5d6133b --- /dev/null +++ b/fastNLP/core/callbacks/more_evaluate_callback.py @@ -0,0 +1,174 @@ +__all__ = [ + 'MoreEvaluateCallback' +] + +import os +from typing import Union, Callable, Optional, Dict + +from fastNLP.core.log import logger +from .has_monitor_callback import HasMonitorCallback +from .topk_saver import TopkSaver + + +class MoreEvaluateCallback(HasMonitorCallback): + def __init__(self, dataloaders, metrics:Dict, evaluate_every:Optional[Union[int, Callable]]=-1, + watch_monitor:Union[str, Callable]=None, watch_monitor_larger_better:bool=True, + evaluate_fn=None, num_eval_sanity_batch=2, + topk=0, topk_monitor=None, topk_larger_better=True, + folder=None, only_state_dict=True, save_object='model', model_save_fn=None, + save_evaluate_results=True, save_kwargs=None, + **kwargs): + """ + 当评测时需要调用不同的 evaluate_fn (例如在大部分生成任务中,一般使用训练 loss 作为训练过程中的 evaluate ;但同时在训练到 + 一定 epoch 数量之后,会让 model 生成的完整的数据评测 bleu 等。此刻就可能需要两种不同的 evaluate_fn ),只使用 Trainer + 无法满足需求,可以通过调用本 callback 进行。如果需要根据本 callback 中的评测结果进行模型保存,请传入 topk 以及 + topk_monitor 等相关参数。可以通过 evaluate_every 或 watch_monitor 控制触发进行 evaluate 的条件。 + + 如果设置了 evaluate 结果更好就保存的话,将按如下文件结构进行保存 + - folder/ + - YYYY-mm-dd-HH_MM_SS_fffff/ # 自动根据当前脚本的启动时间创建的 + - {save_object}-epoch_{epoch_idx}-batch_{global_batch_idx}-{topk_monitor}_{monitor_value}/ # 满足topk条件存储文件名 + + :param dataloaders: 需要评估的数据 + :param metrics: 使用的 metrics 。 + :param evaluate_every: 可以为负数、正数和函数;(1) 为负整数时表示每隔几个 epoch validate 一次;(2) 为正整数则表示每隔几个 batch + evaluate 一次;(3) 为函数时表示用户自己传入的用于控制 validate 的频率的函数,该函数的应该接受 trainer 对象作为参数,并返回 + 一个 bool 值,返回为 True 说明需要进行 evaluate ;将在每个 batch 结束后调用该函数判断是否需要 evaluate 。 + :param watch_monitor: 这个值用来表示监控的 Trainer 中的 evaluate 结果的,当该值不为 None ,evaluate_every 失效。本参数的 + 意义是,当检测到 Trainer 中 evaluate results 的 {watch_monitor} 的结果更好时,则进行一次 evaluate 。该参数有两种 + 取值: (1) str 类型,监控的 metric 值。如果在 evaluation 结果中没有找到完全一致的名称,将使用 最短公共字符串算法 找到最 + 匹配的那个作为 monitor ; (2) 也可以传入一个函数,接受参数为 evaluation 的结果(字典类型),返回一个 float 值作为 monitor + 的结果,如果当前结果中没有相关的monitor 值请返回 None 。 + :param watch_monitor_larger_better: watch_monitor 是否越大越好。 + :param evaluate_fn: 用来控制 `Evaluator` 在评测的前向传播过程中是调用哪一个函数,例如是 `model.evaluate_step` 还是 + `model.forward`;(1) 如果该值是 None,那么我们会默认使用 `evaluate_step` 当做前向传播的函数,如果在模型中没有 + 找到该方法,则使用 `model.forward` 函数;(2) 如果为 str 类型,则尝试从 model 中寻找该方法,找不到则报错。 + :param num_eval_sanity_batch: 在初始化 Evaluator 后运行多少个 sanity check 的 batch ,检测一下。 + :param topk: 如果需要根据当前 callback 中的 evaluate 结果保存模型或 Trainer ,可以通过设置 tokp 实现。(1)为 -1 表示每次 + evaluate 后都保存;(2)为 0 (默认),表示不保存;(3)为整数,表示保存性能最 topk 个。 + :param topk_monitor: 如果需要根据当前 callback 中的 evaluate 结果保存。这个参数是指在当前 callback 中的 evaluate 结果寻找 + :param topk_larger_better: topk_monitor 的值是否时越大越好。 + :param folder: 保存的文件夹,fastNLP 将在该文件下以时间戳创建子文件夹,并在里面保存。因此不同次运行可以将被保存到不同的 + 时间戳文件夹中。如果为 None ,默认使用当前文件夹。 + :param only_state_dict: 保存模型时是否只保存 state_dict 。当 model_save_fn 不为 None 时,该参数无效。 + :param save_object: 可选 ['trainer', 'model'],表示在保存时的保存对象为 trainer+model 还是 只是model 。 + :param model_save_fn: 个性化的保存函数,当触发保存操作时,就调用这个函数,这个函数应当接受一个文件夹作为参数,不返回任何东西。 + 如果传入了 model_save_fn 函数,fastNLP 将不再进行模型相关的保存。在多卡场景下,我们只在 rank 0 上会运行该函数。 + :param save_evaluate_results: 是否保存 evaluate 的结果。如果为 True ,在保存 topk 模型的 folder 中还将额外保存一个 + fastnlp_evaluate_results.json 文件,记录当前的 results。仅在设置了 topk 的场景下有用,默认为 True 。 + :param save_kwargs: dict。更多的保存相关的参数。 + :param kwargs: 其它与 Evaluator 相关的初始化参数,如果不传入,将从 Trainer 中获取。请特别留意 evaluate_fn 的设置。 + """ + super(MoreEvaluateCallback, self).__init__(watch_monitor, watch_monitor_larger_better, + must_have_monitor=False) + + if watch_monitor is None and evaluate_every is None: + raise RuntimeError("`evaluate_every` and `watch_monitor` cannot be None at the same time.") + if watch_monitor is not None and evaluate_every is not None: + raise RuntimeError("`evaluate_every` and `watch_monitor` cannot be set at the same time.") + self.watch_monitor = watch_monitor + + if topk_monitor is not None and topk == 0: + raise RuntimeError("`topk_monitor` is set, but `topk` is 0.") + if topk != 0 and topk_monitor is None: + raise RuntimeError("`topk` is set, but `topk_monitor` is None.") + assert save_object in ['trainer', 'model'] + + self.dataloaders = dataloaders + self.metrics = metrics + self.evaluate_every = evaluate_every + self.evaluate_fn = evaluate_fn + self.num_eval_sanity_batch = num_eval_sanity_batch + if save_kwargs is None: + save_kwargs = {} + self.topk_saver = TopkSaver(topk=topk, monitor=topk_monitor, larger_better=topk_larger_better, + folder=folder, only_state_dict=only_state_dict, + model_save_fn=model_save_fn, save_evaluate_results=save_evaluate_results, + save_object=save_object, **save_kwargs) + self.kwargs = kwargs + + @property + def need_reproducible_sampler(self) -> bool: + return self.topk_saver.save_object == 'trainer' + + def on_after_trainer_initialized(self, trainer, driver): + # 如果是需要 watch 的,不能没有 evaluator + if self.watch_monitor is not None: + assert trainer.evaluator is not None, f"You set `watch_monitor={self.watch_monitor}`, but no " \ + f"evaluate_dataloaders is provided in Trainer." + + if trainer.evaluate_fn is self.evaluate_fn: + logger.warning_once("The `evaluate_fn` is the same as in Trainer, there seems no need to use " + "`MoreEvaluateCallback`.") + + # 初始化 evaluator , 同时避免调用 super 对 monitor 赋值 + kwargs = { + 'model': self.kwargs.get('model', trainer.model), + 'dataloaders': self.dataloaders, + 'metrics': self.metrics, + 'driver': self.kwargs.get('driver', trainer.driver), + 'device': self.kwargs.get('device', trainer.device), + 'batch_step_fn': self.kwargs.get('batch_step_fn', trainer.evaluate_batch_step_fn), + 'evaluate_fn': self.evaluate_fn, + 'input_mapping': self.kwargs.get('input_mapping', trainer.input_mapping), + 'output_mapping': self.kwargs.get('output_mapping', trainer.output_mapping), + 'fp16': self.kwargs.get('fp16', trainer.fp16), + 'use_dist_sampler': self.kwargs.get('use_dist_sampler', + trainer.kwargs.get('eval_use_dist_sampler', None)), + 'progress_bar': self.kwargs.get('progress_bar', trainer.kwargs.get('progress_bar', 'auto')), + 'verbose': self.kwargs.get('verbose', 1) + } + + for key, value in self.kwargs.items(): + if key not in kwargs: + kwargs[key] = value + from fastNLP.core.controllers.evaluator import Evaluator + self.evaluator = Evaluator(**kwargs) + if self.num_eval_sanity_batch>0: + results = self.evaluator.run(num_eval_batch_per_dl=self.num_eval_sanity_batch) + self.topk_saver.get_monitor_value(results) + + def on_validate_end(self, trainer, results): + if self.is_better_results(results, keep_if_better=True): + results = self.evaluator.run() + self.topk_saver.save_topk(trainer, results) + + def on_train_epoch_end(self, trainer): + if self.watch_monitor is not None: + return + if isinstance(self.evaluate_every, int) and self.evaluate_every < 0: + validate_every = -self.evaluate_every + if trainer.cur_epoch_idx % validate_every == 0: + results = self.evaluator.run() + self.topk_saver.save_topk(trainer, results) + + def on_train_batch_end(self, trainer): + if self.watch_monitor is not None: + return + if callable(self.evaluate_every): + if self.evaluate_every(self): + results = self.evaluator.run() + self.topk_saver.save_topk(trainer, results) + elif self.evaluate_every > 0 and trainer.global_forward_batches % self.evaluate_every == 0: + results = self.evaluator.run() + self.topk_saver.save_topk(trainer, results) + + def on_save_checkpoint(self, trainer) -> Dict: + states = {'topk_saver': self.topk_saver.state_dict()} + if isinstance(self._real_monitor, str): + states['_real_monitor'] = self._real_monitor + states['monitor_value'] = self.monitor_value + return states + + def on_load_checkpoint(self, trainer, states: Optional[Dict]): + topk_saver_states = states['topk_saver'] + self.topk_saver.load_state_dict(topk_saver_states) + if '_real_monitor' in states: + self._real_monitor = states["_real_monitor"] + self.monitor_value = states['monitor_value'] + + @property + def callback_name(self): + metric_names = '+'.join(sorted(self.metrics.keys())) + return f'more_evaluate_callback#metric_name-{metric_names}#monitor-{self.monitor_name}#topk_saver:{self.topk_saver}' + diff --git a/fastNLP/core/callbacks/progress_callback.py b/fastNLP/core/callbacks/progress_callback.py index f3d5a435..64d72bd0 100644 --- a/fastNLP/core/callbacks/progress_callback.py +++ b/fastNLP/core/callbacks/progress_callback.py @@ -9,7 +9,6 @@ __all__ = [ ] from .has_monitor_callback import HasMonitorCallback -from fastNLP.core.callbacks.utils import _get_monitor_value from fastNLP.core.utils import f_rich_progress from fastNLP.core.log import logger @@ -42,7 +41,8 @@ class RichCallback(ProgressCallback): :param loss_round_ndigit: 显示的 loss 保留多少位有效数字 :param monitor: 当检测到这个key的结果更好时,会打印出不同的颜色进行提示。监控的 metric 值。如果在 evaluation 结果中没有找到 完全一致的名称,将使用 最短公共字符串算法 找到最匹配的那个作为 monitor 。如果为 None,将尝试使用 Trainer 设置的 monitor - 。也可以传入一个函数,接受参数为 evaluation 的结果(字典类型),返回一个 float 值作为 monitor 的结果。 + 。也可以传入一个函数,接受参数为 evaluation 的结果(字典类型),返回一个 float 值作为 monitor 的结果,如果当前结果中没有 + 相关的 monitor 值请返回 None 。 :param larger_better: 是否是 monitor 的结果越大越好。 :param format_json: 是否格式化 json 再打印 """ @@ -135,7 +135,8 @@ class RawTextCallback(ProgressCallback): :param loss_round_ndigit: 显示的 loss 保留多少位有效数字 :param monitor: 当检测到这个key的结果更好时,会打印出不同的颜色进行提示。监控的 metric 值。如果在 evaluation 结果中没有找到 完全一致的名称,将使用 最短公共字符串算法 找到最匹配的那个作为 monitor 。如果为 None,将尝试使用 Trainer 设置的 monitor - 。也可以传入一个函数,接受参数为 evaluation 的结果(字典类型),返回一个 float 值作为 monitor 的结果。 + 。也可以传入一个函数,接受参数为 evaluation 的结果(字典类型),返回一个 float 值作为 monitor 的结果,如果当前结果中没有 + 相关的 monitor 值请返回 None 。 :param larger_better: 是否是monitor的结果越大越好。 :param format_json: 是否format json再打印 """ diff --git a/fastNLP/core/callbacks/topk_saver.py b/fastNLP/core/callbacks/topk_saver.py new file mode 100644 index 00000000..e2b0eb29 --- /dev/null +++ b/fastNLP/core/callbacks/topk_saver.py @@ -0,0 +1,246 @@ +import json +import os +from copy import deepcopy +from pathlib import Path +from typing import Optional, Dict, Tuple + +from fastNLP.core.utils import rank_zero_rm +from fastNLP.core.log import logger +from fastNLP.envs import FASTNLP_LAUNCH_TIME +from fastNLP.envs import rank_zero_call +from fastNLP.envs.env import FASTNLP_EVALUATE_RESULT_FILENAME +from .has_monitor_callback import MonitorUtility + + +class Saver: + def __init__(self, folder, only_state_dict, model_save_fn, **kwargs): + """ + 执行保存的对象。保存的文件组织结构为 + - folder # 当前初始化的参数 + - YYYY-mm-dd-HH_MM_SS_fffff/ # 自动根据当前脚本的启动时间创建的 + - folder_name # 由 save() 调用时传入。 + + :param folder: + :param only_state_dict: + :param model_save_fn: + :param kwargs: + """ + if folder is None: + logger.warning( + "Parameter `folder` is None, and we will use the current work directory to find and load your model.") + folder = Path.cwd() + folder = Path(folder) + if not folder.exists(): + raise NotADirectoryError(f"Path '{folder.absolute()}' is not existed!") + elif folder.is_file(): + raise ValueError("Parameter `folder` should be a directory instead of a file.") + + self.folder = folder + self.only_state_dict = only_state_dict + self.model_save_fn = model_save_fn + self.kwargs = kwargs + self.eval_results = kwargs.get('eval_results', True) + self.timestamp_path = self.folder.joinpath(os.environ[FASTNLP_LAUNCH_TIME]) + + @rank_zero_call + def save(self, save_fn, folder_name): + """ + 执行保存的函数,将数据保存在 folder/timestamp/folder_name 下。其中 folder 为用户在初始化指定, + timestamp 为当前脚本的启动时间。 + + :param save_fn: 调用的保存函数,应该可接受参数 folder:str, only_state_dict: bool, model_save_fn: callable, kwargs + :param folder_name: 保存的 folder 名称,将被创建。 + :return: 返回实际发生保存的 folder 绝对路径。如果为 None 则没有创建。 + """ + folder = self.timestamp_path.joinpath(folder_name) + folder.mkdir(parents=True, exist_ok=True) + save_fn( + folder=folder, + only_state_dict=self.only_state_dict, + model_save_fn=self.model_save_fn, + **self.kwargs + ) + return str(os.path.abspath(folder)) + + @rank_zero_call + def save_json(self, results, path): + """ + 以 json 格式保存 results 到 path 中 + + :param results: + :param path: + :return: + """ + with open(path, 'w', encoding='utf8') as f: + json.dump(results, f, indent=2) + + @rank_zero_call + def rm(self, folder_name): + """ + 移除 folder/timestamp/folder_name 。其中 folder 为用户在初始化指定, timestamp 为当前脚本的启动时间。 + + :param folder_name: + :return: + """ + folder = self.timestamp_path.joinpath(folder_name) + rank_zero_rm(folder) + + def state_dict(self): + states = { + 'timestamp_path': str(self.timestamp_path), + } + return states + + def load_state_dict(self, states): + timestamp_path = states['timestamp_path'] + if not os.path.exists(timestamp_path): + logger.info(f"The resuming checkpoint folder {timestamp_path} is not exists, checkpoint will save to " + f" {self.timestamp_path.absolute()}.") + else: + logger.info(f"Resume to save checkpoint in path: {timestamp_path}.") + self.timestamp_path = Path(timestamp_path) + + def __str__(self): + return 'saver' # saver是无状态的,不需要有特定名字 + + +class TopkQueue: + def __init__(self, topk): + """ + 用于维护处于 topk 的 key, value 对。 + + :param int topk: 整数,-1 表示所有数据都是 topk 的; 如果是 0, 表示没有任何数据是满足 topk 的。 + """ + assert isinstance(topk, int) + self.topk = topk + self.topk_dict = {} # 其中 key 为保存的 + + def push(self, key, value) -> Optional[Tuple[str, float]]: + """ + 将 key/value 推入 topk 的 queue 中,以 value 为标准,如果满足 topk 则保留此次推入的信息,同时如果新推入的数据将之前的数据给 + 挤出了 topk ,则会返回被挤出的 (key, value);如果返回为 (None, None),说明满足 topk 且没有数据被挤出。如果不满足 topk ,则返回 + 推入的 (key, value) 本身。这里排序只根据 value 是否更大了判断,因此如果有的情况是越小越好,请在输入前取负号。 + + :param str key: + :param float value: 如果为 None, 则不做任何操作。 + :return: (1)返回输入的 (key, value) ,说明不满足 topk; (2) 返回(None, None),说明满足 topk 且没有被挤出过去的记录; (3) + 返回非输入的 (key, value) , 说明输入满足 topk,且挤出了之前的记录。 + """ + if value is None: + return key, value + if self.topk < 0: + return None, None + if self.topk == 0: + return key, value + if len(self.topk_dict) value: + return key, value + else: + min_value = self.topk_dict.pop(min_key) + self.topk_dict[key] = value + return min_key, min_value + + def state_dict(self): + return deepcopy(self.topk_dict) + + def load_state_dict(self, states): + self.topk_dict.update(states) + + def __str__(self): + return f'topk-{self.topk}' + + def __bool__(self): + # 仅当 topk 为 0 时,表明该 topk_queue 无意义。 + return self.topk != 0 + + +class TopkSaver(MonitorUtility, Saver): + def __init__(self, topk, monitor, larger_better, folder, only_state_dict, + model_save_fn, save_evaluate_results, + save_object, **kwargs): + """ + 用来保存识别 tokp 模型并保存。 + + :param topk: + :param monitor: + :param larger_better: + :param folder: + :param only_state_dict: + :param model_save_fn: + :param save_evaluate_results: + :param save_object: + :param kwargs: + """ + MonitorUtility.__init__(self, monitor, larger_better) + Saver.__init__(self, folder, only_state_dict, model_save_fn, **kwargs) + + if monitor is not None and topk == 0: + raise RuntimeError("`monitor` is set, but `topk` is 0.") + if topk != 0 and monitor is None: + raise RuntimeError("`topk` is set, but `monitor` is None.") + + assert save_object in ['trainer', 'model'] + + self.saver = Saver(folder, only_state_dict, model_save_fn, **kwargs) + self.topk_queue = TopkQueue(topk) + self.save_evaluate_results = save_evaluate_results + self.save_object = save_object + self.save_fn_name = 'save' if save_object == 'trainer' else 'save_model' + + @rank_zero_call + def save_topk(self, trainer, results: Dict) -> Optional[str]: + """ + 根据 results 是否满足 topk 的相关设定决定是否保存,如果发生了保存,将返回保存的文件夹。 + + :param trainer: + :param results: + :return: + """ + if self.monitor is not None and self.topk_queue: + monitor_value = self.get_monitor_value(results) + if monitor_value is None: + return + key = f"{self.save_object}-epoch_{trainer.cur_epoch_idx}-batch_{trainer.global_forward_batches}" \ + f"-{self.monitor_name}_{monitor_value}" + pop_key, pop_value = self.topk_queue.push(key, monitor_value if self.larger_better else -monitor_value) + if pop_key == key: # 说明不足以构成 topk,被退回了 + return None + folder = self.save(trainer, key) + if self.save_evaluate_results and folder: + try: + self.save_json(self.itemize_results(results), + os.path.join(folder, FASTNLP_EVALUATE_RESULT_FILENAME)) + except: + logger.exception(f"Fail to save evaluate results to {folder}") + + if pop_key and pop_key != key: # 说明需要移除之前的 topk + self.rm(pop_key) + return folder + + def save(self, trainer, folder_name): + fn = getattr(trainer, self.save_fn_name) + return super().save(fn, folder_name) + + def state_dict(self): + states = { + 'topk_queue': self.topk_queue.state_dict(), + 'saver': self.saver.state_dict() + } + if isinstance(self._real_monitor, str): + states['_real_monitor'] = self._real_monitor + + return states + + def load_state_dict(self, states): + topk_queue_states = states['topk_queue'] + saver_states = states['saver'] + self.topk_queue.load_state_dict(topk_queue_states) + self.saver.load_state_dict(saver_states) + if '_real_monitor' in states: + self._real_monitor = states["_real_monitor"] + + def __str__(self): + return f'topk-{self.topk_queue}#saver-{self.saver}#save_object-{self.save_object}' diff --git a/fastNLP/core/callbacks/utils.py b/fastNLP/core/callbacks/utils.py index 7ece3bb9..e7ebb7aa 100644 --- a/fastNLP/core/callbacks/utils.py +++ b/fastNLP/core/callbacks/utils.py @@ -1,4 +1,6 @@ from typing import Optional, Union +import os + from fastNLP.core.log.logger import logger from difflib import SequenceMatcher from fastNLP.core.utils.utils import _get_fun_msg @@ -15,7 +17,7 @@ def _get_monitor_value(monitor: Union[callable, str], real_monitor: Optional[str :return: 返回两个值(str, value),其中str就是最终要到的key,value就是这个key对应的value。如果value为None说明当前results中没有 找到对应的 monitor """ - if len(res)==0: + if len(res) == 0 or monitor is None: return monitor, None if callable(monitor): @@ -56,4 +58,3 @@ def _match_length(a:str, b:str)->int: match = SequenceMatcher(None, short, long).find_longest_match(0, len(short), 0, len(long)) return match.size - diff --git a/fastNLP/core/controllers/evaluator.py b/fastNLP/core/controllers/evaluator.py index 38522c9b..60703ef5 100644 --- a/fastNLP/core/controllers/evaluator.py +++ b/fastNLP/core/controllers/evaluator.py @@ -38,7 +38,7 @@ class Evaluator: driver: Union[str, Driver] = 'torch', device: Optional[Union[int, List[int], str]] = None, batch_step_fn: Optional[callable] = None, - evaluate_fn: Optional[str] = None, # 首先尝试找 evaluate_step, 找不到 forward, callable + evaluate_fn: Optional[str] = None, input_mapping: Optional[Union[Callable, Dict]] = None, output_mapping: Optional[Union[Callable, Dict]] = None, model_wo_auto_param_call: bool = False, @@ -57,8 +57,9 @@ class Evaluator: :param batch_step_fn: callable的对象,接受 (evaluator, batch) 作为参数,其中 evaluator 为 Evaluator 对象,batch 为 DataLoader 中返回的对象。一个 batch_step_fn 的例子可参考 fastNLP.core.controller.loops.evaluate_batch_loop 的 batch_step_fn 函数。 - :param evaluate_fn: 用来控制 `Evaluator` 在评测的前向传播过程中是调用哪一个函数,例如是 `model.evaluate_step` 还是 `model.forward`; - 默认为 None,如果该值是 None,那么我们会默认使用 `evaluate_step` 当做前向传播的函数,如果在模型中没有找到该方法,则使用 `model.forward` 函数; + :param evaluate_fn: 用来控制 `Evaluator` 在评测的前向传播过程中是调用哪一个函数,例如是 `model.evaluate_step` 还是 + `model.forward`;(1) 如果该值是 None,那么我们会默认使用 `evaluate_step` 当做前向传播的函数,如果在模型中没有 + 找到该方法,则使用 `model.forward` 函数;(2) 如果为 str 类型,则尝试从 model 中寻找该方法,找不到则报错。 :param input_mapping: 对 dataloader 中输出的内容将通过 input_mapping 处理之后再输入到 model 以及 metric 中 :param output_mapping: 对 model 输出的内容,将通过 output_mapping 处理之后再输入到 metric 中。 :param model_wo_auto_param_call: 是否关闭在训练时调用我们的 auto_param_call 来自动匹配 batch 和 forward 函数的参数的行为; @@ -69,6 +70,7 @@ class Evaluator: :param kwargs: bool model_use_eval_mode: 是否在 evaluate 的时候将 model 的状态设置成 eval 状态。在 eval 状态下,model 的dropout 与 batch normalization 将会关闭。默认为True。 + TODO 还没完成。 Union[bool] auto_tensor_conversion_for_metric: 是否自动将输出中的 tensor 适配到 metrics 支持的。例如 model 输出是 paddlepaddle 的 tensor ,但是想利用 torchmetrics 的metric对象, 当 auto_tensor_conversion_for_metric 为True时,fastNLP 将自动将输出中 paddle 的 tensor (其它非 tensor 的参数 @@ -119,10 +121,6 @@ class Evaluator: self._metric_wrapper = None _ = self.metrics_wrapper # 触发检查 - if self._dist_sampler is not None and not self.driver.is_distributed(): - logger.warning_once("Running in a non-distributed driver, but with distributed sampler, it may cause " - "different process evaluating on different data.") - if evaluate_fn is not None and not isinstance(evaluate_fn, str): raise TypeError("Parameter `evaluate_fn` can only be `str` type when it is not None.") self._evaluate_step, self._evaluate_step_signature_fn = \ diff --git a/fastNLP/core/controllers/trainer.py b/fastNLP/core/controllers/trainer.py index 40ec635d..779d3d83 100644 --- a/fastNLP/core/controllers/trainer.py +++ b/fastNLP/core/controllers/trainer.py @@ -14,10 +14,10 @@ __all__ = [ from .loops import Loop, TrainBatchLoop from .utils import State, TrainerState -from .utils.utils import check_validate_every +from .utils.utils import check_evaluate_every from .evaluator import Evaluator from fastNLP.core.controllers.utils.utils import TrainerEventTrigger, _TruncatedDataLoader -from fastNLP.core.callbacks import Callback, CallbackManager, Events, EventsList, Filter +from fastNLP.core.callbacks import Callback, CallbackManager, Events, EventsList from fastNLP.core.callbacks.callback import _CallbackWrapper from fastNLP.core.callbacks.callback_events import _SingleEventState from fastNLP.core.drivers import Driver @@ -26,7 +26,7 @@ from fastNLP.core.utils import get_fn_arg_names, match_and_substitute_params, nu from fastNLP.core.utils.utils import _check_valid_parameters_number from fastNLP.envs import rank_zero_call from fastNLP.core.log import logger -from fastNLP.envs import FASTNLP_MODEL_FILENAME +from fastNLP.envs import FASTNLP_MODEL_FILENAME, FASTNLP_CHECKPOINT_FILENAME from fastNLP.core.utils.exceptions import EarlyStopException @@ -94,9 +94,9 @@ class Trainer(TrainerEventTrigger): evaluate_step 这个函数,如果没有则使用 forward 函数。 :param callbacks: 训练当中触发的 callback 类,该参数应当为一个列表,其中的每一个元素都应当继承 `Callback` 类; :param metrics: 应当为一个字典,其中 key 表示 monitor,例如 {"acc1": AccMetric(), "acc2": AccMetric()}; - :param evaluate_every: 可以为负数、正数或者函数;为负数时表示每隔几个 epoch validate 一次;为正数则表示每隔几个 batch validate 一次; - 为函数时表示用户自己传入的用于控制 Trainer 中的 validate 的频率的函数,该函数的应该接受当前 trainer 对象作为参数,并 - 返回一个 bool 值,返回为 True 说明需要进行 validate ;将在每个 batch 结束后调用该函数判断是否需要 validate 。 + :param evaluate_every: 可以为负数、正数或者函数;为负数时表示每隔几个 epoch evaluate 一次;为正数则表示每隔几个 batch evaluate 一次; + 为函数时表示用户自己传入的用于控制 Trainer 中的 evaluate 的频率的函数,该函数的应该接受当前 trainer 对象作为参数,并 + 返回一个 bool 值,返回为 True 说明需要进行 evaluate ;将在每个 batch 结束后调用该函数判断是否需要 evaluate 。 :param input_mapping: 应当为一个字典或者一个函数,表示在当前 step 拿到一个 batch 的训练数据后,应当做怎样的映射处理;如果其是 一个字典,并且 batch 也是一个 `Dict`,那么我们会把 batch 中同样在 input_mapping 中的 key 修改为 input_mapping 的对应 key 的 value;如果 batch 是一个 `dataclass`,那么我们会先将该 dataclass 转换为一个 Dict,然后再进行上述转换;如果 batch 此时是其它 @@ -124,7 +124,7 @@ class Trainer(TrainerEventTrigger): set_grad_to_none: 是否在训练过程中在每一次 optimizer 更新后将 grad 置为 None; use_dist_sampler: 表示是否使用分布式的 sampler 。在多卡时,分布式 sampler 将自动决定每张卡上读取的 sample ,使得一个epoch 内所有卡的 sample 加起来为一整个数据集的 sample。默认会根据 driver 是否为分布式进行设置。 - use_eval_dist_sampler: 表示在 Evaluator 中在使用 TorchDDPDriver 的时候是否将 dataloader 的 sampler 替换为分布式的 sampler;默认为 True; + eval_use_dist_sampler: 表示在 Evaluator 中在使用 TorchDDPDriver 的时候是否将 dataloader 的 sampler 替换为分布式的 sampler;默认为 True; output_from_new_proc: 应当为一个字符串,表示在多进程的 driver 中其它进程的输出流应当被做如何处理;其值应当为以下之一: ["all", "ignore", "only_error"];当该参数的值不是以上值时,该值应当表示一个文件夹的名字,我们会将其他 rank 的输出流重定向到 log 文件中,然后将 log 文件保存在通过该参数值设定的文件夹中;默认为 "only_error"; @@ -214,13 +214,13 @@ class Trainer(TrainerEventTrigger): """ 设置内部的 Evaluator """ if metrics is None and evaluate_dataloaders is not None: - raise ValueError("You have set 'evaluate_dataloader' but forget to set 'metrics'.") + raise ValueError("You have set 'evaluate_dataloaders' but forget to set 'metrics'.") if metrics is not None and evaluate_dataloaders is None: - raise ValueError("You have set 'metrics' but forget to set 'evaluate_dataloader'.") + raise ValueError("You have set 'metrics' but forget to set 'evaluate_dataloaders'.") self.metrics = metrics - self.validate_every = evaluate_every + self.evaluate_every = evaluate_every self.driver.setup() self.driver.barrier() @@ -235,7 +235,7 @@ class Trainer(TrainerEventTrigger): self.monitor = monitor self.larger_better = larger_better if metrics is not None and evaluate_dataloaders is not None: - check_validate_every(evaluate_every) + check_evaluate_every(evaluate_every) self.evaluator = Evaluator( model=model, dataloaders=evaluate_dataloaders, @@ -248,7 +248,7 @@ class Trainer(TrainerEventTrigger): output_mapping=output_mapping, fp16=fp16, verbose=0, - use_dist_sampler=kwargs.get("use_eval_dist_sampler", None), + use_dist_sampler=kwargs.get("eval_use_dist_sampler", None), progress_bar=kwargs.get('progress_bar', 'auto') ) @@ -261,11 +261,14 @@ class Trainer(TrainerEventTrigger): self.driver.set_deterministic_dataloader(self.dataloader) self.dataloader = self.driver.set_dist_repro_dataloader(dataloader=self.train_dataloader, dist=_dist_sampler, - reproducible=self.callback_manager.has_trainer_checkpoint) + reproducible=self.callback_manager._need_reproducible_sampler) self.set_grad_to_none = kwargs.get("set_grad_to_none", True) - self.on_after_trainer_initialized(self.driver) + self.evaluate_batch_step_fn = evaluate_batch_step_fn + self.kwargs = kwargs + + self.on_after_trainer_initialized(self.driver) self.driver.barrier() def run(self, num_train_batch_per_epoch: int = -1, num_eval_batch_per_dl: int = -1, @@ -364,10 +367,10 @@ class Trainer(TrainerEventTrigger): :return: """ if self.evaluator is not None: - if callable(self.validate_every): - if self.validate_every(self): + if callable(self.evaluate_every): + if self.evaluate_every(self): self.run_evaluate() - elif self.validate_every > 0 and self.global_forward_batches % self.validate_every == 0: + elif self.evaluate_every > 0 and self.global_forward_batches % self.evaluate_every == 0: self.run_evaluate() def epoch_validate(self): @@ -377,8 +380,8 @@ class Trainer(TrainerEventTrigger): :return: """ if self.evaluator is not None: - if isinstance(self.validate_every, int) and self.validate_every < 0: - validate_every = -self.validate_every + if isinstance(self.evaluate_every, int) and self.evaluate_every < 0: + validate_every = -self.evaluate_every if self.cur_epoch_idx % validate_every == 0: self.run_evaluate() @@ -427,7 +430,7 @@ class Trainer(TrainerEventTrigger): self._custom_callbacks[None] = [] if self.marker is not None: if len(self._custom_callbacks[self.marker]) == 0: - print(f"You have set `trainer.marker = {self.marker}`, but there are no callback function matched " + logger.info(f"You have set `trainer.marker = {self.marker}`, but there are no callback function matched " f"`{self.marker}` that is added through function `Trainer.on`") _own_callbacks += self._custom_callbacks[self.marker] for each_callback in _own_callbacks: @@ -528,10 +531,10 @@ class Trainer(TrainerEventTrigger): r""" 用于帮助用户保存模型的辅助函数,具体实际的保存模型的操作由具体的 driver 实现; - :param folder: 保存模型的地址; - :param only_state_dict: 是否只保存模型的 `state_dict`; + :param folder: 保存模型的文件夹。如果没有传入 model_save_fn 参数,则在这个文件夹下创建 fastnlp_model.pkl.tar 文件。 + :param only_state_dict: 仅在 model_save_fn 为空时,有效。是否只保存模型的 `state_dict`; :param model_save_fn: 用户自己定制的用来替换该保存函数本身保存逻辑的函数; - :param kwargs: 一些 driver 的保存模型的函数的参数另有其它; + :param kwargs: """ self.on_save_model() @@ -568,14 +571,19 @@ class Trainer(TrainerEventTrigger): self.on_load_model() self.driver.barrier() if not isinstance(folder, (io.BytesIO, BinaryIO)): - if model_load_fn is not None: - if not callable(model_load_fn): - raise ValueError("Parameter `model_save_fn` should be `Callable` type when it is not None.") - rank_zero_call(model_load_fn)(folder) - else: - if isinstance(folder, str): - folder = Path(folder) - self.driver.load_model(folder.joinpath(FASTNLP_MODEL_FILENAME), only_state_dict, **kwargs) + try: + if model_load_fn is not None: + if not callable(model_load_fn): + raise ValueError("Parameter `model_save_fn` should be `Callable` type when it is not None.") + rank_zero_call(model_load_fn)(folder) + else: + if isinstance(folder, str): + folder = Path(folder) + self.driver.load_model(folder.joinpath(FASTNLP_MODEL_FILENAME), only_state_dict, **kwargs) + except FileNotFoundError as e: + if FASTNLP_MODEL_FILENAME not in os.listdir(folder): + logger.error(f"fastNLP model checkpoint file:{FASTNLP_MODEL_FILENAME} is not found in {folder}.") + raise e else: if model_load_fn is not None: raise RuntimeError("It is not allowed to specify a `model_save_fn` parameter with `folder` being " @@ -585,11 +593,13 @@ class Trainer(TrainerEventTrigger): def save(self, folder: Union[str, Path], only_state_dict: bool = True, model_save_fn: Optional[Callable] = None, **kwargs): r""" - 用于断点重训 Trainer 的保存函数; + 用于断点重训 Trainer 的保存函数。 - :param folder: - :param only_state_dict: - :param model_save_fn: + :param folder: 保存在哪个文件夹下,会在该文件下声称两个文件:fastnlp_checkpoint.pkl.tar 与 fastnlp_model.pkl.tar 。 + 如果 model_save_fn 不为空,则没有 fastnlp_model.pkl.tar 文件。 + :param only_state_dict: 当 model_save_fn 为空时有效,表明是否仅保存模型的权重。 + :param model_save_fn: 如果模型保存比较特殊,可以传入该函数自定义保存过程,输入应该接受一个文件夹(实际上就是接受上面的 folder + 参数),不必返回任何东西。 :param kwargs: :return: """ @@ -602,17 +612,6 @@ class Trainer(TrainerEventTrigger): 'num_consumed_batches': self.batch_idx_in_epoch - getattr(self, 'start_batch_idx_in_epoch', 0) } - # 3. validate filter state; - if self.evaluator is not None: - val_filter_state = {} - if hasattr(self.step_validate, "__fastNLP_filter__"): - val_filter_state["step_validate"] = self.step_validate.__fastNLP_filter__.state_dict() - if hasattr(self.epoch_validate, "__fastNLP_filter__"): - val_filter_state["epoch_validate"] = self.epoch_validate.__fastNLP_filter__.state_dict() - states["val_filter_state"] = val_filter_state - else: - states["val_filter_state"] = None - if isinstance(folder, str): folder = Path(folder) @@ -649,32 +648,30 @@ class Trainer(TrainerEventTrigger): dataloader = self.dataloader if not resume_training: dataloader = None - - if model_load_fn is not None: - if not callable(model_load_fn): - raise ValueError("Parameter `model_save_fn` should be `Callable`.") - rank_zero_call(model_load_fn)(folder) - states = self.driver.load(folder=folder, dataloader=dataloader, should_load_model=False, **kwargs) - else: - states = self.driver.load(folder=folder, dataloader=dataloader, only_state_dict=only_state_dict, should_load_model=True, **kwargs) + try: + if model_load_fn is not None: + if not callable(model_load_fn): + raise ValueError("Parameter `model_save_fn` should be `Callable`.") + rank_zero_call(model_load_fn)(folder) + states = self.driver.load(folder=folder, dataloader=dataloader, should_load_model=False, **kwargs) + else: + states = self.driver.load(folder=folder, dataloader=dataloader, only_state_dict=only_state_dict, should_load_model=True, **kwargs) + except FileNotFoundError as e: + if FASTNLP_CHECKPOINT_FILENAME not in os.listdir(folder) and FASTNLP_MODEL_FILENAME in os.listdir(folder): + logger.error("It seems that you are trying to load the trainer checkpoint from a model checkpoint folder.") + elif FASTNLP_CHECKPOINT_FILENAME not in os.listdir(folder): + logger.error(f"fastNLP Trainer checkpoint file:{FASTNLP_CHECKPOINT_FILENAME} is not found in {folder}.") + raise e if not resume_training: return self.dataloader = states.pop('dataloader') - # 2. validate filter state; - if self.evaluator is not None: - val_filter_state = states["val_filter_state"] - if hasattr(self.step_validate, "__fastNLP_filter__"): - self.step_validate.__fastNLP_filter__.load_state_dict(val_filter_state["step_validate"]) - if hasattr(self.epoch_validate, "__fastNLP_filter__"): - self.epoch_validate.__fastNLP_filter__.load_state_dict(val_filter_state["epoch_validate"]) - - # 3. 恢复 trainer_state 的状态; + # 1. 恢复 trainer_state 的状态; self.trainer_state.load_state_dict(states["trainer_state"]) - # 4. 修改 trainer_state.batch_idx_in_epoch + # 2. 修改 trainer_state.batch_idx_in_epoch # sampler 是类似 RandomSampler 的sampler,不是 batch_sampler; # 这里的原则就是应当使得 '还会产生的batch数量' + 'batch_idx_in_epoch' = '原来不断点训练的batch的总数'。其中由于 # '还会产生的batch数量' 是由还剩多少 sample 决定的,因此只能通过调整 'batch_idx_in_epoch' 使得等式成立 diff --git a/fastNLP/core/controllers/utils/utils.py b/fastNLP/core/controllers/utils/utils.py index 3d25fd6b..cc7a1b66 100644 --- a/fastNLP/core/controllers/utils/utils.py +++ b/fastNLP/core/controllers/utils/utils.py @@ -126,7 +126,7 @@ class _TruncatedDataLoader: return getattr(self.dataloader, item) -def check_validate_every(validate_every): +def check_evaluate_every(validate_every): if not callable(validate_every) and (not isinstance(validate_every, int) or validate_every == 0): raise ValueError("Parameter 'evaluate_every' should be set to 'int' type and either < 0 or > 0.") if callable(validate_every): diff --git a/fastNLP/core/dataloaders/torch_dataloader/fdl.py b/fastNLP/core/dataloaders/torch_dataloader/fdl.py index 8eeea1f4..cf8e2c31 100644 --- a/fastNLP/core/dataloaders/torch_dataloader/fdl.py +++ b/fastNLP/core/dataloaders/torch_dataloader/fdl.py @@ -11,6 +11,7 @@ from fastNLP.core.collators.collator import _MultiCollator from fastNLP.core.utils.utils import indice_collate_wrapper from fastNLP.io.data_bundle import DataBundle from fastNLP.envs.imports import _NEED_IMPORT_TORCH +from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler, UnrepeatedSampler if _NEED_IMPORT_TORCH: from torch.utils.data import DataLoader, Sampler @@ -48,8 +49,8 @@ class TorchDataLoader(DataLoader): """ def __init__(self, dataset, batch_size: int = 1, - shuffle: bool = False, sampler: Optional["Sampler[int]"] = None, - batch_sampler: Optional["Sampler[Sequence[int]]"] = None, + shuffle: bool = False, sampler: Union["Sampler[int]", ReproducibleSampler, UnrepeatedSampler] = None, + batch_sampler: Union["Sampler[Sequence[int]]", ReproducibleBatchSampler] = None, num_workers: int = 0, collate_fn: Optional[Callable] = None, pin_memory: bool = False, drop_last: bool = False, timeout: float = 0, worker_init_fn: Optional[Callable] = None, diff --git a/fastNLP/core/drivers/driver.py b/fastNLP/core/drivers/driver.py index 1a810865..03e3667e 100644 --- a/fastNLP/core/drivers/driver.py +++ b/fastNLP/core/drivers/driver.py @@ -380,7 +380,6 @@ class Driver(ABC): """ # 单卡 driver 不需要这个函数; if self._pids is not None: - exc_type, exc_value, exc_traceback_obj = sys.exc_info() _write_exc_info = { 'exc_type': str(exc_type.__name__), diff --git a/fastNLP/core/drivers/torch_driver/ddp.py b/fastNLP/core/drivers/torch_driver/ddp.py index d68b6a0d..0fca3856 100644 --- a/fastNLP/core/drivers/torch_driver/ddp.py +++ b/fastNLP/core/drivers/torch_driver/ddp.py @@ -526,7 +526,7 @@ class TorchDDPDriver(TorchDriver): def barrier(self): if int(os.environ.get(FASTNLP_NO_SYNC, 0)) < 1: # 当 FASTNLP_NO_SYNC 小于 1 时实际执行 - torch.distributed.barrier(async_op=True) + torch.distributed.barrier(async_op=False) def is_distributed(self): return True diff --git a/fastNLP/core/drivers/torch_driver/torch_driver.py b/fastNLP/core/drivers/torch_driver/torch_driver.py index 5638b4c6..8e37f550 100644 --- a/fastNLP/core/drivers/torch_driver/torch_driver.py +++ b/fastNLP/core/drivers/torch_driver/torch_driver.py @@ -9,8 +9,9 @@ from fastNLP.envs.imports import _NEED_IMPORT_TORCH from pathlib import Path if _NEED_IMPORT_TORCH: import torch - from torch.utils.data import DataLoader, IterableDataset, RandomSampler, Sampler, BatchSampler, Dataset + from torch.utils.data import DataLoader, IterableDataset, Sampler, BatchSampler, Dataset from torch.optim import Optimizer + from torch.utils.data import RandomSampler as TorchRandomSampler _reduces = { 'sum': torch.max, 'min': torch.min, @@ -30,7 +31,7 @@ from fastNLP.core.utils import apply_to_collection, torch_move_data_to_device from fastNLP.envs import rank_zero_call from fastNLP.envs import FASTNLP_SEED_WORKERS, FASTNLP_GLOBAL_RANK, FASTNLP_MODEL_FILENAME, FASTNLP_CHECKPOINT_FILENAME from fastNLP.core.log import logger -from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler, RandomBatchSampler +from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler, RandomBatchSampler, RandomSampler class TorchDriver(Driver): @@ -211,8 +212,8 @@ class TorchDriver(Driver): states['sampler_states'] = sampler_states else: - raise RuntimeError( - 'The sampler has no `state_dict()` method, it will fail to recover to the specific batch.') + raise RuntimeError('The sampler has no `state_dict()` method, fastNLP cannot save the training ' + 'state.') # 2. 保存模型的状态; if should_save_model: @@ -283,6 +284,9 @@ class TorchDriver(Driver): sampler = dataloader_args.batch_sampler elif isinstance(dataloader_args.sampler, ReproducibleSampler): sampler = dataloader_args.sampler + elif isinstance(dataloader_args.sampler, TorchRandomSampler): + sampler = RandomSampler(dataloader_args.sampler.data_source) + logger.debug("Replace torch RandomSampler into fastNLP RandomSampler.") elif self.is_distributed(): raise RuntimeError("It is not allowed to use checkpoint retraining when you do not use our or " "`ReproducibleSampler`.") diff --git a/fastNLP/core/metrics/accuracy.py b/fastNLP/core/metrics/accuracy.py index 8b3889d0..d9ccb332 100644 --- a/fastNLP/core/metrics/accuracy.py +++ b/fastNLP/core/metrics/accuracy.py @@ -19,7 +19,7 @@ class Accuracy(Metric): :param str backend: 目前支持四种类型的backend, ['auto', 'torch', 'paddle', 'jittor']。其中 auto 表示根据实际调用 Metric.update() 函数时传入的参数决定具体的 backend ,一般情况下直接使用 'auto' 即可。 - :param bool aggregate_when_get_metric: 在计算 metric 的时候是否自动将各个进程上的相同的 element 的数字聚合后再得到metric, + :param bool aggregate_when_get_metric: 在计算 metric 的时候是否自动将各个进程上的相同的 element 的数字聚合后再得到 metric, 当 backend 不支持分布式时,该参数无意义。如果为 None ,将在 Evaluator 中根据 sampler 是否使用分布式进行自动设置。 """ super(Accuracy, self).__init__(backend=backend, aggregate_when_get_metric=aggregate_when_get_metric) diff --git a/fastNLP/core/samplers/unrepeated_sampler.py b/fastNLP/core/samplers/unrepeated_sampler.py index 02ec1162..809d0d38 100644 --- a/fastNLP/core/samplers/unrepeated_sampler.py +++ b/fastNLP/core/samplers/unrepeated_sampler.py @@ -84,6 +84,8 @@ class UnrepeatedRandomSampler(UnrepeatedSampler): :param rank: :return: """ + assert num_replicas<=len(self.dataset), f"The number of replicas({num_replicas}) should be lesser than the " \ + f"number of samples({len(self.dataset)})." assert num_replicas>0 and isinstance(num_replicas, int) assert isinstance(rank, int) and 0<=rank dict: + return self.loss.item() + + trainer_params.train_dataloader = _dataloader + trainer_params.evaluate_dataloaders = _dataloader + trainer_params.metrics = {'loss': LossMetric()} + + trainer_params.more_metrics = {"acc": Accuracy()} + + return trainer_params + + +@pytest.mark.parametrize("driver,device", [("torch", "cpu"), ("torch_ddp", [0, 1]), ("torch", 1)]) # ("torch", "cpu"), ("torch_ddp", [0, 1]), ("torch", 1) +@pytest.mark.parametrize("version", [0, 1]) +@pytest.mark.parametrize("only_state_dict", [True, False]) +@magic_argv_env_context +def test_model_more_evaluate_callback_1( + model_and_optimizers: TrainerParameters, + driver, + device, + version, + only_state_dict +): + try: + path = Path.cwd().joinpath(f"test_model_checkpoint") + path.mkdir(exist_ok=True, parents=True) + + if version == 0: + callbacks = [ + MoreEvaluateCallback(dataloaders=model_and_optimizers.evaluate_dataloaders, + metrics=model_and_optimizers.more_metrics, + evaluate_every=-1, + folder=path, topk=-1, + topk_monitor='acc', only_state_dict=only_state_dict, save_object='model') + ] + elif version == 1: + callbacks = [ + MoreEvaluateCallback(dataloaders=model_and_optimizers.evaluate_dataloaders, + metrics=model_and_optimizers.more_metrics, + evaluate_every=None, watch_monitor='loss', watch_monitor_larger_better=False, + folder=path, topk=1, topk_monitor='acc', only_state_dict=only_state_dict, + save_object='model') + ] + n_epochs = 5 + trainer = Trainer( + model=model_and_optimizers.model, + driver=driver, + device=device, + optimizers=model_and_optimizers.optimizers, + train_dataloader=model_and_optimizers.train_dataloader, + evaluate_dataloaders=model_and_optimizers.evaluate_dataloaders, + input_mapping=model_and_optimizers.input_mapping, + output_mapping=model_and_optimizers.output_mapping, + metrics=model_and_optimizers.metrics, + n_epochs=n_epochs, + callbacks=callbacks, + output_from_new_proc="all", + evaluate_fn='train_step' + ) + + trainer.run() + + all_saved_model_paths = {w.name: w for w in path.joinpath(os.environ[FASTNLP_LAUNCH_TIME]).iterdir()} + # 检查生成保存模型文件的数量是不是正确的; + if version == 0: + assert len(all_saved_model_paths) == n_epochs + elif version == 1: + assert len(all_saved_model_paths) == 1 + + for folder in all_saved_model_paths: + trainer = Trainer( + model=model_and_optimizers.model, + driver=driver, + device=device, + optimizers=model_and_optimizers.optimizers, + train_dataloader=model_and_optimizers.train_dataloader, + evaluate_dataloaders=model_and_optimizers.evaluate_dataloaders, + input_mapping=model_and_optimizers.input_mapping, + output_mapping=model_and_optimizers.output_mapping, + metrics=model_and_optimizers.metrics, + n_epochs=2, + output_from_new_proc="all", + evaluate_fn='train_step' + ) + folder = path.joinpath(os.environ[FASTNLP_LAUNCH_TIME]).joinpath(folder) + trainer.load_model(folder, only_state_dict=only_state_dict) + + trainer.run() + + finally: + rank_zero_rm(path) + + if dist.is_initialized(): + dist.destroy_process_group() + + +@pytest.mark.parametrize("driver,device", [("torch", "cpu"), ("torch_ddp", [0, 1]), ("torch", 0)]) # ("torch", "cpu"), ("torch_ddp", [0, 1]), ("torch", 1) +@pytest.mark.parametrize("version", [0, 1]) +@pytest.mark.parametrize("only_state_dict", [True, False]) +@magic_argv_env_context +def test_trainer_checkpoint_callback_1( + model_and_optimizers: TrainerParameters, + driver, + device, + version, + only_state_dict +): + try: + path = Path.cwd().joinpath(f"test_model_checkpoint") + path.mkdir(exist_ok=True, parents=True) + + if version == 0: + callbacks = [ + MoreEvaluateCallback(dataloaders=model_and_optimizers.evaluate_dataloaders, + metrics=model_and_optimizers.more_metrics, + evaluate_every=-1, + folder=path, topk=-1, + topk_monitor='acc', only_state_dict=only_state_dict, save_object='trainer') + ] + elif version == 1: + callbacks = [ + MoreEvaluateCallback(dataloaders=model_and_optimizers.evaluate_dataloaders, + metrics=model_and_optimizers.more_metrics, + evaluate_every=None, watch_monitor='loss', watch_monitor_larger_better=False, + folder=path, topk=1, topk_monitor='acc', only_state_dict=only_state_dict, + save_object='trainer') + ] + n_epochs = 5 + trainer = Trainer( + model=model_and_optimizers.model, + driver=driver, + device=device, + optimizers=model_and_optimizers.optimizers, + train_dataloader=model_and_optimizers.train_dataloader, + evaluate_dataloaders=model_and_optimizers.evaluate_dataloaders, + input_mapping=model_and_optimizers.input_mapping, + output_mapping=model_and_optimizers.output_mapping, + metrics=model_and_optimizers.metrics, + n_epochs=n_epochs, + callbacks=callbacks, + output_from_new_proc="all", + evaluate_fn='train_step' + ) + + trainer.run() + + all_saved_model_paths = {w.name: w for w in path.joinpath(os.environ[FASTNLP_LAUNCH_TIME]).iterdir()} + # 检查生成保存模型文件的数量是不是正确的; + if version == 0: + assert len(all_saved_model_paths) == n_epochs + elif version == 1: + assert len(all_saved_model_paths) == 1 + + for folder in all_saved_model_paths: + trainer = Trainer( + model=model_and_optimizers.model, + driver=driver, + device=device, + optimizers=model_and_optimizers.optimizers, + train_dataloader=model_and_optimizers.train_dataloader, + evaluate_dataloaders=model_and_optimizers.evaluate_dataloaders, + input_mapping=model_and_optimizers.input_mapping, + output_mapping=model_and_optimizers.output_mapping, + metrics=model_and_optimizers.metrics, + n_epochs=7, + output_from_new_proc="all", + evaluate_fn='train_step' + ) + folder = path.joinpath(os.environ[FASTNLP_LAUNCH_TIME]).joinpath(folder) + trainer.load(folder, only_state_dict=only_state_dict) + + trainer.run() + + finally: + rank_zero_rm(path) + + if dist.is_initialized(): + dist.destroy_process_group() diff --git a/tests/core/controllers/test_trainer_wo_evaluator_torch.py b/tests/core/controllers/test_trainer_wo_evaluator_torch.py index 43fdfc3d..a0cdcb22 100644 --- a/tests/core/controllers/test_trainer_wo_evaluator_torch.py +++ b/tests/core/controllers/test_trainer_wo_evaluator_torch.py @@ -15,7 +15,7 @@ from tests.helpers.datasets.torch_data import TorchNormalDataset_Classification from tests.helpers.callbacks.helper_callbacks import RecordLossCallback from tests.helpers.callbacks.helper_callbacks_torch import RecordAccumulationStepsCallback_Torch from tests.helpers.utils import magic_argv_env_context, Capturing -from fastNLP.core import synchronize_safe_rm +from fastNLP.core import rank_zero_rm @dataclass @@ -239,7 +239,7 @@ def test_trainer_output_from_new_proc( assert err_path.exists() path = Path(os.path.abspath(output_from_new_proc)) - synchronize_safe_rm(path) + rank_zero_rm(path) @pytest.mark.parametrize("driver,device", [("torch", [1, 2])]) diff --git a/tests/core/drivers/paddle_driver/test_single_device.py b/tests/core/drivers/paddle_driver/test_single_device.py index 79527f39..0c8e4256 100644 --- a/tests/core/drivers/paddle_driver/test_single_device.py +++ b/tests/core/drivers/paddle_driver/test_single_device.py @@ -11,7 +11,7 @@ from tests.helpers.models.paddle_model import PaddleNormalModel_Classification_1 from tests.helpers.datasets.paddle_data import PaddleNormalDataset, PaddleRandomMaxDataset from tests.helpers.datasets.torch_data import TorchNormalDataset from tests.helpers.models.torch_model import TorchNormalModel_Classification_1 -from fastNLP.core import synchronize_safe_rm +from fastNLP.core import rank_zero_rm import paddle from paddle.io import DataLoader, BatchSampler @@ -578,11 +578,11 @@ def test_save_and_load_model(prepare_test_save_load, only_state_dict): assert paddle.equal_all(res1["pred"], res2["pred"]) finally: if only_state_dict: - synchronize_safe_rm(path) + rank_zero_rm(path) else: - synchronize_safe_rm(path + ".pdiparams") - synchronize_safe_rm(path + ".pdiparams.info") - synchronize_safe_rm(path + ".pdmodel") + rank_zero_rm(path + ".pdiparams") + rank_zero_rm(path + ".pdiparams.info") + rank_zero_rm(path + ".pdmodel") @pytest.mark.parametrize("only_state_dict", ([True, False])) def test_save_and_load_with_randombatchsampler(only_state_dict): @@ -652,7 +652,7 @@ def test_save_and_load_with_randombatchsampler(only_state_dict): assert len(left_y_batches) + len(already_seen_y_set) == len(dataset) assert len(left_y_batches | already_seen_y_set) == len(dataset) finally: - synchronize_safe_rm(path) + rank_zero_rm(path) @pytest.mark.parametrize("only_state_dict", ([True, False])) def test_save_and_load_with_randomsampler(only_state_dict): @@ -730,4 +730,4 @@ def test_save_and_load_with_randomsampler(only_state_dict): assert len(left_y_batches) + len(already_seen_y_set) == len(dataset) assert len(left_y_batches | already_seen_y_set) == len(dataset) finally: - synchronize_safe_rm(path) + rank_zero_rm(path) diff --git a/tests/core/log/test_logger.py b/tests/core/log/test_logger.py index 4fe49bef..7c1b96e7 100644 --- a/tests/core/log/test_logger.py +++ b/tests/core/log/test_logger.py @@ -6,7 +6,7 @@ import logging import re from fastNLP.envs.env import FASTNLP_LAUNCH_TIME -from fastNLP.core import synchronize_safe_rm +from fastNLP.core import rank_zero_rm from fastNLP.core.log.logger import logger from tests.helpers.utils import magic_argv_env_context, recover_logger @@ -56,7 +56,7 @@ def test_add_file_ddp_1_torch(): pattern = re.compile(msg) assert len(pattern.findall(line)) == 1 - synchronize_safe_rm(filepath) + rank_zero_rm(filepath) dist.barrier() dist.destroy_process_group() @@ -105,7 +105,7 @@ def test_add_file_ddp_2_torch(): pattern = re.compile(msg) assert len(pattern.findall(line)) == 1 finally: - synchronize_safe_rm(path) + rank_zero_rm(path) dist.barrier() dist.destroy_process_group() @@ -155,7 +155,7 @@ def test_add_file_ddp_3_torch(): pattern = re.compile(msg) assert len(pattern.findall(line)) == 1 - synchronize_safe_rm(file) + rank_zero_rm(file) dist.barrier() dist.destroy_process_group() @@ -202,7 +202,7 @@ def test_add_file_ddp_4_torch(): pattern = re.compile(msg) assert len(pattern.findall(line)) == 1 finally: - synchronize_safe_rm(path) + rank_zero_rm(path) dist.barrier() dist.destroy_process_group() @@ -225,7 +225,7 @@ class TestLogger: line = ''.join([l for l in f]) assert self.msg in line finally: - synchronize_safe_rm(path) + rank_zero_rm(path) @recover_logger def test_add_file_2(self): @@ -243,7 +243,7 @@ class TestLogger: line = ''.join([l for l in f]) assert self.msg in line finally: - synchronize_safe_rm(origin_path) + rank_zero_rm(origin_path) @recover_logger def test_add_file_3(self): @@ -279,7 +279,7 @@ class TestLogger: line = ''.join([l for l in f]) assert self.msg in line finally: - synchronize_safe_rm(path) + rank_zero_rm(path) @recover_logger def test_stdout(self, capsys): diff --git a/tests/core/utils/test_cache_results.py b/tests/core/utils/test_cache_results.py index 64303f70..b652ff70 100644 --- a/tests/core/utils/test_cache_results.py +++ b/tests/core/utils/test_cache_results.py @@ -8,7 +8,7 @@ import sys from fastNLP.core.utils.cache_results import cache_results from tests.helpers.common.utils import check_time_elapse -from fastNLP.core import synchronize_safe_rm +from fastNLP.core import rank_zero_rm def get_subprocess_results(cmd): @@ -56,7 +56,7 @@ class TestCacheResults: res = demo() finally: - synchronize_safe_rm(cache_fp) + rank_zero_rm(cache_fp) def test_cache_save_refresh(self): cache_fp = 'demo.pkl' @@ -70,7 +70,7 @@ class TestCacheResults: with check_time_elapse(1, op='ge'): res = demo() finally: - synchronize_safe_rm(cache_fp) + rank_zero_rm(cache_fp) def test_cache_no_func_change(self): cache_fp = os.path.abspath('demo.pkl') @@ -91,7 +91,7 @@ class TestCacheResults: with check_time_elapse(1, op='lt'): res = demo() finally: - synchronize_safe_rm('demo.pkl') + rank_zero_rm('demo.pkl') def test_cache_func_change(self, capsys): cache_fp = 'demo.pkl' @@ -121,7 +121,7 @@ class TestCacheResults: assert 'is different from its last cache' not in output[0] finally: - synchronize_safe_rm('demo.pkl') + rank_zero_rm('demo.pkl') def test_cache_check_hash(self): cache_fp = 'demo.pkl' @@ -152,7 +152,7 @@ class TestCacheResults: assert 'is different from its last cache' in output[0] finally: - synchronize_safe_rm('demo.pkl') + rank_zero_rm('demo.pkl') # 外部 function 改变也会 导致改变 def test_refer_fun_change(self): @@ -177,7 +177,7 @@ class TestCacheResults: assert 'is different from its last cache' in res finally: - synchronize_safe_rm(cache_fp) + rank_zero_rm(cache_fp) # 外部 method 改变也会 导致改变 def test_refer_class_method_change(self): @@ -202,7 +202,7 @@ class TestCacheResults: assert 'is different from its last cache' in res finally: - synchronize_safe_rm(cache_fp) + rank_zero_rm(cache_fp) def test_duplicate_keyword(self): with pytest.raises(RuntimeError): @@ -240,7 +240,7 @@ class TestCacheResults: results = cache() assert (1, 2) == results finally: - synchronize_safe_rm('demo/') + rank_zero_rm('demo/') def test_result_none_error(self): @cache_results('demo.pkl') @@ -251,7 +251,7 @@ class TestCacheResults: with pytest.raises(RuntimeError): results = cache() finally: - synchronize_safe_rm('demo.pkl') + rank_zero_rm('demo.pkl') if __name__ == '__main__': diff --git a/tests/envs/test_set_backend.py b/tests/envs/test_set_backend.py index 03931bdc..395c854d 100644 --- a/tests/envs/test_set_backend.py +++ b/tests/envs/test_set_backend.py @@ -2,7 +2,7 @@ import os from fastNLP.envs.set_backend import dump_fastnlp_backend from tests.helpers.utils import Capturing -from fastNLP.core import synchronize_safe_rm +from fastNLP.core import rank_zero_rm def test_dump_fastnlp_envs(): @@ -14,4 +14,4 @@ def test_dump_fastnlp_envs(): assert filepath in output[0] assert os.path.exists(filepath) finally: - synchronize_safe_rm(filepath) + rank_zero_rm(filepath) diff --git a/tests/modules/mix_modules/test_mix_module.py b/tests/modules/mix_modules/test_mix_module.py index ae249c74..6025540b 100644 --- a/tests/modules/mix_modules/test_mix_module.py +++ b/tests/modules/mix_modules/test_mix_module.py @@ -9,7 +9,7 @@ import numpy as np from fastNLP.modules.mix_modules.mix_module import MixModule from fastNLP.modules.mix_modules.utils import paddle2torch, torch2paddle -from fastNLP.core import synchronize_safe_rm +from fastNLP.core import rank_zero_rm ############################################################################ @@ -227,7 +227,7 @@ class TorchPaddleMixModuleTestCase(unittest.TestCase): self.assertDictEqual(state_dict, new_state_dict) finally: - synchronize_safe_rm(path) + rank_zero_rm(path) def if_device_correct(self, device):