| @@ -4,8 +4,7 @@ __all__ = [ | |||||
| 'EventsList', | 'EventsList', | ||||
| 'Filter', | 'Filter', | ||||
| 'CallbackManager', | 'CallbackManager', | ||||
| 'ModelCheckpointCallback', | |||||
| 'TrainerCheckpointCallback', | |||||
| 'CheckpointCallback', | |||||
| 'choose_progress_callback', | 'choose_progress_callback', | ||||
| 'ProgressCallback', | 'ProgressCallback', | ||||
| 'RichCallback', | 'RichCallback', | ||||
| @@ -13,18 +12,21 @@ __all__ = [ | |||||
| 'LoadBestModelCallback', | 'LoadBestModelCallback', | ||||
| "EarlyStopCallback", | "EarlyStopCallback", | ||||
| 'MoreEvaluateCallback', | |||||
| "TorchWarmupCallback", | "TorchWarmupCallback", | ||||
| "TorchGradClipCallback" | |||||
| "TorchGradClipCallback", | |||||
| ] | ] | ||||
| from .callback import Callback | from .callback import Callback | ||||
| from .callback_events import EventsList, Events, Filter | from .callback_events import EventsList, Events, Filter | ||||
| from .callback_manager import CallbackManager | from .callback_manager import CallbackManager | ||||
| from .checkpoint_callback import ModelCheckpointCallback, TrainerCheckpointCallback | |||||
| from .checkpoint_callback import CheckpointCallback | |||||
| from .progress_callback import choose_progress_callback, ProgressCallback, RichCallback | from .progress_callback import choose_progress_callback, ProgressCallback, RichCallback | ||||
| from .lr_scheduler_callback import LRSchedCallback | from .lr_scheduler_callback import LRSchedCallback | ||||
| from .load_best_model_callback import LoadBestModelCallback | from .load_best_model_callback import LoadBestModelCallback | ||||
| from .early_stop_callback import EarlyStopCallback | from .early_stop_callback import EarlyStopCallback | ||||
| from .torch_callbacks import * | from .torch_callbacks import * | ||||
| from .more_evaluate_callback import MoreEvaluateCallback | |||||
| @@ -236,7 +236,7 @@ class Callback: | |||||
| 结束 validate 时调用,并把 validate 的结果传入。 | 结束 validate 时调用,并把 validate 的结果传入。 | ||||
| :param trainer: | :param trainer: | ||||
| :param results: | |||||
| :param results: Evaluate 的结果,一般是个 dict 。 | |||||
| :return: | :return: | ||||
| """ | """ | ||||
| pass | pass | ||||
| @@ -250,6 +250,15 @@ class Callback: | |||||
| """ | """ | ||||
| return self.__class__.__name__ | return self.__class__.__name__ | ||||
| @property | |||||
| def need_reproducible_sampler(self) -> bool: | |||||
| """ | |||||
| 当前 callback 是否需要能够复现的 sampler 。一般用于 checkpoint 类的 callback 。 | |||||
| :return: | |||||
| """ | |||||
| return False | |||||
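As a brief illustration of this new hook, a minimal sketch of a user-defined callback (the class name and print logic are illustrative, not part of this diff):

class MyStatefulCallback(Callback):
    """Illustrative only: a callback that participates in checkpointing and
    therefore asks the CallbackManager for a reproducible sampler."""

    def on_validate_end(self, trainer, results):
        # `results` is the evaluate result, usually a dict (see the docstring above).
        print(results)

    @property
    def need_reproducible_sampler(self) -> bool:
        # Restoring training state mid-epoch only makes sense if the sampler can be replayed.
        return True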
| class _CallbackWrapper(Callback): | class _CallbackWrapper(Callback): | ||||
| """ | """ | ||||
| @@ -8,7 +8,6 @@ __all__ = [ | |||||
| from .callback_events import Events | from .callback_events import Events | ||||
| from .callback import Callback | from .callback import Callback | ||||
| from .checkpoint_callback import TrainerCheckpointCallback | |||||
| from .progress_callback import ProgressCallback, choose_progress_callback | from .progress_callback import ProgressCallback, choose_progress_callback | ||||
| from fastNLP.core.log import logger | from fastNLP.core.log import logger | ||||
| @@ -45,7 +44,7 @@ class CallbackManager: | |||||
| :param callbacks: 初始化时可以传入的一系列 callback 类,通常为用户在初始化 'Trainer' 时直接传入的 callback 类; | :param callbacks: 初始化时可以传入的一系列 callback 类,通常为用户在初始化 'Trainer' 时直接传入的 callback 类; | ||||
| """ | """ | ||||
| self._has_trainer_checkpoint = False | |||||
| self._need_reproducible_sampler = False | |||||
| _has_progress_callback = False | _has_progress_callback = False | ||||
| _callbacks = [] | _callbacks = [] | ||||
| @@ -98,8 +97,7 @@ class CallbackManager: | |||||
| :return: | :return: | ||||
| """ | """ | ||||
| for each_callback in self.class_callbacks: | for each_callback in self.class_callbacks: | ||||
| if isinstance(each_callback, TrainerCheckpointCallback): | |||||
| self._has_trainer_checkpoint = True | |||||
| self._need_reproducible_sampler |= each_callback.need_reproducible_sampler | |||||
| self.dissect_one_callback(each_callback) | self.dissect_one_callback(each_callback) | ||||
| def dissect_one_callback(self, callback: Callback): | def dissect_one_callback(self, callback: Callback): | ||||
| @@ -211,7 +209,7 @@ class CallbackManager: | |||||
| @property | @property | ||||
| def has_trainer_checkpoint(self) -> bool: | def has_trainer_checkpoint(self) -> bool: | ||||
| return self._has_trainer_checkpoint | |||||
| return self._need_reproducible_sampler | |||||
| @_transfer | @_transfer | ||||
| def on_after_trainer_initialized(self, trainer): | def on_after_trainer_initialized(self, trainer): | ||||
| @@ -1,339 +1,151 @@ | |||||
| __all__ = [ | __all__ = [ | ||||
| 'ModelCheckpointCallback', | |||||
| 'TrainerCheckpointCallback' | |||||
| 'CheckpointCallback' | |||||
| ] | ] | ||||
| import os | |||||
| from typing import Union, Optional, Callable, Dict, Sequence, Any, Mapping | |||||
| from typing import Union, Optional, Callable, Dict, Sequence | |||||
| from pathlib import Path | from pathlib import Path | ||||
| import sys | import sys | ||||
| from copy import deepcopy | |||||
| import fastNLP | |||||
| from .has_monitor_callback import HasMonitorCallback | |||||
| from fastNLP.core.log import logger | from fastNLP.core.log import logger | ||||
| from fastNLP.envs import FASTNLP_LAUNCH_TIME, FASTNLP_GLOBAL_RANK | |||||
| from fastNLP.core.utils import synchronize_safe_rm, synchronize_mkdir | |||||
| from .topk_saver import TopkSaver | |||||
| from .callback import Callback | |||||
| class CheckpointCallback(HasMonitorCallback): | |||||
| def __init__( | |||||
| self, | |||||
| monitor:Optional[Union[str, Callable]]=None, | |||||
| save_folder: Optional[Union[str, Path]] = None, | |||||
| save_every_n_epochs: Optional[int] = None, | |||||
| save_every_n_batches: Optional[int] = None, | |||||
| save_last: bool = False, | |||||
| save_topk: Optional[int] = None, | |||||
| save_on_exception: Optional[Union[BaseException, Sequence[BaseException]]] = None, | |||||
| larger_better: bool = True, | |||||
| only_state_dict: bool = True, | |||||
| model_save_fn: Optional[Callable] = None, | |||||
| **kwargs, | |||||
| ): | |||||
| class CheckpointCallback(Callback): | |||||
| def __init__(self, folder: Optional[Union[str, Path]] = None, every_n_epochs: Optional[int] = None, | |||||
| every_n_batches: Optional[int] = None, last: bool = False, | |||||
| on_exceptions: Optional[Union[BaseException, Sequence[BaseException]]] = None, topk: int = 0, | |||||
| monitor: Optional[Union[str, Callable]] = None, larger_better: bool = True, | |||||
| only_state_dict: bool = True, model_save_fn: Optional[Callable] = None, save_object: str = 'model', | |||||
| save_evaluate_results=True, **kwargs): | |||||
| """ | """ | ||||
| 请使用 ModelCheckpointCallback 与 TrainerCheckpointCallback 。 | |||||
| 保存模型 checkpoint 的 callback ,其保存的文件目录以及文件名命名规则如下 | |||||
| - folder/ | |||||
| - YYYY-mm-dd-HH_MM_SS_fffff/ # 自动根据当前脚本的启动时间创建的 | |||||
| - {save_object}-epoch_{epoch_idx}/ # 满足 every_n_epochs 条件保存的模型 | |||||
| - {save_object}-epoch_{epoch_idx}-batch_{global_batch_idx}/ # 满足 every_n_batches 保存的模型 | |||||
| - {save_object}-last/ # 最后一个 epoch 的保存 | |||||
| - {save_object}-epoch_{epoch_idx}-batch_{global_batch_idx}-exception_{exception_type}/ # exception时保存。 | |||||
| - {save_object}-epoch_{epoch_idx}-batch_{global_batch_idx}-{monitor}_{monitor_value}/ # 满足topk条件存储文件名 | |||||
| 若 model_save_fn 为 None ,则以上每个 folder 中,将生成 fastnlp_model.pkl.tar 文件。 |||||
| 若 model_save_fn 不为 None,则 fastNLP 将 folder 绝对路径传递给该函数,fastNLP 在该 folder 下不进行模型保存。 | |||||
| :param monitor: 监控的 metric 值。如果在 evaluation 结果中没有找到完全一致的名称,将使用 最短公共字符串算法 找到最匹配 | :param monitor: 监控的 metric 值。如果在 evaluation 结果中没有找到完全一致的名称,将使用 最短公共字符串算法 找到最匹配 | ||||
| 的那个作为 monitor 。如果为 None,将尝试使用 Trainer 设置的 monitor 。也可以传入一个函数,接受参数为 evaluation 的结 | 的那个作为 monitor 。如果为 None,将尝试使用 Trainer 设置的 monitor 。也可以传入一个函数,接受参数为 evaluation 的结 | ||||
| 果(字典类型),返回一个 float 值作为 monitor 的结果。 | |||||
| :param save_folder: 保存的文件夹,fastNLP 将在该文件下以时间戳创建子文件夹,并在里面保存。因此不同次运行可以将被保存到不同的 | |||||
| 果(字典类型),返回一个 float 值作为 monitor 的结果,如果当前结果中没有相关的 monitor 值请返回 None 。 | |||||
| :param folder: 保存的文件夹,fastNLP 将在该文件下以时间戳创建子文件夹,并在里面保存。因此不同次运行可以将被保存到不同的 | |||||
| 时间戳文件夹中。如果为 None ,默认使用当前文件夹。 | 时间戳文件夹中。如果为 None ,默认使用当前文件夹。 | ||||
| :param save_every_n_epochs: 多少个 epoch 保存一次。 | |||||
| :param save_every_n_batches: 多少个 batch 保存一次。 | |||||
| :param save_last: 如果为 True ,将在每次 epoch 运行结束都保存一次,会覆盖之前的保存。 | |||||
| :param save_topk: 保存 monitor 结果 topK 个。 | |||||
| :param save_on_exception: 在出异常信息时,是否保存。传入需要捕获的异常的类。 | |||||
| :param every_n_epochs: 多少个 epoch 保存一次。 | |||||
| :param every_n_batches: 多少个 batch 保存一次。 | |||||
| :param last: 如果为 True ,将在每次 epoch 运行结束都保存一次,会覆盖之前的保存。 | |||||
| :param topk: 保存 monitor 结果 topK 个。 | |||||
| :param on_exceptions: 在出异常信息时,是否保存。传入需要捕获的异常的类。 | |||||
| :param larger_better: monitor 的值是否时越大越好。 | :param larger_better: monitor 的值是否时越大越好。 | ||||
| :param only_state_dict: 保存模型时是否只保存 state_dict 。当 model_save_fn 不为 None 时,该参数无效。 | :param only_state_dict: 保存模型时是否只保存 state_dict 。当 model_save_fn 不为 None 时,该参数无效。 | ||||
| :param model_save_fn: 个性化的保存函数,当触发保存操作时,就调用这个函数,这个函数应当接受一个文件夹作为参数,不返回任何东西。 | :param model_save_fn: 个性化的保存函数,当触发保存操作时,就调用这个函数,这个函数应当接受一个文件夹作为参数,不返回任何东西。 | ||||
| 如果传入了 model_save_fn 函数,fastNLP 将不再进行模型相关的保存。在多卡场景下,我们只在 rank 0 上会运行该函数。 | 如果传入了 model_save_fn 函数,fastNLP 将不再进行模型相关的保存。在多卡场景下,我们只在 rank 0 上会运行该函数。 | ||||
| :param save_object: 可选 ['trainer', 'model'],表示在保存时的保存对象为 trainer+model 还是 只是model 。 | |||||
| :param save_evaluate_results: 是否保存 evaluate 的结果。如果为 True ,在保存 topk 模型的 folder 中还将额外保存一个 | |||||
| fastnlp_evaluate_results.json 文件,记录当前的 results。仅在设置了 topk 的场景下有用,默认为 True 。 | |||||
| :param kwargs: | :param kwargs: | ||||
| """ | """ | ||||
| super().__init__(monitor=monitor, larger_better=larger_better, | |||||
| must_have_monitor=save_topk is not None) | |||||
| if save_folder is None: | |||||
| super().__init__() | |||||
| if folder is None: | |||||
| logger.warning( | logger.warning( | ||||
| "Parameter `path` is None, and we will use the current work directory to find and load your model.") | |||||
| save_folder = Path.cwd() | |||||
| save_folder = Path(save_folder) | |||||
| if not save_folder.exists(): | |||||
| raise NotADirectoryError(f"Path '{save_folder.absolute()}' is not existed!") | |||||
| elif save_folder.is_file(): | |||||
| raise ValueError("Parameter `save_folder` should be a directory instead of a file.") | |||||
| if save_every_n_epochs is not None: | |||||
| if not isinstance(save_every_n_epochs, int) or save_every_n_epochs < 1: | |||||
| raise ValueError("parameter save_after_epoch_num should be an int and greater than or equal to 1.") | |||||
| "Parameter `folder` is None, and we will use the current work directory to save your model.") |||||
| folder = Path.cwd() | |||||
| folder = Path(folder) | |||||
| if not folder.exists(): | |||||
| raise NotADirectoryError(f"Path '{folder.absolute()}' does not exist!") |||||
| elif folder.is_file(): | |||||
| raise ValueError("Parameter `folder` should be a directory instead of a file.") | |||||
| if every_n_epochs is not None: | |||||
| if not isinstance(every_n_epochs, int) or every_n_epochs < 1: | |||||
| raise ValueError("Parameter `every_n_epochs` should be an int and greater than or equal to 1.") | |||||
| else: | else: | ||||
| save_every_n_epochs = sys.maxsize # 使得没有数字可以整除 | |||||
| every_n_epochs = sys.maxsize # 使得没有数字可以整除 | |||||
| if save_every_n_batches is not None: | |||||
| if not isinstance(save_every_n_batches, int) or save_every_n_batches < 1: | |||||
| raise ValueError( | |||||
| "parameter save_every_n_batches should be an int and greater than or equal to 1.") | |||||
| if every_n_batches is not None: | |||||
| if not isinstance(every_n_batches, int) or every_n_batches < 1: | |||||
| raise ValueError("Parameter `every_n_batches` should be an int and greater than or equal to 1.") | |||||
| else: | else: | ||||
| save_every_n_batches = sys.maxsize # 使得没有数字可以整除 | |||||
| every_n_batches = sys.maxsize # 使得没有数字可以整除 | |||||
| if save_topk is not None: | |||||
| if not isinstance(save_topk, int) or save_topk < 1: | |||||
| raise ValueError("parameter save_topk should be an int and greater than or equal to 1.") | |||||
| if topk is not None: | |||||
| if not isinstance(topk, int): | |||||
| raise ValueError("Parameter `topk` should be an int.") | |||||
| else: | |||||
| topk = 0 | |||||
| if save_on_exception is not None: | |||||
| if not isinstance(save_on_exception, Sequence): | |||||
| save_on_exception = [save_on_exception] | |||||
| if on_exceptions is not None: | |||||
| if not isinstance(on_exceptions, Sequence): | |||||
| on_exceptions = [on_exceptions] | |||||
| for exception in save_on_exception: | |||||
| for exception in on_exceptions: | |||||
| if not issubclass(exception, BaseException): | if not issubclass(exception, BaseException): | ||||
| raise TypeError("Each exception in parameter `save_on_exception` can only be " | |||||
| raise TypeError("Each exception in parameter `on_exceptions` can only be " |||||
| "`BaseException` type.") | "`BaseException` type.") | ||||
| else: | else: | ||||
| save_on_exception = [] | |||||
| on_exceptions = [] | |||||
| self.save_folder = save_folder | |||||
| self.save_every_n_epochs = save_every_n_epochs | |||||
| self.save_every_n_batches = save_every_n_batches | |||||
| self.save_last = save_last | |||||
| self.save_topk = save_topk | |||||
| self.only_state_dict = only_state_dict | |||||
| self.model_save_fn = model_save_fn | |||||
| self.save_on_exception = save_on_exception | |||||
| self.kwargs = kwargs | |||||
| self.topk_saver = TopkSaver(topk, monitor, larger_better, folder, only_state_dict, | |||||
| model_save_fn, save_evaluate_results, | |||||
| save_object, **kwargs) | |||||
| self.topk = topk | |||||
| self.save_object = save_object | |||||
| # 这些参数是专门留给 topk 模式专门使用的; | |||||
| self._topk_model = {} | |||||
| self._topn = 0 # 表示目前已经保存了几个最好的模型; | |||||
| self.every_n_epochs = every_n_epochs | |||||
| self.every_n_batches = every_n_batches | |||||
| self.last = last | |||||
| self.exceptions = on_exceptions | |||||
| # 注意这里应当保证只有进程 0 在执行这个操作,因为当用户使用 python -m torch.distributed.launch 来拉起进程的时候, | |||||
| # FASTNLP_LAUNCH_TIME 在每一个进程上的值是不一样的; | |||||
| self.timestamp_path = self.save_folder.joinpath(os.environ[FASTNLP_LAUNCH_TIME]) | |||||
| # 该 folder 只在保存真的要发生的时候再创建。 | |||||
| @property | |||||
| def need_reproducible_sampler(self) -> bool: | |||||
| return self.save_object == 'trainer' | |||||
| def on_after_trainer_initialized(self, trainer, driver): | def on_after_trainer_initialized(self, trainer, driver): | ||||
| if self.save_topk is not None: | |||||
| super().on_after_trainer_initialized(trainer, driver) | |||||
| if self.save_topk is not None and trainer.evaluator is None: | |||||
| logger.warning("You set `save_topk`, but `evaluate_dataloaders` is not set in Trainer.") | |||||
| if self.topk_saver.topk_queue: # 需要设置 monitor | |||||
| if self.topk_saver.monitor is None: | |||||
| self.topk_saver.set_monitor(monitor=trainer.monitor, larger_better=trainer.larger_better) | |||||
| if self.topk_saver.topk_queue and trainer.evaluator is None: | |||||
| logger.warning(f"You set `topk={self.topk}`, but `evaluate_dataloaders` is not set in Trainer.") | |||||
| def on_validate_end(self, trainer, results): | def on_validate_end(self, trainer, results): | ||||
| self._save_topk(trainer, results) | |||||
| # 如果发生了保存,则返回的 folder 不为 None | |||||
| folder = self.topk_saver.save_topk(trainer, results) | |||||
| def on_train_epoch_end(self, trainer: "fastNLP.Trainer"): | def on_train_epoch_end(self, trainer: "fastNLP.Trainer"): | ||||
| if trainer.cur_epoch_idx % self.save_every_n_epochs == 0: | |||||
| folder_name = f'{self.folder_prefix}-epoch_{trainer.cur_epoch_idx}' | |||||
| self.save(trainer, folder_name=folder_name) | |||||
| if self.save_last: | |||||
| folder_name = f'{self.folder_prefix}-last' | |||||
| self.save(trainer, folder_name=folder_name) | |||||
| if trainer.cur_epoch_idx % self.every_n_epochs == 0: | |||||
| folder_name = f'{self.save_object}-epoch_{trainer.cur_epoch_idx}' | |||||
| self.topk_saver.save(trainer, folder_name=folder_name) | |||||
| if self.last: | |||||
| folder_name = f'{self.save_object}-last' | |||||
| self.topk_saver.save(trainer, folder_name=folder_name) | |||||
| def on_train_batch_end(self, trainer): | def on_train_batch_end(self, trainer): | ||||
| if trainer.global_forward_batches % self.save_every_n_batches == 0: | |||||
| folder_name = f'{self.folder_prefix}-epoch_{trainer.cur_epoch_idx}-batch_{trainer.global_forward_batches}' | |||||
| self.save(trainer, folder_name=folder_name) | |||||
| if trainer.global_forward_batches % self.every_n_batches == 0: | |||||
| folder_name = f'{self.save_object}-epoch_{trainer.cur_epoch_idx}-batch_{trainer.global_forward_batches}' | |||||
| self.topk_saver.save(trainer, folder_name=folder_name) | |||||
| def on_exception(self, trainer, exception: BaseException): | def on_exception(self, trainer, exception: BaseException): | ||||
| if exception.__class__ in self.save_on_exception: | |||||
| folder_name = f'{self.folder_prefix}-epoch_{trainer.cur_epoch_idx}-batch_{trainer.global_forward_batches}-' \ | |||||
| f'exception_{exception.__class__.__name__}' | |||||
| self.save(trainer=trainer, folder_name=folder_name) | |||||
| if exception.__class__ in self.exceptions: | |||||
| folder_name = f'{self.save_object}-epoch_{trainer.cur_epoch_idx}-batch_{trainer.global_forward_batches}-' \ | |||||
| f'exception_{exception.__class__.__name__}' | |||||
| self.topk_saver.save(trainer, folder_name=folder_name) | |||||
| def on_save_checkpoint(self, trainer) -> Dict: | def on_save_checkpoint(self, trainer) -> Dict: | ||||
| """ | """ | ||||
| 保存 timestamp_path 使得之后可以继续训练并保存到该文件夹。 | |||||
| topk_model的状态 | |||||
| _real_monitor的值 | |||||
| 保存状态,以便之后可以继续使用 | |||||
| """ | """ | ||||
| states = {} | states = {} | ||||
| states['timestamp_path'] = str(self.timestamp_path.absolute()) | |||||
| states['_topk_model'] = deepcopy(self._topk_model) | |||||
| states['save_topk'] = 0 if self.save_topk is None else self.save_topk | |||||
| if isinstance(self._real_monitor, str): | |||||
| states['_real_monitor'] = self._real_monitor | |||||
| states['topk_saver'] = self.topk_saver.state_dict() | |||||
| return states | return states | ||||
| def on_load_checkpoint(self, trainer, states: Optional[Dict]): | def on_load_checkpoint(self, trainer, states: Optional[Dict]): | ||||
| timestamp_path = states['timestamp_path'] | |||||
| if not os.path.exists(timestamp_path): | |||||
| logger.info(f"The resuming checkpoint folder {timestamp_path} is not exists, will checkpoint save to " | |||||
| f" {self.timestamp_path.absolute()}.") | |||||
| else: | |||||
| logger.info(f"Resume to checkpoint in path: {timestamp_path}.") | |||||
| self.timestamp_path = Path(timestamp_path) | |||||
| _topk_model = states['_topk_model'] | |||||
| save_topk = None if int(states['save_topk']) == 0 else int(states['save_topk']) | |||||
| if save_topk is not None and self.save_topk is not None: | |||||
| assert self.save_topk == save_topk, f"The checkpoint set save_topk={save_topk}, while this callback set it " \ | |||||
| f"as {save_topk}." | |||||
| self._topk_model.update(self._topk_model) | |||||
| self._real_monitor = states["_real_monitor"] | |||||
| def _save_topk(self, trainer: "fastNLP.Trainer", results: Dict): | |||||
| """ | |||||
| 根据validate_res决定保存哪些model的函数。会自动移除掉不满足topk的文件夹。 | |||||
| :param trainer: | |||||
| :param results: | |||||
| :return: | |||||
| """ | |||||
| if self.save_topk is not None: | |||||
| monitor_value = self.get_monitor_value(results=results) | |||||
| if monitor_value is None: | |||||
| return | |||||
| folder_name = f"{self.folder_prefix}-epoch_{trainer.cur_epoch_idx}-batch_{trainer.global_forward_batches}" \ | |||||
| f"-{self._real_monitor}_{monitor_value}" | |||||
| _should_save = False | |||||
| if self._topn < self.save_topk: | |||||
| self._topk_model[folder_name] = monitor_value | |||||
| self._topn += 1 | |||||
| _should_save = True | |||||
| else: | |||||
| _least_valuable_model = (min if self.larger_better else max)(self._topk_model, | |||||
| key=lambda x: self._topk_model[x]) | |||||
| if self.is_former_monitor_value_better(monitor_value, self._topk_model[_least_valuable_model]): | |||||
| self._topk_model[folder_name] = monitor_value | |||||
| _should_save = True | |||||
| self._topk_model.pop(_least_valuable_model) | |||||
| synchronize_safe_rm(self.timestamp_path.joinpath(_least_valuable_model)) | |||||
| assert len(self._topk_model) == self.save_topk == self._topn | |||||
| if _should_save: | |||||
| self.save(trainer, folder_name=folder_name) | |||||
| def save(self, trainer, folder_name): | |||||
| """ | |||||
| 执行保存的函数,将数据保存在 save_folder/timestamp/folder_name 下。 | |||||
| :param trainer: | |||||
| :param folder_name: | |||||
| :return: | |||||
| """ | |||||
| folder = self.timestamp_path.joinpath(folder_name) | |||||
| if int(os.environ.get(FASTNLP_GLOBAL_RANK, 0)) == 0: # 只在进程0上创建 | |||||
| synchronize_mkdir(folder) | |||||
| _fn = getattr(trainer, self.save_fn_name) | |||||
| _fn( | |||||
| folder=folder, | |||||
| only_state_dict=self.only_state_dict, | |||||
| model_save_fn=self.model_save_fn, | |||||
| **self.kwargs | |||||
| ) | |||||
| @property | |||||
| def folder_prefix(self): | |||||
| raise NotImplementedError("The `folder_prefix` is not specified") | |||||
| @property | |||||
| def save_fn_name(self): | |||||
| raise NotImplementedError("The `save_fn_name` is not specified.") | |||||
| class ModelCheckpointCallback(CheckpointCallback): | |||||
| """ | |||||
| 保存模型 checkpoint 的 callback ,其保存的文件目录以及文件名命名规则如下 | |||||
| - save_folder/ | |||||
| - YYYY-mm-dd-HH_MM_SS_fffff/ # 自动根据当前脚本的启动时间创建的 | |||||
| - model-epoch_{epoch_idx}/ # 满足 save_every_n_epochs 条件保存的模型 | |||||
| - model-epoch_{epoch_idx}-batch_{global_batch_idx}/ # 满足 save_every_n_batches 保存的模型 | |||||
| - model-last/ # 最后一个 epoch 的保存 | |||||
| - model-epoch_{epoch_idx}-batch_{global_batch_idx}-exception_{exception_type}/ # exception时保存。 | |||||
| - model-epoch_{epoch_idx}-batch_{global_batch_idx}-{monitor}_{monitor_value}/ # 满足topk条件存储文件名 | |||||
| topk_saver_states = states['topk_saver'] | |||||
| self.topk_saver.load_state_dict(topk_saver_states) | |||||
| model_save_fn 为 None ,则以上每个 folder 中,将生成 fastnlp_model.pkl.tar 文件。 | |||||
| 若 model_save_fn 不为 None,则 fastNLP 将 folder 绝对路径传递给该函数,fastNLP 不在该 folder 下创建任何文件。 | |||||
| :param monitor: 监控的 metric 值。如果在 evaluation 结果中没有找到完全一致的名称,将使用 最短公共字符串算法 找到最匹配 | |||||
| 的那个作为 monitor 。如果为 None,将尝试使用 Trainer 设置的 monitor 。也可以传入一个函数,接受参数为 evaluation 的结 | |||||
| 果(字典类型),返回一个 float 值作为 monitor 的结果。 | |||||
| :param save_folder: 保存的文件夹,fastNLP 将在该文件下以时间戳创建子文件夹,并在里面保存。因此不同次运行可以将被保存到不同的 | |||||
| 时间戳文件夹中。如果为 None ,默认使用当前文件夹。 | |||||
| :param save_every_n_epochs: 多少个 epoch 保存一次。 | |||||
| :param save_every_n_batches: 多少个 batch 保存一次。 | |||||
| :param save_last: 如果为 True ,将在每次 epoch 运行结束都保存一次,会覆盖之前的保存。 | |||||
| :param save_topk: 保存 monitor 结果 topK 个。 | |||||
| :param save_on_exception: 在出异常信息时,是否保存。传入需要捕获的异常的类。 | |||||
| :param larger_better: monitor 的值是否时越大越好。 | |||||
| :param only_state_dict: 保存模型时是否只保存 state_dict 。当 model_save_fn 不为 None 时,该参数无效。 | |||||
| :param model_save_fn: 个性化的保存函数,当触发保存操作时,就调用这个函数,这个函数应当接受一个文件夹作为参数,不返回任何东西。 | |||||
| 如果传入了 model_save_fn 函数,fastNLP 将不再进行模型相关的保存。在多卡场景下,我们只在 rank 0 上会运行该函数。 | |||||
| :param kwargs: | |||||
| """ | |||||
| @property | |||||
| def save_fn_name(self): | |||||
| """ | |||||
| 调用 Trainer 中的哪个函数。 | |||||
| :return: | |||||
| """ | |||||
| return 'save_model' | |||||
| @property | |||||
| def callback_name(self): | |||||
| """ | |||||
| 通过该值决定两个 CheckpointCallback 实例是否可以共用断点重训的状态; | |||||
| :return: | |||||
| """ | |||||
| return f"model_checkpoint#monitor-{self.monitor_name}#topK-{self.save_topk}#only_state_dict-{self.only_state_dict}" | |||||
| @property | |||||
| def folder_prefix(self): | |||||
| return 'model' | |||||
| class TrainerCheckpointCallback(CheckpointCallback): | |||||
| """ | |||||
| 保存 Trainer checkpoint 的 callback ,其保存的文件目录以及文件名命名规则如下 | |||||
| - save_folder/ | |||||
| - YYYY-mm-dd-HH_MM_SS_fffff/ # 自动根据当前脚本的启动时间创建的 | |||||
| - trainer-epoch_{epoch_idx}/ # 满足 save_every_n_epochs 条件保存的模型 | |||||
| - trainer-epoch_{epoch_idx}-batch_{global_batch_idx}/ # 满足 save_every_n_batches 保存的模型 | |||||
| - trainer-last/ # 最后一个 epoch 的保存 | |||||
| - trainer-epoch_{epoch_idx}-batch_{global_batch_idx}-exception_{exception_type}/ # exception时保存。 | |||||
| - trainer-epoch_{epoch_idx}-batch_{global_batch_idx}-{monitor}_{monitor_value}/ # 满足topk条件存储文件名 | |||||
| model_save_fn 为 None ,则以上每个 folder 中,将生成两个文件:fastnlp_trainer.pkl.tar 以及 fastnlp_model.pkl.tar 。 | |||||
| 若 model_save_fn 不为 None,则 fastNLP 只会在每个 folder 下生成 fastnlp_trainer.pkl.tar 文件。 | |||||
| :param monitor: 监控的 metric 值。如果在 evaluation 结果中没有找到完全一致的名称,将使用 最短公共字符串算法 找到最匹配 | |||||
| 的那个作为 monitor 。如果为 None,将尝试使用 Trainer 设置的 monitor 。也可以传入一个函数,接受参数为 evaluation 的结 | |||||
| 果(字典类型),返回一个 float 值作为 monitor 的结果。 | |||||
| :param save_folder: 保存的文件夹,fastNLP 将在该文件下以时间戳创建子文件夹,并在里面保存。因此不同次运行可以将被保存到不同的 | |||||
| 时间戳文件夹中。如果为 None ,默认使用当前文件夹。 | |||||
| :param save_every_n_epochs: 多少个 epoch 保存一次。 | |||||
| :param save_every_n_batches: 多少个 batch 保存一次。 | |||||
| :param save_last: 如果为 True ,将在每次 epoch 运行结束都保存一次,会覆盖之前的保存。 | |||||
| :param save_topk: 保存 monitor 结果 topK 个。 | |||||
| :param save_on_exception: 在出异常信息时,是否保存。 | |||||
| :param larger_better: monitor 的值是否时越大越好。 | |||||
| :param only_state_dict: 保存模型时是否只保存 state_dict 。当 model_save_fn 不为 None 时,该参数无意义。 | |||||
| :param model_save_fn: 个性化的保存函数,当触发保存操作时,就调用这个函数,这个函数应当接受一个文件夹作为参数,不返回任何东西。 | |||||
| 如果传入了 model_save_fn 函数,fastNLP 将不再进行模型相关的保存。在多卡场景下,我们只在 rank 0 上会运行该函数。 | |||||
| :param kwargs: | |||||
| """ | |||||
| @property | |||||
| def save_fn_name(self): | |||||
| """ | |||||
| 调用 Trainer 中的哪个函数。 | |||||
| :return: | |||||
| """ | |||||
| return 'save' | |||||
| @property | |||||
| def callback_name(self): | |||||
| """ | |||||
| 通过该值决定两个 CheckpointCallback 实例是否可以共用断点重训的状态; | |||||
| :return: | |||||
| """ | |||||
| return f"trainer_checkpoint#monitor-{self.monitor_name}#topK-{self.save_topk}#only_state_dict-{self.only_state_dict}" | |||||
| @property | |||||
| def folder_prefix(self): | |||||
| return 'trainer' | |||||
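To make the consolidated CheckpointCallback above concrete, a hedged usage sketch: only parameters documented in its __init__ docstring are used, the folder path is illustrative (it must already exist), and the Trainer wiring via `callbacks=` is an assumption.

from fastNLP.core.callbacks import CheckpointCallback

ckpt_cb = CheckpointCallback(
    folder='checkpoints',     # an existing directory; a timestamped sub-folder is created inside it
    every_n_epochs=2,         # save every 2 epochs
    topk=3,                   # additionally keep the 3 best checkpoints ...
    monitor='acc',            # ... judged by this evaluation metric
    larger_better=True,
    save_object='trainer',    # 'trainer' saves trainer+model and requests a reproducible sampler
)
# trainer = Trainer(model=model, ..., callbacks=[ckpt_cb])  # assumed Trainer wiring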
| @@ -1,10 +1,12 @@ | |||||
| __all__ = [ | __all__ = [ | ||||
| 'HasMonitorCallback', | 'HasMonitorCallback', | ||||
| 'ExecuteOnceBetterMonitor' | |||||
| 'ExecuteOnceBetterMonitor', | |||||
| 'MonitorUtility' | |||||
| ] | ] | ||||
| from typing import Dict, Union, Any | from typing import Dict, Union, Any | ||||
| from abc import ABC | from abc import ABC | ||||
| import functools | |||||
| from fastNLP.core.utils import apply_to_collection | from fastNLP.core.utils import apply_to_collection | ||||
| from fastNLP.core.callbacks import Callback | from fastNLP.core.callbacks import Callback | ||||
| @@ -27,21 +29,13 @@ class CanItemDataType(ABC): | |||||
| return NotImplemented | return NotImplemented | ||||
| class MonitorUtility: | |||||
| """ | |||||
| 计算 monitor 的相关函数 | |||||
| class HasMonitorCallback(Callback): | |||||
| def __init__(self, monitor, larger_better, must_have_monitor=False): | |||||
| """ | |||||
| 该 callback 不直接进行使用,作为其它相关 callback 的父类使用,如果 callback 有使用 monitor 可以继承该函数里面实现了 | |||||
| (1)判断monitor合法性;(2)在需要时, 根据trainer的monitor设置自己的monitor名称。 | |||||
| :param monitor: 监控的 metric 值。如果在 evaluation 结果中没有找到完全一致的名称,将使用 最短公共字符串算法 找到最匹配 | |||||
| 的那个作为 monitor 。如果为 None,将尝试使用 Trainer 设置的 monitor 。也可以传入一个函数,接受参数为 evaluation 的结 | |||||
| 果(字典类型),返回一个 float 值作为 monitor 的结果。 | |||||
| :param larger_better: monitor 是否时越大越好 | |||||
| :param must_have_monitor: 这个 callback 是否必须有 monitor 设置。如果设置为 True ,且没检测到设置 monitor 会报错。 | |||||
| """ | |||||
| """ | |||||
| def __init__(self, monitor, larger_better): | |||||
| self.set_monitor(monitor, larger_better) | self.set_monitor(monitor, larger_better) | ||||
| self.must_have_moinitor = must_have_monitor | |||||
| def set_monitor(self, monitor, larger_better): | def set_monitor(self, monitor, larger_better): | ||||
| if callable(monitor): # 检查是否能够接受一个参数 | if callable(monitor): # 检查是否能够接受一个参数 | ||||
| @@ -57,26 +51,14 @@ class HasMonitorCallback(Callback): | |||||
| self.monitor_value = float('inf') | self.monitor_value = float('inf') | ||||
| self._real_monitor = self.monitor | self._real_monitor = self.monitor | ||||
| def on_after_trainer_initialized(self, trainer, driver): | |||||
| def itemize_results(self, results): | |||||
| """ | """ | ||||
| 如果本身的 monitor 没有设置,则根据 Trainer 中的 monitor 设置 monitor 。 | |||||
| 同时对于必须要有 monitor 设置的 callback ,该函数会进行检查。 | |||||
| 将结果中有 .item() 方法的都调用一下,使得结果可以保存 |||||
| :param trainer: | |||||
| :param driver: | |||||
| :param results: | |||||
| :return: | :return: | ||||
| """ | """ | ||||
| if self.monitor is None and trainer.monitor is not None: | |||||
| self.set_monitor(monitor=trainer.monitor, larger_better=trainer.larger_better) | |||||
| if self.must_have_moinitor and self.monitor is None: | |||||
| raise RuntimeError(f"No `monitor` is set for {self.__class__.__name__}. " | |||||
| f"You can set it in the initialization or through Trainer.") | |||||
| def on_sanity_check_end(self, trainer, sanity_check_res): | |||||
| # 主要核对一下 monitor 是否存在。 | |||||
| if self.monitor is not None: | |||||
| self.get_monitor_value(results=sanity_check_res) | |||||
| return apply_to_collection(results, dtype=CanItemDataType, function=lambda x: x.item()) | |||||
| def get_monitor_value(self, results:Dict)->Union[float, None]: | def get_monitor_value(self, results:Dict)->Union[float, None]: | ||||
| """ | """ | ||||
| @@ -85,10 +67,10 @@ class HasMonitorCallback(Callback): | |||||
| :param results: | :param results: | ||||
| :return: 如果为 None ,表明此次没有找到合适的monitor | :return: 如果为 None ,表明此次没有找到合适的monitor | ||||
| """ | """ | ||||
| if len(results)==0: | |||||
| if len(results) == 0 or self.monitor is None: | |||||
| return None | return None | ||||
| # 保证所有的 tensor 都被转换为了 python 特定的类型 | # 保证所有的 tensor 都被转换为了 python 特定的类型 | ||||
| results = apply_to_collection(results, dtype=CanItemDataType, function=lambda x: x.item()) | |||||
| results = self.itemize_results(results) | |||||
| use_monitor, monitor_value = _get_monitor_value(monitor=self.monitor, | use_monitor, monitor_value = _get_monitor_value(monitor=self.monitor, | ||||
| real_monitor=self._real_monitor, | real_monitor=self._real_monitor, | ||||
| res=results) | res=results) | ||||
| @@ -97,7 +79,7 @@ class HasMonitorCallback(Callback): | |||||
| # 第一次运行 | # 第一次运行 | ||||
| if isinstance(self.monitor, str) and self._real_monitor == self.monitor and use_monitor != self.monitor: | if isinstance(self.monitor, str) and self._real_monitor == self.monitor and use_monitor != self.monitor: | ||||
| logger.warning(f"We can not find `{self.monitor}` in the evaluation result (with keys as {list(results.keys())}), " | logger.warning(f"We can not find `{self.monitor}` in the evaluation result (with keys as {list(results.keys())}), " | ||||
| f"we use the `{use_monitor}` as the monitor for `{self.__class__.__name__}`.") | |||||
| f"we use the `{use_monitor}` as the monitor for `{self.__class__.__name__}`.") | |||||
| # 检测到此次和上次不同。 | # 检测到此次和上次不同。 | ||||
| elif isinstance(self.monitor, str) and self._real_monitor != self.monitor and use_monitor != self._real_monitor: | elif isinstance(self.monitor, str) and self._real_monitor != self.monitor and use_monitor != self._real_monitor: | ||||
| logger.warning(f"Change of monitor detected for `{self.__class__.__name__}`. " | logger.warning(f"Change of monitor detected for `{self.__class__.__name__}`. " | ||||
| @@ -165,7 +147,10 @@ class HasMonitorCallback(Callback): | |||||
| """ | """ | ||||
| if callable(self.monitor): | if callable(self.monitor): | ||||
| try: | try: | ||||
| monitor_name = self.monitor.__qualname__ | |||||
| monitor = self.monitor | |||||
| while isinstance(monitor, functools.partial): | |||||
| monitor = monitor.func | |||||
| monitor_name = monitor.__qualname__ | |||||
| except: | except: | ||||
| monitor_name = self.monitor.__name__ | monitor_name = self.monitor.__name__ | ||||
| elif self.monitor is None: | elif self.monitor is None: | ||||
| @@ -176,6 +161,46 @@ class HasMonitorCallback(Callback): | |||||
| return monitor_name | return monitor_name | ||||
| class HasMonitorCallback(MonitorUtility, Callback): | |||||
| def __init__(self, monitor, larger_better, must_have_monitor=False): | |||||
| """ | |||||
| 该 callback 不直接进行使用,作为其它相关 callback 的父类使用,如果 callback 有使用 monitor 可以继承该函数里面实现了 | |||||
| (1)判断monitor合法性;(2)在需要时, 根据trainer的monitor设置自己的monitor名称。 | |||||
| :param monitor: 监控的 metric 值。如果在 evaluation 结果中没有找到完全一致的名称,将使用 最短公共字符串算法 找到最匹配 | |||||
| 的那个作为 monitor 。如果为 None,将尝试使用 Trainer 设置的 monitor 。也可以传入一个函数,接受参数为 evaluation 的结 | |||||
| 果(字典类型),返回一个 float 值作为 monitor 的结果,如果当前结果中没有相关的 monitor 值请返回 None 。 | |||||
| :param larger_better: monitor 是否时越大越好 | |||||
| :param must_have_monitor: 这个 callback 是否必须有 monitor 设置。如果设置为 True ,且没检测到设置 monitor 会报错。 | |||||
| """ | |||||
| super().__init__(monitor, larger_better) | |||||
| self.must_have_monitor = must_have_monitor | |||||
| def on_after_trainer_initialized(self, trainer, driver): | |||||
| """ | |||||
| 如果本身的 monitor 没有设置,则根据 Trainer 中的 monitor 设置 monitor 。 | |||||
| 同时对于必须要有 monitor 设置的 callback ,该函数会进行检查。 | |||||
| :param trainer: | |||||
| :param driver: | |||||
| :return: | |||||
| """ | |||||
| if self.monitor is None and trainer.monitor is not None: | |||||
| self.set_monitor(monitor=trainer.monitor, larger_better=trainer.larger_better) | |||||
| if self.must_have_monitor and self.monitor is None: | |||||
| raise RuntimeError(f"No `monitor` is set for {self.__class__.__name__}. " | |||||
| f"You can set it in the initialization or through Trainer.") | |||||
| if self.must_have_monitor and self.monitor is not None and trainer.evaluator is None: | |||||
| raise RuntimeError(f"No `evaluate_dataloaders` is set for Trainer. But Callback: {self.__class__.__name__}" | |||||
| f" need to watch the monitor:`{self.monitor_name}`.") | |||||
| def on_sanity_check_end(self, trainer, sanity_check_res): | |||||
| # 主要核对一下 monitor 是否存在。 | |||||
| if self.monitor is not None: | |||||
| self.get_monitor_value(results=sanity_check_res) | |||||
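Since `monitor` may be a callable, and `monitor_name` now unwraps `functools.partial` objects (see the change above), a hedged sketch of such a monitor; the metric key 'acc#dev' is illustrative.

import functools

def pick_metric(results, key):
    # Receives the evaluation results dict; returns the monitored float, or None if absent.
    return results.get(key, None)

dev_acc_monitor = functools.partial(pick_metric, key='acc#dev')
# Passed as monitor=dev_acc_monitor to a HasMonitorCallback subclass;
# monitor_name will report pick_metric.__qualname__ instead of the partial object itself.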
| class ExecuteOnceBetterMonitor(HasMonitorCallback): | class ExecuteOnceBetterMonitor(HasMonitorCallback): | ||||
| def __init__(self, monitor, larger_better, execute_fn): | def __init__(self, monitor, larger_better, execute_fn): | ||||
| """ | """ | ||||
| @@ -183,13 +208,13 @@ class ExecuteOnceBetterMonitor(HasMonitorCallback): | |||||
| :param monitor: 监控的 metric 值。如果在 evaluation 结果中没有找到完全一致的名称,将使用 最短公共字符串算法 找到最匹配 | :param monitor: 监控的 metric 值。如果在 evaluation 结果中没有找到完全一致的名称,将使用 最短公共字符串算法 找到最匹配 | ||||
| 的那个作为 monitor 。如果为 None,将尝试使用 Trainer 设置的 monitor 。也可以传入一个函数,接受参数为 evaluation 的结 | 的那个作为 monitor 。如果为 None,将尝试使用 Trainer 设置的 monitor 。也可以传入一个函数,接受参数为 evaluation 的结 | ||||
| 果(字典类型),返回一个 float 值作为 monitor 的结果。 | |||||
| 果(字典类型),返回一个 float 值作为 monitor 的结果,如果当前结果中没有相关的 monitor 值请返回 None 。 | |||||
| :param larger_better: monitor 是否时越大越好 | :param larger_better: monitor 是否时越大越好 | ||||
| :param execute_fn: 一个可执行的函数,不接受任何参数,不反回值。在 monitor 取得更好结果的时候会调用。 | :param execute_fn: 一个可执行的函数,不接受任何参数,不反回值。在 monitor 取得更好结果的时候会调用。 | ||||
| """ | """ | ||||
| super().__init__(monitor, larger_better, must_have_monitor=True) | super().__init__(monitor, larger_better, must_have_monitor=True) | ||||
| _check_valid_parameters_number(execute_fn, expected_params=[], fn_name='execute_fn') | _check_valid_parameters_number(execute_fn, expected_params=[], fn_name='execute_fn') | ||||
| self.execute_fn = execute_fn() | |||||
| self.execute_fn = execute_fn | |||||
| def on_validate_end(self, trainer, results): | def on_validate_end(self, trainer, results): | ||||
| if self.is_better_results(results): | if self.is_better_results(results): | ||||
| @@ -23,7 +23,7 @@ class LoadBestModelCallback(HasMonitorCallback): | |||||
| :param str monitor: 监控的 metric 值。如果在 evaluation 结果中没有找到完全一致的名称,将使用 最短公共字符串算法 找到最匹配 | :param str monitor: 监控的 metric 值。如果在 evaluation 结果中没有找到完全一致的名称,将使用 最短公共字符串算法 找到最匹配 | ||||
| 的那个作为 monitor 。如果为 None,将尝试使用 Trainer 设置的 monitor 。也可以传入一个函数,接受参数为 evaluation 的结 | 的那个作为 monitor 。如果为 None,将尝试使用 Trainer 设置的 monitor 。也可以传入一个函数,接受参数为 evaluation 的结 | ||||
| 果(字典类型),返回一个 float 值作为 monitor 的结果。 | |||||
| 果(字典类型),返回一个 float 值作为 monitor 的结果,如果当前结果中没有相关的 monitor 值请返回 None 。 | |||||
| :param larger_better: 该 metric 值是否是越大越好。 | :param larger_better: 该 metric 值是否是越大越好。 | ||||
| :param save_folder: 保存的文件夹,如果为空,则保存在内存中。不为空,则保存一份权重到文件中,当为多机训练,且本值不为空时,请确保 | :param save_folder: 保存的文件夹,如果为空,则保存在内存中。不为空,则保存一份权重到文件中,当为多机训练,且本值不为空时,请确保 | ||||
| 不同的机器均可访问当该路径。当 model_save_fn 不为 None 时该值一定不能为空。 | 不同的机器均可访问当该路径。当 model_save_fn 不为 None 时该值一定不能为空。 | ||||
| @@ -72,7 +72,7 @@ class LoadBestModelCallback(HasMonitorCallback): | |||||
| logger.debug(f"Synchronize best model save folder: {self.real_save_folder} for LoadBestModelCallback.") | logger.debug(f"Synchronize best model save folder: {self.real_save_folder} for LoadBestModelCallback.") | ||||
| except NotImplementedError: | except NotImplementedError: | ||||
| raise RuntimeError(f"Currently {driver.__class__.__name__} does not support using `save_folder` to " | raise RuntimeError(f"Currently {driver.__class__.__name__} does not support using `save_folder` to " | ||||
| f"save best model when launch using script.") | |||||
| f"save best model when launch using module.") | |||||
| super().on_after_trainer_initialized(trainer, driver) | super().on_after_trainer_initialized(trainer, driver) | ||||
| @@ -87,7 +87,7 @@ class LoadBestModelCallback(HasMonitorCallback): | |||||
| trainer.save_model(folder=self.buffer, only_state_dict=self.only_state_dict) | trainer.save_model(folder=self.buffer, only_state_dict=self.only_state_dict) | ||||
| def on_train_end(self, trainer): | def on_train_end(self, trainer): | ||||
| logger.info(f"Loading best model with {self._real_monitor}: {self.monitor_value}...") | |||||
| logger.info(f"Loading best model with {self.monitor_name}: {self.monitor_value}...") | |||||
| if self.real_save_folder: | if self.real_save_folder: | ||||
| trainer.load_model(folder=self.real_save_folder, only_state_dict=self.only_state_dict, | trainer.load_model(folder=self.real_save_folder, only_state_dict=self.only_state_dict, | ||||
| model_load_fn=self.model_load_fn) | model_load_fn=self.model_load_fn) | ||||
| @@ -0,0 +1,174 @@ | |||||
| __all__ = [ | |||||
| 'MoreEvaluateCallback' | |||||
| ] | |||||
| import os | |||||
| from typing import Union, Callable, Optional, Dict | |||||
| from fastNLP.core.log import logger | |||||
| from .has_monitor_callback import HasMonitorCallback | |||||
| from .topk_saver import TopkSaver | |||||
| class MoreEvaluateCallback(HasMonitorCallback): | |||||
| def __init__(self, dataloaders, metrics:Dict, evaluate_every:Optional[Union[int, Callable]]=-1, | |||||
| watch_monitor:Union[str, Callable]=None, watch_monitor_larger_better:bool=True, | |||||
| evaluate_fn=None, num_eval_sanity_batch=2, | |||||
| topk=0, topk_monitor=None, topk_larger_better=True, | |||||
| folder=None, only_state_dict=True, save_object='model', model_save_fn=None, | |||||
| save_evaluate_results=True, save_kwargs=None, | |||||
| **kwargs): | |||||
| """ | |||||
| 当评测时需要调用不同的 evaluate_fn (例如在大部分生成任务中,一般使用训练 loss 作为训练过程中的 evaluate ;但同时在训练到 | |||||
| 一定 epoch 数量之后,会让 model 生成的完整的数据评测 bleu 等。此刻就可能需要两种不同的 evaluate_fn ),只使用 Trainer | |||||
| 无法满足需求,可以通过调用本 callback 进行。如果需要根据本 callback 中的评测结果进行模型保存,请传入 topk 以及 | |||||
| topk_monitor 等相关参数。可以通过 evaluate_every 或 watch_monitor 控制触发进行 evaluate 的条件。 | |||||
| 如果设置了 evaluate 结果更好就保存的话,将按如下文件结构进行保存 | |||||
| - folder/ | |||||
| - YYYY-mm-dd-HH_MM_SS_fffff/ # 自动根据当前脚本的启动时间创建的 | |||||
| - {save_object}-epoch_{epoch_idx}-batch_{global_batch_idx}-{topk_monitor}_{monitor_value}/ # 满足topk条件存储文件名 | |||||
| :param dataloaders: 需要评估的数据 | |||||
| :param metrics: 使用的 metrics 。 | |||||
| :param evaluate_every: 可以为负数、正数和函数;(1) 为负整数时表示每隔几个 epoch validate 一次;(2) 为正整数则表示每隔几个 batch | |||||
| evaluate 一次;(3) 为函数时表示用户自己传入的用于控制 validate 的频率的函数,该函数的应该接受 trainer 对象作为参数,并返回 | |||||
| 一个 bool 值,返回为 True 说明需要进行 evaluate ;将在每个 batch 结束后调用该函数判断是否需要 evaluate 。 | |||||
| :param watch_monitor: 这个值用来表示监控的 Trainer 中的 evaluate 结果的,当该值不为 None ,evaluate_every 失效。本参数的 | |||||
| 意义是,当检测到 Trainer 中 evaluate results 的 {watch_monitor} 的结果更好时,则进行一次 evaluate 。该参数有两种 | |||||
| 取值: (1) str 类型,监控的 metric 值。如果在 evaluation 结果中没有找到完全一致的名称,将使用 最短公共字符串算法 找到最 | |||||
| 匹配的那个作为 monitor ; (2) 也可以传入一个函数,接受参数为 evaluation 的结果(字典类型),返回一个 float 值作为 monitor | |||||
| 的结果,如果当前结果中没有相关的monitor 值请返回 None 。 | |||||
| :param watch_monitor_larger_better: watch_monitor 是否越大越好。 | |||||
| :param evaluate_fn: 用来控制 `Evaluator` 在评测的前向传播过程中是调用哪一个函数,例如是 `model.evaluate_step` 还是 | |||||
| `model.forward`;(1) 如果该值是 None,那么我们会默认使用 `evaluate_step` 当做前向传播的函数,如果在模型中没有 | |||||
| 找到该方法,则使用 `model.forward` 函数;(2) 如果为 str 类型,则尝试从 model 中寻找该方法,找不到则报错。 | |||||
| :param num_eval_sanity_batch: 在初始化 Evaluator 后运行多少个 sanity check 的 batch ,检测一下。 | |||||
| :param topk: 如果需要根据当前 callback 中的 evaluate 结果保存模型或 Trainer ,可以通过设置 topk 实现。(1)为 -1 表示每次 |||||
| evaluate 后都保存;(2)为 0 (默认),表示不保存;(3)为整数,表示保存性能最好的 topk 个。 |||||
| :param topk_monitor: 如果需要根据当前 callback 中的 evaluate 结果保存,这个参数表示在当前 callback 的 evaluate 结果中寻找哪个值作为 monitor 。 |||||
| :param topk_larger_better: topk_monitor 的值是否时越大越好。 | |||||
| :param folder: 保存的文件夹,fastNLP 将在该文件下以时间戳创建子文件夹,并在里面保存。因此不同次运行可以将被保存到不同的 | |||||
| 时间戳文件夹中。如果为 None ,默认使用当前文件夹。 | |||||
| :param only_state_dict: 保存模型时是否只保存 state_dict 。当 model_save_fn 不为 None 时,该参数无效。 | |||||
| :param save_object: 可选 ['trainer', 'model'],表示在保存时的保存对象为 trainer+model 还是 只是model 。 | |||||
| :param model_save_fn: 个性化的保存函数,当触发保存操作时,就调用这个函数,这个函数应当接受一个文件夹作为参数,不返回任何东西。 | |||||
| 如果传入了 model_save_fn 函数,fastNLP 将不再进行模型相关的保存。在多卡场景下,我们只在 rank 0 上会运行该函数。 | |||||
| :param save_evaluate_results: 是否保存 evaluate 的结果。如果为 True ,在保存 topk 模型的 folder 中还将额外保存一个 | |||||
| fastnlp_evaluate_results.json 文件,记录当前的 results。仅在设置了 topk 的场景下有用,默认为 True 。 | |||||
| :param save_kwargs: dict。更多的保存相关的参数。 | |||||
| :param kwargs: 其它与 Evaluator 相关的初始化参数,如果不传入,将从 Trainer 中获取。请特别留意 evaluate_fn 的设置。 | |||||
| """ | |||||
| super(MoreEvaluateCallback, self).__init__(watch_monitor, watch_monitor_larger_better, | |||||
| must_have_monitor=False) | |||||
| if watch_monitor is None and evaluate_every is None: | |||||
| raise RuntimeError("`evaluate_every` and `watch_monitor` cannot be None at the same time.") | |||||
| if watch_monitor is not None and evaluate_every is not None: | |||||
| raise RuntimeError("`evaluate_every` and `watch_monitor` cannot be set at the same time.") | |||||
| self.watch_monitor = watch_monitor | |||||
| if topk_monitor is not None and topk == 0: | |||||
| raise RuntimeError("`topk_monitor` is set, but `topk` is 0.") | |||||
| if topk != 0 and topk_monitor is None: | |||||
| raise RuntimeError("`topk` is set, but `topk_monitor` is None.") | |||||
| assert save_object in ['trainer', 'model'] | |||||
| self.dataloaders = dataloaders | |||||
| self.metrics = metrics | |||||
| self.evaluate_every = evaluate_every | |||||
| self.evaluate_fn = evaluate_fn | |||||
| self.num_eval_sanity_batch = num_eval_sanity_batch | |||||
| if save_kwargs is None: | |||||
| save_kwargs = {} | |||||
| self.topk_saver = TopkSaver(topk=topk, monitor=topk_monitor, larger_better=topk_larger_better, | |||||
| folder=folder, only_state_dict=only_state_dict, | |||||
| model_save_fn=model_save_fn, save_evaluate_results=save_evaluate_results, | |||||
| save_object=save_object, **save_kwargs) | |||||
| self.kwargs = kwargs | |||||
| @property | |||||
| def need_reproducible_sampler(self) -> bool: | |||||
| return self.topk_saver.save_object == 'trainer' | |||||
| def on_after_trainer_initialized(self, trainer, driver): | |||||
| # 如果是需要 watch 的,不能没有 evaluator | |||||
| if self.watch_monitor is not None: | |||||
| assert trainer.evaluator is not None, f"You set `watch_monitor={self.watch_monitor}`, but no " \ | |||||
| f"evaluate_dataloaders is provided in Trainer." | |||||
| if trainer.evaluate_fn is self.evaluate_fn: | |||||
| logger.warning_once("The `evaluate_fn` is the same as in Trainer, so there seems to be no need to use " |||||
| "`MoreEvaluateCallback`.") | |||||
| # 初始化 evaluator , 同时避免调用 super 对 monitor 赋值 | |||||
| kwargs = { | |||||
| 'model': self.kwargs.get('model', trainer.model), | |||||
| 'dataloaders': self.dataloaders, | |||||
| 'metrics': self.metrics, | |||||
| 'driver': self.kwargs.get('driver', trainer.driver), | |||||
| 'device': self.kwargs.get('device', trainer.device), | |||||
| 'batch_step_fn': self.kwargs.get('batch_step_fn', trainer.evaluate_batch_step_fn), | |||||
| 'evaluate_fn': self.evaluate_fn, | |||||
| 'input_mapping': self.kwargs.get('input_mapping', trainer.input_mapping), | |||||
| 'output_mapping': self.kwargs.get('output_mapping', trainer.output_mapping), | |||||
| 'fp16': self.kwargs.get('fp16', trainer.fp16), | |||||
| 'use_dist_sampler': self.kwargs.get('use_dist_sampler', | |||||
| trainer.kwargs.get('eval_use_dist_sampler', None)), | |||||
| 'progress_bar': self.kwargs.get('progress_bar', trainer.kwargs.get('progress_bar', 'auto')), | |||||
| 'verbose': self.kwargs.get('verbose', 1) | |||||
| } | |||||
| for key, value in self.kwargs.items(): | |||||
| if key not in kwargs: | |||||
| kwargs[key] = value | |||||
| from fastNLP.core.controllers.evaluator import Evaluator | |||||
| self.evaluator = Evaluator(**kwargs) | |||||
| if self.num_eval_sanity_batch > 0: |||||
| results = self.evaluator.run(num_eval_batch_per_dl=self.num_eval_sanity_batch) | |||||
| self.topk_saver.get_monitor_value(results) | |||||
| def on_validate_end(self, trainer, results): | |||||
| if self.is_better_results(results, keep_if_better=True): | |||||
| results = self.evaluator.run() | |||||
| self.topk_saver.save_topk(trainer, results) | |||||
| def on_train_epoch_end(self, trainer): | |||||
| if self.watch_monitor is not None: | |||||
| return | |||||
| if isinstance(self.evaluate_every, int) and self.evaluate_every < 0: | |||||
| validate_every = -self.evaluate_every | |||||
| if trainer.cur_epoch_idx % validate_every == 0: | |||||
| results = self.evaluator.run() | |||||
| self.topk_saver.save_topk(trainer, results) | |||||
| def on_train_batch_end(self, trainer): | |||||
| if self.watch_monitor is not None: | |||||
| return | |||||
| if callable(self.evaluate_every): | |||||
| if self.evaluate_every(trainer): |||||
| results = self.evaluator.run() | |||||
| self.topk_saver.save_topk(trainer, results) | |||||
| elif self.evaluate_every > 0 and trainer.global_forward_batches % self.evaluate_every == 0: | |||||
| results = self.evaluator.run() | |||||
| self.topk_saver.save_topk(trainer, results) | |||||
| def on_save_checkpoint(self, trainer) -> Dict: | |||||
| states = {'topk_saver': self.topk_saver.state_dict()} | |||||
| if isinstance(self._real_monitor, str): | |||||
| states['_real_monitor'] = self._real_monitor | |||||
| states['monitor_value'] = self.monitor_value | |||||
| return states | |||||
| def on_load_checkpoint(self, trainer, states: Optional[Dict]): | |||||
| topk_saver_states = states['topk_saver'] | |||||
| self.topk_saver.load_state_dict(topk_saver_states) | |||||
| if '_real_monitor' in states: | |||||
| self._real_monitor = states["_real_monitor"] | |||||
| self.monitor_value = states['monitor_value'] | |||||
| @property | |||||
| def callback_name(self): | |||||
| metric_names = '+'.join(sorted(self.metrics.keys())) | |||||
| return f'more_evaluate_callback#metric_name-{metric_names}#monitor-{self.monitor_name}#topk_saver:{self.topk_saver}' | |||||
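A hedged usage sketch for MoreEvaluateCallback; the dataloader, metric object and `evaluate_fn` name are placeholders the user must supply, and only parameters described in the __init__ docstring are used.

from fastNLP.core.callbacks import MoreEvaluateCallback

def every_500_batches(trainer):
    # evaluate_every may be a callable that receives the trainer and returns a bool
    return trainer.global_forward_batches % 500 == 0

more_eval_cb = MoreEvaluateCallback(
    dataloaders=test_dataloader,       # placeholder: the extra evaluation data
    metrics={'bleu': bleu_metric},     # placeholder: any fastNLP-compatible metric
    evaluate_every=every_500_batches,
    evaluate_fn='generate_step',       # assumed name of a method on the model
    topk=1, topk_monitor='bleu', topk_larger_better=True,
    folder='extra_eval_ckpts',         # must be an existing directory
)
# trainer = Trainer(model=model, ..., callbacks=[more_eval_cb])  # assumed Trainer wiring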
| @@ -9,7 +9,6 @@ __all__ = [ | |||||
| ] | ] | ||||
| from .has_monitor_callback import HasMonitorCallback | from .has_monitor_callback import HasMonitorCallback | ||||
| from fastNLP.core.callbacks.utils import _get_monitor_value | |||||
| from fastNLP.core.utils import f_rich_progress | from fastNLP.core.utils import f_rich_progress | ||||
| from fastNLP.core.log import logger | from fastNLP.core.log import logger | ||||
| @@ -42,7 +41,8 @@ class RichCallback(ProgressCallback): | |||||
| :param loss_round_ndigit: 显示的 loss 保留多少位有效数字 | :param loss_round_ndigit: 显示的 loss 保留多少位有效数字 | ||||
| :param monitor: 当检测到这个key的结果更好时,会打印出不同的颜色进行提示。监控的 metric 值。如果在 evaluation 结果中没有找到 | :param monitor: 当检测到这个key的结果更好时,会打印出不同的颜色进行提示。监控的 metric 值。如果在 evaluation 结果中没有找到 | ||||
| 完全一致的名称,将使用 最短公共字符串算法 找到最匹配的那个作为 monitor 。如果为 None,将尝试使用 Trainer 设置的 monitor | 完全一致的名称,将使用 最短公共字符串算法 找到最匹配的那个作为 monitor 。如果为 None,将尝试使用 Trainer 设置的 monitor | ||||
| 。也可以传入一个函数,接受参数为 evaluation 的结果(字典类型),返回一个 float 值作为 monitor 的结果。 | |||||
| 。也可以传入一个函数,接受参数为 evaluation 的结果(字典类型),返回一个 float 值作为 monitor 的结果,如果当前结果中没有 | |||||
| 相关的 monitor 值请返回 None 。 | |||||
| :param larger_better: 是否是 monitor 的结果越大越好。 | :param larger_better: 是否是 monitor 的结果越大越好。 | ||||
| :param format_json: 是否格式化 json 再打印 | :param format_json: 是否格式化 json 再打印 | ||||
| """ | """ | ||||
| @@ -135,7 +135,8 @@ class RawTextCallback(ProgressCallback): | |||||
| :param loss_round_ndigit: 显示的 loss 保留多少位有效数字 | :param loss_round_ndigit: 显示的 loss 保留多少位有效数字 | ||||
| :param monitor: 当检测到这个key的结果更好时,会打印出不同的颜色进行提示。监控的 metric 值。如果在 evaluation 结果中没有找到 | :param monitor: 当检测到这个key的结果更好时,会打印出不同的颜色进行提示。监控的 metric 值。如果在 evaluation 结果中没有找到 | ||||
| 完全一致的名称,将使用 最短公共字符串算法 找到最匹配的那个作为 monitor 。如果为 None,将尝试使用 Trainer 设置的 monitor | 完全一致的名称,将使用 最短公共字符串算法 找到最匹配的那个作为 monitor 。如果为 None,将尝试使用 Trainer 设置的 monitor | ||||
| 。也可以传入一个函数,接受参数为 evaluation 的结果(字典类型),返回一个 float 值作为 monitor 的结果。 | |||||
| 。也可以传入一个函数,接受参数为 evaluation 的结果(字典类型),返回一个 float 值作为 monitor 的结果,如果当前结果中没有 | |||||
| 相关的 monitor 值请返回 None 。 | |||||
| :param larger_better: 是否是monitor的结果越大越好。 | :param larger_better: 是否是monitor的结果越大越好。 | ||||
| :param format_json: 是否format json再打印 | :param format_json: 是否format json再打印 | ||||
| """ | """ | ||||
| @@ -0,0 +1,246 @@ | |||||
| import json | |||||
| import os | |||||
| from copy import deepcopy | |||||
| from pathlib import Path | |||||
| from typing import Optional, Dict, Tuple | |||||
| from fastNLP.core.utils import rank_zero_rm | |||||
| from fastNLP.core.log import logger | |||||
| from fastNLP.envs import FASTNLP_LAUNCH_TIME | |||||
| from fastNLP.envs import rank_zero_call | |||||
| from fastNLP.envs.env import FASTNLP_EVALUATE_RESULT_FILENAME | |||||
| from .has_monitor_callback import MonitorUtility | |||||
| class Saver: | |||||
| def __init__(self, folder, only_state_dict, model_save_fn, **kwargs): | |||||
| """ | |||||
| 执行保存的对象。保存的文件组织结构为 | |||||
| - folder # 当前初始化的参数 | |||||
| - YYYY-mm-dd-HH_MM_SS_fffff/ # 自动根据当前脚本的启动时间创建的 | |||||
| - folder_name # 由 save() 调用时传入。 | |||||
| :param folder: | |||||
| :param only_state_dict: | |||||
| :param model_save_fn: | |||||
| :param kwargs: | |||||
| """ | |||||
| if folder is None: | |||||
| logger.warning( | |||||
| "Parameter `folder` is None, and we will use the current work directory to save your model.") |||||
| folder = Path.cwd() | |||||
| folder = Path(folder) | |||||
| if not folder.exists(): | |||||
| raise NotADirectoryError(f"Path '{folder.absolute()}' does not exist!") |||||
| elif folder.is_file(): | |||||
| raise ValueError("Parameter `folder` should be a directory instead of a file.") | |||||
| self.folder = folder | |||||
| self.only_state_dict = only_state_dict | |||||
| self.model_save_fn = model_save_fn | |||||
| self.kwargs = kwargs | |||||
| self.eval_results = kwargs.get('eval_results', True) | |||||
| self.timestamp_path = self.folder.joinpath(os.environ[FASTNLP_LAUNCH_TIME]) | |||||
| @rank_zero_call | |||||
| def save(self, save_fn, folder_name): | |||||
| """ | |||||
| 执行保存的函数,将数据保存在 folder/timestamp/folder_name 下。其中 folder 为用户在初始化指定, | |||||
| timestamp 为当前脚本的启动时间。 | |||||
| :param save_fn: 调用的保存函数,应该可接受参数 folder:str, only_state_dict: bool, model_save_fn: callable, kwargs | |||||
| :param folder_name: 保存的 folder 名称,将被创建。 | |||||
| :return: 返回实际发生保存的 folder 绝对路径。如果为 None 则没有创建。 | |||||
| """ | |||||
| folder = self.timestamp_path.joinpath(folder_name) | |||||
| folder.mkdir(parents=True, exist_ok=True) | |||||
| save_fn( | |||||
| folder=folder, | |||||
| only_state_dict=self.only_state_dict, | |||||
| model_save_fn=self.model_save_fn, | |||||
| **self.kwargs | |||||
| ) | |||||
| return str(os.path.abspath(folder)) | |||||
| @rank_zero_call | |||||
| def save_json(self, results, path): | |||||
| """ | |||||
| 以 json 格式保存 results 到 path 中 | |||||
| :param results: | |||||
| :param path: | |||||
| :return: | |||||
| """ | |||||
| with open(path, 'w', encoding='utf8') as f: | |||||
| json.dump(results, f, indent=2) | |||||
| @rank_zero_call | |||||
| def rm(self, folder_name): | |||||
| """ | |||||
| 移除 folder/timestamp/folder_name 。其中 folder 为用户在初始化指定, timestamp 为当前脚本的启动时间。 | |||||
| :param folder_name: | |||||
| :return: | |||||
| """ | |||||
| folder = self.timestamp_path.joinpath(folder_name) | |||||
| rank_zero_rm(folder) | |||||
| def state_dict(self): | |||||
| states = { | |||||
| 'timestamp_path': str(self.timestamp_path), | |||||
| } | |||||
| return states | |||||
| def load_state_dict(self, states): | |||||
| timestamp_path = states['timestamp_path'] | |||||
| if not os.path.exists(timestamp_path): | |||||
| logger.info(f"The resuming checkpoint folder {timestamp_path} does not exist, checkpoints will be saved to " |||||
| f"{self.timestamp_path.absolute()}.") |||||
| else: | |||||
| logger.info(f"Resume to save checkpoint in path: {timestamp_path}.") | |||||
| self.timestamp_path = Path(timestamp_path) | |||||
| def __str__(self): | |||||
| return 'saver' # saver是无状态的,不需要有特定名字 | |||||
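A hedged sketch of the Saver state round-trip, so that a resumed run keeps writing into the original timestamp folder; it assumes FASTNLP_LAUNCH_TIME is set (fastNLP sets it at launch) and that 'checkpoints' already exists.

saver = Saver(folder='checkpoints', only_state_dict=True, model_save_fn=None)
states = saver.state_dict()              # {'timestamp_path': 'checkpoints/<launch time>'}

resumed_saver = Saver(folder='checkpoints', only_state_dict=True, model_save_fn=None)
resumed_saver.load_state_dict(states)    # reuse the original timestamp folder if it still exists
# resumed_saver.save(trainer.save_model, folder_name='model-last')  # trainer supplied by fastNLP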
| class TopkQueue: | |||||
| def __init__(self, topk): | |||||
| """ | |||||
| 用于维护处于 topk 的 key, value 对。 | |||||
| :param int topk: 整数,-1 表示所有数据都是 topk 的; 如果是 0, 表示没有任何数据是满足 topk 的。 | |||||
| """ | |||||
| assert isinstance(topk, int) | |||||
| self.topk = topk | |||||
| self.topk_dict = {} # 其中 key 为保存的文件夹名, value 为对应的 monitor 值 |||||
| def push(self, key, value) -> Optional[Tuple[str, float]]: | |||||
| """ | |||||
| 将 key/value 推入 topk 的 queue 中,以 value 为标准,如果满足 topk 则保留此次推入的信息,同时如果新推入的数据将之前的数据给 | |||||
| 挤出了 topk ,则会返回被挤出的 (key, value);如果返回为 (None, None),说明满足 topk 且没有数据被挤出。如果不满足 topk ,则返回 | |||||
| 推入的 (key, value) 本身。这里排序只根据 value 是否更大了判断,因此如果有的情况是越小越好,请在输入前取负号。 | |||||
| :param str key: | |||||
| :param float value: 如果为 None, 则不做任何操作。 | |||||
| :return: (1)返回输入的 (key, value) ,说明不满足 topk; (2) 返回(None, None),说明满足 topk 且没有被挤出过去的记录; (3) | |||||
| 返回非输入的 (key, value) , 说明输入满足 topk,且挤出了之前的记录。 | |||||
| """ | |||||
| if value is None: | |||||
| return key, value | |||||
| if self.topk < 0: | |||||
| return None, None | |||||
| if self.topk == 0: | |||||
| return key, value | |||||
| if len(self.topk_dict) < self.topk: |||||
| self.topk_dict[key] = value |||||
| return None, None |||||
| min_key = min(self.topk_dict, key=lambda x: self.topk_dict[x]) |||||
| if self.topk_dict[min_key] > value: | |||||
| return key, value | |||||
| else: | |||||
| min_value = self.topk_dict.pop(min_key) | |||||
| self.topk_dict[key] = value | |||||
| return min_key, min_value | |||||
| def state_dict(self): | |||||
| return deepcopy(self.topk_dict) | |||||
| def load_state_dict(self, states): | |||||
| self.topk_dict.update(states) | |||||
| def __str__(self): | |||||
| return f'topk-{self.topk}' | |||||
| def __bool__(self): | |||||
| # 仅当 topk 为 0 时,表明该 topk_queue 无意义。 | |||||
| return self.topk != 0 | |||||
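A short worked example of the push() contract described above; the keys are illustrative folder names and larger values are treated as better (negate the values beforehand otherwise, as the docstring notes).

q = TopkQueue(topk=2)
print(q.push('ckpt-a', 0.80))   # (None, None): kept, nothing evicted
print(q.push('ckpt-b', 0.85))   # (None, None): kept, the queue is now full
print(q.push('ckpt-c', 0.70))   # ('ckpt-c', 0.70): rejected, not within the top 2
print(q.push('ckpt-d', 0.90))   # ('ckpt-a', 0.80): kept, the previous minimum is evicted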
| class TopkSaver(MonitorUtility, Saver): | |||||
| def __init__(self, topk, monitor, larger_better, folder, only_state_dict, | |||||
| model_save_fn, save_evaluate_results, | |||||
| save_object, **kwargs): | |||||
| """ | |||||
| 用来识别 topk 模型并保存。 |||||
| :param topk: | |||||
| :param monitor: | |||||
| :param larger_better: | |||||
| :param folder: | |||||
| :param only_state_dict: | |||||
| :param model_save_fn: | |||||
| :param save_evaluate_results: | |||||
| :param save_object: | |||||
| :param kwargs: | |||||
| """ | |||||
| MonitorUtility.__init__(self, monitor, larger_better) | |||||
| Saver.__init__(self, folder, only_state_dict, model_save_fn, **kwargs) | |||||
| if monitor is not None and topk == 0: | |||||
| raise RuntimeError("`monitor` is set, but `topk` is 0.") | |||||
| if topk != 0 and monitor is None: | |||||
| raise RuntimeError("`topk` is set, but `monitor` is None.") | |||||
| assert save_object in ['trainer', 'model'] | |||||
| self.saver = Saver(folder, only_state_dict, model_save_fn, **kwargs) | |||||
| self.topk_queue = TopkQueue(topk) | |||||
| self.save_evaluate_results = save_evaluate_results | |||||
| self.save_object = save_object | |||||
| self.save_fn_name = 'save' if save_object == 'trainer' else 'save_model' | |||||
| @rank_zero_call | |||||
| def save_topk(self, trainer, results: Dict) -> Optional[str]: | |||||
| """ | |||||
| 根据 results 是否满足 topk 的相关设定决定是否保存,如果发生了保存,将返回保存的文件夹。 | |||||
| :param trainer: | |||||
| :param results: | |||||
| :return: | |||||
| """ | |||||
| if self.monitor is not None and self.topk_queue: | |||||
| monitor_value = self.get_monitor_value(results) | |||||
| if monitor_value is None: | |||||
| return | |||||
| key = f"{self.save_object}-epoch_{trainer.cur_epoch_idx}-batch_{trainer.global_forward_batches}" \ | |||||
| f"-{self.monitor_name}_{monitor_value}" | |||||
| pop_key, pop_value = self.topk_queue.push(key, monitor_value if self.larger_better else -monitor_value) | |||||
| if pop_key == key: # 说明不足以构成 topk,被退回了 | |||||
| return None | |||||
| folder = self.save(trainer, key) | |||||
| if self.save_evaluate_results and folder: | |||||
| try: | |||||
| self.save_json(self.itemize_results(results), | |||||
| os.path.join(folder, FASTNLP_EVALUATE_RESULT_FILENAME)) | |||||
| except Exception: | |||||
| logger.exception(f"Fail to save evaluate results to {folder}") | |||||
| if pop_key and pop_key != key: # 说明需要移除之前的 topk | |||||
| self.rm(pop_key) | |||||
| return folder | |||||
| def save(self, trainer, folder_name): | |||||
| fn = getattr(trainer, self.save_fn_name) | |||||
| return super().save(fn, folder_name) | |||||
| def state_dict(self): | |||||
| states = { | |||||
| 'topk_queue': self.topk_queue.state_dict(), | |||||
| 'saver': self.saver.state_dict() | |||||
| } | |||||
| if isinstance(self._real_monitor, str): | |||||
| states['_real_monitor'] = self._real_monitor | |||||
| return states | |||||
| def load_state_dict(self, states): | |||||
| topk_queue_states = states['topk_queue'] | |||||
| saver_states = states['saver'] | |||||
| self.topk_queue.load_state_dict(topk_queue_states) | |||||
| self.saver.load_state_dict(saver_states) | |||||
| if '_real_monitor' in states: | |||||
| self._real_monitor = states["_real_monitor"] | |||||
| def __str__(self): | |||||
| return f'topk-{self.topk_queue}#saver-{self.saver}#save_object-{self.save_object}' | |||||
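A hedged sketch of how `TopkSaver` might be driven; `trainer` stands for a fastNLP Trainer instance, the result key `'acc'` is hypothetical, and `save_topk` is decorated with `rank_zero_call`, so it only acts on rank 0:

    saver = TopkSaver(topk=3, monitor='acc', larger_better=True, folder='checkpoints',
                      only_state_dict=True, model_save_fn=None,
                      save_evaluate_results=True, save_object='model')
    # Typically called after each evaluate; returns the folder name if a save happened.
    folder = saver.save_topk(trainer, results={'acc': 0.91})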
| @@ -1,4 +1,6 @@ | |||||
| from typing import Optional, Union | from typing import Optional, Union | ||||
| import os | |||||
| from fastNLP.core.log.logger import logger | from fastNLP.core.log.logger import logger | ||||
| from difflib import SequenceMatcher | from difflib import SequenceMatcher | ||||
| from fastNLP.core.utils.utils import _get_fun_msg | from fastNLP.core.utils.utils import _get_fun_msg | ||||
| @@ -15,7 +17,7 @@ def _get_monitor_value(monitor: Union[callable, str], real_monitor: Optional[str | |||||
| :return: 返回两个值(str, value),其中str就是最终要到的key,value就是这个key对应的value。如果value为None说明当前results中没有 | :return: 返回两个值(str, value),其中str就是最终要到的key,value就是这个key对应的value。如果value为None说明当前results中没有 | ||||
| 找到对应的 monitor | 找到对应的 monitor | ||||
| """ | """ | ||||
| if len(res)==0: | |||||
| if len(res) == 0 or monitor is None: | |||||
| return monitor, None | return monitor, None | ||||
| if callable(monitor): | if callable(monitor): | ||||
| @@ -56,4 +58,3 @@ def _match_length(a:str, b:str)->int: | |||||
| match = SequenceMatcher(None, short, long).find_longest_match(0, len(short), 0, len(long)) | match = SequenceMatcher(None, short, long).find_longest_match(0, len(short), 0, len(long)) | ||||
| return match.size | return match.size | ||||
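The longest-common-substring matching used above can be illustrated with a standalone sketch; this shows only the idea, not the library's exact implementation:

    from difflib import SequenceMatcher

    def best_key(monitor: str, results: dict) -> str:
        def overlap(a, b):
            short, long_ = (a, b) if len(a) < len(b) else (b, a)
            return SequenceMatcher(None, short, long_).find_longest_match(0, len(short), 0, len(long_)).size
        return max(results, key=lambda k: overlap(monitor, k))

    assert best_key('acc', {'acc#acc#dl1': 0.9, 'loss#loss#dl1': 1.2}) == 'acc#acc#dl1'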
| @@ -38,7 +38,7 @@ class Evaluator: | |||||
| driver: Union[str, Driver] = 'torch', | driver: Union[str, Driver] = 'torch', | ||||
| device: Optional[Union[int, List[int], str]] = None, | device: Optional[Union[int, List[int], str]] = None, | ||||
| batch_step_fn: Optional[callable] = None, | batch_step_fn: Optional[callable] = None, | ||||
| evaluate_fn: Optional[str] = None, # 首先尝试找 evaluate_step, 找不到 forward, callable | |||||
| evaluate_fn: Optional[str] = None, | |||||
| input_mapping: Optional[Union[Callable, Dict]] = None, | input_mapping: Optional[Union[Callable, Dict]] = None, | ||||
| output_mapping: Optional[Union[Callable, Dict]] = None, | output_mapping: Optional[Union[Callable, Dict]] = None, | ||||
| model_wo_auto_param_call: bool = False, | model_wo_auto_param_call: bool = False, | ||||
| @@ -57,8 +57,9 @@ class Evaluator: | |||||
| :param batch_step_fn: callable的对象,接受 (evaluator, batch) 作为参数,其中 evaluator 为 Evaluator 对象,batch 为 | :param batch_step_fn: callable的对象,接受 (evaluator, batch) 作为参数,其中 evaluator 为 Evaluator 对象,batch 为 | ||||
| DataLoader 中返回的对象。一个 batch_step_fn 的例子可参考 fastNLP.core.controller.loops.evaluate_batch_loop 的 | DataLoader 中返回的对象。一个 batch_step_fn 的例子可参考 fastNLP.core.controller.loops.evaluate_batch_loop 的 | ||||
| batch_step_fn 函数。 | batch_step_fn 函数。 | ||||
| :param evaluate_fn: 用来控制 `Evaluator` 在评测的前向传播过程中是调用哪一个函数,例如是 `model.evaluate_step` 还是 `model.forward`; | |||||
| 默认为 None,如果该值是 None,那么我们会默认使用 `evaluate_step` 当做前向传播的函数,如果在模型中没有找到该方法,则使用 `model.forward` 函数; | |||||
| :param evaluate_fn: 用来控制 `Evaluator` 在评测的前向传播过程中是调用哪一个函数,例如是 `model.evaluate_step` 还是 | |||||
| `model.forward`;(1) 如果该值是 None,那么我们会默认使用 `evaluate_step` 当做前向传播的函数,如果在模型中没有 | |||||
| 找到该方法,则使用 `model.forward` 函数;(2) 如果为 str 类型,则尝试从 model 中寻找该方法,找不到则报错。 | |||||
| :param input_mapping: 对 dataloader 中输出的内容将通过 input_mapping 处理之后再输入到 model 以及 metric 中 | :param input_mapping: 对 dataloader 中输出的内容将通过 input_mapping 处理之后再输入到 model 以及 metric 中 | ||||
| :param output_mapping: 对 model 输出的内容,将通过 output_mapping 处理之后再输入到 metric 中。 | :param output_mapping: 对 model 输出的内容,将通过 output_mapping 处理之后再输入到 metric 中。 | ||||
| :param model_wo_auto_param_call: 是否关闭在训练时调用我们的 auto_param_call 来自动匹配 batch 和 forward 函数的参数的行为; | :param model_wo_auto_param_call: 是否关闭在训练时调用我们的 auto_param_call 来自动匹配 batch 和 forward 函数的参数的行为; | ||||
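To illustrate the `evaluate_fn` behaviour documented above, a hedged construction sketch; `model`, `val_dataloader`, the import locations and the `run()` call are assumptions rather than facts taken from this diff:

    evaluator = Evaluator(model=model, dataloaders=val_dataloader,
                          metrics={'acc': Accuracy()},
                          driver='torch', device=0,
                          evaluate_fn='forward')   # call `model.forward` instead of trying `evaluate_step` first
    results = evaluator.run()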
| @@ -69,6 +70,7 @@ class Evaluator: | |||||
| :param kwargs: | :param kwargs: | ||||
| bool model_use_eval_mode: 是否在 evaluate 的时候将 model 的状态设置成 eval 状态。在 eval 状态下,model 的dropout | bool model_use_eval_mode: 是否在 evaluate 的时候将 model 的状态设置成 eval 状态。在 eval 状态下,model 的dropout | ||||
| 与 batch normalization 将会关闭。默认为True。 | 与 batch normalization 将会关闭。默认为True。 | ||||
| TODO 还没完成。 | |||||
| Union[bool] auto_tensor_conversion_for_metric: 是否自动将输出中的 | Union[bool] auto_tensor_conversion_for_metric: 是否自动将输出中的 | ||||
| tensor 适配到 metrics 支持的。例如 model 输出是 paddlepaddle 的 tensor ,但是想利用 torchmetrics 的metric对象, | tensor 适配到 metrics 支持的。例如 model 输出是 paddlepaddle 的 tensor ,但是想利用 torchmetrics 的metric对象, | ||||
| 当 auto_tensor_conversion_for_metric 为True时,fastNLP 将自动将输出中 paddle 的 tensor (其它非 tensor 的参数 | 当 auto_tensor_conversion_for_metric 为True时,fastNLP 将自动将输出中 paddle 的 tensor (其它非 tensor 的参数 | ||||
| @@ -119,10 +121,6 @@ class Evaluator: | |||||
| self._metric_wrapper = None | self._metric_wrapper = None | ||||
| _ = self.metrics_wrapper # 触发检查 | _ = self.metrics_wrapper # 触发检查 | ||||
| if self._dist_sampler is not None and not self.driver.is_distributed(): | |||||
| logger.warning_once("Running in a non-distributed driver, but with distributed sampler, it may cause " | |||||
| "different process evaluating on different data.") | |||||
| if evaluate_fn is not None and not isinstance(evaluate_fn, str): | if evaluate_fn is not None and not isinstance(evaluate_fn, str): | ||||
| raise TypeError("Parameter `evaluate_fn` can only be `str` type when it is not None.") | raise TypeError("Parameter `evaluate_fn` can only be `str` type when it is not None.") | ||||
| self._evaluate_step, self._evaluate_step_signature_fn = \ | self._evaluate_step, self._evaluate_step_signature_fn = \ | ||||
| @@ -14,10 +14,10 @@ __all__ = [ | |||||
| from .loops import Loop, TrainBatchLoop | from .loops import Loop, TrainBatchLoop | ||||
| from .utils import State, TrainerState | from .utils import State, TrainerState | ||||
| from .utils.utils import check_validate_every | |||||
| from .utils.utils import check_evaluate_every | |||||
| from .evaluator import Evaluator | from .evaluator import Evaluator | ||||
| from fastNLP.core.controllers.utils.utils import TrainerEventTrigger, _TruncatedDataLoader | from fastNLP.core.controllers.utils.utils import TrainerEventTrigger, _TruncatedDataLoader | ||||
| from fastNLP.core.callbacks import Callback, CallbackManager, Events, EventsList, Filter | |||||
| from fastNLP.core.callbacks import Callback, CallbackManager, Events, EventsList | |||||
| from fastNLP.core.callbacks.callback import _CallbackWrapper | from fastNLP.core.callbacks.callback import _CallbackWrapper | ||||
| from fastNLP.core.callbacks.callback_events import _SingleEventState | from fastNLP.core.callbacks.callback_events import _SingleEventState | ||||
| from fastNLP.core.drivers import Driver | from fastNLP.core.drivers import Driver | ||||
| @@ -26,7 +26,7 @@ from fastNLP.core.utils import get_fn_arg_names, match_and_substitute_params, nu | |||||
| from fastNLP.core.utils.utils import _check_valid_parameters_number | from fastNLP.core.utils.utils import _check_valid_parameters_number | ||||
| from fastNLP.envs import rank_zero_call | from fastNLP.envs import rank_zero_call | ||||
| from fastNLP.core.log import logger | from fastNLP.core.log import logger | ||||
| from fastNLP.envs import FASTNLP_MODEL_FILENAME | |||||
| from fastNLP.envs import FASTNLP_MODEL_FILENAME, FASTNLP_CHECKPOINT_FILENAME | |||||
| from fastNLP.core.utils.exceptions import EarlyStopException | from fastNLP.core.utils.exceptions import EarlyStopException | ||||
| @@ -94,9 +94,9 @@ class Trainer(TrainerEventTrigger): | |||||
| evaluate_step 这个函数,如果没有则使用 forward 函数。 | evaluate_step 这个函数,如果没有则使用 forward 函数。 | ||||
| :param callbacks: 训练当中触发的 callback 类,该参数应当为一个列表,其中的每一个元素都应当继承 `Callback` 类; | :param callbacks: 训练当中触发的 callback 类,该参数应当为一个列表,其中的每一个元素都应当继承 `Callback` 类; | ||||
| :param metrics: 应当为一个字典,其中 key 表示 monitor,例如 {"acc1": AccMetric(), "acc2": AccMetric()}; | :param metrics: 应当为一个字典,其中 key 表示 monitor,例如 {"acc1": AccMetric(), "acc2": AccMetric()}; | ||||
| :param evaluate_every: 可以为负数、正数或者函数;为负数时表示每隔几个 epoch validate 一次;为正数则表示每隔几个 batch validate 一次; | |||||
| 为函数时表示用户自己传入的用于控制 Trainer 中的 validate 的频率的函数,该函数的应该接受当前 trainer 对象作为参数,并 | |||||
| 返回一个 bool 值,返回为 True 说明需要进行 validate ;将在每个 batch 结束后调用该函数判断是否需要 validate 。 | |||||
| :param evaluate_every: 可以为负数、正数或者函数;为负数时表示每隔几个 epoch evaluate 一次;为正数则表示每隔几个 batch evaluate 一次; | |||||
| 为函数时表示用户自己传入的用于控制 Trainer 中的 evaluate 的频率的函数,该函数的应该接受当前 trainer 对象作为参数,并 | |||||
| 返回一个 bool 值,返回为 True 说明需要进行 evaluate ;将在每个 batch 结束后调用该函数判断是否需要 evaluate 。 | |||||
| :param input_mapping: 应当为一个字典或者一个函数,表示在当前 step 拿到一个 batch 的训练数据后,应当做怎样的映射处理;如果其是 | :param input_mapping: 应当为一个字典或者一个函数,表示在当前 step 拿到一个 batch 的训练数据后,应当做怎样的映射处理;如果其是 | ||||
| 一个字典,并且 batch 也是一个 `Dict`,那么我们会把 batch 中同样在 input_mapping 中的 key 修改为 input_mapping 的对应 key 的 | 一个字典,并且 batch 也是一个 `Dict`,那么我们会把 batch 中同样在 input_mapping 中的 key 修改为 input_mapping 的对应 key 的 | ||||
| value;如果 batch 是一个 `dataclass`,那么我们会先将该 dataclass 转换为一个 Dict,然后再进行上述转换;如果 batch 此时是其它 | value;如果 batch 是一个 `dataclass`,那么我们会先将该 dataclass 转换为一个 Dict,然后再进行上述转换;如果 batch 此时是其它 | ||||
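As a concrete example of the callable form of `evaluate_every` described above (the threshold of 100 is arbitrary; `global_forward_batches` is the trainer attribute used elsewhere in this diff):

    def evaluate_every(trainer):
        # called after every batch; returning True triggers an evaluate run
        return trainer.global_forward_batches % 100 == 0

    # then: Trainer(..., evaluate_every=evaluate_every, ...)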
| @@ -124,7 +124,7 @@ class Trainer(TrainerEventTrigger): | |||||
| set_grad_to_none: 是否在训练过程中在每一次 optimizer 更新后将 grad 置为 None; | set_grad_to_none: 是否在训练过程中在每一次 optimizer 更新后将 grad 置为 None; | ||||
| use_dist_sampler: 表示是否使用分布式的 sampler 。在多卡时,分布式 sampler 将自动决定每张卡上读取的 sample ,使得一个epoch | use_dist_sampler: 表示是否使用分布式的 sampler 。在多卡时,分布式 sampler 将自动决定每张卡上读取的 sample ,使得一个epoch | ||||
| 内所有卡的 sample 加起来为一整个数据集的 sample。默认会根据 driver 是否为分布式进行设置。 | 内所有卡的 sample 加起来为一整个数据集的 sample。默认会根据 driver 是否为分布式进行设置。 | ||||
| use_eval_dist_sampler: 表示在 Evaluator 中在使用 TorchDDPDriver 的时候是否将 dataloader 的 sampler 替换为分布式的 sampler;默认为 True; | |||||
| eval_use_dist_sampler: 表示在 Evaluator 中在使用 TorchDDPDriver 的时候是否将 dataloader 的 sampler 替换为分布式的 sampler;默认为 True; | |||||
| output_from_new_proc: 应当为一个字符串,表示在多进程的 driver 中其它进程的输出流应当被做如何处理;其值应当为以下之一: | output_from_new_proc: 应当为一个字符串,表示在多进程的 driver 中其它进程的输出流应当被做如何处理;其值应当为以下之一: | ||||
| ["all", "ignore", "only_error"];当该参数的值不是以上值时,该值应当表示一个文件夹的名字,我们会将其他 rank 的输出流重定向到 | ["all", "ignore", "only_error"];当该参数的值不是以上值时,该值应当表示一个文件夹的名字,我们会将其他 rank 的输出流重定向到 | ||||
| log 文件中,然后将 log 文件保存在通过该参数值设定的文件夹中;默认为 "only_error"; | log 文件中,然后将 log 文件保存在通过该参数值设定的文件夹中;默认为 "only_error"; | ||||
| @@ -214,13 +214,13 @@ class Trainer(TrainerEventTrigger): | |||||
| """ 设置内部的 Evaluator """ | """ 设置内部的 Evaluator """ | ||||
| if metrics is None and evaluate_dataloaders is not None: | if metrics is None and evaluate_dataloaders is not None: | ||||
| raise ValueError("You have set 'evaluate_dataloader' but forget to set 'metrics'.") | |||||
| raise ValueError("You have set 'evaluate_dataloaders' but forget to set 'metrics'.") | |||||
| if metrics is not None and evaluate_dataloaders is None: | if metrics is not None and evaluate_dataloaders is None: | ||||
| raise ValueError("You have set 'metrics' but forget to set 'evaluate_dataloader'.") | |||||
| raise ValueError("You have set 'metrics' but forget to set 'evaluate_dataloaders'.") | |||||
| self.metrics = metrics | self.metrics = metrics | ||||
| self.validate_every = evaluate_every | |||||
| self.evaluate_every = evaluate_every | |||||
| self.driver.setup() | self.driver.setup() | ||||
| self.driver.barrier() | self.driver.barrier() | ||||
| @@ -235,7 +235,7 @@ class Trainer(TrainerEventTrigger): | |||||
| self.monitor = monitor | self.monitor = monitor | ||||
| self.larger_better = larger_better | self.larger_better = larger_better | ||||
| if metrics is not None and evaluate_dataloaders is not None: | if metrics is not None and evaluate_dataloaders is not None: | ||||
| check_validate_every(evaluate_every) | |||||
| check_evaluate_every(evaluate_every) | |||||
| self.evaluator = Evaluator( | self.evaluator = Evaluator( | ||||
| model=model, | model=model, | ||||
| dataloaders=evaluate_dataloaders, | dataloaders=evaluate_dataloaders, | ||||
| @@ -248,7 +248,7 @@ class Trainer(TrainerEventTrigger): | |||||
| output_mapping=output_mapping, | output_mapping=output_mapping, | ||||
| fp16=fp16, | fp16=fp16, | ||||
| verbose=0, | verbose=0, | ||||
| use_dist_sampler=kwargs.get("use_eval_dist_sampler", None), | |||||
| use_dist_sampler=kwargs.get("eval_use_dist_sampler", None), | |||||
| progress_bar=kwargs.get('progress_bar', 'auto') | progress_bar=kwargs.get('progress_bar', 'auto') | ||||
| ) | ) | ||||
| @@ -261,11 +261,14 @@ class Trainer(TrainerEventTrigger): | |||||
| self.driver.set_deterministic_dataloader(self.dataloader) | self.driver.set_deterministic_dataloader(self.dataloader) | ||||
| self.dataloader = self.driver.set_dist_repro_dataloader(dataloader=self.train_dataloader, dist=_dist_sampler, | self.dataloader = self.driver.set_dist_repro_dataloader(dataloader=self.train_dataloader, dist=_dist_sampler, | ||||
| reproducible=self.callback_manager.has_trainer_checkpoint) | |||||
| reproducible=self.callback_manager._need_reproducible_sampler) | |||||
| self.set_grad_to_none = kwargs.get("set_grad_to_none", True) | self.set_grad_to_none = kwargs.get("set_grad_to_none", True) | ||||
| self.on_after_trainer_initialized(self.driver) | |||||
| self.evaluate_batch_step_fn = evaluate_batch_step_fn | |||||
| self.kwargs = kwargs | |||||
| self.on_after_trainer_initialized(self.driver) | |||||
| self.driver.barrier() | self.driver.barrier() | ||||
| def run(self, num_train_batch_per_epoch: int = -1, num_eval_batch_per_dl: int = -1, | def run(self, num_train_batch_per_epoch: int = -1, num_eval_batch_per_dl: int = -1, | ||||
| @@ -364,10 +367,10 @@ class Trainer(TrainerEventTrigger): | |||||
| :return: | :return: | ||||
| """ | """ | ||||
| if self.evaluator is not None: | if self.evaluator is not None: | ||||
| if callable(self.validate_every): | |||||
| if self.validate_every(self): | |||||
| if callable(self.evaluate_every): | |||||
| if self.evaluate_every(self): | |||||
| self.run_evaluate() | self.run_evaluate() | ||||
| elif self.validate_every > 0 and self.global_forward_batches % self.validate_every == 0: | |||||
| elif self.evaluate_every > 0 and self.global_forward_batches % self.evaluate_every == 0: | |||||
| self.run_evaluate() | self.run_evaluate() | ||||
| def epoch_validate(self): | def epoch_validate(self): | ||||
| @@ -377,8 +380,8 @@ class Trainer(TrainerEventTrigger): | |||||
| :return: | :return: | ||||
| """ | """ | ||||
| if self.evaluator is not None: | if self.evaluator is not None: | ||||
| if isinstance(self.validate_every, int) and self.validate_every < 0: | |||||
| validate_every = -self.validate_every | |||||
| if isinstance(self.evaluate_every, int) and self.evaluate_every < 0: | |||||
| validate_every = -self.evaluate_every | |||||
| if self.cur_epoch_idx % validate_every == 0: | if self.cur_epoch_idx % validate_every == 0: | ||||
| self.run_evaluate() | self.run_evaluate() | ||||
| @@ -427,7 +430,7 @@ class Trainer(TrainerEventTrigger): | |||||
| self._custom_callbacks[None] = [] | self._custom_callbacks[None] = [] | ||||
| if self.marker is not None: | if self.marker is not None: | ||||
| if len(self._custom_callbacks[self.marker]) == 0: | if len(self._custom_callbacks[self.marker]) == 0: | ||||
| print(f"You have set `trainer.marker = {self.marker}`, but there are no callback function matched " | |||||
| logger.info(f"You have set `trainer.marker = {self.marker}`, but there are no callback function matched " | |||||
| f"`{self.marker}` that is added through function `Trainer.on`") | f"`{self.marker}` that is added through function `Trainer.on`") | ||||
| _own_callbacks += self._custom_callbacks[self.marker] | _own_callbacks += self._custom_callbacks[self.marker] | ||||
| for each_callback in _own_callbacks: | for each_callback in _own_callbacks: | ||||
| @@ -528,10 +531,10 @@ class Trainer(TrainerEventTrigger): | |||||
| r""" | r""" | ||||
| 用于帮助用户保存模型的辅助函数,具体实际的保存模型的操作由具体的 driver 实现; | 用于帮助用户保存模型的辅助函数,具体实际的保存模型的操作由具体的 driver 实现; | ||||
| :param folder: 保存模型的地址; | |||||
| :param only_state_dict: 是否只保存模型的 `state_dict`; | |||||
| :param folder: 保存模型的文件夹。如果没有传入 model_save_fn 参数,则在这个文件夹下创建 fastnlp_model.pkl.tar 文件。 | |||||
| :param only_state_dict: 仅在 model_save_fn 为空时有效。是否只保存模型的 `state_dict`; | |||||
| :param model_save_fn: 用户自己定制的用来替换该保存函数本身保存逻辑的函数; | :param model_save_fn: 用户自己定制的用来替换该保存函数本身保存逻辑的函数; | ||||
| :param kwargs: 一些 driver 的保存模型的函数的参数另有其它; | |||||
| :param kwargs: | |||||
| """ | """ | ||||
| self.on_save_model() | self.on_save_model() | ||||
| @@ -568,14 +571,19 @@ class Trainer(TrainerEventTrigger): | |||||
| self.on_load_model() | self.on_load_model() | ||||
| self.driver.barrier() | self.driver.barrier() | ||||
| if not isinstance(folder, (io.BytesIO, BinaryIO)): | if not isinstance(folder, (io.BytesIO, BinaryIO)): | ||||
| if model_load_fn is not None: | |||||
| if not callable(model_load_fn): | |||||
| raise ValueError("Parameter `model_save_fn` should be `Callable` type when it is not None.") | |||||
| rank_zero_call(model_load_fn)(folder) | |||||
| else: | |||||
| if isinstance(folder, str): | |||||
| folder = Path(folder) | |||||
| self.driver.load_model(folder.joinpath(FASTNLP_MODEL_FILENAME), only_state_dict, **kwargs) | |||||
| try: | |||||
| if model_load_fn is not None: | |||||
| if not callable(model_load_fn): | |||||
| raise ValueError("Parameter `model_save_fn` should be `Callable` type when it is not None.") | |||||
| rank_zero_call(model_load_fn)(folder) | |||||
| else: | |||||
| if isinstance(folder, str): | |||||
| folder = Path(folder) | |||||
| self.driver.load_model(folder.joinpath(FASTNLP_MODEL_FILENAME), only_state_dict, **kwargs) | |||||
| except FileNotFoundError as e: | |||||
| if FASTNLP_MODEL_FILENAME not in os.listdir(folder): | |||||
| logger.error(f"fastNLP model checkpoint file:{FASTNLP_MODEL_FILENAME} is not found in {folder}.") | |||||
| raise e | |||||
| else: | else: | ||||
| if model_load_fn is not None: | if model_load_fn is not None: | ||||
| raise RuntimeError("It is not allowed to specify a `model_save_fn` parameter with `folder` being " | raise RuntimeError("It is not allowed to specify a `model_save_fn` parameter with `folder` being " | ||||
| @@ -585,11 +593,13 @@ class Trainer(TrainerEventTrigger): | |||||
| def save(self, folder: Union[str, Path], only_state_dict: bool = True, model_save_fn: Optional[Callable] = None, **kwargs): | def save(self, folder: Union[str, Path], only_state_dict: bool = True, model_save_fn: Optional[Callable] = None, **kwargs): | ||||
| r""" | r""" | ||||
| 用于断点重训 Trainer 的保存函数; | |||||
| 用于断点重训 Trainer 的保存函数。 | |||||
| :param folder: | |||||
| :param only_state_dict: | |||||
| :param model_save_fn: | |||||
| :param folder: 保存在哪个文件夹下,会在该文件夹下生成两个文件:fastnlp_checkpoint.pkl.tar 与 fastnlp_model.pkl.tar 。 | |||||
| 如果 model_save_fn 不为空,则没有 fastnlp_model.pkl.tar 文件。 | |||||
| :param only_state_dict: 当 model_save_fn 为空时有效,表明是否仅保存模型的权重。 | |||||
| :param model_save_fn: 如果模型保存比较特殊,可以传入该函数自定义保存过程,该函数应当接受一个文件夹作为输入(实际上就是上面的 folder | |||||
| 参数),不需要返回任何东西。 | |||||
| :param kwargs: | :param kwargs: | ||||
| :return: | :return: | ||||
| """ | """ | ||||
| @@ -602,17 +612,6 @@ class Trainer(TrainerEventTrigger): | |||||
| 'num_consumed_batches': self.batch_idx_in_epoch - getattr(self, 'start_batch_idx_in_epoch', 0) | 'num_consumed_batches': self.batch_idx_in_epoch - getattr(self, 'start_batch_idx_in_epoch', 0) | ||||
| } | } | ||||
| # 3. validate filter state; | |||||
| if self.evaluator is not None: | |||||
| val_filter_state = {} | |||||
| if hasattr(self.step_validate, "__fastNLP_filter__"): | |||||
| val_filter_state["step_validate"] = self.step_validate.__fastNLP_filter__.state_dict() | |||||
| if hasattr(self.epoch_validate, "__fastNLP_filter__"): | |||||
| val_filter_state["epoch_validate"] = self.epoch_validate.__fastNLP_filter__.state_dict() | |||||
| states["val_filter_state"] = val_filter_state | |||||
| else: | |||||
| states["val_filter_state"] = None | |||||
| if isinstance(folder, str): | if isinstance(folder, str): | ||||
| folder = Path(folder) | folder = Path(folder) | ||||
| @@ -649,32 +648,30 @@ class Trainer(TrainerEventTrigger): | |||||
| dataloader = self.dataloader | dataloader = self.dataloader | ||||
| if not resume_training: | if not resume_training: | ||||
| dataloader = None | dataloader = None | ||||
| if model_load_fn is not None: | |||||
| if not callable(model_load_fn): | |||||
| raise ValueError("Parameter `model_save_fn` should be `Callable`.") | |||||
| rank_zero_call(model_load_fn)(folder) | |||||
| states = self.driver.load(folder=folder, dataloader=dataloader, should_load_model=False, **kwargs) | |||||
| else: | |||||
| states = self.driver.load(folder=folder, dataloader=dataloader, only_state_dict=only_state_dict, should_load_model=True, **kwargs) | |||||
| try: | |||||
| if model_load_fn is not None: | |||||
| if not callable(model_load_fn): | |||||
| raise ValueError("Parameter `model_save_fn` should be `Callable`.") | |||||
| rank_zero_call(model_load_fn)(folder) | |||||
| states = self.driver.load(folder=folder, dataloader=dataloader, should_load_model=False, **kwargs) | |||||
| else: | |||||
| states = self.driver.load(folder=folder, dataloader=dataloader, only_state_dict=only_state_dict, should_load_model=True, **kwargs) | |||||
| except FileNotFoundError as e: | |||||
| if FASTNLP_CHECKPOINT_FILENAME not in os.listdir(folder) and FASTNLP_MODEL_FILENAME in os.listdir(folder): | |||||
| logger.error("It seems that you are trying to load the trainer checkpoint from a model checkpoint folder.") | |||||
| elif FASTNLP_CHECKPOINT_FILENAME not in os.listdir(folder): | |||||
| logger.error(f"fastNLP Trainer checkpoint file:{FASTNLP_CHECKPOINT_FILENAME} is not found in {folder}.") | |||||
| raise e | |||||
| if not resume_training: | if not resume_training: | ||||
| return | return | ||||
| self.dataloader = states.pop('dataloader') | self.dataloader = states.pop('dataloader') | ||||
| # 2. validate filter state; | |||||
| if self.evaluator is not None: | |||||
| val_filter_state = states["val_filter_state"] | |||||
| if hasattr(self.step_validate, "__fastNLP_filter__"): | |||||
| self.step_validate.__fastNLP_filter__.load_state_dict(val_filter_state["step_validate"]) | |||||
| if hasattr(self.epoch_validate, "__fastNLP_filter__"): | |||||
| self.epoch_validate.__fastNLP_filter__.load_state_dict(val_filter_state["epoch_validate"]) | |||||
| # 3. 恢复 trainer_state 的状态; | |||||
| # 1. 恢复 trainer_state 的状态; | |||||
| self.trainer_state.load_state_dict(states["trainer_state"]) | self.trainer_state.load_state_dict(states["trainer_state"]) | ||||
| # 4. 修改 trainer_state.batch_idx_in_epoch | |||||
| # 2. 修改 trainer_state.batch_idx_in_epoch | |||||
| # sampler 是类似 RandomSampler 的sampler,不是 batch_sampler; | # sampler 是类似 RandomSampler 的sampler,不是 batch_sampler; | ||||
| # 这里的原则就是应当使得 '还会产生的batch数量' + 'batch_idx_in_epoch' = '原来不断点训练的batch的总数'。其中由于 | # 这里的原则就是应当使得 '还会产生的batch数量' + 'batch_idx_in_epoch' = '原来不断点训练的batch的总数'。其中由于 | ||||
| # '还会产生的batch数量' 是由还剩多少 sample 决定的,因此只能通过调整 'batch_idx_in_epoch' 使得等式成立 | # '还会产生的batch数量' 是由还剩多少 sample 决定的,因此只能通过调整 'batch_idx_in_epoch' 使得等式成立 | ||||
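A small numeric illustration of the bookkeeping rule described in the comment above (all numbers are invented and drop_last effects are ignored):

    num_samples, batch_size = 100, 4
    total_batches = num_samples // batch_size                # 25 batches in an uninterrupted epoch
    num_consumed_samples = 40                                 # recovered from the sampler state
    remaining_batches = (num_samples - num_consumed_samples) // batch_size   # 15 batches still to come
    batch_idx_in_epoch = total_batches - remaining_batches    # 10, so 15 + 10 == 25 holds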
| @@ -126,7 +126,7 @@ class _TruncatedDataLoader: | |||||
| return getattr(self.dataloader, item) | return getattr(self.dataloader, item) | ||||
| def check_validate_every(validate_every): | |||||
| def check_evaluate_every(validate_every): | |||||
| if not callable(validate_every) and (not isinstance(validate_every, int) or validate_every == 0): | if not callable(validate_every) and (not isinstance(validate_every, int) or validate_every == 0): | ||||
| raise ValueError("Parameter 'evaluate_every' should be set to 'int' type and either < 0 or > 0.") | raise ValueError("Parameter 'evaluate_every' should be set to 'int' type and either < 0 or > 0.") | ||||
| if callable(validate_every): | if callable(validate_every): | ||||
| @@ -11,6 +11,7 @@ from fastNLP.core.collators.collator import _MultiCollator | |||||
| from fastNLP.core.utils.utils import indice_collate_wrapper | from fastNLP.core.utils.utils import indice_collate_wrapper | ||||
| from fastNLP.io.data_bundle import DataBundle | from fastNLP.io.data_bundle import DataBundle | ||||
| from fastNLP.envs.imports import _NEED_IMPORT_TORCH | from fastNLP.envs.imports import _NEED_IMPORT_TORCH | ||||
| from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler, UnrepeatedSampler | |||||
| if _NEED_IMPORT_TORCH: | if _NEED_IMPORT_TORCH: | ||||
| from torch.utils.data import DataLoader, Sampler | from torch.utils.data import DataLoader, Sampler | ||||
| @@ -48,8 +49,8 @@ class TorchDataLoader(DataLoader): | |||||
| """ | """ | ||||
| def __init__(self, dataset, batch_size: int = 1, | def __init__(self, dataset, batch_size: int = 1, | ||||
| shuffle: bool = False, sampler: Optional["Sampler[int]"] = None, | |||||
| batch_sampler: Optional["Sampler[Sequence[int]]"] = None, | |||||
| shuffle: bool = False, sampler: Union["Sampler[int]", ReproducibleSampler, UnrepeatedSampler] = None, | |||||
| batch_sampler: Union["Sampler[Sequence[int]]", ReproducibleBatchSampler] = None, | |||||
| num_workers: int = 0, collate_fn: Optional[Callable] = None, | num_workers: int = 0, collate_fn: Optional[Callable] = None, | ||||
| pin_memory: bool = False, drop_last: bool = False, | pin_memory: bool = False, drop_last: bool = False, | ||||
| timeout: float = 0, worker_init_fn: Optional[Callable] = None, | timeout: float = 0, worker_init_fn: Optional[Callable] = None, | ||||
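Given the widened `sampler`/`batch_sampler` annotations above, a hedged sketch of passing a fastNLP reproducible sampler directly; `train_dataset` is a placeholder and the remaining defaults are assumed:

    from fastNLP.core.samplers import RandomSampler

    dl = TorchDataLoader(dataset=train_dataset, batch_size=8,
                         sampler=RandomSampler(train_dataset))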
| @@ -380,7 +380,6 @@ class Driver(ABC): | |||||
| """ | """ | ||||
| # 单卡 driver 不需要这个函数; | # 单卡 driver 不需要这个函数; | ||||
| if self._pids is not None: | if self._pids is not None: | ||||
| exc_type, exc_value, exc_traceback_obj = sys.exc_info() | exc_type, exc_value, exc_traceback_obj = sys.exc_info() | ||||
| _write_exc_info = { | _write_exc_info = { | ||||
| 'exc_type': str(exc_type.__name__), | 'exc_type': str(exc_type.__name__), | ||||
| @@ -526,7 +526,7 @@ class TorchDDPDriver(TorchDriver): | |||||
| def barrier(self): | def barrier(self): | ||||
| if int(os.environ.get(FASTNLP_NO_SYNC, 0)) < 1: # 当 FASTNLP_NO_SYNC 小于 1 时实际执行 | if int(os.environ.get(FASTNLP_NO_SYNC, 0)) < 1: # 当 FASTNLP_NO_SYNC 小于 1 时实际执行 | ||||
| torch.distributed.barrier(async_op=True) | |||||
| torch.distributed.barrier(async_op=False) | |||||
| def is_distributed(self): | def is_distributed(self): | ||||
| return True | return True | ||||
| @@ -9,8 +9,9 @@ from fastNLP.envs.imports import _NEED_IMPORT_TORCH | |||||
| from pathlib import Path | from pathlib import Path | ||||
| if _NEED_IMPORT_TORCH: | if _NEED_IMPORT_TORCH: | ||||
| import torch | import torch | ||||
| from torch.utils.data import DataLoader, IterableDataset, RandomSampler, Sampler, BatchSampler, Dataset | |||||
| from torch.utils.data import DataLoader, IterableDataset, Sampler, BatchSampler, Dataset | |||||
| from torch.optim import Optimizer | from torch.optim import Optimizer | ||||
| from torch.utils.data import RandomSampler as TorchRandomSampler | |||||
| _reduces = { | _reduces = { | ||||
| 'sum': torch.sum, | 'sum': torch.sum, | ||||
| 'min': torch.min, | 'min': torch.min, | ||||
| @@ -30,7 +31,7 @@ from fastNLP.core.utils import apply_to_collection, torch_move_data_to_device | |||||
| from fastNLP.envs import rank_zero_call | from fastNLP.envs import rank_zero_call | ||||
| from fastNLP.envs import FASTNLP_SEED_WORKERS, FASTNLP_GLOBAL_RANK, FASTNLP_MODEL_FILENAME, FASTNLP_CHECKPOINT_FILENAME | from fastNLP.envs import FASTNLP_SEED_WORKERS, FASTNLP_GLOBAL_RANK, FASTNLP_MODEL_FILENAME, FASTNLP_CHECKPOINT_FILENAME | ||||
| from fastNLP.core.log import logger | from fastNLP.core.log import logger | ||||
| from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler, RandomBatchSampler | |||||
| from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler, RandomBatchSampler, RandomSampler | |||||
| class TorchDriver(Driver): | class TorchDriver(Driver): | ||||
| @@ -211,8 +212,8 @@ class TorchDriver(Driver): | |||||
| states['sampler_states'] = sampler_states | states['sampler_states'] = sampler_states | ||||
| else: | else: | ||||
| raise RuntimeError( | |||||
| 'The sampler has no `state_dict()` method, it will fail to recover to the specific batch.') | |||||
| raise RuntimeError('The sampler has no `state_dict()` method, fastNLP cannot save the training ' | |||||
| 'state.') | |||||
| # 2. 保存模型的状态; | # 2. 保存模型的状态; | ||||
| if should_save_model: | if should_save_model: | ||||
| @@ -283,6 +284,9 @@ class TorchDriver(Driver): | |||||
| sampler = dataloader_args.batch_sampler | sampler = dataloader_args.batch_sampler | ||||
| elif isinstance(dataloader_args.sampler, ReproducibleSampler): | elif isinstance(dataloader_args.sampler, ReproducibleSampler): | ||||
| sampler = dataloader_args.sampler | sampler = dataloader_args.sampler | ||||
| elif isinstance(dataloader_args.sampler, TorchRandomSampler): | |||||
| sampler = RandomSampler(dataloader_args.sampler.data_source) | |||||
| logger.debug("Replace torch RandomSampler into fastNLP RandomSampler.") | |||||
| elif self.is_distributed(): | elif self.is_distributed(): | ||||
| raise RuntimeError("It is not allowed to use checkpoint retraining when you do not use our or " | raise RuntimeError("It is not allowed to use checkpoint retraining when you do not use our or " | ||||
| "`ReproducibleSampler`.") | "`ReproducibleSampler`.") | ||||
| @@ -19,7 +19,7 @@ class Accuracy(Metric): | |||||
| :param str backend: 目前支持四种类型的backend, ['auto', 'torch', 'paddle', 'jittor']。其中 auto 表示根据实际调用 Metric.update() | :param str backend: 目前支持四种类型的backend, ['auto', 'torch', 'paddle', 'jittor']。其中 auto 表示根据实际调用 Metric.update() | ||||
| 函数时传入的参数决定具体的 backend ,一般情况下直接使用 'auto' 即可。 | 函数时传入的参数决定具体的 backend ,一般情况下直接使用 'auto' 即可。 | ||||
| :param bool aggregate_when_get_metric: 在计算 metric 的时候是否自动将各个进程上的相同的 element 的数字聚合后再得到metric, | |||||
| :param bool aggregate_when_get_metric: 在计算 metric 的时候是否自动将各个进程上的相同的 element 的数字聚合后再得到 metric, | |||||
| 当 backend 不支持分布式时,该参数无意义。如果为 None ,将在 Evaluator 中根据 sampler 是否使用分布式进行自动设置。 | 当 backend 不支持分布式时,该参数无意义。如果为 None ,将在 Evaluator 中根据 sampler 是否使用分布式进行自动设置。 | ||||
| """ | """ | ||||
| super(Accuracy, self).__init__(backend=backend, aggregate_when_get_metric=aggregate_when_get_metric) | super(Accuracy, self).__init__(backend=backend, aggregate_when_get_metric=aggregate_when_get_metric) | ||||
| @@ -84,6 +84,8 @@ class UnrepeatedRandomSampler(UnrepeatedSampler): | |||||
| :param rank: | :param rank: | ||||
| :return: | :return: | ||||
| """ | """ | ||||
| assert num_replicas<=len(self.dataset), f"The number of replicas({num_replicas}) should not be larger than the " \ | |||||
| f"number of samples({len(self.dataset)})." | |||||
| assert num_replicas>0 and isinstance(num_replicas, int) | assert num_replicas>0 and isinstance(num_replicas, int) | ||||
| assert isinstance(rank, int) and 0<=rank<num_replicas | assert isinstance(rank, int) and 0<=rank<num_replicas | ||||
| # 注意初始化该函数时,所有的状态都应当默认是一个 epoch 刚开始训练的状态; | # 注意初始化该函数时,所有的状态都应当默认是一个 epoch 刚开始训练的状态; | ||||
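To make the new assertion above concrete, an illustrative (non-library) computation of how 10 samples could be split over 4 replicas without repetition; with `num_replicas` larger than the dataset, some replicas would receive nothing, which is what the assert forbids:

    n, num_replicas = 10, 4
    per_rank = [n // num_replicas + (1 if rank < n % num_replicas else 0)
                for rank in range(num_replicas)]
    assert per_rank == [3, 3, 2, 2] and sum(per_rank) == n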
| @@ -24,8 +24,8 @@ __all__ = [ | |||||
| 'indice_collate_wrapper', | 'indice_collate_wrapper', | ||||
| 'deprecated', | 'deprecated', | ||||
| 'seq_len_to_mask', | 'seq_len_to_mask', | ||||
| 'synchronize_safe_rm', | |||||
| 'synchronize_mkdir' | |||||
| 'rank_zero_rm', | |||||
| 'rank_zero_mkdir' | |||||
| ] | ] | ||||
| from .cache_results import cache_results | from .cache_results import cache_results | ||||
| @@ -37,6 +37,6 @@ from .torch_paddle_utils import torch_paddle_move_data_to_device | |||||
| from .torch_utils import torch_move_data_to_device | from .torch_utils import torch_move_data_to_device | ||||
| from .utils import get_fn_arg_names, auto_param_call, check_user_specific_params, \ | from .utils import get_fn_arg_names, auto_param_call, check_user_specific_params, \ | ||||
| dataclass_to_dict, match_and_substitute_params, apply_to_collection, nullcontext, pretty_table_printer, Option, \ | dataclass_to_dict, match_and_substitute_params, apply_to_collection, nullcontext, pretty_table_printer, Option, \ | ||||
| indice_collate_wrapper, deprecated, seq_len_to_mask, synchronize_safe_rm, synchronize_mkdir | |||||
| indice_collate_wrapper, deprecated, seq_len_to_mask, rank_zero_rm, rank_zero_mkdir | |||||
| @@ -38,8 +38,8 @@ __all__ = [ | |||||
| 'indice_collate_wrapper', | 'indice_collate_wrapper', | ||||
| 'deprecated', | 'deprecated', | ||||
| 'seq_len_to_mask', | 'seq_len_to_mask', | ||||
| 'synchronize_safe_rm', | |||||
| 'synchronize_mkdir' | |||||
| 'rank_zero_rm', | |||||
| 'rank_zero_mkdir' | |||||
| ] | ] | ||||
| @@ -629,7 +629,7 @@ def wait_filepath(path, exist=True): | |||||
| def synchronize_safe_rm(path: Optional[Union[str, Path]]): | |||||
| def rank_zero_rm(path: Optional[Union[str, Path]]): | |||||
| """ | """ | ||||
| 这个是因为在分布式文件系统中可能会发生错误,rank0下发删除成功后就运行走了,但实际的删除需要rank0的机器发送到远程文件系统再去执行,这个时候 | 这个是因为在分布式文件系统中可能会发生错误,rank0下发删除成功后就运行走了,但实际的删除需要rank0的机器发送到远程文件系统再去执行,这个时候 | ||||
| 在rank0那里,确实已经删除成功了,但是在远程文件系统那里这个操作还没完成,rank1读取的时候还是读取到存在这个文件; | 在rank0那里,确实已经删除成功了,但是在远程文件系统那里这个操作还没完成,rank1读取的时候还是读取到存在这个文件; | ||||
| @@ -638,15 +638,14 @@ def synchronize_safe_rm(path: Optional[Union[str, Path]]): | |||||
| :param path: | :param path: | ||||
| :return: | :return: | ||||
| """ | """ | ||||
| if path is None: | |||||
| return | |||||
| if isinstance(path, str): | |||||
| path = Path(path) | |||||
| if not path.exists(): | |||||
| return | |||||
| if int(os.environ.get(FASTNLP_GLOBAL_RANK, 0)) == 0: | if int(os.environ.get(FASTNLP_GLOBAL_RANK, 0)) == 0: | ||||
| if path is None: | |||||
| return | |||||
| if isinstance(path, str): | |||||
| path = Path(path) | |||||
| if not path.exists(): | |||||
| return | |||||
| _recursive_rm(path) | _recursive_rm(path) | ||||
| wait_filepath(path, exist=False) | |||||
| def _recursive_rm(path: Path): | def _recursive_rm(path: Path): | ||||
| @@ -662,21 +661,19 @@ def _recursive_rm(path: Path): | |||||
| path.rmdir() | path.rmdir() | ||||
| def synchronize_mkdir(path: Optional[Union[str, Path]]): | |||||
| def rank_zero_mkdir(path: Optional[Union[str, Path]]): | |||||
| """ | """ | ||||
| 注意该函数是用来创建文件夹,如果需要创建一个文件,不要使用该函数; | 注意该函数是用来创建文件夹,如果需要创建一个文件,不要使用该函数; | ||||
| 该函数会保证所有进程都检测到 path 创建之后才退出,请保证不同进程上 path 是完全一样的,否则会陷入死锁状态。 | 该函数会保证所有进程都检测到 path 创建之后才退出,请保证不同进程上 path 是完全一样的,否则会陷入死锁状态。 | ||||
| """ | """ | ||||
| if path is None: | |||||
| return | |||||
| if isinstance(path, str): | |||||
| path = Path(path) | |||||
| if int(os.environ.get(FASTNLP_GLOBAL_RANK, 0)) == 0: | if int(os.environ.get(FASTNLP_GLOBAL_RANK, 0)) == 0: | ||||
| path.mkdir(parents=True, exist_ok=True) | |||||
| if path is None: | |||||
| return | |||||
| if isinstance(path, str): | |||||
| path = Path(path) | |||||
| wait_filepath(path, exist=True) | |||||
| path.mkdir(parents=True, exist_ok=True) | |||||
| def get_class_that_defined_method(method): | def get_class_that_defined_method(method): | ||||
| @@ -49,7 +49,7 @@ FASTNLP_BACKEND_LAUNCH = "FASTNLP_BACKEND_LAUNCH" | |||||
| # 为 '2' 表示 barrier 与 gather/broadcast 都关闭。 | # 为 '2' 表示 barrier 与 gather/broadcast 都关闭。 | ||||
| FASTNLP_NO_SYNC = 'FASTNLP_NO_SYNC' | FASTNLP_NO_SYNC = 'FASTNLP_NO_SYNC' | ||||
| # todo 注释 直接使用的变量 | |||||
| # 保存各种内容时的默认名称 | |||||
| FASTNLP_MODEL_FILENAME = "fastnlp_model.pkl.tar" | FASTNLP_MODEL_FILENAME = "fastnlp_model.pkl.tar" | ||||
| FASTNLP_CHECKPOINT_FILENAME = "fastnlp_checkpoint.pkl.tar" | FASTNLP_CHECKPOINT_FILENAME = "fastnlp_checkpoint.pkl.tar" | ||||
| FASTNLP_EVALUATE_RESULT_FILENAME = 'fastnlp_evaluate_results.json' | |||||
| @@ -7,13 +7,14 @@ from torch.optim import SGD | |||||
| import torch.distributed as dist | import torch.distributed as dist | ||||
| from pathlib import Path | from pathlib import Path | ||||
| import re | import re | ||||
| import time | |||||
| from fastNLP.core.callbacks.checkpoint_callback import ModelCheckpointCallback, TrainerCheckpointCallback | |||||
| from fastNLP.core.callbacks.checkpoint_callback import CheckpointCallback | |||||
| from fastNLP.core.controllers.trainer import Trainer | from fastNLP.core.controllers.trainer import Trainer | ||||
| from fastNLP.envs import FASTNLP_LAUNCH_TIME, FASTNLP_DISTRIBUTED_CHECK | from fastNLP.envs import FASTNLP_LAUNCH_TIME, FASTNLP_DISTRIBUTED_CHECK | ||||
| from tests.helpers.utils import magic_argv_env_context | from tests.helpers.utils import magic_argv_env_context | ||||
| from fastNLP.core import synchronize_safe_rm | |||||
| from fastNLP.core import rank_zero_rm | |||||
| from tests.helpers.models.torch_model import TorchNormalModel_Classification_1 | from tests.helpers.models.torch_model import TorchNormalModel_Classification_1 | ||||
| from tests.helpers.datasets.torch_data import TorchArgMaxDatset | from tests.helpers.datasets.torch_data import TorchArgMaxDatset | ||||
| from torchmetrics import Accuracy | from torchmetrics import Accuracy | ||||
| @@ -80,44 +81,21 @@ def test_model_checkpoint_callback_1( | |||||
| version, | version, | ||||
| only_state_dict | only_state_dict | ||||
| ): | ): | ||||
| # def test_model_checkpoint_callback_1( | |||||
| # model_and_optimizers: TrainerParameters, | |||||
| # driver='torch_ddp', | |||||
| # device=[0, 1], | |||||
| # version=1, | |||||
| # only_state_dict=True | |||||
| # ): | |||||
| path = Path.cwd().joinpath(f"test_model_checkpoint") | |||||
| path.mkdir(exist_ok=True, parents=True) | |||||
| try: | |||||
| path = Path.cwd().joinpath(f"test_model_checkpoint") | |||||
| path.mkdir(exist_ok=True, parents=True) | |||||
| if version == 0: | |||||
| callbacks = [ | |||||
| ModelCheckpointCallback( | |||||
| monitor="acc", | |||||
| save_folder=path, | |||||
| save_every_n_epochs=1, | |||||
| save_every_n_batches=123, # 避免和 epoch 的保存重复; | |||||
| save_topk=None, | |||||
| save_last=False, | |||||
| save_on_exception=None, | |||||
| only_state_dict=only_state_dict | |||||
| ) | |||||
| ] | |||||
| elif version == 1: | |||||
| callbacks = [ | |||||
| ModelCheckpointCallback( | |||||
| monitor="acc", | |||||
| save_folder=path, | |||||
| save_every_n_epochs=3, | |||||
| save_every_n_batches=None, | |||||
| save_topk=2, | |||||
| save_last=True, | |||||
| save_on_exception=None, | |||||
| only_state_dict=only_state_dict | |||||
| ) | |||||
| ] | |||||
| if version == 0: | |||||
| callbacks = [ | |||||
| CheckpointCallback(folder=path, every_n_epochs=1, every_n_batches=123, last=False, on_exceptions=None, topk=0, | |||||
| monitor=None, only_state_dict=only_state_dict, save_object='model') | |||||
| ] | |||||
| elif version == 1: | |||||
| callbacks = [ | |||||
| CheckpointCallback(folder=path, every_n_epochs=3, every_n_batches=None, last=True, on_exceptions=None, topk=2, | |||||
| monitor="acc", only_state_dict=only_state_dict, save_object='model') | |||||
| ] | |||||
| try: | |||||
| trainer = Trainer( | trainer = Trainer( | ||||
| model=model_and_optimizers.model, | model=model_and_optimizers.model, | ||||
| driver=driver, | driver=driver, | ||||
| @@ -134,7 +112,7 @@ def test_model_checkpoint_callback_1( | |||||
| ) | ) | ||||
| trainer.run() | trainer.run() | ||||
| print("Finish train") | |||||
| all_saved_model_paths = {w.name: w for w in path.joinpath(os.environ[FASTNLP_LAUNCH_TIME]).iterdir()} | all_saved_model_paths = {w.name: w for w in path.joinpath(os.environ[FASTNLP_LAUNCH_TIME]).iterdir()} | ||||
| # 检查生成保存模型文件的数量是不是正确的; | # 检查生成保存模型文件的数量是不是正确的; | ||||
| if version == 0: | if version == 0: | ||||
| @@ -217,8 +195,7 @@ def test_model_checkpoint_callback_1( | |||||
| trainer.run() | trainer.run() | ||||
| finally: | finally: | ||||
| synchronize_safe_rm(path) | |||||
| pass | |||||
| rank_zero_rm(path) | |||||
| if dist.is_initialized(): | if dist.is_initialized(): | ||||
| dist.destroy_process_group() | dist.destroy_process_group() | ||||
| @@ -233,30 +210,23 @@ def test_model_checkpoint_callback_2( | |||||
| device, | device, | ||||
| only_state_dict | only_state_dict | ||||
| ): | ): | ||||
| path = Path.cwd().joinpath("test_model_checkpoint") | |||||
| path.mkdir(exist_ok=True, parents=True) | |||||
| try: | |||||
| path = Path.cwd().joinpath("test_model_checkpoint") | |||||
| path.mkdir(exist_ok=True, parents=True) | |||||
| from fastNLP.core.callbacks.callback_events import Events | |||||
| @Trainer.on(Events.on_train_epoch_end) | |||||
| def raise_exception(trainer): | |||||
| if trainer.driver.get_local_rank() == 0 and trainer.cur_epoch_idx == 4: | |||||
| raise NotImplementedError | |||||
| callbacks = [ | |||||
| ModelCheckpointCallback( | |||||
| monitor="acc1", | |||||
| save_folder=path, | |||||
| save_every_n_epochs=None, | |||||
| save_every_n_batches=None, | |||||
| save_topk=None, | |||||
| save_last=False, | |||||
| save_on_exception=NotImplementedError, | |||||
| only_state_dict=only_state_dict | |||||
| ), | |||||
| ] | |||||
| from fastNLP.core.callbacks.callback_events import Events | |||||
| @Trainer.on(Events.on_train_epoch_end) | |||||
| def raise_exception(trainer): | |||||
| if trainer.driver.get_local_rank() == 0 and trainer.cur_epoch_idx == 4: | |||||
| raise NotImplementedError | |||||
| callbacks = [ | |||||
| CheckpointCallback(folder=path, every_n_epochs=None, every_n_batches=None, last=False, | |||||
| on_exceptions=NotImplementedError, topk=None, monitor=None, only_state_dict=only_state_dict, | |||||
| save_object='model'), | |||||
| ] | |||||
| try: | |||||
| with pytest.raises(NotImplementedError): | with pytest.raises(NotImplementedError): | ||||
| trainer = Trainer( | trainer = Trainer( | ||||
| model=model_and_optimizers.model, | model=model_and_optimizers.model, | ||||
| @@ -315,14 +285,14 @@ def test_model_checkpoint_callback_2( | |||||
| trainer.run() | trainer.run() | ||||
| finally: | finally: | ||||
| synchronize_safe_rm(path) | |||||
| rank_zero_rm(path) | |||||
| # pass | # pass | ||||
| if dist.is_initialized(): | if dist.is_initialized(): | ||||
| dist.destroy_process_group() | dist.destroy_process_group() | ||||
| @pytest.mark.parametrize("driver,device", [("torch", "cpu"), ("torch_ddp", [6, 7]), ("torch", 7)]) # ("torch", "cpu"), ("torch_ddp", [0, 1]), ("torch", 1) | |||||
| @pytest.mark.parametrize("driver,device", [("torch", "cpu"), ("torch_ddp", [0, 1]), ("torch", 0)]) # ("torch", "cpu"), ("torch_ddp", [0, 1]), ("torch", 1) | |||||
| @pytest.mark.parametrize("version", [0, 1]) | @pytest.mark.parametrize("version", [0, 1]) | ||||
| @pytest.mark.parametrize("only_state_dict", [True, False]) | @pytest.mark.parametrize("only_state_dict", [True, False]) | ||||
| @magic_argv_env_context | @magic_argv_env_context | ||||
| @@ -333,37 +303,21 @@ def test_trainer_checkpoint_callback_1( | |||||
| version, | version, | ||||
| only_state_dict | only_state_dict | ||||
| ): | ): | ||||
| path = Path.cwd().joinpath(f"test_model_checkpoint") | |||||
| path.mkdir(exist_ok=True, parents=True) | |||||
| try: | |||||
| path = Path.cwd().joinpath(f"test_model_checkpoint") | |||||
| path.mkdir(exist_ok=True, parents=True) | |||||
| if version == 0: | |||||
| callbacks = [ | |||||
| TrainerCheckpointCallback( | |||||
| monitor="acc", | |||||
| save_folder=path, | |||||
| save_every_n_epochs=7, | |||||
| save_every_n_batches=123, # 避免和 epoch 的保存重复; | |||||
| save_topk=None, | |||||
| save_last=False, | |||||
| save_on_exception=None, | |||||
| only_state_dict=only_state_dict | |||||
| ) | |||||
| ] | |||||
| elif version == 1: | |||||
| callbacks = [ | |||||
| TrainerCheckpointCallback( | |||||
| monitor="acc", | |||||
| save_folder=path, | |||||
| save_every_n_epochs=None, | |||||
| save_every_n_batches=None, | |||||
| save_topk=2, | |||||
| save_last=True, | |||||
| save_on_exception=None, | |||||
| only_state_dict=only_state_dict | |||||
| ) | |||||
| ] | |||||
| if version == 0: | |||||
| callbacks = [ | |||||
| CheckpointCallback(folder=path, every_n_epochs=7, every_n_batches=123, last=False, on_exceptions=None, topk=0, | |||||
| monitor=None, only_state_dict=only_state_dict, save_object='trainer') | |||||
| ] | |||||
| elif version == 1: | |||||
| callbacks = [ | |||||
| CheckpointCallback(folder=path, every_n_epochs=None, every_n_batches=None, last=True, on_exceptions=None, | |||||
| topk=2, monitor="acc", only_state_dict=only_state_dict, save_object='trainer') | |||||
| ] | |||||
| try: | |||||
| trainer = Trainer( | trainer = Trainer( | ||||
| model=model_and_optimizers.model, | model=model_and_optimizers.model, | ||||
| driver=driver, | driver=driver, | ||||
| @@ -461,8 +415,7 @@ def test_trainer_checkpoint_callback_1( | |||||
| trainer.run() | trainer.run() | ||||
| finally: | finally: | ||||
| synchronize_safe_rm(path) | |||||
| pass | |||||
| rank_zero_rm(path) | |||||
| if dist.is_initialized(): | if dist.is_initialized(): | ||||
| dist.destroy_process_group() | dist.destroy_process_group() | ||||
| @@ -594,12 +547,12 @@ def test_trainer_checkpoint_callback_2( | |||||
| callbacks = [ | callbacks = [ | ||||
| TrainerCheckpointCallback( | TrainerCheckpointCallback( | ||||
| monitor="acc", | monitor="acc", | ||||
| save_folder=path, | |||||
| save_every_n_epochs=None, | |||||
| save_every_n_batches=50, | |||||
| save_topk=None, | |||||
| save_last=False, | |||||
| save_on_exception=None, | |||||
| folder=path, | |||||
| every_n_epochs=None, | |||||
| every_n_batches=50, | |||||
| topk=None, | |||||
| last=False, | |||||
| on_exception=None, | |||||
| model_save_fn=model_save_fn | model_save_fn=model_save_fn | ||||
| ) | ) | ||||
| ] | ] | ||||
| @@ -607,12 +560,12 @@ def test_trainer_checkpoint_callback_2( | |||||
| callbacks = [ | callbacks = [ | ||||
| TrainerCheckpointCallback( | TrainerCheckpointCallback( | ||||
| monitor="acc", | monitor="acc", | ||||
| save_folder=path, | |||||
| save_every_n_epochs=None, | |||||
| save_every_n_batches=None, | |||||
| save_topk=1, | |||||
| save_last=True, | |||||
| save_on_exception=None, | |||||
| folder=path, | |||||
| every_n_epochs=None, | |||||
| every_n_batches=None, | |||||
| topk=1, | |||||
| last=True, | |||||
| on_exception=None, | |||||
| model_save_fn=model_save_fn | model_save_fn=model_save_fn | ||||
| ) | ) | ||||
| ] | ] | ||||
| @@ -710,7 +663,7 @@ def test_trainer_checkpoint_callback_2( | |||||
| trainer.run() | trainer.run() | ||||
| finally: | finally: | ||||
| synchronize_safe_rm(path) | |||||
| rank_zero_rm(path) | |||||
| # pass | # pass | ||||
| if dist.is_initialized(): | if dist.is_initialized(): | ||||
| @@ -0,0 +1,263 @@ | |||||
| """ | |||||
| 测试 more_evaluate_callback | |||||
| (1) 能不能正确 evaluate ; | |||||
| (2) 能不能保存 topk 并 load 进来进行训练 | |||||
| """ | |||||
| import pytest | |||||
| import os | |||||
| from typing import Any | |||||
| from dataclasses import dataclass | |||||
| from torch.utils.data import DataLoader | |||||
| from torch.optim import SGD | |||||
| import torch.distributed as dist | |||||
| from pathlib import Path | |||||
| import re | |||||
| from fastNLP.core.controllers.trainer import Trainer | |||||
| from fastNLP.envs import FASTNLP_LAUNCH_TIME, FASTNLP_DISTRIBUTED_CHECK | |||||
| from tests.helpers.utils import magic_argv_env_context | |||||
| from fastNLP.core import rank_zero_rm | |||||
| from tests.helpers.models.torch_model import TorchNormalModel_Classification_1 | |||||
| from tests.helpers.datasets.torch_data import TorchArgMaxDatset | |||||
| from torchmetrics import Accuracy | |||||
| from fastNLP.core.metrics import Metric | |||||
| from fastNLP.core.log import logger | |||||
| from fastNLP.core.callbacks import MoreEvaluateCallback | |||||
| @dataclass | |||||
| class ArgMaxDatasetConfig: | |||||
| num_labels: int = 10 | |||||
| feature_dimension: int = 10 | |||||
| data_num: int = 100 | |||||
| seed: int = 0 | |||||
| batch_size: int = 4 | |||||
| shuffle: bool = True | |||||
| @dataclass | |||||
| class TrainerParameters: | |||||
| model: Any = None | |||||
| optimizers: Any = None | |||||
| train_dataloader: Any = None | |||||
| evaluate_dataloaders: Any = None | |||||
| input_mapping: Any = None | |||||
| output_mapping: Any = None | |||||
| metrics: Any = None | |||||
| more_metrics: Any = None | |||||
| @pytest.fixture(scope="module", params=[0], autouse=True) | |||||
| def model_and_optimizers(request): | |||||
| trainer_params = TrainerParameters() | |||||
| trainer_params.model = TorchNormalModel_Classification_1( | |||||
| num_labels=ArgMaxDatasetConfig.num_labels, | |||||
| feature_dimension=ArgMaxDatasetConfig.feature_dimension | |||||
| ) | |||||
| trainer_params.optimizers = SGD(trainer_params.model.parameters(), lr=0.001) | |||||
| dataset = TorchArgMaxDatset( | |||||
| feature_dimension=ArgMaxDatasetConfig.feature_dimension, | |||||
| data_num=ArgMaxDatasetConfig.data_num, | |||||
| seed=ArgMaxDatasetConfig.seed | |||||
| ) | |||||
| _dataloader = DataLoader( | |||||
| dataset=dataset, | |||||
| batch_size=ArgMaxDatasetConfig.batch_size, | |||||
| shuffle=True | |||||
| ) | |||||
| class LossMetric(Metric): | |||||
| def __init__(self): | |||||
| super().__init__() | |||||
| self.register_element('loss') | |||||
| def update(self, loss): | |||||
| self.loss += loss.item() | |||||
| def get_metric(self) -> float: | |||||
| return self.loss.item() | |||||
| trainer_params.train_dataloader = _dataloader | |||||
| trainer_params.evaluate_dataloaders = _dataloader | |||||
| trainer_params.metrics = {'loss': LossMetric()} | |||||
| trainer_params.more_metrics = {"acc": Accuracy()} | |||||
| return trainer_params | |||||
| @pytest.mark.parametrize("driver,device", [("torch", "cpu"), ("torch_ddp", [0, 1]), ("torch", 1)]) # ("torch", "cpu"), ("torch_ddp", [0, 1]), ("torch", 1) | |||||
| @pytest.mark.parametrize("version", [0, 1]) | |||||
| @pytest.mark.parametrize("only_state_dict", [True, False]) | |||||
| @magic_argv_env_context | |||||
| def test_model_more_evaluate_callback_1( | |||||
| model_and_optimizers: TrainerParameters, | |||||
| driver, | |||||
| device, | |||||
| version, | |||||
| only_state_dict | |||||
| ): | |||||
| try: | |||||
| path = Path.cwd().joinpath(f"test_model_checkpoint") | |||||
| path.mkdir(exist_ok=True, parents=True) | |||||
| if version == 0: | |||||
| callbacks = [ | |||||
| MoreEvaluateCallback(dataloaders=model_and_optimizers.evaluate_dataloaders, | |||||
| metrics=model_and_optimizers.more_metrics, | |||||
| evaluate_every=-1, | |||||
| folder=path, topk=-1, | |||||
| topk_monitor='acc', only_state_dict=only_state_dict, save_object='model') | |||||
| ] | |||||
| elif version == 1: | |||||
| callbacks = [ | |||||
| MoreEvaluateCallback(dataloaders=model_and_optimizers.evaluate_dataloaders, | |||||
| metrics=model_and_optimizers.more_metrics, | |||||
| evaluate_every=None, watch_monitor='loss', watch_monitor_larger_better=False, | |||||
| folder=path, topk=1, topk_monitor='acc', only_state_dict=only_state_dict, | |||||
| save_object='model') | |||||
| ] | |||||
| n_epochs = 5 | |||||
| trainer = Trainer( | |||||
| model=model_and_optimizers.model, | |||||
| driver=driver, | |||||
| device=device, | |||||
| optimizers=model_and_optimizers.optimizers, | |||||
| train_dataloader=model_and_optimizers.train_dataloader, | |||||
| evaluate_dataloaders=model_and_optimizers.evaluate_dataloaders, | |||||
| input_mapping=model_and_optimizers.input_mapping, | |||||
| output_mapping=model_and_optimizers.output_mapping, | |||||
| metrics=model_and_optimizers.metrics, | |||||
| n_epochs=n_epochs, | |||||
| callbacks=callbacks, | |||||
| output_from_new_proc="all", | |||||
| evaluate_fn='train_step' | |||||
| ) | |||||
| trainer.run() | |||||
| all_saved_model_paths = {w.name: w for w in path.joinpath(os.environ[FASTNLP_LAUNCH_TIME]).iterdir()} | |||||
| # 检查生成保存模型文件的数量是不是正确的; | |||||
| if version == 0: | |||||
| assert len(all_saved_model_paths) == n_epochs | |||||
| elif version == 1: | |||||
| assert len(all_saved_model_paths) == 1 | |||||
| for folder in all_saved_model_paths: | |||||
| trainer = Trainer( | |||||
| model=model_and_optimizers.model, | |||||
| driver=driver, | |||||
| device=device, | |||||
| optimizers=model_and_optimizers.optimizers, | |||||
| train_dataloader=model_and_optimizers.train_dataloader, | |||||
| evaluate_dataloaders=model_and_optimizers.evaluate_dataloaders, | |||||
| input_mapping=model_and_optimizers.input_mapping, | |||||
| output_mapping=model_and_optimizers.output_mapping, | |||||
| metrics=model_and_optimizers.metrics, | |||||
| n_epochs=2, | |||||
| output_from_new_proc="all", | |||||
| evaluate_fn='train_step' | |||||
| ) | |||||
| folder = path.joinpath(os.environ[FASTNLP_LAUNCH_TIME]).joinpath(folder) | |||||
| trainer.load_model(folder, only_state_dict=only_state_dict) | |||||
| trainer.run() | |||||
| finally: | |||||
| rank_zero_rm(path) | |||||
| if dist.is_initialized(): | |||||
| dist.destroy_process_group() | |||||
| @pytest.mark.parametrize("driver,device", [("torch", "cpu"), ("torch_ddp", [0, 1]), ("torch", 0)]) # ("torch", "cpu"), ("torch_ddp", [0, 1]), ("torch", 1) | |||||
| @pytest.mark.parametrize("version", [0, 1]) | |||||
| @pytest.mark.parametrize("only_state_dict", [True, False]) | |||||
| @magic_argv_env_context | |||||
| def test_trainer_checkpoint_callback_1( | |||||
| model_and_optimizers: TrainerParameters, | |||||
| driver, | |||||
| device, | |||||
| version, | |||||
| only_state_dict | |||||
| ): | |||||
| try: | |||||
| path = Path.cwd().joinpath(f"test_model_checkpoint") | |||||
| path.mkdir(exist_ok=True, parents=True) | |||||
| if version == 0: | |||||
| callbacks = [ | |||||
| MoreEvaluateCallback(dataloaders=model_and_optimizers.evaluate_dataloaders, | |||||
| metrics=model_and_optimizers.more_metrics, | |||||
| evaluate_every=-1, | |||||
| folder=path, topk=-1, | |||||
| topk_monitor='acc', only_state_dict=only_state_dict, save_object='trainer') | |||||
| ] | |||||
| elif version == 1: | |||||
| callbacks = [ | |||||
| MoreEvaluateCallback(dataloaders=model_and_optimizers.evaluate_dataloaders, | |||||
| metrics=model_and_optimizers.more_metrics, | |||||
| evaluate_every=None, watch_monitor='loss', watch_monitor_larger_better=False, | |||||
| folder=path, topk=1, topk_monitor='acc', only_state_dict=only_state_dict, | |||||
| save_object='trainer') | |||||
| ] | |||||
| n_epochs = 5 | |||||
| trainer = Trainer( | |||||
| model=model_and_optimizers.model, | |||||
| driver=driver, | |||||
| device=device, | |||||
| optimizers=model_and_optimizers.optimizers, | |||||
| train_dataloader=model_and_optimizers.train_dataloader, | |||||
| evaluate_dataloaders=model_and_optimizers.evaluate_dataloaders, | |||||
| input_mapping=model_and_optimizers.input_mapping, | |||||
| output_mapping=model_and_optimizers.output_mapping, | |||||
| metrics=model_and_optimizers.metrics, | |||||
| n_epochs=n_epochs, | |||||
| callbacks=callbacks, | |||||
| output_from_new_proc="all", | |||||
| evaluate_fn='train_step' | |||||
| ) | |||||
| trainer.run() | |||||
| all_saved_model_paths = {w.name: w for w in path.joinpath(os.environ[FASTNLP_LAUNCH_TIME]).iterdir()} | |||||
| # check that the number of saved model files is correct | |||||
| if version == 0: | |||||
| assert len(all_saved_model_paths) == n_epochs | |||||
| elif version == 1: | |||||
| assert len(all_saved_model_paths) == 1 | |||||
| for folder in all_saved_model_paths: | |||||
| trainer = Trainer( | |||||
| model=model_and_optimizers.model, | |||||
| driver=driver, | |||||
| device=device, | |||||
| optimizers=model_and_optimizers.optimizers, | |||||
| train_dataloader=model_and_optimizers.train_dataloader, | |||||
| evaluate_dataloaders=model_and_optimizers.evaluate_dataloaders, | |||||
| input_mapping=model_and_optimizers.input_mapping, | |||||
| output_mapping=model_and_optimizers.output_mapping, | |||||
| metrics=model_and_optimizers.metrics, | |||||
| n_epochs=7, | |||||
| output_from_new_proc="all", | |||||
| evaluate_fn='train_step' | |||||
| ) | |||||
| folder = path.joinpath(os.environ[FASTNLP_LAUNCH_TIME]).joinpath(folder) | |||||
| trainer.load(folder, only_state_dict=only_state_dict) | |||||
| trainer.run() | |||||
| finally: | |||||
| rank_zero_rm(path) | |||||
| if dist.is_initialized(): | |||||
| dist.destroy_process_group() | |||||
| @@ -15,7 +15,7 @@ from tests.helpers.datasets.torch_data import TorchNormalDataset_Classification | |||||
| from tests.helpers.callbacks.helper_callbacks import RecordLossCallback | from tests.helpers.callbacks.helper_callbacks import RecordLossCallback | ||||
| from tests.helpers.callbacks.helper_callbacks_torch import RecordAccumulationStepsCallback_Torch | from tests.helpers.callbacks.helper_callbacks_torch import RecordAccumulationStepsCallback_Torch | ||||
| from tests.helpers.utils import magic_argv_env_context, Capturing | from tests.helpers.utils import magic_argv_env_context, Capturing | ||||
| from fastNLP.core import synchronize_safe_rm | |||||
| from fastNLP.core import rank_zero_rm | |||||
| @dataclass | @dataclass | ||||
| @@ -239,7 +239,7 @@ def test_trainer_output_from_new_proc( | |||||
| assert err_path.exists() | assert err_path.exists() | ||||
| path = Path(os.path.abspath(output_from_new_proc)) | path = Path(os.path.abspath(output_from_new_proc)) | ||||
| synchronize_safe_rm(path) | |||||
| rank_zero_rm(path) | |||||
| @pytest.mark.parametrize("driver,device", [("torch", [1, 2])]) | @pytest.mark.parametrize("driver,device", [("torch", [1, 2])]) | ||||
| @@ -11,7 +11,7 @@ from tests.helpers.models.paddle_model import PaddleNormalModel_Classification_1 | |||||
| from tests.helpers.datasets.paddle_data import PaddleNormalDataset, PaddleRandomMaxDataset | from tests.helpers.datasets.paddle_data import PaddleNormalDataset, PaddleRandomMaxDataset | ||||
| from tests.helpers.datasets.torch_data import TorchNormalDataset | from tests.helpers.datasets.torch_data import TorchNormalDataset | ||||
| from tests.helpers.models.torch_model import TorchNormalModel_Classification_1 | from tests.helpers.models.torch_model import TorchNormalModel_Classification_1 | ||||
| from fastNLP.core import synchronize_safe_rm | |||||
| from fastNLP.core import rank_zero_rm | |||||
| import paddle | import paddle | ||||
| from paddle.io import DataLoader, BatchSampler | from paddle.io import DataLoader, BatchSampler | ||||
| @@ -578,11 +578,11 @@ def test_save_and_load_model(prepare_test_save_load, only_state_dict): | |||||
| assert paddle.equal_all(res1["pred"], res2["pred"]) | assert paddle.equal_all(res1["pred"], res2["pred"]) | ||||
| finally: | finally: | ||||
| if only_state_dict: | if only_state_dict: | ||||
| synchronize_safe_rm(path) | |||||
| rank_zero_rm(path) | |||||
| else: | else: | ||||
| synchronize_safe_rm(path + ".pdiparams") | |||||
| synchronize_safe_rm(path + ".pdiparams.info") | |||||
| synchronize_safe_rm(path + ".pdmodel") | |||||
| rank_zero_rm(path + ".pdiparams") | |||||
| rank_zero_rm(path + ".pdiparams.info") | |||||
| rank_zero_rm(path + ".pdmodel") | |||||
| @pytest.mark.parametrize("only_state_dict", ([True, False])) | @pytest.mark.parametrize("only_state_dict", ([True, False])) | ||||
| def test_save_and_load_with_randombatchsampler(only_state_dict): | def test_save_and_load_with_randombatchsampler(only_state_dict): | ||||
| @@ -652,7 +652,7 @@ def test_save_and_load_with_randombatchsampler(only_state_dict): | |||||
| assert len(left_y_batches) + len(already_seen_y_set) == len(dataset) | assert len(left_y_batches) + len(already_seen_y_set) == len(dataset) | ||||
| assert len(left_y_batches | already_seen_y_set) == len(dataset) | assert len(left_y_batches | already_seen_y_set) == len(dataset) | ||||
| finally: | finally: | ||||
| synchronize_safe_rm(path) | |||||
| rank_zero_rm(path) | |||||
| @pytest.mark.parametrize("only_state_dict", ([True, False])) | @pytest.mark.parametrize("only_state_dict", ([True, False])) | ||||
| def test_save_and_load_with_randomsampler(only_state_dict): | def test_save_and_load_with_randomsampler(only_state_dict): | ||||
| @@ -730,4 +730,4 @@ def test_save_and_load_with_randomsampler(only_state_dict): | |||||
| assert len(left_y_batches) + len(already_seen_y_set) == len(dataset) | assert len(left_y_batches) + len(already_seen_y_set) == len(dataset) | ||||
| assert len(left_y_batches | already_seen_y_set) == len(dataset) | assert len(left_y_batches | already_seen_y_set) == len(dataset) | ||||
| finally: | finally: | ||||
| synchronize_safe_rm(path) | |||||
| rank_zero_rm(path) | |||||
| @@ -6,7 +6,7 @@ import logging | |||||
| import re | import re | ||||
| from fastNLP.envs.env import FASTNLP_LAUNCH_TIME | from fastNLP.envs.env import FASTNLP_LAUNCH_TIME | ||||
| from fastNLP.core import synchronize_safe_rm | |||||
| from fastNLP.core import rank_zero_rm | |||||
| from fastNLP.core.log.logger import logger | from fastNLP.core.log.logger import logger | ||||
| from tests.helpers.utils import magic_argv_env_context, recover_logger | from tests.helpers.utils import magic_argv_env_context, recover_logger | ||||
| @@ -56,7 +56,7 @@ def test_add_file_ddp_1_torch(): | |||||
| pattern = re.compile(msg) | pattern = re.compile(msg) | ||||
| assert len(pattern.findall(line)) == 1 | assert len(pattern.findall(line)) == 1 | ||||
| synchronize_safe_rm(filepath) | |||||
| rank_zero_rm(filepath) | |||||
| dist.barrier() | dist.barrier() | ||||
| dist.destroy_process_group() | dist.destroy_process_group() | ||||
| @@ -105,7 +105,7 @@ def test_add_file_ddp_2_torch(): | |||||
| pattern = re.compile(msg) | pattern = re.compile(msg) | ||||
| assert len(pattern.findall(line)) == 1 | assert len(pattern.findall(line)) == 1 | ||||
| finally: | finally: | ||||
| synchronize_safe_rm(path) | |||||
| rank_zero_rm(path) | |||||
| dist.barrier() | dist.barrier() | ||||
| dist.destroy_process_group() | dist.destroy_process_group() | ||||
| @@ -155,7 +155,7 @@ def test_add_file_ddp_3_torch(): | |||||
| pattern = re.compile(msg) | pattern = re.compile(msg) | ||||
| assert len(pattern.findall(line)) == 1 | assert len(pattern.findall(line)) == 1 | ||||
| synchronize_safe_rm(file) | |||||
| rank_zero_rm(file) | |||||
| dist.barrier() | dist.barrier() | ||||
| dist.destroy_process_group() | dist.destroy_process_group() | ||||
| @@ -202,7 +202,7 @@ def test_add_file_ddp_4_torch(): | |||||
| pattern = re.compile(msg) | pattern = re.compile(msg) | ||||
| assert len(pattern.findall(line)) == 1 | assert len(pattern.findall(line)) == 1 | ||||
| finally: | finally: | ||||
| synchronize_safe_rm(path) | |||||
| rank_zero_rm(path) | |||||
| dist.barrier() | dist.barrier() | ||||
| dist.destroy_process_group() | dist.destroy_process_group() | ||||
| @@ -225,7 +225,7 @@ class TestLogger: | |||||
| line = ''.join([l for l in f]) | line = ''.join([l for l in f]) | ||||
| assert self.msg in line | assert self.msg in line | ||||
| finally: | finally: | ||||
| synchronize_safe_rm(path) | |||||
| rank_zero_rm(path) | |||||
| @recover_logger | @recover_logger | ||||
| def test_add_file_2(self): | def test_add_file_2(self): | ||||
| @@ -243,7 +243,7 @@ class TestLogger: | |||||
| line = ''.join([l for l in f]) | line = ''.join([l for l in f]) | ||||
| assert self.msg in line | assert self.msg in line | ||||
| finally: | finally: | ||||
| synchronize_safe_rm(origin_path) | |||||
| rank_zero_rm(origin_path) | |||||
| @recover_logger | @recover_logger | ||||
| def test_add_file_3(self): | def test_add_file_3(self): | ||||
| @@ -279,7 +279,7 @@ class TestLogger: | |||||
| line = ''.join([l for l in f]) | line = ''.join([l for l in f]) | ||||
| assert self.msg in line | assert self.msg in line | ||||
| finally: | finally: | ||||
| synchronize_safe_rm(path) | |||||
| rank_zero_rm(path) | |||||
| @recover_logger | @recover_logger | ||||
| def test_stdout(self, capsys): | def test_stdout(self, capsys): | ||||
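In the DDP logger tests the rename keeps the same shutdown ordering: the shared log file is removed, every rank then hits a barrier, and only afterwards is the process group destroyed. A hedged sketch of that ordering, with filepath as a hypothetical shared log file and the rank-0-only deletion being an assumption based on the helper's name:

import torch.distributed as dist
from fastNLP.core import rank_zero_rm

def cleanup_shared_log(filepath):
    # mirrors the teardown used by the test_add_file_ddp_* tests above
    rank_zero_rm(filepath)           # presumably only rank 0 performs the deletion
    if dist.is_initialized():
        dist.barrier()               # keep ranks in step before tearing the group down
        dist.destroy_process_group()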
| @@ -8,7 +8,7 @@ import sys | |||||
| from fastNLP.core.utils.cache_results import cache_results | from fastNLP.core.utils.cache_results import cache_results | ||||
| from tests.helpers.common.utils import check_time_elapse | from tests.helpers.common.utils import check_time_elapse | ||||
| from fastNLP.core import synchronize_safe_rm | |||||
| from fastNLP.core import rank_zero_rm | |||||
| def get_subprocess_results(cmd): | def get_subprocess_results(cmd): | ||||
| @@ -56,7 +56,7 @@ class TestCacheResults: | |||||
| res = demo() | res = demo() | ||||
| finally: | finally: | ||||
| synchronize_safe_rm(cache_fp) | |||||
| rank_zero_rm(cache_fp) | |||||
| def test_cache_save_refresh(self): | def test_cache_save_refresh(self): | ||||
| cache_fp = 'demo.pkl' | cache_fp = 'demo.pkl' | ||||
| @@ -70,7 +70,7 @@ class TestCacheResults: | |||||
| with check_time_elapse(1, op='ge'): | with check_time_elapse(1, op='ge'): | ||||
| res = demo() | res = demo() | ||||
| finally: | finally: | ||||
| synchronize_safe_rm(cache_fp) | |||||
| rank_zero_rm(cache_fp) | |||||
| def test_cache_no_func_change(self): | def test_cache_no_func_change(self): | ||||
| cache_fp = os.path.abspath('demo.pkl') | cache_fp = os.path.abspath('demo.pkl') | ||||
| @@ -91,7 +91,7 @@ class TestCacheResults: | |||||
| with check_time_elapse(1, op='lt'): | with check_time_elapse(1, op='lt'): | ||||
| res = demo() | res = demo() | ||||
| finally: | finally: | ||||
| synchronize_safe_rm('demo.pkl') | |||||
| rank_zero_rm('demo.pkl') | |||||
| def test_cache_func_change(self, capsys): | def test_cache_func_change(self, capsys): | ||||
| cache_fp = 'demo.pkl' | cache_fp = 'demo.pkl' | ||||
| @@ -121,7 +121,7 @@ class TestCacheResults: | |||||
| assert 'is different from its last cache' not in output[0] | assert 'is different from its last cache' not in output[0] | ||||
| finally: | finally: | ||||
| synchronize_safe_rm('demo.pkl') | |||||
| rank_zero_rm('demo.pkl') | |||||
| def test_cache_check_hash(self): | def test_cache_check_hash(self): | ||||
| cache_fp = 'demo.pkl' | cache_fp = 'demo.pkl' | ||||
| @@ -152,7 +152,7 @@ class TestCacheResults: | |||||
| assert 'is different from its last cache' in output[0] | assert 'is different from its last cache' in output[0] | ||||
| finally: | finally: | ||||
| synchronize_safe_rm('demo.pkl') | |||||
| rank_zero_rm('demo.pkl') | |||||
| # changing an external function also invalidates the cache | # changing an external function also invalidates the cache | ||||
| def test_refer_fun_change(self): | def test_refer_fun_change(self): | ||||
| @@ -177,7 +177,7 @@ class TestCacheResults: | |||||
| assert 'is different from its last cache' in res | assert 'is different from its last cache' in res | ||||
| finally: | finally: | ||||
| synchronize_safe_rm(cache_fp) | |||||
| rank_zero_rm(cache_fp) | |||||
| # changing an external method also invalidates the cache | # changing an external method also invalidates the cache | ||||
| def test_refer_class_method_change(self): | def test_refer_class_method_change(self): | ||||
| @@ -202,7 +202,7 @@ class TestCacheResults: | |||||
| assert 'is different from its last cache' in res | assert 'is different from its last cache' in res | ||||
| finally: | finally: | ||||
| synchronize_safe_rm(cache_fp) | |||||
| rank_zero_rm(cache_fp) | |||||
| def test_duplicate_keyword(self): | def test_duplicate_keyword(self): | ||||
| with pytest.raises(RuntimeError): | with pytest.raises(RuntimeError): | ||||
| @@ -240,7 +240,7 @@ class TestCacheResults: | |||||
| results = cache() | results = cache() | ||||
| assert (1, 2) == results | assert (1, 2) == results | ||||
| finally: | finally: | ||||
| synchronize_safe_rm('demo/') | |||||
| rank_zero_rm('demo/') | |||||
| def test_result_none_error(self): | def test_result_none_error(self): | ||||
| @cache_results('demo.pkl') | @cache_results('demo.pkl') | ||||
| @@ -251,7 +251,7 @@ class TestCacheResults: | |||||
| with pytest.raises(RuntimeError): | with pytest.raises(RuntimeError): | ||||
| results = cache() | results = cache() | ||||
| finally: | finally: | ||||
| synchronize_safe_rm('demo.pkl') | |||||
| rank_zero_rm('demo.pkl') | |||||
| if __name__ == '__main__': | if __name__ == '__main__': | ||||
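The TestCacheResults hunks above cover the decorator's happy path plus its guard rails (the "is different from its last cache" warning, reserved-keyword errors, and the error on a None result). A minimal sketch of the happy path, assuming a hypothetical slow_feature_extraction function; the decorator and cleanup imports are the ones this diff switches the tests to:

from fastNLP.core.utils.cache_results import cache_results
from fastNLP.core import rank_zero_rm

@cache_results('demo.pkl')
def slow_feature_extraction():
    # the first call runs this body and pickles the return value to demo.pkl;
    # later calls load the cached value instead (a changed body triggers the warning above)
    return [i * i for i in range(10)]

try:
    first = slow_feature_extraction()    # computed, then cached
    second = slow_feature_extraction()   # read back from demo.pkl
    assert first == second
finally:
    rank_zero_rm('demo.pkl')             # the same cleanup every cache test above uses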
| @@ -2,7 +2,7 @@ import os | |||||
| from fastNLP.envs.set_backend import dump_fastnlp_backend | from fastNLP.envs.set_backend import dump_fastnlp_backend | ||||
| from tests.helpers.utils import Capturing | from tests.helpers.utils import Capturing | ||||
| from fastNLP.core import synchronize_safe_rm | |||||
| from fastNLP.core import rank_zero_rm | |||||
| def test_dump_fastnlp_envs(): | def test_dump_fastnlp_envs(): | ||||
| @@ -14,4 +14,4 @@ def test_dump_fastnlp_envs(): | |||||
| assert filepath in output[0] | assert filepath in output[0] | ||||
| assert os.path.exists(filepath) | assert os.path.exists(filepath) | ||||
| finally: | finally: | ||||
| synchronize_safe_rm(filepath) | |||||
| rank_zero_rm(filepath) | |||||
| @@ -9,7 +9,7 @@ import numpy as np | |||||
| from fastNLP.modules.mix_modules.mix_module import MixModule | from fastNLP.modules.mix_modules.mix_module import MixModule | ||||
| from fastNLP.modules.mix_modules.utils import paddle2torch, torch2paddle | from fastNLP.modules.mix_modules.utils import paddle2torch, torch2paddle | ||||
| from fastNLP.core import synchronize_safe_rm | |||||
| from fastNLP.core import rank_zero_rm | |||||
| ############################################################################ | ############################################################################ | ||||
| @@ -227,7 +227,7 @@ class TorchPaddleMixModuleTestCase(unittest.TestCase): | |||||
| self.assertDictEqual(state_dict, new_state_dict) | self.assertDictEqual(state_dict, new_state_dict) | ||||
| finally: | finally: | ||||
| synchronize_safe_rm(path) | |||||
| rank_zero_rm(path) | |||||
| def if_device_correct(self, device): | def if_device_correct(self, device): | ||||