@@ -4,8 +4,7 @@ __all__ = [ | |||
'EventsList', | |||
'Filter', | |||
'CallbackManager', | |||
'ModelCheckpointCallback', | |||
'TrainerCheckpointCallback', | |||
'CheckpointCallback', | |||
'choose_progress_callback', | |||
'ProgressCallback', | |||
'RichCallback', | |||
@@ -13,18 +12,21 @@ __all__ = [ | |||
'LoadBestModelCallback', | |||
"EarlyStopCallback", | |||
'MoreEvaluateCallback', | |||
"TorchWarmupCallback", | |||
"TorchGradClipCallback" | |||
"TorchGradClipCallback", | |||
] | |||
from .callback import Callback | |||
from .callback_events import EventsList, Events, Filter | |||
from .callback_manager import CallbackManager | |||
from .checkpoint_callback import ModelCheckpointCallback, TrainerCheckpointCallback | |||
from .checkpoint_callback import CheckpointCallback | |||
from .progress_callback import choose_progress_callback, ProgressCallback, RichCallback | |||
from .lr_scheduler_callback import LRSchedCallback | |||
from .load_best_model_callback import LoadBestModelCallback | |||
from .early_stop_callback import EarlyStopCallback | |||
from .torch_callbacks import * | |||
from .more_evaluate_callback import MoreEvaluateCallback | |||
@@ -236,7 +236,7 @@ class Callback: | |||
结束 validate 时调用,并把 validate 的结果传入。 | |||
:param trainer: | |||
:param results: | |||
:param results: Evaluate 的结果,一般是个 dict 。 | |||
:return: | |||
""" | |||
pass | |||
@@ -250,6 +250,15 @@ class Callback: | |||
""" | |||
return self.__class__.__name__ | |||
@property | |||
def need_reproducible_sampler(self) -> bool: | |||
""" | |||
当前 callback 是否需要能够复现的 sampler 。一般用于 checkpoint 类的 callback 。 | |||
:return: | |||
""" | |||
return False | |||
class _CallbackWrapper(Callback): | |||
""" | |||
@@ -8,7 +8,6 @@ __all__ = [ | |||
from .callback_events import Events | |||
from .callback import Callback | |||
from .checkpoint_callback import TrainerCheckpointCallback | |||
from .progress_callback import ProgressCallback, choose_progress_callback | |||
from fastNLP.core.log import logger | |||
@@ -45,7 +44,7 @@ class CallbackManager: | |||
:param callbacks: 初始化时可以传入的一系列 callback 类,通常为用户在初始化 'Trainer' 时直接传入的 callback 类; | |||
""" | |||
self._has_trainer_checkpoint = False | |||
self._need_reproducible_sampler = False | |||
_has_progress_callback = False | |||
_callbacks = [] | |||
@@ -98,8 +97,7 @@ class CallbackManager: | |||
:return: | |||
""" | |||
for each_callback in self.class_callbacks: | |||
if isinstance(each_callback, TrainerCheckpointCallback): | |||
self._has_trainer_checkpoint = True | |||
self._need_reproducible_sampler |= each_callback.need_reproducible_sampler | |||
self.dissect_one_callback(each_callback) | |||
def dissect_one_callback(self, callback: Callback): | |||
@@ -211,7 +209,7 @@ class CallbackManager: | |||
@property | |||
def has_trainer_checkpoint(self) -> bool: | |||
return self._has_trainer_checkpoint | |||
return self._need_reproducible_sampler | |||
@_transfer | |||
def on_after_trainer_initialized(self, trainer): | |||
@@ -1,339 +1,151 @@ | |||
__all__ = [ | |||
'ModelCheckpointCallback', | |||
'TrainerCheckpointCallback' | |||
'CheckpointCallback' | |||
] | |||
import os | |||
from typing import Union, Optional, Callable, Dict, Sequence, Any, Mapping | |||
from typing import Union, Optional, Callable, Dict, Sequence | |||
from pathlib import Path | |||
import sys | |||
from copy import deepcopy | |||
import fastNLP | |||
from .has_monitor_callback import HasMonitorCallback | |||
from fastNLP.core.log import logger | |||
from fastNLP.envs import FASTNLP_LAUNCH_TIME, FASTNLP_GLOBAL_RANK | |||
from fastNLP.core.utils import synchronize_safe_rm, synchronize_mkdir | |||
from .topk_saver import TopkSaver | |||
from .callback import Callback | |||
class CheckpointCallback(HasMonitorCallback): | |||
def __init__( | |||
self, | |||
monitor:Optional[Union[str, Callable]]=None, | |||
save_folder: Optional[Union[str, Path]] = None, | |||
save_every_n_epochs: Optional[int] = None, | |||
save_every_n_batches: Optional[int] = None, | |||
save_last: bool = False, | |||
save_topk: Optional[int] = None, | |||
save_on_exception: Optional[Union[BaseException, Sequence[BaseException]]] = None, | |||
larger_better: bool = True, | |||
only_state_dict: bool = True, | |||
model_save_fn: Optional[Callable] = None, | |||
**kwargs, | |||
): | |||
class CheckpointCallback(Callback): | |||
def __init__(self, folder: Optional[Union[str, Path]] = None, every_n_epochs: Optional[int] = None, | |||
every_n_batches: Optional[int] = None, last: bool = False, | |||
on_exceptions: Optional[Union[BaseException, Sequence[BaseException]]] = None, topk: int = 0, | |||
monitor: Optional[Union[str, Callable]] = None, larger_better: bool = True, | |||
only_state_dict: bool = True, model_save_fn: Optional[Callable] = None, save_object: str = 'model', | |||
save_evaluate_results=True, **kwargs): | |||
""" | |||
请使用 ModelCheckpointCallback 与 TrainerCheckpointCallback 。 | |||
保存模型 checkpoint 的 callback ,其保存的文件目录以及文件名命名规则如下 | |||
- folder/ | |||
- YYYY-mm-dd-HH_MM_SS_fffff/ # 自动根据当前脚本的启动时间创建的 | |||
- {save_object}-epoch_{epoch_idx}/ # 满足 every_n_epochs 条件保存的模型 | |||
- {save_object}-epoch_{epoch_idx}-batch_{global_batch_idx}/ # 满足 every_n_batches 保存的模型 | |||
- {save_object}-last/ # 最后一个 epoch 的保存 | |||
- {save_object}-epoch_{epoch_idx}-batch_{global_batch_idx}-exception_{exception_type}/ # exception时保存。 | |||
- {save_object}-epoch_{epoch_idx}-batch_{global_batch_idx}-{monitor}_{monitor_value}/ # 满足topk条件存储文件名 | |||
model_save_fn 为 None ,则以上每个 folder 中,将生成 fastnlp_model.pkl.tar 文件。 | |||
若 model_save_fn 不为 None,则 fastNLP 将 folder 绝对路径传递给该函数,fastNLP 在该 folder 下不进行模型保存。 | |||
:param monitor: 监控的 metric 值。如果在 evaluation 结果中没有找到完全一致的名称,将使用 最短公共字符串算法 找到最匹配 | |||
的那个作为 monitor 。如果为 None,将尝试使用 Trainer 设置的 monitor 。也可以传入一个函数,接受参数为 evaluation 的结 | |||
果(字典类型),返回一个 float 值作为 monitor 的结果。 | |||
:param save_folder: 保存的文件夹,fastNLP 将在该文件下以时间戳创建子文件夹,并在里面保存。因此不同次运行可以将被保存到不同的 | |||
果(字典类型),返回一个 float 值作为 monitor 的结果,如果当前结果中没有相关的 monitor 值请返回 None 。 | |||
:param folder: 保存的文件夹,fastNLP 将在该文件下以时间戳创建子文件夹,并在里面保存。因此不同次运行可以将被保存到不同的 | |||
时间戳文件夹中。如果为 None ,默认使用当前文件夹。 | |||
:param save_every_n_epochs: 多少个 epoch 保存一次。 | |||
:param save_every_n_batches: 多少个 batch 保存一次。 | |||
:param save_last: 如果为 True ,将在每次 epoch 运行结束都保存一次,会覆盖之前的保存。 | |||
:param save_topk: 保存 monitor 结果 topK 个。 | |||
:param save_on_exception: 在出异常信息时,是否保存。传入需要捕获的异常的类。 | |||
:param every_n_epochs: 多少个 epoch 保存一次。 | |||
:param every_n_batches: 多少个 batch 保存一次。 | |||
:param last: 如果为 True ,将在每次 epoch 运行结束都保存一次,会覆盖之前的保存。 | |||
:param topk: 保存 monitor 结果 topK 个。 | |||
:param on_exceptions: 在出异常信息时,是否保存。传入需要捕获的异常的类。 | |||
:param larger_better: monitor 的值是否时越大越好。 | |||
:param only_state_dict: 保存模型时是否只保存 state_dict 。当 model_save_fn 不为 None 时,该参数无效。 | |||
:param model_save_fn: 个性化的保存函数,当触发保存操作时,就调用这个函数,这个函数应当接受一个文件夹作为参数,不返回任何东西。 | |||
如果传入了 model_save_fn 函数,fastNLP 将不再进行模型相关的保存。在多卡场景下,我们只在 rank 0 上会运行该函数。 | |||
:param save_object: 可选 ['trainer', 'model'],表示在保存时的保存对象为 trainer+model 还是 只是model 。 | |||
:param save_evaluate_results: 是否保存 evaluate 的结果。如果为 True ,在保存 topk 模型的 folder 中还将额外保存一个 | |||
fastnlp_evaluate_results.json 文件,记录当前的 results。仅在设置了 topk 的场景下有用,默认为 True 。 | |||
:param kwargs: | |||
""" | |||
super().__init__(monitor=monitor, larger_better=larger_better, | |||
must_have_monitor=save_topk is not None) | |||
if save_folder is None: | |||
super().__init__() | |||
if folder is None: | |||
logger.warning( | |||
"Parameter `path` is None, and we will use the current work directory to find and load your model.") | |||
save_folder = Path.cwd() | |||
save_folder = Path(save_folder) | |||
if not save_folder.exists(): | |||
raise NotADirectoryError(f"Path '{save_folder.absolute()}' is not existed!") | |||
elif save_folder.is_file(): | |||
raise ValueError("Parameter `save_folder` should be a directory instead of a file.") | |||
if save_every_n_epochs is not None: | |||
if not isinstance(save_every_n_epochs, int) or save_every_n_epochs < 1: | |||
raise ValueError("parameter save_after_epoch_num should be an int and greater than or equal to 1.") | |||
"Parameter `folder` is None, and we will use the current work directory to find and load your model.") | |||
folder = Path.cwd() | |||
folder = Path(folder) | |||
if not folder.exists(): | |||
raise NotADirectoryError(f"Path '{folder.absolute()}' is not existed!") | |||
elif folder.is_file(): | |||
raise ValueError("Parameter `folder` should be a directory instead of a file.") | |||
if every_n_epochs is not None: | |||
if not isinstance(every_n_epochs, int) or every_n_epochs < 1: | |||
raise ValueError("Parameter `every_n_epochs` should be an int and greater than or equal to 1.") | |||
else: | |||
save_every_n_epochs = sys.maxsize # 使得没有数字可以整除 | |||
every_n_epochs = sys.maxsize # 使得没有数字可以整除 | |||
if save_every_n_batches is not None: | |||
if not isinstance(save_every_n_batches, int) or save_every_n_batches < 1: | |||
raise ValueError( | |||
"parameter save_every_n_batches should be an int and greater than or equal to 1.") | |||
if every_n_batches is not None: | |||
if not isinstance(every_n_batches, int) or every_n_batches < 1: | |||
raise ValueError("Parameter `every_n_batches` should be an int and greater than or equal to 1.") | |||
else: | |||
save_every_n_batches = sys.maxsize # 使得没有数字可以整除 | |||
every_n_batches = sys.maxsize # 使得没有数字可以整除 | |||
if save_topk is not None: | |||
if not isinstance(save_topk, int) or save_topk < 1: | |||
raise ValueError("parameter save_topk should be an int and greater than or equal to 1.") | |||
if topk is not None: | |||
if not isinstance(topk, int): | |||
raise ValueError("Parameter `topk` should be an int.") | |||
else: | |||
topk = 0 | |||
if save_on_exception is not None: | |||
if not isinstance(save_on_exception, Sequence): | |||
save_on_exception = [save_on_exception] | |||
if on_exceptions is not None: | |||
if not isinstance(on_exceptions, Sequence): | |||
on_exceptions = [on_exceptions] | |||
for exception in save_on_exception: | |||
for exception in on_exceptions: | |||
if not issubclass(exception, BaseException): | |||
raise TypeError("Each exception in parameter `save_on_exception` can only be " | |||
raise TypeError("Each exception in parameter `on_exception` can only be " | |||
"`BaseException` type.") | |||
else: | |||
save_on_exception = [] | |||
on_exceptions = [] | |||
self.save_folder = save_folder | |||
self.save_every_n_epochs = save_every_n_epochs | |||
self.save_every_n_batches = save_every_n_batches | |||
self.save_last = save_last | |||
self.save_topk = save_topk | |||
self.only_state_dict = only_state_dict | |||
self.model_save_fn = model_save_fn | |||
self.save_on_exception = save_on_exception | |||
self.kwargs = kwargs | |||
self.topk_saver = TopkSaver(topk, monitor, larger_better, folder, only_state_dict, | |||
model_save_fn, save_evaluate_results, | |||
save_object, **kwargs) | |||
self.topk = topk | |||
self.save_object = save_object | |||
# 这些参数是专门留给 topk 模式专门使用的; | |||
self._topk_model = {} | |||
self._topn = 0 # 表示目前已经保存了几个最好的模型; | |||
self.every_n_epochs = every_n_epochs | |||
self.every_n_batches = every_n_batches | |||
self.last = last | |||
self.exceptions = on_exceptions | |||
# 注意这里应当保证只有进程 0 在执行这个操作,因为当用户使用 python -m torch.distributed.launch 来拉起进程的时候, | |||
# FASTNLP_LAUNCH_TIME 在每一个进程上的值是不一样的; | |||
self.timestamp_path = self.save_folder.joinpath(os.environ[FASTNLP_LAUNCH_TIME]) | |||
# 该 folder 只在保存真的要发生的时候再创建。 | |||
@property | |||
def need_reproducible_sampler(self) -> bool: | |||
return self.save_object == 'trainer' | |||
def on_after_trainer_initialized(self, trainer, driver): | |||
if self.save_topk is not None: | |||
super().on_after_trainer_initialized(trainer, driver) | |||
if self.save_topk is not None and trainer.evaluator is None: | |||
logger.warning("You set `save_topk`, but `evaluate_dataloaders` is not set in Trainer.") | |||
if self.topk_saver.topk_queue: # 需要设置 monitor | |||
if self.topk_saver.monitor is None: | |||
self.topk_saver.set_monitor(monitor=trainer.monitor, larger_better=trainer.larger_better) | |||
if self.topk_saver.topk_queue and trainer.evaluator is None: | |||
logger.warning(f"You set `topk={self.topk}`, but `evaluate_dataloaders` is not set in Trainer.") | |||
def on_validate_end(self, trainer, results): | |||
self._save_topk(trainer, results) | |||
# 如果发生了保存,则返回的 folder 不为 None | |||
folder = self.topk_saver.save_topk(trainer, results) | |||
def on_train_epoch_end(self, trainer: "fastNLP.Trainer"): | |||
if trainer.cur_epoch_idx % self.save_every_n_epochs == 0: | |||
folder_name = f'{self.folder_prefix}-epoch_{trainer.cur_epoch_idx}' | |||
self.save(trainer, folder_name=folder_name) | |||
if self.save_last: | |||
folder_name = f'{self.folder_prefix}-last' | |||
self.save(trainer, folder_name=folder_name) | |||
if trainer.cur_epoch_idx % self.every_n_epochs == 0: | |||
folder_name = f'{self.save_object}-epoch_{trainer.cur_epoch_idx}' | |||
self.topk_saver.save(trainer, folder_name=folder_name) | |||
if self.last: | |||
folder_name = f'{self.save_object}-last' | |||
self.topk_saver.save(trainer, folder_name=folder_name) | |||
def on_train_batch_end(self, trainer): | |||
if trainer.global_forward_batches % self.save_every_n_batches == 0: | |||
folder_name = f'{self.folder_prefix}-epoch_{trainer.cur_epoch_idx}-batch_{trainer.global_forward_batches}' | |||
self.save(trainer, folder_name=folder_name) | |||
if trainer.global_forward_batches % self.every_n_batches == 0: | |||
folder_name = f'{self.save_object}-epoch_{trainer.cur_epoch_idx}-batch_{trainer.global_forward_batches}' | |||
self.topk_saver.save(trainer, folder_name=folder_name) | |||
def on_exception(self, trainer, exception: BaseException): | |||
if exception.__class__ in self.save_on_exception: | |||
folder_name = f'{self.folder_prefix}-epoch_{trainer.cur_epoch_idx}-batch_{trainer.global_forward_batches}-' \ | |||
f'exception_{exception.__class__.__name__}' | |||
self.save(trainer=trainer, folder_name=folder_name) | |||
if exception.__class__ in self.exceptions: | |||
folder_name = f'{self.save_object}-epoch_{trainer.cur_epoch_idx}-batch_{trainer.global_forward_batches}-' \ | |||
f'exception_{exception.__class__.__name__}' | |||
self.topk_saver.save(trainer, folder_name=folder_name) | |||
def on_save_checkpoint(self, trainer) -> Dict: | |||
""" | |||
保存 timestamp_path 使得之后可以继续训练并保存到该文件夹。 | |||
topk_model的状态 | |||
_real_monitor的值 | |||
保存状态,以便之后可以继续使用 | |||
""" | |||
states = {} | |||
states['timestamp_path'] = str(self.timestamp_path.absolute()) | |||
states['_topk_model'] = deepcopy(self._topk_model) | |||
states['save_topk'] = 0 if self.save_topk is None else self.save_topk | |||
if isinstance(self._real_monitor, str): | |||
states['_real_monitor'] = self._real_monitor | |||
states['topk_saver'] = self.topk_saver.state_dict() | |||
return states | |||
def on_load_checkpoint(self, trainer, states: Optional[Dict]): | |||
timestamp_path = states['timestamp_path'] | |||
if not os.path.exists(timestamp_path): | |||
logger.info(f"The resuming checkpoint folder {timestamp_path} is not exists, will checkpoint save to " | |||
f" {self.timestamp_path.absolute()}.") | |||
else: | |||
logger.info(f"Resume to checkpoint in path: {timestamp_path}.") | |||
self.timestamp_path = Path(timestamp_path) | |||
_topk_model = states['_topk_model'] | |||
save_topk = None if int(states['save_topk']) == 0 else int(states['save_topk']) | |||
if save_topk is not None and self.save_topk is not None: | |||
assert self.save_topk == save_topk, f"The checkpoint set save_topk={save_topk}, while this callback set it " \ | |||
f"as {save_topk}." | |||
self._topk_model.update(self._topk_model) | |||
self._real_monitor = states["_real_monitor"] | |||
def _save_topk(self, trainer: "fastNLP.Trainer", results: Dict): | |||
""" | |||
根据validate_res决定保存哪些model的函数。会自动移除掉不满足topk的文件夹。 | |||
:param trainer: | |||
:param results: | |||
:return: | |||
""" | |||
if self.save_topk is not None: | |||
monitor_value = self.get_monitor_value(results=results) | |||
if monitor_value is None: | |||
return | |||
folder_name = f"{self.folder_prefix}-epoch_{trainer.cur_epoch_idx}-batch_{trainer.global_forward_batches}" \ | |||
f"-{self._real_monitor}_{monitor_value}" | |||
_should_save = False | |||
if self._topn < self.save_topk: | |||
self._topk_model[folder_name] = monitor_value | |||
self._topn += 1 | |||
_should_save = True | |||
else: | |||
_least_valuable_model = (min if self.larger_better else max)(self._topk_model, | |||
key=lambda x: self._topk_model[x]) | |||
if self.is_former_monitor_value_better(monitor_value, self._topk_model[_least_valuable_model]): | |||
self._topk_model[folder_name] = monitor_value | |||
_should_save = True | |||
self._topk_model.pop(_least_valuable_model) | |||
synchronize_safe_rm(self.timestamp_path.joinpath(_least_valuable_model)) | |||
assert len(self._topk_model) == self.save_topk == self._topn | |||
if _should_save: | |||
self.save(trainer, folder_name=folder_name) | |||
def save(self, trainer, folder_name): | |||
""" | |||
执行保存的函数,将数据保存在 save_folder/timestamp/folder_name 下。 | |||
:param trainer: | |||
:param folder_name: | |||
:return: | |||
""" | |||
folder = self.timestamp_path.joinpath(folder_name) | |||
if int(os.environ.get(FASTNLP_GLOBAL_RANK, 0)) == 0: # 只在进程0上创建 | |||
synchronize_mkdir(folder) | |||
_fn = getattr(trainer, self.save_fn_name) | |||
_fn( | |||
folder=folder, | |||
only_state_dict=self.only_state_dict, | |||
model_save_fn=self.model_save_fn, | |||
**self.kwargs | |||
) | |||
@property | |||
def folder_prefix(self): | |||
raise NotImplementedError("The `folder_prefix` is not specified") | |||
@property | |||
def save_fn_name(self): | |||
raise NotImplementedError("The `save_fn_name` is not specified.") | |||
class ModelCheckpointCallback(CheckpointCallback): | |||
""" | |||
保存模型 checkpoint 的 callback ,其保存的文件目录以及文件名命名规则如下 | |||
- save_folder/ | |||
- YYYY-mm-dd-HH_MM_SS_fffff/ # 自动根据当前脚本的启动时间创建的 | |||
- model-epoch_{epoch_idx}/ # 满足 save_every_n_epochs 条件保存的模型 | |||
- model-epoch_{epoch_idx}-batch_{global_batch_idx}/ # 满足 save_every_n_batches 保存的模型 | |||
- model-last/ # 最后一个 epoch 的保存 | |||
- model-epoch_{epoch_idx}-batch_{global_batch_idx}-exception_{exception_type}/ # exception时保存。 | |||
- model-epoch_{epoch_idx}-batch_{global_batch_idx}-{monitor}_{monitor_value}/ # 满足topk条件存储文件名 | |||
topk_saver_states = states['topk_saver'] | |||
self.topk_saver.load_state_dict(topk_saver_states) | |||
model_save_fn 为 None ,则以上每个 folder 中,将生成 fastnlp_model.pkl.tar 文件。 | |||
若 model_save_fn 不为 None,则 fastNLP 将 folder 绝对路径传递给该函数,fastNLP 不在该 folder 下创建任何文件。 | |||
:param monitor: 监控的 metric 值。如果在 evaluation 结果中没有找到完全一致的名称,将使用 最短公共字符串算法 找到最匹配 | |||
的那个作为 monitor 。如果为 None,将尝试使用 Trainer 设置的 monitor 。也可以传入一个函数,接受参数为 evaluation 的结 | |||
果(字典类型),返回一个 float 值作为 monitor 的结果。 | |||
:param save_folder: 保存的文件夹,fastNLP 将在该文件下以时间戳创建子文件夹,并在里面保存。因此不同次运行可以将被保存到不同的 | |||
时间戳文件夹中。如果为 None ,默认使用当前文件夹。 | |||
:param save_every_n_epochs: 多少个 epoch 保存一次。 | |||
:param save_every_n_batches: 多少个 batch 保存一次。 | |||
:param save_last: 如果为 True ,将在每次 epoch 运行结束都保存一次,会覆盖之前的保存。 | |||
:param save_topk: 保存 monitor 结果 topK 个。 | |||
:param save_on_exception: 在出异常信息时,是否保存。传入需要捕获的异常的类。 | |||
:param larger_better: monitor 的值是否时越大越好。 | |||
:param only_state_dict: 保存模型时是否只保存 state_dict 。当 model_save_fn 不为 None 时,该参数无效。 | |||
:param model_save_fn: 个性化的保存函数,当触发保存操作时,就调用这个函数,这个函数应当接受一个文件夹作为参数,不返回任何东西。 | |||
如果传入了 model_save_fn 函数,fastNLP 将不再进行模型相关的保存。在多卡场景下,我们只在 rank 0 上会运行该函数。 | |||
:param kwargs: | |||
""" | |||
@property | |||
def save_fn_name(self): | |||
""" | |||
调用 Trainer 中的哪个函数。 | |||
:return: | |||
""" | |||
return 'save_model' | |||
@property | |||
def callback_name(self): | |||
""" | |||
通过该值决定两个 CheckpointCallback 实例是否可以共用断点重训的状态; | |||
:return: | |||
""" | |||
return f"model_checkpoint#monitor-{self.monitor_name}#topK-{self.save_topk}#only_state_dict-{self.only_state_dict}" | |||
@property | |||
def folder_prefix(self): | |||
return 'model' | |||
class TrainerCheckpointCallback(CheckpointCallback): | |||
""" | |||
保存 Trainer checkpoint 的 callback ,其保存的文件目录以及文件名命名规则如下 | |||
- save_folder/ | |||
- YYYY-mm-dd-HH_MM_SS_fffff/ # 自动根据当前脚本的启动时间创建的 | |||
- trainer-epoch_{epoch_idx}/ # 满足 save_every_n_epochs 条件保存的模型 | |||
- trainer-epoch_{epoch_idx}-batch_{global_batch_idx}/ # 满足 save_every_n_batches 保存的模型 | |||
- trainer-last/ # 最后一个 epoch 的保存 | |||
- trainer-epoch_{epoch_idx}-batch_{global_batch_idx}-exception_{exception_type}/ # exception时保存。 | |||
- trainer-epoch_{epoch_idx}-batch_{global_batch_idx}-{monitor}_{monitor_value}/ # 满足topk条件存储文件名 | |||
model_save_fn 为 None ,则以上每个 folder 中,将生成两个文件:fastnlp_trainer.pkl.tar 以及 fastnlp_model.pkl.tar 。 | |||
若 model_save_fn 不为 None,则 fastNLP 只会在每个 folder 下生成 fastnlp_trainer.pkl.tar 文件。 | |||
:param monitor: 监控的 metric 值。如果在 evaluation 结果中没有找到完全一致的名称,将使用 最短公共字符串算法 找到最匹配 | |||
的那个作为 monitor 。如果为 None,将尝试使用 Trainer 设置的 monitor 。也可以传入一个函数,接受参数为 evaluation 的结 | |||
果(字典类型),返回一个 float 值作为 monitor 的结果。 | |||
:param save_folder: 保存的文件夹,fastNLP 将在该文件下以时间戳创建子文件夹,并在里面保存。因此不同次运行可以将被保存到不同的 | |||
时间戳文件夹中。如果为 None ,默认使用当前文件夹。 | |||
:param save_every_n_epochs: 多少个 epoch 保存一次。 | |||
:param save_every_n_batches: 多少个 batch 保存一次。 | |||
:param save_last: 如果为 True ,将在每次 epoch 运行结束都保存一次,会覆盖之前的保存。 | |||
:param save_topk: 保存 monitor 结果 topK 个。 | |||
:param save_on_exception: 在出异常信息时,是否保存。 | |||
:param larger_better: monitor 的值是否时越大越好。 | |||
:param only_state_dict: 保存模型时是否只保存 state_dict 。当 model_save_fn 不为 None 时,该参数无意义。 | |||
:param model_save_fn: 个性化的保存函数,当触发保存操作时,就调用这个函数,这个函数应当接受一个文件夹作为参数,不返回任何东西。 | |||
如果传入了 model_save_fn 函数,fastNLP 将不再进行模型相关的保存。在多卡场景下,我们只在 rank 0 上会运行该函数。 | |||
:param kwargs: | |||
""" | |||
@property | |||
def save_fn_name(self): | |||
""" | |||
调用 Trainer 中的哪个函数。 | |||
:return: | |||
""" | |||
return 'save' | |||
@property | |||
def callback_name(self): | |||
""" | |||
通过该值决定两个 CheckpointCallback 实例是否可以共用断点重训的状态; | |||
:return: | |||
""" | |||
return f"trainer_checkpoint#monitor-{self.monitor_name}#topK-{self.save_topk}#only_state_dict-{self.only_state_dict}" | |||
@property | |||
def folder_prefix(self): | |||
return 'trainer' |
@@ -1,10 +1,12 @@ | |||
__all__ = [ | |||
'HasMonitorCallback', | |||
'ExecuteOnceBetterMonitor' | |||
'ExecuteOnceBetterMonitor', | |||
'MonitorUtility' | |||
] | |||
from typing import Dict, Union, Any | |||
from abc import ABC | |||
import functools | |||
from fastNLP.core.utils import apply_to_collection | |||
from fastNLP.core.callbacks import Callback | |||
@@ -27,21 +29,13 @@ class CanItemDataType(ABC): | |||
return NotImplemented | |||
class MonitorUtility: | |||
""" | |||
计算 monitor 的相关函数 | |||
class HasMonitorCallback(Callback): | |||
def __init__(self, monitor, larger_better, must_have_monitor=False): | |||
""" | |||
该 callback 不直接进行使用,作为其它相关 callback 的父类使用,如果 callback 有使用 monitor 可以继承该函数里面实现了 | |||
(1)判断monitor合法性;(2)在需要时, 根据trainer的monitor设置自己的monitor名称。 | |||
:param monitor: 监控的 metric 值。如果在 evaluation 结果中没有找到完全一致的名称,将使用 最短公共字符串算法 找到最匹配 | |||
的那个作为 monitor 。如果为 None,将尝试使用 Trainer 设置的 monitor 。也可以传入一个函数,接受参数为 evaluation 的结 | |||
果(字典类型),返回一个 float 值作为 monitor 的结果。 | |||
:param larger_better: monitor 是否时越大越好 | |||
:param must_have_monitor: 这个 callback 是否必须有 monitor 设置。如果设置为 True ,且没检测到设置 monitor 会报错。 | |||
""" | |||
""" | |||
def __init__(self, monitor, larger_better): | |||
self.set_monitor(monitor, larger_better) | |||
self.must_have_moinitor = must_have_monitor | |||
def set_monitor(self, monitor, larger_better): | |||
if callable(monitor): # 检查是否能够接受一个参数 | |||
@@ -57,26 +51,14 @@ class HasMonitorCallback(Callback): | |||
self.monitor_value = float('inf') | |||
self._real_monitor = self.monitor | |||
def on_after_trainer_initialized(self, trainer, driver): | |||
def itemize_results(self, results): | |||
""" | |||
如果本身的 monitor 没有设置,则根据 Trainer 中的 monitor 设置 monitor 。 | |||
同时对于必须要有 monitor 设置的 callback ,该函数会进行检查。 | |||
将结果中有 .item() 方法的都调用一下,使得可以结果可以保存 | |||
:param trainer: | |||
:param driver: | |||
:param results: | |||
:return: | |||
""" | |||
if self.monitor is None and trainer.monitor is not None: | |||
self.set_monitor(monitor=trainer.monitor, larger_better=trainer.larger_better) | |||
if self.must_have_moinitor and self.monitor is None: | |||
raise RuntimeError(f"No `monitor` is set for {self.__class__.__name__}. " | |||
f"You can set it in the initialization or through Trainer.") | |||
def on_sanity_check_end(self, trainer, sanity_check_res): | |||
# 主要核对一下 monitor 是否存在。 | |||
if self.monitor is not None: | |||
self.get_monitor_value(results=sanity_check_res) | |||
return apply_to_collection(results, dtype=CanItemDataType, function=lambda x: x.item()) | |||
def get_monitor_value(self, results:Dict)->Union[float, None]: | |||
""" | |||
@@ -85,10 +67,10 @@ class HasMonitorCallback(Callback): | |||
:param results: | |||
:return: 如果为 None ,表明此次没有找到合适的monitor | |||
""" | |||
if len(results)==0: | |||
if len(results) == 0 or self.monitor is None: | |||
return None | |||
# 保证所有的 tensor 都被转换为了 python 特定的类型 | |||
results = apply_to_collection(results, dtype=CanItemDataType, function=lambda x: x.item()) | |||
results = self.itemize_results(results) | |||
use_monitor, monitor_value = _get_monitor_value(monitor=self.monitor, | |||
real_monitor=self._real_monitor, | |||
res=results) | |||
@@ -97,7 +79,7 @@ class HasMonitorCallback(Callback): | |||
# 第一次运行 | |||
if isinstance(self.monitor, str) and self._real_monitor == self.monitor and use_monitor != self.monitor: | |||
logger.warning(f"We can not find `{self.monitor}` in the evaluation result (with keys as {list(results.keys())}), " | |||
f"we use the `{use_monitor}` as the monitor for `{self.__class__.__name__}`.") | |||
f"we use the `{use_monitor}` as the monitor for `{self.__class__.__name__}`.") | |||
# 检测到此次和上次不同。 | |||
elif isinstance(self.monitor, str) and self._real_monitor != self.monitor and use_monitor != self._real_monitor: | |||
logger.warning(f"Change of monitor detected for `{self.__class__.__name__}`. " | |||
@@ -165,7 +147,10 @@ class HasMonitorCallback(Callback): | |||
""" | |||
if callable(self.monitor): | |||
try: | |||
monitor_name = self.monitor.__qualname__ | |||
monitor = self.monitor | |||
while isinstance(monitor, functools.partial): | |||
monitor = monitor.func | |||
monitor_name = monitor.__qualname__ | |||
except: | |||
monitor_name = self.monitor.__name__ | |||
elif self.monitor is None: | |||
@@ -176,6 +161,46 @@ class HasMonitorCallback(Callback): | |||
return monitor_name | |||
class HasMonitorCallback(MonitorUtility, Callback): | |||
def __init__(self, monitor, larger_better, must_have_monitor=False): | |||
""" | |||
该 callback 不直接进行使用,作为其它相关 callback 的父类使用,如果 callback 有使用 monitor 可以继承该函数里面实现了 | |||
(1)判断monitor合法性;(2)在需要时, 根据trainer的monitor设置自己的monitor名称。 | |||
:param monitor: 监控的 metric 值。如果在 evaluation 结果中没有找到完全一致的名称,将使用 最短公共字符串算法 找到最匹配 | |||
的那个作为 monitor 。如果为 None,将尝试使用 Trainer 设置的 monitor 。也可以传入一个函数,接受参数为 evaluation 的结 | |||
果(字典类型),返回一个 float 值作为 monitor 的结果,如果当前结果中没有相关的 monitor 值请返回 None 。 | |||
:param larger_better: monitor 是否时越大越好 | |||
:param must_have_monitor: 这个 callback 是否必须有 monitor 设置。如果设置为 True ,且没检测到设置 monitor 会报错。 | |||
""" | |||
super().__init__(monitor, larger_better) | |||
self.must_have_monitor = must_have_monitor | |||
def on_after_trainer_initialized(self, trainer, driver): | |||
""" | |||
如果本身的 monitor 没有设置,则根据 Trainer 中的 monitor 设置 monitor 。 | |||
同时对于必须要有 monitor 设置的 callback ,该函数会进行检查。 | |||
:param trainer: | |||
:param driver: | |||
:return: | |||
""" | |||
if self.monitor is None and trainer.monitor is not None: | |||
self.set_monitor(monitor=trainer.monitor, larger_better=trainer.larger_better) | |||
if self.must_have_monitor and self.monitor is None: | |||
raise RuntimeError(f"No `monitor` is set for {self.__class__.__name__}. " | |||
f"You can set it in the initialization or through Trainer.") | |||
if self.must_have_monitor and self.monitor is not None and trainer.evaluator is None: | |||
raise RuntimeError(f"No `evaluate_dataloaders` is set for Trainer. But Callback: {self.__class__.__name__}" | |||
f" need to watch the monitor:`{self.monitor_name}`.") | |||
def on_sanity_check_end(self, trainer, sanity_check_res): | |||
# 主要核对一下 monitor 是否存在。 | |||
if self.monitor is not None: | |||
self.get_monitor_value(results=sanity_check_res) | |||
class ExecuteOnceBetterMonitor(HasMonitorCallback): | |||
def __init__(self, monitor, larger_better, execute_fn): | |||
""" | |||
@@ -183,13 +208,13 @@ class ExecuteOnceBetterMonitor(HasMonitorCallback): | |||
:param monitor: 监控的 metric 值。如果在 evaluation 结果中没有找到完全一致的名称,将使用 最短公共字符串算法 找到最匹配 | |||
的那个作为 monitor 。如果为 None,将尝试使用 Trainer 设置的 monitor 。也可以传入一个函数,接受参数为 evaluation 的结 | |||
果(字典类型),返回一个 float 值作为 monitor 的结果。 | |||
果(字典类型),返回一个 float 值作为 monitor 的结果,如果当前结果中没有相关的 monitor 值请返回 None 。 | |||
:param larger_better: monitor 是否时越大越好 | |||
:param execute_fn: 一个可执行的函数,不接受任何参数,不反回值。在 monitor 取得更好结果的时候会调用。 | |||
""" | |||
super().__init__(monitor, larger_better, must_have_monitor=True) | |||
_check_valid_parameters_number(execute_fn, expected_params=[], fn_name='execute_fn') | |||
self.execute_fn = execute_fn() | |||
self.execute_fn = execute_fn | |||
def on_validate_end(self, trainer, results): | |||
if self.is_better_results(results): |
@@ -23,7 +23,7 @@ class LoadBestModelCallback(HasMonitorCallback): | |||
:param str monitor: 监控的 metric 值。如果在 evaluation 结果中没有找到完全一致的名称,将使用 最短公共字符串算法 找到最匹配 | |||
的那个作为 monitor 。如果为 None,将尝试使用 Trainer 设置的 monitor 。也可以传入一个函数,接受参数为 evaluation 的结 | |||
果(字典类型),返回一个 float 值作为 monitor 的结果。 | |||
果(字典类型),返回一个 float 值作为 monitor 的结果,如果当前结果中没有相关的 monitor 值请返回 None 。 | |||
:param larger_better: 该 metric 值是否是越大越好。 | |||
:param save_folder: 保存的文件夹,如果为空,则保存在内存中。不为空,则保存一份权重到文件中,当为多机训练,且本值不为空时,请确保 | |||
不同的机器均可访问当该路径。当 model_save_fn 不为 None 时该值一定不能为空。 | |||
@@ -72,7 +72,7 @@ class LoadBestModelCallback(HasMonitorCallback): | |||
logger.debug(f"Synchronize best model save folder: {self.real_save_folder} for LoadBestModelCallback.") | |||
except NotImplementedError: | |||
raise RuntimeError(f"Currently {driver.__class__.__name__} does not support using `save_folder` to " | |||
f"save best model when launch using script.") | |||
f"save best model when launch using module.") | |||
super().on_after_trainer_initialized(trainer, driver) | |||
@@ -87,7 +87,7 @@ class LoadBestModelCallback(HasMonitorCallback): | |||
trainer.save_model(folder=self.buffer, only_state_dict=self.only_state_dict) | |||
def on_train_end(self, trainer): | |||
logger.info(f"Loading best model with {self._real_monitor}: {self.monitor_value}...") | |||
logger.info(f"Loading best model with {self.monitor_name}: {self.monitor_value}...") | |||
if self.real_save_folder: | |||
trainer.load_model(folder=self.real_save_folder, only_state_dict=self.only_state_dict, | |||
model_load_fn=self.model_load_fn) | |||
@@ -0,0 +1,174 @@ | |||
__all__ = [ | |||
'MoreEvaluateCallback' | |||
] | |||
import os | |||
from typing import Union, Callable, Optional, Dict | |||
from fastNLP.core.log import logger | |||
from .has_monitor_callback import HasMonitorCallback | |||
from .topk_saver import TopkSaver | |||
class MoreEvaluateCallback(HasMonitorCallback):
    def __init__(self, dataloaders, metrics:Dict, evaluate_every:Optional[Union[int, Callable]]=-1,
                 watch_monitor:Union[str, Callable]=None, watch_monitor_larger_better:bool=True,
                 evaluate_fn=None, num_eval_sanity_batch=2,
                 topk=0, topk_monitor=None, topk_larger_better=True,
                 folder=None, only_state_dict=True, save_object='model', model_save_fn=None,
                 save_evaluate_results=True, save_kwargs=None,
                 **kwargs):
        """
        Run an additional evaluation loop, possibly with a different ``evaluate_fn`` than the
        Trainer's own evaluation (e.g. training loss during most of training, but full generation
        plus BLEU after some epochs), which a single Trainer cannot express. To save checkpoints
        based on THIS callback's evaluation results, pass ``topk`` and ``topk_monitor``. The
        evaluation trigger is controlled by either ``evaluate_every`` or ``watch_monitor``.

        When top-k saving is enabled, files are organised as::

            - folder/
                - YYYY-mm-dd-HH_MM_SS_fffff/  # created from the launch time of the script
                    - {save_object}-epoch_{epoch_idx}-batch_{global_batch_idx}-{topk_monitor}_{monitor_value}/

        :param dataloaders: the data to evaluate on.
        :param metrics: the metrics to compute for this evaluation.
        :param evaluate_every: (1) negative int: evaluate every ``-evaluate_every`` epochs;
            (2) positive int: evaluate every ``evaluate_every`` batches; (3) callable taking the
            trainer object and returning a bool, checked after every batch — True triggers an
            evaluation. Mutually exclusive with ``watch_monitor``: pass ``evaluate_every=None``
            when ``watch_monitor`` is used.
        :param watch_monitor: when not None, ``evaluate_every`` is disabled and an evaluation is
            triggered whenever the Trainer's own evaluation improves on this monitor. Either
            (1) a str metric name — fuzzy-matched by longest common substring when no exact key
            exists in the evaluation results — or (2) a callable mapping the evaluation result
            dict to a float (return None when the monitor is absent from the results).
        :param watch_monitor_larger_better: whether larger ``watch_monitor`` values are better.
        :param evaluate_fn: which model method the ``Evaluator`` calls for the forward pass,
            e.g. ``model.evaluate_step`` vs ``model.forward``. (1) None: use ``evaluate_step``
            when the model has it, else fall back to ``model.forward``; (2) str: look the method
            up on the model and raise when missing.
        :param num_eval_sanity_batch: number of sanity-check batches run right after the
            Evaluator is created.
        :param topk: (1) -1: save after every evaluation; (2) 0 (default): never save;
            (3) positive int: keep the best ``topk`` checkpoints.
        :param topk_monitor: monitor looked up in THIS callback's evaluation results to rank the
            top-k checkpoints.
        :param topk_larger_better: whether larger ``topk_monitor`` values are better.
        :param folder: base folder for saving; fastNLP creates a timestamped sub-folder inside
            it, so different runs land in different sub-folders. None means the current directory.
        :param only_state_dict: whether to save only the model's state dict; ignored when
            ``model_save_fn`` is given.
        :param save_object: one of ['trainer', 'model'] — save trainer+model, or the model only.
        :param model_save_fn: custom save function taking a folder and returning nothing. When
            given, fastNLP performs no model saving of its own; in multi-card setups it only runs
            on rank 0.
        :param save_evaluate_results: when True, also write a fastnlp_evaluate_results.json file
            with the current results next to each top-k checkpoint. Only meaningful when ``topk``
            is set. Default True.
        :param save_kwargs: dict of extra saving-related arguments.
        :param kwargs: extra Evaluator constructor arguments; anything not passed here is taken
            from the Trainer. Pay particular attention to ``evaluate_fn``.
        """
        super(MoreEvaluateCallback, self).__init__(watch_monitor, watch_monitor_larger_better,
                                                   must_have_monitor=False)
        if watch_monitor is None and evaluate_every is None:
            raise RuntimeError("`evaluate_every` and `watch_monitor` cannot be None at the same time.")
        if watch_monitor is not None and evaluate_every is not None:
            raise RuntimeError("`evaluate_every` and `watch_monitor` cannot be set at the same time.")
        self.watch_monitor = watch_monitor

        # topk and topk_monitor must be set (or unset) together.
        if topk_monitor is not None and topk == 0:
            raise RuntimeError("`topk_monitor` is set, but `topk` is 0.")
        if topk != 0 and topk_monitor is None:
            raise RuntimeError("`topk` is set, but `topk_monitor` is None.")
        assert save_object in ['trainer', 'model']

        self.dataloaders = dataloaders
        self.metrics = metrics
        self.evaluate_every = evaluate_every
        self.evaluate_fn = evaluate_fn
        self.num_eval_sanity_batch = num_eval_sanity_batch
        if save_kwargs is None:
            save_kwargs = {}
        self.topk_saver = TopkSaver(topk=topk, monitor=topk_monitor, larger_better=topk_larger_better,
                                    folder=folder, only_state_dict=only_state_dict,
                                    model_save_fn=model_save_fn, save_evaluate_results=save_evaluate_results,
                                    save_object=save_object, **save_kwargs)
        self.kwargs = kwargs

    @property
    def need_reproducible_sampler(self) -> bool:
        # Saving a whole trainer implies resuming, which needs a reproducible sampler.
        return self.topk_saver.save_object == 'trainer'

    def on_after_trainer_initialized(self, trainer, driver):
        # Watching a Trainer monitor only makes sense when the Trainer evaluates at all.
        if self.watch_monitor is not None:
            assert trainer.evaluator is not None, f"You set `watch_monitor={self.watch_monitor}`, but no " \
                                                  f"evaluate_dataloaders is provided in Trainer."

        if trainer.evaluate_fn is self.evaluate_fn:
            logger.warning_once("The `evaluate_fn` is the same as in Trainer, there seems no need to use "
                                "`MoreEvaluateCallback`.")

        # Build our own Evaluator, defaulting unspecified settings to the Trainer's.
        kwargs = {
            'model': self.kwargs.get('model', trainer.model),
            'dataloaders': self.dataloaders,
            'metrics': self.metrics,
            'driver': self.kwargs.get('driver', trainer.driver),
            'device': self.kwargs.get('device', trainer.device),
            'batch_step_fn': self.kwargs.get('batch_step_fn', trainer.evaluate_batch_step_fn),
            'evaluate_fn': self.evaluate_fn,
            'input_mapping': self.kwargs.get('input_mapping', trainer.input_mapping),
            'output_mapping': self.kwargs.get('output_mapping', trainer.output_mapping),
            'fp16': self.kwargs.get('fp16', trainer.fp16),
            'use_dist_sampler': self.kwargs.get('use_dist_sampler',
                                                trainer.kwargs.get('eval_use_dist_sampler', None)),
            'progress_bar': self.kwargs.get('progress_bar', trainer.kwargs.get('progress_bar', 'auto')),
            'verbose': self.kwargs.get('verbose', 1)
        }
        for key, value in self.kwargs.items():
            if key not in kwargs:
                kwargs[key] = value
        from fastNLP.core.controllers.evaluator import Evaluator
        self.evaluator = Evaluator(**kwargs)
        if self.num_eval_sanity_batch > 0:
            results = self.evaluator.run(num_eval_batch_per_dl=self.num_eval_sanity_batch)
            self.topk_saver.get_monitor_value(results)  # warm up monitor matching on real keys

    def on_validate_end(self, trainer, results):
        # Triggered by the Trainer's own evaluation: when the watched monitor improved,
        # run this callback's evaluation and (maybe) save a top-k checkpoint.
        if self.is_better_results(results, keep_if_better=True):
            results = self.evaluator.run()
            self.topk_saver.save_topk(trainer, results)

    def on_train_epoch_end(self, trainer):
        if self.watch_monitor is not None:
            return
        if isinstance(self.evaluate_every, int) and self.evaluate_every < 0:
            evaluate_every = -self.evaluate_every
            if trainer.cur_epoch_idx % evaluate_every == 0:
                results = self.evaluator.run()
                self.topk_saver.save_topk(trainer, results)

    def on_train_batch_end(self, trainer):
        if self.watch_monitor is not None:
            return
        if callable(self.evaluate_every):
            # BUGFIX: the user-supplied predicate is documented to receive the trainer,
            # not this callback (mirrors Trainer.step_validate's use of evaluate_every).
            if self.evaluate_every(trainer):
                results = self.evaluator.run()
                self.topk_saver.save_topk(trainer, results)
        elif self.evaluate_every > 0 and trainer.global_forward_batches % self.evaluate_every == 0:
            results = self.evaluator.run()
            self.topk_saver.save_topk(trainer, results)

    def on_save_checkpoint(self, trainer) -> Dict:
        states = {'topk_saver': self.topk_saver.state_dict()}
        if isinstance(self._real_monitor, str):
            states['_real_monitor'] = self._real_monitor
        states['monitor_value'] = self.monitor_value
        return states

    def on_load_checkpoint(self, trainer, states: Optional[Dict]):
        topk_saver_states = states['topk_saver']
        self.topk_saver.load_state_dict(topk_saver_states)
        if '_real_monitor' in states:
            self._real_monitor = states["_real_monitor"]
        self.monitor_value = states['monitor_value']

    @property
    def callback_name(self):
        # Distinguish several MoreEvaluateCallback instances by metrics/monitor/saver config.
        metric_names = '+'.join(sorted(self.metrics.keys()))
        return f'more_evaluate_callback#metric_name-{metric_names}#monitor-{self.monitor_name}#topk_saver:{self.topk_saver}'
@@ -9,7 +9,6 @@ __all__ = [ | |||
] | |||
from .has_monitor_callback import HasMonitorCallback | |||
from fastNLP.core.callbacks.utils import _get_monitor_value | |||
from fastNLP.core.utils import f_rich_progress | |||
from fastNLP.core.log import logger | |||
@@ -42,7 +41,8 @@ class RichCallback(ProgressCallback): | |||
:param loss_round_ndigit: 显示的 loss 保留多少位有效数字 | |||
:param monitor: 当检测到这个key的结果更好时,会打印出不同的颜色进行提示。监控的 metric 值。如果在 evaluation 结果中没有找到 | |||
完全一致的名称,将使用 最短公共字符串算法 找到最匹配的那个作为 monitor 。如果为 None,将尝试使用 Trainer 设置的 monitor | |||
。也可以传入一个函数,接受参数为 evaluation 的结果(字典类型),返回一个 float 值作为 monitor 的结果。 | |||
。也可以传入一个函数,接受参数为 evaluation 的结果(字典类型),返回一个 float 值作为 monitor 的结果,如果当前结果中没有 | |||
相关的 monitor 值请返回 None 。 | |||
:param larger_better: 是否是 monitor 的结果越大越好。 | |||
:param format_json: 是否格式化 json 再打印 | |||
""" | |||
@@ -135,7 +135,8 @@ class RawTextCallback(ProgressCallback): | |||
:param loss_round_ndigit: 显示的 loss 保留多少位有效数字 | |||
:param monitor: 当检测到这个key的结果更好时,会打印出不同的颜色进行提示。监控的 metric 值。如果在 evaluation 结果中没有找到 | |||
完全一致的名称,将使用 最短公共字符串算法 找到最匹配的那个作为 monitor 。如果为 None,将尝试使用 Trainer 设置的 monitor | |||
。也可以传入一个函数,接受参数为 evaluation 的结果(字典类型),返回一个 float 值作为 monitor 的结果。 | |||
。也可以传入一个函数,接受参数为 evaluation 的结果(字典类型),返回一个 float 值作为 monitor 的结果,如果当前结果中没有 | |||
相关的 monitor 值请返回 None 。 | |||
:param larger_better: 是否是monitor的结果越大越好。 | |||
:param format_json: 是否format json再打印 | |||
""" | |||
@@ -0,0 +1,246 @@ | |||
import json | |||
import os | |||
from copy import deepcopy | |||
from pathlib import Path | |||
from typing import Optional, Dict, Tuple | |||
from fastNLP.core.utils import rank_zero_rm | |||
from fastNLP.core.log import logger | |||
from fastNLP.envs import FASTNLP_LAUNCH_TIME | |||
from fastNLP.envs import rank_zero_call | |||
from fastNLP.envs.env import FASTNLP_EVALUATE_RESULT_FILENAME | |||
from .has_monitor_callback import MonitorUtility | |||
class Saver:
    def __init__(self, folder, only_state_dict, model_save_fn, **kwargs):
        """
        Object that performs the actual saving. Files are organised as::

            - folder                            # the ``folder`` passed in here
                - YYYY-mm-dd-HH_MM_SS_fffff/    # created from the launch time of the script
                    - folder_name               # passed to save() per call

        :param folder: base directory; must be an existing directory. None means the current
            working directory.
        :param only_state_dict: forwarded to the save function: whether only the state dict is
            saved.
        :param model_save_fn: optional user-defined save function, forwarded to the save function.
        :param kwargs: forwarded verbatim to the save function on every call.
        """
        if folder is None:
            logger.warning(
                "Parameter `folder` is None, and we will use the current work directory to save your model.")
            folder = Path.cwd()
        folder = Path(folder)
        if not folder.exists():
            raise NotADirectoryError(f"Path '{folder.absolute()}' does not exist!")
        elif folder.is_file():
            raise ValueError("Parameter `folder` should be a directory instead of a file.")
        self.folder = folder
        self.only_state_dict = only_state_dict
        self.model_save_fn = model_save_fn
        self.kwargs = kwargs
        self.eval_results = kwargs.get('eval_results', True)
        # One timestamped sub-folder per script launch, shared by all saves of this run.
        self.timestamp_path = self.folder.joinpath(os.environ[FASTNLP_LAUNCH_TIME])

    @rank_zero_call
    def save(self, save_fn, folder_name):
        """
        Save into folder/timestamp/folder_name, where ``folder`` was given at initialization
        and ``timestamp`` is the launch time of the current script.

        :param save_fn: the function doing the saving; must accept folder:str,
            only_state_dict:bool, model_save_fn:callable, plus kwargs.
        :param folder_name: name of the sub-folder to create and save into.
        :return: absolute path of the folder actually written to; None when nothing was created.
        """
        folder = self.timestamp_path.joinpath(folder_name)
        folder.mkdir(parents=True, exist_ok=True)
        save_fn(
            folder=folder,
            only_state_dict=self.only_state_dict,
            model_save_fn=self.model_save_fn,
            **self.kwargs
        )
        return str(os.path.abspath(folder))

    @rank_zero_call
    def save_json(self, results, path):
        """
        Dump ``results`` to ``path`` as JSON.

        :param results: JSON-serializable object.
        :param path: target file path.
        :return:
        """
        with open(path, 'w', encoding='utf8') as f:
            json.dump(results, f, indent=2)

    @rank_zero_call
    def rm(self, folder_name):
        """
        Remove folder/timestamp/folder_name, where ``folder`` was given at initialization and
        ``timestamp`` is the launch time of the current script.

        :param folder_name: name of the sub-folder to remove.
        :return:
        """
        folder = self.timestamp_path.joinpath(folder_name)
        rank_zero_rm(folder)

    def state_dict(self):
        # Only the timestamp path is stateful; everything else is reconstructed from __init__.
        states = {
            'timestamp_path': str(self.timestamp_path),
        }
        return states

    def load_state_dict(self, states):
        timestamp_path = states['timestamp_path']
        if not os.path.exists(timestamp_path):
            # The original run's folder is gone; keep saving under this run's timestamp instead.
            logger.info(f"The resuming checkpoint folder {timestamp_path} does not exist, checkpoint will save to "
                        f"{self.timestamp_path.absolute()}.")
        else:
            logger.info(f"Resume to save checkpoint in path: {timestamp_path}.")
            self.timestamp_path = Path(timestamp_path)

    def __str__(self):
        return 'saver'  # the saver is stateless, no distinctive name needed
class TopkQueue:
    def __init__(self, topk):
        """
        Keep track of the top-k (key, value) pairs, ranked by value.

        :param int topk: -1 means every entry counts as top-k; 0 means no entry ever does.
        """
        assert isinstance(topk, int)
        self.topk = topk
        self.topk_dict = {}  # key -> value for the entries currently in the top-k

    def push(self, key, value) -> Optional[Tuple[str, float]]:
        """
        Offer a key/value pair, ranked by value. When the pair makes it into the top-k it is
        kept; when it additionally evicts an older entry, that evicted (key, value) is returned.
        (None, None) means accepted with no eviction; getting the offered pair back means it was
        rejected. Ranking is strictly "larger is better" — negate values that should be
        minimized before pushing.

        :param str key:
        :param float value: None is a no-op (the pair is returned unchanged).
        :return: (1) the offered (key, value) when it did not make the top-k; (2) (None, None)
            when accepted without eviction; (3) a different (key, value) when accepted and an
            older entry was pushed out.
        """
        if value is None:
            return key, value
        if self.topk < 0:
            # unlimited: everything is accepted, nothing ever evicted
            return None, None
        if self.topk == 0:
            # disabled: everything is rejected
            return key, value
        if len(self.topk_dict) < self.topk:
            self.topk_dict[key] = value
            return None, None
        worst_key = min(self.topk_dict, key=self.topk_dict.get)
        if value < self.topk_dict[worst_key]:
            return key, value
        evicted_value = self.topk_dict.pop(worst_key)
        self.topk_dict[key] = value
        return worst_key, evicted_value

    def state_dict(self):
        return deepcopy(self.topk_dict)

    def load_state_dict(self, states):
        self.topk_dict.update(states)

    def __str__(self):
        return f'topk-{self.topk}'

    def __bool__(self):
        # Only a topk of 0 makes this queue meaningless.
        return self.topk != 0
class TopkSaver(MonitorUtility, Saver):
    def __init__(self, topk, monitor, larger_better, folder, only_state_dict,
                 model_save_fn, save_evaluate_results,
                 save_object, **kwargs):
        """
        Identify top-k evaluation results and save the corresponding model/trainer checkpoints.

        :param topk: -1 saves every time, 0 never saves, a positive int keeps the best ``topk``.
        :param monitor: metric name (or callable on the result dict) used for ranking.
        :param larger_better: whether larger monitor values are better.
        :param folder: base directory to save under.
        :param only_state_dict: whether only the state dict is saved.
        :param model_save_fn: optional custom save function.
        :param save_evaluate_results: also write the evaluation results as JSON next to each
            saved checkpoint.
        :param save_object: one of ['trainer', 'model'].
        :param kwargs: forwarded to the underlying Saver.
        """
        MonitorUtility.__init__(self, monitor, larger_better)
        Saver.__init__(self, folder, only_state_dict, model_save_fn, **kwargs)
        if monitor is not None and topk == 0:
            raise RuntimeError("`monitor` is set, but `topk` is 0.")
        if topk != 0 and monitor is None:
            raise RuntimeError("`topk` is set, but `monitor` is None.")
        assert save_object in ['trainer', 'model']

        # Kept for backward compatibility (state_dict layout, __str__); the inherited Saver
        # state on `self` is the one actually used when saving.
        self.saver = Saver(folder, only_state_dict, model_save_fn, **kwargs)
        self.topk_queue = TopkQueue(topk)
        self.save_evaluate_results = save_evaluate_results
        self.save_object = save_object
        self.save_fn_name = 'save' if save_object == 'trainer' else 'save_model'

    @rank_zero_call
    def save_topk(self, trainer, results: Dict) -> Optional[str]:
        """
        Save a checkpoint when ``results`` qualifies for the top-k; return the folder written
        to when a save happened.

        :param trainer:
        :param results: evaluation results used for ranking.
        :return: absolute path of the saved folder, or None when nothing was saved.
        """
        if self.monitor is not None and self.topk_queue:
            monitor_value = self.get_monitor_value(results)
            if monitor_value is None:
                return
            key = f"{self.save_object}-epoch_{trainer.cur_epoch_idx}-batch_{trainer.global_forward_batches}" \
                  f"-{self.monitor_name}_{monitor_value}"
            # TopkQueue ranks "larger is better", so negate when smaller is better.
            pop_key, pop_value = self.topk_queue.push(key, monitor_value if self.larger_better else -monitor_value)
            if pop_key == key:  # rejected: did not make the top-k
                return None
            folder = self.save(trainer, key)
            if self.save_evaluate_results and folder:
                try:
                    self.save_json(self.itemize_results(results),
                                   os.path.join(folder, FASTNLP_EVALUATE_RESULT_FILENAME))
                except:
                    logger.exception(f"Fail to save evaluate results to {folder}")

            if pop_key and pop_key != key:  # a previous top-k entry was evicted: remove its files
                self.rm(pop_key)
            return folder

    def save(self, trainer, folder_name):
        # Dispatch to trainer.save or trainer.save_model depending on save_object.
        fn = getattr(trainer, self.save_fn_name)
        return super().save(fn, folder_name)

    def state_dict(self):
        states = {
            'topk_queue': self.topk_queue.state_dict(),
            # BUGFIX: serialize the inherited Saver state on `self` (which save() uses),
            # not the redundant inner `self.saver`.
            'saver': Saver.state_dict(self)
        }
        if isinstance(self._real_monitor, str):
            states['_real_monitor'] = self._real_monitor

        return states

    def load_state_dict(self, states):
        topk_queue_states = states['topk_queue']
        saver_states = states['saver']
        self.topk_queue.load_state_dict(topk_queue_states)
        # BUGFIX: restore the inherited Saver state on `self` so that save() actually uses
        # the resumed timestamp path; keep the inner saver consistent for __str__/compat.
        Saver.load_state_dict(self, saver_states)
        self.saver.timestamp_path = self.timestamp_path
        if '_real_monitor' in states:
            self._real_monitor = states["_real_monitor"]

    def __str__(self):
        return f'topk-{self.topk_queue}#saver-{self.saver}#save_object-{self.save_object}'
@@ -1,4 +1,6 @@ | |||
from typing import Optional, Union | |||
import os | |||
from fastNLP.core.log.logger import logger | |||
from difflib import SequenceMatcher | |||
from fastNLP.core.utils.utils import _get_fun_msg | |||
@@ -15,7 +17,7 @@ def _get_monitor_value(monitor: Union[callable, str], real_monitor: Optional[str | |||
:return: 返回两个值(str, value),其中str就是最终要到的key,value就是这个key对应的value。如果value为None说明当前results中没有 | |||
找到对应的 monitor | |||
""" | |||
if len(res)==0: | |||
if len(res) == 0 or monitor is None: | |||
return monitor, None | |||
if callable(monitor): | |||
@@ -56,4 +58,3 @@ def _match_length(a:str, b:str)->int: | |||
match = SequenceMatcher(None, short, long).find_longest_match(0, len(short), 0, len(long)) | |||
return match.size | |||
@@ -38,7 +38,7 @@ class Evaluator: | |||
driver: Union[str, Driver] = 'torch', | |||
device: Optional[Union[int, List[int], str]] = None, | |||
batch_step_fn: Optional[callable] = None, | |||
evaluate_fn: Optional[str] = None, # 首先尝试找 evaluate_step, 找不到 forward, callable | |||
evaluate_fn: Optional[str] = None, | |||
input_mapping: Optional[Union[Callable, Dict]] = None, | |||
output_mapping: Optional[Union[Callable, Dict]] = None, | |||
model_wo_auto_param_call: bool = False, | |||
@@ -57,8 +57,9 @@ class Evaluator: | |||
:param batch_step_fn: callable的对象,接受 (evaluator, batch) 作为参数,其中 evaluator 为 Evaluator 对象,batch 为 | |||
DataLoader 中返回的对象。一个 batch_step_fn 的例子可参考 fastNLP.core.controller.loops.evaluate_batch_loop 的 | |||
batch_step_fn 函数。 | |||
:param evaluate_fn: 用来控制 `Evaluator` 在评测的前向传播过程中是调用哪一个函数,例如是 `model.evaluate_step` 还是 `model.forward`; | |||
默认为 None,如果该值是 None,那么我们会默认使用 `evaluate_step` 当做前向传播的函数,如果在模型中没有找到该方法,则使用 `model.forward` 函数; | |||
:param evaluate_fn: 用来控制 `Evaluator` 在评测的前向传播过程中是调用哪一个函数,例如是 `model.evaluate_step` 还是 | |||
`model.forward`;(1) 如果该值是 None,那么我们会默认使用 `evaluate_step` 当做前向传播的函数,如果在模型中没有 | |||
找到该方法,则使用 `model.forward` 函数;(2) 如果为 str 类型,则尝试从 model 中寻找该方法,找不到则报错。 | |||
:param input_mapping: 对 dataloader 中输出的内容将通过 input_mapping 处理之后再输入到 model 以及 metric 中 | |||
:param output_mapping: 对 model 输出的内容,将通过 output_mapping 处理之后再输入到 metric 中。 | |||
:param model_wo_auto_param_call: 是否关闭在训练时调用我们的 auto_param_call 来自动匹配 batch 和 forward 函数的参数的行为; | |||
@@ -69,6 +70,7 @@ class Evaluator: | |||
:param kwargs: | |||
bool model_use_eval_mode: 是否在 evaluate 的时候将 model 的状态设置成 eval 状态。在 eval 状态下,model 的dropout | |||
与 batch normalization 将会关闭。默认为True。 | |||
TODO 还没完成。 | |||
Union[bool] auto_tensor_conversion_for_metric: 是否自动将输出中的 | |||
tensor 适配到 metrics 支持的。例如 model 输出是 paddlepaddle 的 tensor ,但是想利用 torchmetrics 的metric对象, | |||
当 auto_tensor_conversion_for_metric 为True时,fastNLP 将自动将输出中 paddle 的 tensor (其它非 tensor 的参数 | |||
@@ -119,10 +121,6 @@ class Evaluator: | |||
self._metric_wrapper = None | |||
_ = self.metrics_wrapper # 触发检查 | |||
if self._dist_sampler is not None and not self.driver.is_distributed(): | |||
logger.warning_once("Running in a non-distributed driver, but with distributed sampler, it may cause " | |||
"different process evaluating on different data.") | |||
if evaluate_fn is not None and not isinstance(evaluate_fn, str): | |||
raise TypeError("Parameter `evaluate_fn` can only be `str` type when it is not None.") | |||
self._evaluate_step, self._evaluate_step_signature_fn = \ | |||
@@ -14,10 +14,10 @@ __all__ = [ | |||
from .loops import Loop, TrainBatchLoop | |||
from .utils import State, TrainerState | |||
from .utils.utils import check_validate_every | |||
from .utils.utils import check_evaluate_every | |||
from .evaluator import Evaluator | |||
from fastNLP.core.controllers.utils.utils import TrainerEventTrigger, _TruncatedDataLoader | |||
from fastNLP.core.callbacks import Callback, CallbackManager, Events, EventsList, Filter | |||
from fastNLP.core.callbacks import Callback, CallbackManager, Events, EventsList | |||
from fastNLP.core.callbacks.callback import _CallbackWrapper | |||
from fastNLP.core.callbacks.callback_events import _SingleEventState | |||
from fastNLP.core.drivers import Driver | |||
@@ -26,7 +26,7 @@ from fastNLP.core.utils import get_fn_arg_names, match_and_substitute_params, nu | |||
from fastNLP.core.utils.utils import _check_valid_parameters_number | |||
from fastNLP.envs import rank_zero_call | |||
from fastNLP.core.log import logger | |||
from fastNLP.envs import FASTNLP_MODEL_FILENAME | |||
from fastNLP.envs import FASTNLP_MODEL_FILENAME, FASTNLP_CHECKPOINT_FILENAME | |||
from fastNLP.core.utils.exceptions import EarlyStopException | |||
@@ -94,9 +94,9 @@ class Trainer(TrainerEventTrigger): | |||
evaluate_step 这个函数,如果没有则使用 forward 函数。 | |||
:param callbacks: 训练当中触发的 callback 类,该参数应当为一个列表,其中的每一个元素都应当继承 `Callback` 类; | |||
:param metrics: 应当为一个字典,其中 key 表示 monitor,例如 {"acc1": AccMetric(), "acc2": AccMetric()}; | |||
:param evaluate_every: 可以为负数、正数或者函数;为负数时表示每隔几个 epoch validate 一次;为正数则表示每隔几个 batch validate 一次; | |||
为函数时表示用户自己传入的用于控制 Trainer 中的 validate 的频率的函数,该函数的应该接受当前 trainer 对象作为参数,并 | |||
返回一个 bool 值,返回为 True 说明需要进行 validate ;将在每个 batch 结束后调用该函数判断是否需要 validate 。 | |||
:param evaluate_every: 可以为负数、正数或者函数;为负数时表示每隔几个 epoch evaluate 一次;为正数则表示每隔几个 batch evaluate 一次; | |||
为函数时表示用户自己传入的用于控制 Trainer 中的 evaluate 的频率的函数,该函数的应该接受当前 trainer 对象作为参数,并 | |||
返回一个 bool 值,返回为 True 说明需要进行 evaluate ;将在每个 batch 结束后调用该函数判断是否需要 evaluate 。 | |||
:param input_mapping: 应当为一个字典或者一个函数,表示在当前 step 拿到一个 batch 的训练数据后,应当做怎样的映射处理;如果其是 | |||
一个字典,并且 batch 也是一个 `Dict`,那么我们会把 batch 中同样在 input_mapping 中的 key 修改为 input_mapping 的对应 key 的 | |||
value;如果 batch 是一个 `dataclass`,那么我们会先将该 dataclass 转换为一个 Dict,然后再进行上述转换;如果 batch 此时是其它 | |||
@@ -124,7 +124,7 @@ class Trainer(TrainerEventTrigger): | |||
set_grad_to_none: 是否在训练过程中在每一次 optimizer 更新后将 grad 置为 None; | |||
use_dist_sampler: 表示是否使用分布式的 sampler 。在多卡时,分布式 sampler 将自动决定每张卡上读取的 sample ,使得一个epoch | |||
内所有卡的 sample 加起来为一整个数据集的 sample。默认会根据 driver 是否为分布式进行设置。 | |||
use_eval_dist_sampler: 表示在 Evaluator 中在使用 TorchDDPDriver 的时候是否将 dataloader 的 sampler 替换为分布式的 sampler;默认为 True; | |||
eval_use_dist_sampler: 表示在 Evaluator 中在使用 TorchDDPDriver 的时候是否将 dataloader 的 sampler 替换为分布式的 sampler;默认为 True; | |||
output_from_new_proc: 应当为一个字符串,表示在多进程的 driver 中其它进程的输出流应当被做如何处理;其值应当为以下之一: | |||
["all", "ignore", "only_error"];当该参数的值不是以上值时,该值应当表示一个文件夹的名字,我们会将其他 rank 的输出流重定向到 | |||
log 文件中,然后将 log 文件保存在通过该参数值设定的文件夹中;默认为 "only_error"; | |||
@@ -214,13 +214,13 @@ class Trainer(TrainerEventTrigger): | |||
""" 设置内部的 Evaluator """ | |||
if metrics is None and evaluate_dataloaders is not None: | |||
raise ValueError("You have set 'evaluate_dataloader' but forget to set 'metrics'.") | |||
raise ValueError("You have set 'evaluate_dataloaders' but forget to set 'metrics'.") | |||
if metrics is not None and evaluate_dataloaders is None: | |||
raise ValueError("You have set 'metrics' but forget to set 'evaluate_dataloader'.") | |||
raise ValueError("You have set 'metrics' but forget to set 'evaluate_dataloaders'.") | |||
self.metrics = metrics | |||
self.validate_every = evaluate_every | |||
self.evaluate_every = evaluate_every | |||
self.driver.setup() | |||
self.driver.barrier() | |||
@@ -235,7 +235,7 @@ class Trainer(TrainerEventTrigger): | |||
self.monitor = monitor | |||
self.larger_better = larger_better | |||
if metrics is not None and evaluate_dataloaders is not None: | |||
check_validate_every(evaluate_every) | |||
check_evaluate_every(evaluate_every) | |||
self.evaluator = Evaluator( | |||
model=model, | |||
dataloaders=evaluate_dataloaders, | |||
@@ -248,7 +248,7 @@ class Trainer(TrainerEventTrigger): | |||
output_mapping=output_mapping, | |||
fp16=fp16, | |||
verbose=0, | |||
use_dist_sampler=kwargs.get("use_eval_dist_sampler", None), | |||
use_dist_sampler=kwargs.get("eval_use_dist_sampler", None), | |||
progress_bar=kwargs.get('progress_bar', 'auto') | |||
) | |||
@@ -261,11 +261,14 @@ class Trainer(TrainerEventTrigger): | |||
self.driver.set_deterministic_dataloader(self.dataloader) | |||
self.dataloader = self.driver.set_dist_repro_dataloader(dataloader=self.train_dataloader, dist=_dist_sampler, | |||
reproducible=self.callback_manager.has_trainer_checkpoint) | |||
reproducible=self.callback_manager._need_reproducible_sampler) | |||
self.set_grad_to_none = kwargs.get("set_grad_to_none", True) | |||
self.on_after_trainer_initialized(self.driver) | |||
self.evaluate_batch_step_fn = evaluate_batch_step_fn | |||
self.kwargs = kwargs | |||
self.on_after_trainer_initialized(self.driver) | |||
self.driver.barrier() | |||
def run(self, num_train_batch_per_epoch: int = -1, num_eval_batch_per_dl: int = -1, | |||
@@ -364,10 +367,10 @@ class Trainer(TrainerEventTrigger): | |||
:return: | |||
""" | |||
if self.evaluator is not None: | |||
if callable(self.validate_every): | |||
if self.validate_every(self): | |||
if callable(self.evaluate_every): | |||
if self.evaluate_every(self): | |||
self.run_evaluate() | |||
elif self.validate_every > 0 and self.global_forward_batches % self.validate_every == 0: | |||
elif self.evaluate_every > 0 and self.global_forward_batches % self.evaluate_every == 0: | |||
self.run_evaluate() | |||
def epoch_validate(self): | |||
@@ -377,8 +380,8 @@ class Trainer(TrainerEventTrigger): | |||
:return: | |||
""" | |||
if self.evaluator is not None: | |||
if isinstance(self.validate_every, int) and self.validate_every < 0: | |||
validate_every = -self.validate_every | |||
if isinstance(self.evaluate_every, int) and self.evaluate_every < 0: | |||
validate_every = -self.evaluate_every | |||
if self.cur_epoch_idx % validate_every == 0: | |||
self.run_evaluate() | |||
@@ -427,7 +430,7 @@ class Trainer(TrainerEventTrigger): | |||
self._custom_callbacks[None] = [] | |||
if self.marker is not None: | |||
if len(self._custom_callbacks[self.marker]) == 0: | |||
print(f"You have set `trainer.marker = {self.marker}`, but there are no callback function matched " | |||
logger.info(f"You have set `trainer.marker = {self.marker}`, but there are no callback function matched " | |||
f"`{self.marker}` that is added through function `Trainer.on`") | |||
_own_callbacks += self._custom_callbacks[self.marker] | |||
for each_callback in _own_callbacks: | |||
@@ -528,10 +531,10 @@ class Trainer(TrainerEventTrigger): | |||
r""" | |||
用于帮助用户保存模型的辅助函数,具体实际的保存模型的操作由具体的 driver 实现; | |||
:param folder: 保存模型的地址; | |||
:param only_state_dict: 是否只保存模型的 `state_dict`; | |||
:param folder: 保存模型的文件夹。如果没有传入 model_save_fn 参数,则在这个文件夹下创建 fastnlp_model.pkl.tar 文件。 | |||
:param only_state_dict: 仅在 model_save_fn 为空时,有效。是否只保存模型的 `state_dict`; | |||
:param model_save_fn: 用户自己定制的用来替换该保存函数本身保存逻辑的函数; | |||
:param kwargs: 一些 driver 的保存模型的函数的参数另有其它; | |||
:param kwargs: | |||
""" | |||
self.on_save_model() | |||
@@ -568,14 +571,19 @@ class Trainer(TrainerEventTrigger): | |||
self.on_load_model() | |||
self.driver.barrier() | |||
if not isinstance(folder, (io.BytesIO, BinaryIO)): | |||
if model_load_fn is not None: | |||
if not callable(model_load_fn): | |||
raise ValueError("Parameter `model_save_fn` should be `Callable` type when it is not None.") | |||
rank_zero_call(model_load_fn)(folder) | |||
else: | |||
if isinstance(folder, str): | |||
folder = Path(folder) | |||
self.driver.load_model(folder.joinpath(FASTNLP_MODEL_FILENAME), only_state_dict, **kwargs) | |||
try: | |||
if model_load_fn is not None: | |||
if not callable(model_load_fn): | |||
raise ValueError("Parameter `model_save_fn` should be `Callable` type when it is not None.") | |||
rank_zero_call(model_load_fn)(folder) | |||
else: | |||
if isinstance(folder, str): | |||
folder = Path(folder) | |||
self.driver.load_model(folder.joinpath(FASTNLP_MODEL_FILENAME), only_state_dict, **kwargs) | |||
except FileNotFoundError as e: | |||
if FASTNLP_MODEL_FILENAME not in os.listdir(folder): | |||
logger.error(f"fastNLP model checkpoint file:{FASTNLP_MODEL_FILENAME} is not found in {folder}.") | |||
raise e | |||
else: | |||
if model_load_fn is not None: | |||
raise RuntimeError("It is not allowed to specify a `model_save_fn` parameter with `folder` being " | |||
@@ -585,11 +593,13 @@ class Trainer(TrainerEventTrigger): | |||
def save(self, folder: Union[str, Path], only_state_dict: bool = True, model_save_fn: Optional[Callable] = None, **kwargs): | |||
r""" | |||
用于断点重训 Trainer 的保存函数; | |||
用于断点重训 Trainer 的保存函数。 | |||
:param folder: | |||
:param only_state_dict: | |||
:param model_save_fn: | |||
:param folder: 保存在哪个文件夹下,会在该文件下声称两个文件:fastnlp_checkpoint.pkl.tar 与 fastnlp_model.pkl.tar 。 | |||
如果 model_save_fn 不为空,则没有 fastnlp_model.pkl.tar 文件。 | |||
:param only_state_dict: 当 model_save_fn 为空时有效,表明是否仅保存模型的权重。 | |||
:param model_save_fn: 如果模型保存比较特殊,可以传入该函数自定义保存过程,输入应该接受一个文件夹(实际上就是接受上面的 folder | |||
参数),不必返回任何东西。 | |||
:param kwargs: | |||
:return: | |||
""" | |||
@@ -602,17 +612,6 @@ class Trainer(TrainerEventTrigger): | |||
'num_consumed_batches': self.batch_idx_in_epoch - getattr(self, 'start_batch_idx_in_epoch', 0) | |||
} | |||
# 3. validate filter state; | |||
if self.evaluator is not None: | |||
val_filter_state = {} | |||
if hasattr(self.step_validate, "__fastNLP_filter__"): | |||
val_filter_state["step_validate"] = self.step_validate.__fastNLP_filter__.state_dict() | |||
if hasattr(self.epoch_validate, "__fastNLP_filter__"): | |||
val_filter_state["epoch_validate"] = self.epoch_validate.__fastNLP_filter__.state_dict() | |||
states["val_filter_state"] = val_filter_state | |||
else: | |||
states["val_filter_state"] = None | |||
if isinstance(folder, str): | |||
folder = Path(folder) | |||
@@ -649,32 +648,30 @@ class Trainer(TrainerEventTrigger): | |||
dataloader = self.dataloader | |||
if not resume_training: | |||
dataloader = None | |||
if model_load_fn is not None: | |||
if not callable(model_load_fn): | |||
raise ValueError("Parameter `model_save_fn` should be `Callable`.") | |||
rank_zero_call(model_load_fn)(folder) | |||
states = self.driver.load(folder=folder, dataloader=dataloader, should_load_model=False, **kwargs) | |||
else: | |||
states = self.driver.load(folder=folder, dataloader=dataloader, only_state_dict=only_state_dict, should_load_model=True, **kwargs) | |||
try: | |||
if model_load_fn is not None: | |||
if not callable(model_load_fn): | |||
raise ValueError("Parameter `model_save_fn` should be `Callable`.") | |||
rank_zero_call(model_load_fn)(folder) | |||
states = self.driver.load(folder=folder, dataloader=dataloader, should_load_model=False, **kwargs) | |||
else: | |||
states = self.driver.load(folder=folder, dataloader=dataloader, only_state_dict=only_state_dict, should_load_model=True, **kwargs) | |||
except FileNotFoundError as e: | |||
if FASTNLP_CHECKPOINT_FILENAME not in os.listdir(folder) and FASTNLP_MODEL_FILENAME in os.listdir(folder): | |||
logger.error("It seems that you are trying to load the trainer checkpoint from a model checkpoint folder.") | |||
elif FASTNLP_CHECKPOINT_FILENAME not in os.listdir(folder): | |||
logger.error(f"fastNLP Trainer checkpoint file:{FASTNLP_CHECKPOINT_FILENAME} is not found in {folder}.") | |||
raise e | |||
if not resume_training: | |||
return | |||
self.dataloader = states.pop('dataloader') | |||
# 2. validate filter state; | |||
if self.evaluator is not None: | |||
val_filter_state = states["val_filter_state"] | |||
if hasattr(self.step_validate, "__fastNLP_filter__"): | |||
self.step_validate.__fastNLP_filter__.load_state_dict(val_filter_state["step_validate"]) | |||
if hasattr(self.epoch_validate, "__fastNLP_filter__"): | |||
self.epoch_validate.__fastNLP_filter__.load_state_dict(val_filter_state["epoch_validate"]) | |||
# 3. 恢复 trainer_state 的状态; | |||
# 1. 恢复 trainer_state 的状态; | |||
self.trainer_state.load_state_dict(states["trainer_state"]) | |||
# 4. 修改 trainer_state.batch_idx_in_epoch | |||
# 2. 修改 trainer_state.batch_idx_in_epoch | |||
# sampler 是类似 RandomSampler 的sampler,不是 batch_sampler; | |||
# 这里的原则就是应当使得 '还会产生的batch数量' + 'batch_idx_in_epoch' = '原来不断点训练的batch的总数'。其中由于 | |||
# '还会产生的batch数量' 是由还剩多少 sample 决定的,因此只能通过调整 'batch_idx_in_epoch' 使得等式成立 | |||
@@ -126,7 +126,7 @@ class _TruncatedDataLoader: | |||
return getattr(self.dataloader, item) | |||
def check_validate_every(validate_every): | |||
def check_evaluate_every(validate_every): | |||
if not callable(validate_every) and (not isinstance(validate_every, int) or validate_every == 0): | |||
raise ValueError("Parameter 'evaluate_every' should be set to 'int' type and either < 0 or > 0.") | |||
if callable(validate_every): | |||
@@ -11,6 +11,7 @@ from fastNLP.core.collators.collator import _MultiCollator | |||
from fastNLP.core.utils.utils import indice_collate_wrapper | |||
from fastNLP.io.data_bundle import DataBundle | |||
from fastNLP.envs.imports import _NEED_IMPORT_TORCH | |||
from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler, UnrepeatedSampler | |||
if _NEED_IMPORT_TORCH: | |||
from torch.utils.data import DataLoader, Sampler | |||
@@ -48,8 +49,8 @@ class TorchDataLoader(DataLoader): | |||
""" | |||
def __init__(self, dataset, batch_size: int = 1, | |||
shuffle: bool = False, sampler: Optional["Sampler[int]"] = None, | |||
batch_sampler: Optional["Sampler[Sequence[int]]"] = None, | |||
shuffle: bool = False, sampler: Union["Sampler[int]", ReproducibleSampler, UnrepeatedSampler] = None, | |||
batch_sampler: Union["Sampler[Sequence[int]]", ReproducibleBatchSampler] = None, | |||
num_workers: int = 0, collate_fn: Optional[Callable] = None, | |||
pin_memory: bool = False, drop_last: bool = False, | |||
timeout: float = 0, worker_init_fn: Optional[Callable] = None, | |||
@@ -380,7 +380,6 @@ class Driver(ABC): | |||
""" | |||
# 单卡 driver 不需要这个函数; | |||
if self._pids is not None: | |||
exc_type, exc_value, exc_traceback_obj = sys.exc_info() | |||
_write_exc_info = { | |||
'exc_type': str(exc_type.__name__), | |||
@@ -526,7 +526,7 @@ class TorchDDPDriver(TorchDriver): | |||
def barrier(self): | |||
if int(os.environ.get(FASTNLP_NO_SYNC, 0)) < 1: # 当 FASTNLP_NO_SYNC 小于 1 时实际执行 | |||
torch.distributed.barrier(async_op=True) | |||
torch.distributed.barrier(async_op=False) | |||
def is_distributed(self): | |||
return True | |||
@@ -9,8 +9,9 @@ from fastNLP.envs.imports import _NEED_IMPORT_TORCH | |||
from pathlib import Path | |||
if _NEED_IMPORT_TORCH: | |||
import torch | |||
from torch.utils.data import DataLoader, IterableDataset, RandomSampler, Sampler, BatchSampler, Dataset | |||
from torch.utils.data import DataLoader, IterableDataset, Sampler, BatchSampler, Dataset | |||
from torch.optim import Optimizer | |||
from torch.utils.data import RandomSampler as TorchRandomSampler | |||
_reduces = { | |||
'sum': torch.max, | |||
'min': torch.min, | |||
@@ -30,7 +31,7 @@ from fastNLP.core.utils import apply_to_collection, torch_move_data_to_device | |||
from fastNLP.envs import rank_zero_call | |||
from fastNLP.envs import FASTNLP_SEED_WORKERS, FASTNLP_GLOBAL_RANK, FASTNLP_MODEL_FILENAME, FASTNLP_CHECKPOINT_FILENAME | |||
from fastNLP.core.log import logger | |||
from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler, RandomBatchSampler | |||
from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler, RandomBatchSampler, RandomSampler | |||
class TorchDriver(Driver): | |||
@@ -211,8 +212,8 @@ class TorchDriver(Driver): | |||
states['sampler_states'] = sampler_states | |||
else: | |||
raise RuntimeError( | |||
'The sampler has no `state_dict()` method, it will fail to recover to the specific batch.') | |||
raise RuntimeError('The sampler has no `state_dict()` method, fastNLP cannot save the training ' | |||
'state.') | |||
# 2. 保存模型的状态; | |||
if should_save_model: | |||
@@ -283,6 +284,9 @@ class TorchDriver(Driver): | |||
sampler = dataloader_args.batch_sampler | |||
elif isinstance(dataloader_args.sampler, ReproducibleSampler): | |||
sampler = dataloader_args.sampler | |||
elif isinstance(dataloader_args.sampler, TorchRandomSampler): | |||
sampler = RandomSampler(dataloader_args.sampler.data_source) | |||
logger.debug("Replace torch RandomSampler into fastNLP RandomSampler.") | |||
elif self.is_distributed(): | |||
raise RuntimeError("It is not allowed to use checkpoint retraining when you do not use our or " | |||
"`ReproducibleSampler`.") | |||
@@ -19,7 +19,7 @@ class Accuracy(Metric): | |||
:param str backend: 目前支持四种类型的backend, ['auto', 'torch', 'paddle', 'jittor']。其中 auto 表示根据实际调用 Metric.update() | |||
函数时传入的参数决定具体的 backend ,一般情况下直接使用 'auto' 即可。 | |||
:param bool aggregate_when_get_metric: 在计算 metric 的时候是否自动将各个进程上的相同的 element 的数字聚合后再得到metric, | |||
:param bool aggregate_when_get_metric: 在计算 metric 的时候是否自动将各个进程上的相同的 element 的数字聚合后再得到 metric, | |||
当 backend 不支持分布式时,该参数无意义。如果为 None ,将在 Evaluator 中根据 sampler 是否使用分布式进行自动设置。 | |||
""" | |||
super(Accuracy, self).__init__(backend=backend, aggregate_when_get_metric=aggregate_when_get_metric) | |||
@@ -84,6 +84,8 @@ class UnrepeatedRandomSampler(UnrepeatedSampler): | |||
:param rank: | |||
:return: | |||
""" | |||
assert num_replicas<=len(self.dataset), f"The number of replicas({num_replicas}) should be lesser than the " \ | |||
f"number of samples({len(self.dataset)})." | |||
assert num_replicas>0 and isinstance(num_replicas, int) | |||
assert isinstance(rank, int) and 0<=rank<num_replicas | |||
# 注意初始化该函数时,所有的状态都应当默认是一个 epoch 刚开始训练的状态; | |||
@@ -24,8 +24,8 @@ __all__ = [ | |||
'indice_collate_wrapper', | |||
'deprecated', | |||
'seq_len_to_mask', | |||
'synchronize_safe_rm', | |||
'synchronize_mkdir' | |||
'rank_zero_rm', | |||
'rank_zero_mkdir' | |||
] | |||
from .cache_results import cache_results | |||
@@ -37,6 +37,6 @@ from .torch_paddle_utils import torch_paddle_move_data_to_device | |||
from .torch_utils import torch_move_data_to_device | |||
from .utils import get_fn_arg_names, auto_param_call, check_user_specific_params, \ | |||
dataclass_to_dict, match_and_substitute_params, apply_to_collection, nullcontext, pretty_table_printer, Option, \ | |||
indice_collate_wrapper, deprecated, seq_len_to_mask, synchronize_safe_rm, synchronize_mkdir | |||
indice_collate_wrapper, deprecated, seq_len_to_mask, rank_zero_rm, rank_zero_mkdir | |||
@@ -38,8 +38,8 @@ __all__ = [ | |||
'indice_collate_wrapper', | |||
'deprecated', | |||
'seq_len_to_mask', | |||
'synchronize_safe_rm', | |||
'synchronize_mkdir' | |||
'rank_zero_rm', | |||
'rank_zero_mkdir' | |||
] | |||
@@ -629,7 +629,7 @@ def wait_filepath(path, exist=True): | |||
def synchronize_safe_rm(path: Optional[Union[str, Path]]): | |||
def rank_zero_rm(path: Optional[Union[str, Path]]): | |||
""" | |||
这个是因为在分布式文件系统中可能会发生错误,rank0下发删除成功后就运行走了,但实际的删除需要rank0的机器发送到远程文件系统再去执行,这个时候 | |||
在rank0那里,确实已经删除成功了,但是在远程文件系统那里这个操作还没完成,rank1读取的时候还是读取到存在这个文件; | |||
@@ -638,15 +638,14 @@ def synchronize_safe_rm(path: Optional[Union[str, Path]]): | |||
:param path: | |||
:return: | |||
""" | |||
if path is None: | |||
return | |||
if isinstance(path, str): | |||
path = Path(path) | |||
if not path.exists(): | |||
return | |||
if int(os.environ.get(FASTNLP_GLOBAL_RANK, 0)) == 0: | |||
if path is None: | |||
return | |||
if isinstance(path, str): | |||
path = Path(path) | |||
if not path.exists(): | |||
return | |||
_recursive_rm(path) | |||
wait_filepath(path, exist=False) | |||
def _recursive_rm(path: Path): | |||
@@ -662,21 +661,19 @@ def _recursive_rm(path: Path): | |||
path.rmdir() | |||
def synchronize_mkdir(path: Optional[Union[str, Path]]): | |||
def rank_zero_mkdir(path: Optional[Union[str, Path]]): | |||
""" | |||
注意该函数是用来创建文件夹,如果需要创建一个文件,不要使用该函数; | |||
该函数会保证所有进程都检测到 path 创建之后才退出,请保证不同进程上 path 是完全一样的,否则会陷入死锁状态。 | |||
""" | |||
if path is None: | |||
return | |||
if isinstance(path, str): | |||
path = Path(path) | |||
if int(os.environ.get(FASTNLP_GLOBAL_RANK, 0)) == 0: | |||
path.mkdir(parents=True, exist_ok=True) | |||
if path is None: | |||
return | |||
if isinstance(path, str): | |||
path = Path(path) | |||
wait_filepath(path, exist=True) | |||
path.mkdir(parents=True, exist_ok=True) | |||
def get_class_that_defined_method(method): | |||
@@ -49,7 +49,7 @@ FASTNLP_BACKEND_LAUNCH = "FASTNLP_BACKEND_LAUNCH" | |||
# 为 '2' 表示 barrier 与 gather/broadcast 都关闭。 | |||
FASTNLP_NO_SYNC = 'FASTNLP_NO_SYNC' | |||
# todo 注释 直接使用的变量 | |||
# 保存各种内容时的默认名称 | |||
FASTNLP_MODEL_FILENAME = "fastnlp_model.pkl.tar" | |||
FASTNLP_CHECKPOINT_FILENAME = "fastnlp_checkpoint.pkl.tar" | |||
FASTNLP_EVALUATE_RESULT_FILENAME = 'fastnlp_evaluate_results.json' |
@@ -7,13 +7,14 @@ from torch.optim import SGD | |||
import torch.distributed as dist | |||
from pathlib import Path | |||
import re | |||
import time | |||
from fastNLP.core.callbacks.checkpoint_callback import ModelCheckpointCallback, TrainerCheckpointCallback | |||
from fastNLP.core.callbacks.checkpoint_callback import CheckpointCallback | |||
from fastNLP.core.controllers.trainer import Trainer | |||
from fastNLP.envs import FASTNLP_LAUNCH_TIME, FASTNLP_DISTRIBUTED_CHECK | |||
from tests.helpers.utils import magic_argv_env_context | |||
from fastNLP.core import synchronize_safe_rm | |||
from fastNLP.core import rank_zero_rm | |||
from tests.helpers.models.torch_model import TorchNormalModel_Classification_1 | |||
from tests.helpers.datasets.torch_data import TorchArgMaxDatset | |||
from torchmetrics import Accuracy | |||
@@ -80,44 +81,21 @@ def test_model_checkpoint_callback_1( | |||
version, | |||
only_state_dict | |||
): | |||
# def test_model_checkpoint_callback_1( | |||
# model_and_optimizers: TrainerParameters, | |||
# driver='torch_ddp', | |||
# device=[0, 1], | |||
# version=1, | |||
# only_state_dict=True | |||
# ): | |||
path = Path.cwd().joinpath(f"test_model_checkpoint") | |||
path.mkdir(exist_ok=True, parents=True) | |||
try: | |||
path = Path.cwd().joinpath(f"test_model_checkpoint") | |||
path.mkdir(exist_ok=True, parents=True) | |||
if version == 0: | |||
callbacks = [ | |||
ModelCheckpointCallback( | |||
monitor="acc", | |||
save_folder=path, | |||
save_every_n_epochs=1, | |||
save_every_n_batches=123, # 避免和 epoch 的保存重复; | |||
save_topk=None, | |||
save_last=False, | |||
save_on_exception=None, | |||
only_state_dict=only_state_dict | |||
) | |||
] | |||
elif version == 1: | |||
callbacks = [ | |||
ModelCheckpointCallback( | |||
monitor="acc", | |||
save_folder=path, | |||
save_every_n_epochs=3, | |||
save_every_n_batches=None, | |||
save_topk=2, | |||
save_last=True, | |||
save_on_exception=None, | |||
only_state_dict=only_state_dict | |||
) | |||
] | |||
if version == 0: | |||
callbacks = [ | |||
CheckpointCallback(folder=path, every_n_epochs=1, every_n_batches=123, last=False, on_exceptions=None, topk=0, | |||
monitor=None, only_state_dict=only_state_dict, save_object='model') | |||
] | |||
elif version == 1: | |||
callbacks = [ | |||
CheckpointCallback(folder=path, every_n_epochs=3, every_n_batches=None, last=True, on_exceptions=None, topk=2, | |||
monitor="acc", only_state_dict=only_state_dict, save_object='model') | |||
] | |||
try: | |||
trainer = Trainer( | |||
model=model_and_optimizers.model, | |||
driver=driver, | |||
@@ -134,7 +112,7 @@ def test_model_checkpoint_callback_1( | |||
) | |||
trainer.run() | |||
print("Finish train") | |||
all_saved_model_paths = {w.name: w for w in path.joinpath(os.environ[FASTNLP_LAUNCH_TIME]).iterdir()} | |||
# 检查生成保存模型文件的数量是不是正确的; | |||
if version == 0: | |||
@@ -217,8 +195,7 @@ def test_model_checkpoint_callback_1( | |||
trainer.run() | |||
finally: | |||
synchronize_safe_rm(path) | |||
pass | |||
rank_zero_rm(path) | |||
if dist.is_initialized(): | |||
dist.destroy_process_group() | |||
@@ -233,30 +210,23 @@ def test_model_checkpoint_callback_2( | |||
device, | |||
only_state_dict | |||
): | |||
path = Path.cwd().joinpath("test_model_checkpoint") | |||
path.mkdir(exist_ok=True, parents=True) | |||
try: | |||
path = Path.cwd().joinpath("test_model_checkpoint") | |||
path.mkdir(exist_ok=True, parents=True) | |||
from fastNLP.core.callbacks.callback_events import Events | |||
@Trainer.on(Events.on_train_epoch_end) | |||
def raise_exception(trainer): | |||
if trainer.driver.get_local_rank() == 0 and trainer.cur_epoch_idx == 4: | |||
raise NotImplementedError | |||
callbacks = [ | |||
ModelCheckpointCallback( | |||
monitor="acc1", | |||
save_folder=path, | |||
save_every_n_epochs=None, | |||
save_every_n_batches=None, | |||
save_topk=None, | |||
save_last=False, | |||
save_on_exception=NotImplementedError, | |||
only_state_dict=only_state_dict | |||
), | |||
] | |||
from fastNLP.core.callbacks.callback_events import Events | |||
@Trainer.on(Events.on_train_epoch_end) | |||
def raise_exception(trainer): | |||
if trainer.driver.get_local_rank() == 0 and trainer.cur_epoch_idx == 4: | |||
raise NotImplementedError | |||
callbacks = [ | |||
CheckpointCallback(folder=path, every_n_epochs=None, every_n_batches=None, last=False, | |||
on_exceptions=NotImplementedError, topk=None, monitor=None, only_state_dict=only_state_dict, | |||
save_object='model'), | |||
] | |||
try: | |||
with pytest.raises(NotImplementedError): | |||
trainer = Trainer( | |||
model=model_and_optimizers.model, | |||
@@ -315,14 +285,14 @@ def test_model_checkpoint_callback_2( | |||
trainer.run() | |||
finally: | |||
synchronize_safe_rm(path) | |||
rank_zero_rm(path) | |||
# pass | |||
if dist.is_initialized(): | |||
dist.destroy_process_group() | |||
@pytest.mark.parametrize("driver,device", [("torch", "cpu"), ("torch_ddp", [6, 7]), ("torch", 7)]) # ("torch", "cpu"), ("torch_ddp", [0, 1]), ("torch", 1) | |||
@pytest.mark.parametrize("driver,device", [("torch", "cpu"), ("torch_ddp", [0, 1]), ("torch", 0)]) # ("torch", "cpu"), ("torch_ddp", [0, 1]), ("torch", 1) | |||
@pytest.mark.parametrize("version", [0, 1]) | |||
@pytest.mark.parametrize("only_state_dict", [True, False]) | |||
@magic_argv_env_context | |||
@@ -333,37 +303,21 @@ def test_trainer_checkpoint_callback_1( | |||
version, | |||
only_state_dict | |||
): | |||
path = Path.cwd().joinpath(f"test_model_checkpoint") | |||
path.mkdir(exist_ok=True, parents=True) | |||
try: | |||
path = Path.cwd().joinpath(f"test_model_checkpoint") | |||
path.mkdir(exist_ok=True, parents=True) | |||
if version == 0: | |||
callbacks = [ | |||
TrainerCheckpointCallback( | |||
monitor="acc", | |||
save_folder=path, | |||
save_every_n_epochs=7, | |||
save_every_n_batches=123, # 避免和 epoch 的保存重复; | |||
save_topk=None, | |||
save_last=False, | |||
save_on_exception=None, | |||
only_state_dict=only_state_dict | |||
) | |||
] | |||
elif version == 1: | |||
callbacks = [ | |||
TrainerCheckpointCallback( | |||
monitor="acc", | |||
save_folder=path, | |||
save_every_n_epochs=None, | |||
save_every_n_batches=None, | |||
save_topk=2, | |||
save_last=True, | |||
save_on_exception=None, | |||
only_state_dict=only_state_dict | |||
) | |||
] | |||
if version == 0: | |||
callbacks = [ | |||
CheckpointCallback(folder=path, every_n_epochs=7, every_n_batches=123, last=False, on_exceptions=None, topk=0, | |||
monitor=None, only_state_dict=only_state_dict, save_object='trainer') | |||
] | |||
elif version == 1: | |||
callbacks = [ | |||
CheckpointCallback(folder=path, every_n_epochs=None, every_n_batches=None, last=True, on_exceptions=None, | |||
topk=2, monitor="acc", only_state_dict=only_state_dict, save_object='trainer') | |||
] | |||
try: | |||
trainer = Trainer( | |||
model=model_and_optimizers.model, | |||
driver=driver, | |||
@@ -461,8 +415,7 @@ def test_trainer_checkpoint_callback_1( | |||
trainer.run() | |||
finally: | |||
synchronize_safe_rm(path) | |||
pass | |||
rank_zero_rm(path) | |||
if dist.is_initialized(): | |||
dist.destroy_process_group() | |||
@@ -594,12 +547,12 @@ def test_trainer_checkpoint_callback_2( | |||
callbacks = [ | |||
TrainerCheckpointCallback( | |||
monitor="acc", | |||
save_folder=path, | |||
save_every_n_epochs=None, | |||
save_every_n_batches=50, | |||
save_topk=None, | |||
save_last=False, | |||
save_on_exception=None, | |||
folder=path, | |||
every_n_epochs=None, | |||
every_n_batches=50, | |||
topk=None, | |||
last=False, | |||
on_exception=None, | |||
model_save_fn=model_save_fn | |||
) | |||
] | |||
@@ -607,12 +560,12 @@ def test_trainer_checkpoint_callback_2( | |||
callbacks = [ | |||
TrainerCheckpointCallback( | |||
monitor="acc", | |||
save_folder=path, | |||
save_every_n_epochs=None, | |||
save_every_n_batches=None, | |||
save_topk=1, | |||
save_last=True, | |||
save_on_exception=None, | |||
folder=path, | |||
every_n_epochs=None, | |||
every_n_batches=None, | |||
topk=1, | |||
last=True, | |||
on_exception=None, | |||
model_save_fn=model_save_fn | |||
) | |||
] | |||
@@ -710,7 +663,7 @@ def test_trainer_checkpoint_callback_2( | |||
trainer.run() | |||
finally: | |||
synchronize_safe_rm(path) | |||
rank_zero_rm(path) | |||
# pass | |||
if dist.is_initialized(): | |||
@@ -0,0 +1,263 @@ | |||
""" | |||
测试 more_evaluate_callback | |||
(1)能不能正确 evaluate ; | |||
(2) 能不能保存 topk 并load进来进行训练 | |||
""" | |||
import pytest | |||
import os | |||
import pytest | |||
from typing import Any | |||
from dataclasses import dataclass | |||
from torch.utils.data import DataLoader | |||
from torch.optim import SGD | |||
import torch.distributed as dist | |||
from pathlib import Path | |||
import re | |||
from fastNLP.core.controllers.trainer import Trainer | |||
from fastNLP.envs import FASTNLP_LAUNCH_TIME, FASTNLP_DISTRIBUTED_CHECK | |||
from tests.helpers.utils import magic_argv_env_context | |||
from fastNLP.core import rank_zero_rm | |||
from tests.helpers.models.torch_model import TorchNormalModel_Classification_1 | |||
from tests.helpers.datasets.torch_data import TorchArgMaxDatset | |||
from torchmetrics import Accuracy | |||
from fastNLP.core.metrics import Metric | |||
from fastNLP.core.log import logger | |||
from fastNLP.core.callbacks import MoreEvaluateCallback | |||
@dataclass
class ArgMaxDatasetConfig:
    """Shared hyper-parameters for building the arg-max dataset/dataloader used by every test in this module."""
    num_labels: int = 10          # number of target classes
    feature_dimension: int = 10   # width of each input feature vector
    data_num: int = 100           # total number of samples in the dataset
    seed: int = 0                 # RNG seed so the generated dataset is reproducible
    batch_size: int = 4           # batch size for the DataLoader
    shuffle: bool = True          # whether the DataLoader shuffles each epoch
shuffle: bool = True | |||
@dataclass
class TrainerParameters:
    """Bag of keyword arguments for constructing a Trainer; populated by the `model_and_optimizers` fixture."""
    model: Any = None                 # the torch model under test
    optimizers: Any = None            # optimizer (or optimizers) for the model
    train_dataloader: Any = None      # dataloader used for training
    evaluate_dataloaders: Any = None  # dataloader(s) used for evaluation
    input_mapping: Any = None         # optional Trainer input_mapping passthrough
    output_mapping: Any = None        # optional Trainer output_mapping passthrough
    metrics: Any = None               # metrics given to the Trainer's built-in evaluator
    more_metrics: Any = None          # extra metrics given to MoreEvaluateCallback
@pytest.fixture(scope="module", params=[0], autouse=True)
def model_and_optimizers(request):
    """
    Build the TrainerParameters shared by every test in this module: a small
    classification model, an SGD optimizer, one dataloader reused for both
    training and evaluation, a loss metric for the built-in evaluator and an
    accuracy metric for MoreEvaluateCallback.
    """
    trainer_params = TrainerParameters()

    trainer_params.model = TorchNormalModel_Classification_1(
        num_labels=ArgMaxDatasetConfig.num_labels,
        feature_dimension=ArgMaxDatasetConfig.feature_dimension
    )
    trainer_params.optimizers = SGD(trainer_params.model.parameters(), lr=0.001)
    dataset = TorchArgMaxDatset(
        feature_dimension=ArgMaxDatasetConfig.feature_dimension,
        data_num=ArgMaxDatasetConfig.data_num,
        seed=ArgMaxDatasetConfig.seed
    )
    _dataloader = DataLoader(
        dataset=dataset,
        batch_size=ArgMaxDatasetConfig.batch_size,
        # was a hard-coded True; read it from the config so the two stay in sync
        shuffle=ArgMaxDatasetConfig.shuffle
    )

    class LossMetric(Metric):
        """Accumulates the summed scalar loss across update() calls."""
        def __init__(self):
            super().__init__()
            self.register_element('loss')

        def update(self, loss):
            self.loss += loss.item()

        def get_metric(self) -> float:
            # fixed annotation: this returns the scalar loss, not a dict
            return self.loss.item()

    trainer_params.train_dataloader = _dataloader
    trainer_params.evaluate_dataloaders = _dataloader
    trainer_params.metrics = {'loss': LossMetric()}
    trainer_params.more_metrics = {"acc": Accuracy()}

    return trainer_params
@pytest.mark.parametrize("driver,device", [("torch", "cpu"), ("torch_ddp", [0, 1]), ("torch", 1)])  # ("torch", "cpu"), ("torch_ddp", [0, 1]), ("torch", 1)
@pytest.mark.parametrize("version", [0, 1])
@pytest.mark.parametrize("only_state_dict", [True, False])
@magic_argv_env_context
def test_model_more_evaluate_callback_1(
    model_and_optimizers: TrainerParameters,
    driver,
    device,
    version,
    only_state_dict
):
    """
    Train with MoreEvaluateCallback saving model checkpoints (save_object='model'),
    check the expected number of checkpoints was written, then reload each one via
    trainer.load_model and train again.
    """
    # Create the checkpoint folder BEFORE the try block: if it were assigned inside
    # the try and an earlier statement raised, the finally clause would hit a
    # NameError on `path`. (Matches the pattern used by the other checkpoint tests.)
    path = Path.cwd().joinpath("test_model_checkpoint")
    path.mkdir(exist_ok=True, parents=True)

    try:
        if version == 0:
            callbacks = [
                MoreEvaluateCallback(dataloaders=model_and_optimizers.evaluate_dataloaders,
                                     metrics=model_and_optimizers.more_metrics,
                                     evaluate_every=-1,
                                     folder=path, topk=-1,
                                     topk_monitor='acc', only_state_dict=only_state_dict, save_object='model')
            ]
        elif version == 1:
            callbacks = [
                MoreEvaluateCallback(dataloaders=model_and_optimizers.evaluate_dataloaders,
                                     metrics=model_and_optimizers.more_metrics,
                                     evaluate_every=None, watch_monitor='loss', watch_monitor_larger_better=False,
                                     folder=path, topk=1, topk_monitor='acc', only_state_dict=only_state_dict,
                                     save_object='model')
            ]
        n_epochs = 5
        trainer = Trainer(
            model=model_and_optimizers.model,
            driver=driver,
            device=device,
            optimizers=model_and_optimizers.optimizers,
            train_dataloader=model_and_optimizers.train_dataloader,
            evaluate_dataloaders=model_and_optimizers.evaluate_dataloaders,
            input_mapping=model_and_optimizers.input_mapping,
            output_mapping=model_and_optimizers.output_mapping,
            metrics=model_and_optimizers.metrics,
            n_epochs=n_epochs,
            callbacks=callbacks,
            output_from_new_proc="all",
            evaluate_fn='train_step'
        )

        trainer.run()

        all_saved_model_paths = {w.name: w for w in path.joinpath(os.environ[FASTNLP_LAUNCH_TIME]).iterdir()}
        # check that the expected number of model checkpoints was written
        if version == 0:
            assert len(all_saved_model_paths) == n_epochs
        elif version == 1:
            assert len(all_saved_model_paths) == 1

        for folder in all_saved_model_paths:
            trainer = Trainer(
                model=model_and_optimizers.model,
                driver=driver,
                device=device,
                optimizers=model_and_optimizers.optimizers,
                train_dataloader=model_and_optimizers.train_dataloader,
                evaluate_dataloaders=model_and_optimizers.evaluate_dataloaders,
                input_mapping=model_and_optimizers.input_mapping,
                output_mapping=model_and_optimizers.output_mapping,
                metrics=model_and_optimizers.metrics,

                n_epochs=2,
                output_from_new_proc="all",
                evaluate_fn='train_step'
            )
            # don't shadow the loop variable; build the absolute checkpoint path separately
            abs_folder = path.joinpath(os.environ[FASTNLP_LAUNCH_TIME]).joinpath(folder)
            trainer.load_model(abs_folder, only_state_dict=only_state_dict)

            trainer.run()

    finally:
        rank_zero_rm(path)

    if dist.is_initialized():
        dist.destroy_process_group()
@pytest.mark.parametrize("driver,device", [("torch", "cpu"), ("torch_ddp", [0, 1]), ("torch", 0)])  # ("torch", "cpu"), ("torch_ddp", [0, 1]), ("torch", 1)
@pytest.mark.parametrize("version", [0, 1])
@pytest.mark.parametrize("only_state_dict", [True, False])
@magic_argv_env_context
def test_trainer_checkpoint_callback_1(
    model_and_optimizers: TrainerParameters,
    driver,
    device,
    version,
    only_state_dict
):
    """
    Train with MoreEvaluateCallback saving full trainer checkpoints
    (save_object='trainer'), check the expected number of checkpoints was
    written, then resume from each one via trainer.load and train again.
    """
    # Create the checkpoint folder BEFORE the try block: if it were assigned inside
    # the try and an earlier statement raised, the finally clause would hit a
    # NameError on `path`. (Matches the pattern used by the other checkpoint tests.)
    path = Path.cwd().joinpath("test_model_checkpoint")
    path.mkdir(exist_ok=True, parents=True)

    try:
        if version == 0:
            callbacks = [
                MoreEvaluateCallback(dataloaders=model_and_optimizers.evaluate_dataloaders,
                                     metrics=model_and_optimizers.more_metrics,
                                     evaluate_every=-1,
                                     folder=path, topk=-1,
                                     topk_monitor='acc', only_state_dict=only_state_dict, save_object='trainer')
            ]
        elif version == 1:
            callbacks = [
                MoreEvaluateCallback(dataloaders=model_and_optimizers.evaluate_dataloaders,
                                     metrics=model_and_optimizers.more_metrics,
                                     evaluate_every=None, watch_monitor='loss', watch_monitor_larger_better=False,
                                     folder=path, topk=1, topk_monitor='acc', only_state_dict=only_state_dict,
                                     save_object='trainer')
            ]
        n_epochs = 5
        trainer = Trainer(
            model=model_and_optimizers.model,
            driver=driver,
            device=device,
            optimizers=model_and_optimizers.optimizers,
            train_dataloader=model_and_optimizers.train_dataloader,
            evaluate_dataloaders=model_and_optimizers.evaluate_dataloaders,
            input_mapping=model_and_optimizers.input_mapping,
            output_mapping=model_and_optimizers.output_mapping,
            metrics=model_and_optimizers.metrics,
            n_epochs=n_epochs,
            callbacks=callbacks,
            output_from_new_proc="all",
            evaluate_fn='train_step'
        )

        trainer.run()

        all_saved_model_paths = {w.name: w for w in path.joinpath(os.environ[FASTNLP_LAUNCH_TIME]).iterdir()}
        # check that the expected number of trainer checkpoints was written
        if version == 0:
            assert len(all_saved_model_paths) == n_epochs
        elif version == 1:
            assert len(all_saved_model_paths) == 1

        for folder in all_saved_model_paths:
            trainer = Trainer(
                model=model_and_optimizers.model,
                driver=driver,
                device=device,
                optimizers=model_and_optimizers.optimizers,
                train_dataloader=model_and_optimizers.train_dataloader,
                evaluate_dataloaders=model_and_optimizers.evaluate_dataloaders,
                input_mapping=model_and_optimizers.input_mapping,
                output_mapping=model_and_optimizers.output_mapping,
                metrics=model_and_optimizers.metrics,

                n_epochs=7,
                output_from_new_proc="all",
                evaluate_fn='train_step'
            )
            # don't shadow the loop variable; build the absolute checkpoint path separately
            abs_folder = path.joinpath(os.environ[FASTNLP_LAUNCH_TIME]).joinpath(folder)
            trainer.load(abs_folder, only_state_dict=only_state_dict)

            trainer.run()

    finally:
        rank_zero_rm(path)

    if dist.is_initialized():
        dist.destroy_process_group()
@@ -15,7 +15,7 @@ from tests.helpers.datasets.torch_data import TorchNormalDataset_Classification | |||
from tests.helpers.callbacks.helper_callbacks import RecordLossCallback | |||
from tests.helpers.callbacks.helper_callbacks_torch import RecordAccumulationStepsCallback_Torch | |||
from tests.helpers.utils import magic_argv_env_context, Capturing | |||
from fastNLP.core import synchronize_safe_rm | |||
from fastNLP.core import rank_zero_rm | |||
@dataclass | |||
@@ -239,7 +239,7 @@ def test_trainer_output_from_new_proc( | |||
assert err_path.exists() | |||
path = Path(os.path.abspath(output_from_new_proc)) | |||
synchronize_safe_rm(path) | |||
rank_zero_rm(path) | |||
@pytest.mark.parametrize("driver,device", [("torch", [1, 2])]) | |||
@@ -11,7 +11,7 @@ from tests.helpers.models.paddle_model import PaddleNormalModel_Classification_1 | |||
from tests.helpers.datasets.paddle_data import PaddleNormalDataset, PaddleRandomMaxDataset | |||
from tests.helpers.datasets.torch_data import TorchNormalDataset | |||
from tests.helpers.models.torch_model import TorchNormalModel_Classification_1 | |||
from fastNLP.core import synchronize_safe_rm | |||
from fastNLP.core import rank_zero_rm | |||
import paddle | |||
from paddle.io import DataLoader, BatchSampler | |||
@@ -578,11 +578,11 @@ def test_save_and_load_model(prepare_test_save_load, only_state_dict): | |||
assert paddle.equal_all(res1["pred"], res2["pred"]) | |||
finally: | |||
if only_state_dict: | |||
synchronize_safe_rm(path) | |||
rank_zero_rm(path) | |||
else: | |||
synchronize_safe_rm(path + ".pdiparams") | |||
synchronize_safe_rm(path + ".pdiparams.info") | |||
synchronize_safe_rm(path + ".pdmodel") | |||
rank_zero_rm(path + ".pdiparams") | |||
rank_zero_rm(path + ".pdiparams.info") | |||
rank_zero_rm(path + ".pdmodel") | |||
@pytest.mark.parametrize("only_state_dict", ([True, False])) | |||
def test_save_and_load_with_randombatchsampler(only_state_dict): | |||
@@ -652,7 +652,7 @@ def test_save_and_load_with_randombatchsampler(only_state_dict): | |||
assert len(left_y_batches) + len(already_seen_y_set) == len(dataset) | |||
assert len(left_y_batches | already_seen_y_set) == len(dataset) | |||
finally: | |||
synchronize_safe_rm(path) | |||
rank_zero_rm(path) | |||
@pytest.mark.parametrize("only_state_dict", ([True, False])) | |||
def test_save_and_load_with_randomsampler(only_state_dict): | |||
@@ -730,4 +730,4 @@ def test_save_and_load_with_randomsampler(only_state_dict): | |||
assert len(left_y_batches) + len(already_seen_y_set) == len(dataset) | |||
assert len(left_y_batches | already_seen_y_set) == len(dataset) | |||
finally: | |||
synchronize_safe_rm(path) | |||
rank_zero_rm(path) |
@@ -6,7 +6,7 @@ import logging | |||
import re | |||
from fastNLP.envs.env import FASTNLP_LAUNCH_TIME | |||
from fastNLP.core import synchronize_safe_rm | |||
from fastNLP.core import rank_zero_rm | |||
from fastNLP.core.log.logger import logger | |||
from tests.helpers.utils import magic_argv_env_context, recover_logger | |||
@@ -56,7 +56,7 @@ def test_add_file_ddp_1_torch(): | |||
pattern = re.compile(msg) | |||
assert len(pattern.findall(line)) == 1 | |||
synchronize_safe_rm(filepath) | |||
rank_zero_rm(filepath) | |||
dist.barrier() | |||
dist.destroy_process_group() | |||
@@ -105,7 +105,7 @@ def test_add_file_ddp_2_torch(): | |||
pattern = re.compile(msg) | |||
assert len(pattern.findall(line)) == 1 | |||
finally: | |||
synchronize_safe_rm(path) | |||
rank_zero_rm(path) | |||
dist.barrier() | |||
dist.destroy_process_group() | |||
@@ -155,7 +155,7 @@ def test_add_file_ddp_3_torch(): | |||
pattern = re.compile(msg) | |||
assert len(pattern.findall(line)) == 1 | |||
synchronize_safe_rm(file) | |||
rank_zero_rm(file) | |||
dist.barrier() | |||
dist.destroy_process_group() | |||
@@ -202,7 +202,7 @@ def test_add_file_ddp_4_torch(): | |||
pattern = re.compile(msg) | |||
assert len(pattern.findall(line)) == 1 | |||
finally: | |||
synchronize_safe_rm(path) | |||
rank_zero_rm(path) | |||
dist.barrier() | |||
dist.destroy_process_group() | |||
@@ -225,7 +225,7 @@ class TestLogger: | |||
line = ''.join([l for l in f]) | |||
assert self.msg in line | |||
finally: | |||
synchronize_safe_rm(path) | |||
rank_zero_rm(path) | |||
@recover_logger | |||
def test_add_file_2(self): | |||
@@ -243,7 +243,7 @@ class TestLogger: | |||
line = ''.join([l for l in f]) | |||
assert self.msg in line | |||
finally: | |||
synchronize_safe_rm(origin_path) | |||
rank_zero_rm(origin_path) | |||
@recover_logger | |||
def test_add_file_3(self): | |||
@@ -279,7 +279,7 @@ class TestLogger: | |||
line = ''.join([l for l in f]) | |||
assert self.msg in line | |||
finally: | |||
synchronize_safe_rm(path) | |||
rank_zero_rm(path) | |||
@recover_logger | |||
def test_stdout(self, capsys): | |||
@@ -8,7 +8,7 @@ import sys | |||
from fastNLP.core.utils.cache_results import cache_results | |||
from tests.helpers.common.utils import check_time_elapse | |||
from fastNLP.core import synchronize_safe_rm | |||
from fastNLP.core import rank_zero_rm | |||
def get_subprocess_results(cmd): | |||
@@ -56,7 +56,7 @@ class TestCacheResults: | |||
res = demo() | |||
finally: | |||
synchronize_safe_rm(cache_fp) | |||
rank_zero_rm(cache_fp) | |||
def test_cache_save_refresh(self): | |||
cache_fp = 'demo.pkl' | |||
@@ -70,7 +70,7 @@ class TestCacheResults: | |||
with check_time_elapse(1, op='ge'): | |||
res = demo() | |||
finally: | |||
synchronize_safe_rm(cache_fp) | |||
rank_zero_rm(cache_fp) | |||
def test_cache_no_func_change(self): | |||
cache_fp = os.path.abspath('demo.pkl') | |||
@@ -91,7 +91,7 @@ class TestCacheResults: | |||
with check_time_elapse(1, op='lt'): | |||
res = demo() | |||
finally: | |||
synchronize_safe_rm('demo.pkl') | |||
rank_zero_rm('demo.pkl') | |||
def test_cache_func_change(self, capsys): | |||
cache_fp = 'demo.pkl' | |||
@@ -121,7 +121,7 @@ class TestCacheResults: | |||
assert 'is different from its last cache' not in output[0] | |||
finally: | |||
synchronize_safe_rm('demo.pkl') | |||
rank_zero_rm('demo.pkl') | |||
def test_cache_check_hash(self): | |||
cache_fp = 'demo.pkl' | |||
@@ -152,7 +152,7 @@ class TestCacheResults: | |||
assert 'is different from its last cache' in output[0] | |||
finally: | |||
synchronize_safe_rm('demo.pkl') | |||
rank_zero_rm('demo.pkl') | |||
# 外部 function 改变也会 导致改变 | |||
def test_refer_fun_change(self): | |||
@@ -177,7 +177,7 @@ class TestCacheResults: | |||
assert 'is different from its last cache' in res | |||
finally: | |||
synchronize_safe_rm(cache_fp) | |||
rank_zero_rm(cache_fp) | |||
# 外部 method 改变也会 导致改变 | |||
def test_refer_class_method_change(self): | |||
@@ -202,7 +202,7 @@ class TestCacheResults: | |||
assert 'is different from its last cache' in res | |||
finally: | |||
synchronize_safe_rm(cache_fp) | |||
rank_zero_rm(cache_fp) | |||
def test_duplicate_keyword(self): | |||
with pytest.raises(RuntimeError): | |||
@@ -240,7 +240,7 @@ class TestCacheResults: | |||
results = cache() | |||
assert (1, 2) == results | |||
finally: | |||
synchronize_safe_rm('demo/') | |||
rank_zero_rm('demo/') | |||
def test_result_none_error(self): | |||
@cache_results('demo.pkl') | |||
@@ -251,7 +251,7 @@ class TestCacheResults: | |||
with pytest.raises(RuntimeError): | |||
results = cache() | |||
finally: | |||
synchronize_safe_rm('demo.pkl') | |||
rank_zero_rm('demo.pkl') | |||
if __name__ == '__main__': | |||
@@ -2,7 +2,7 @@ import os | |||
from fastNLP.envs.set_backend import dump_fastnlp_backend | |||
from tests.helpers.utils import Capturing | |||
from fastNLP.core import synchronize_safe_rm | |||
from fastNLP.core import rank_zero_rm | |||
def test_dump_fastnlp_envs(): | |||
@@ -14,4 +14,4 @@ def test_dump_fastnlp_envs(): | |||
assert filepath in output[0] | |||
assert os.path.exists(filepath) | |||
finally: | |||
synchronize_safe_rm(filepath) | |||
rank_zero_rm(filepath) |
@@ -9,7 +9,7 @@ import numpy as np | |||
from fastNLP.modules.mix_modules.mix_module import MixModule | |||
from fastNLP.modules.mix_modules.utils import paddle2torch, torch2paddle | |||
from fastNLP.core import synchronize_safe_rm | |||
from fastNLP.core import rank_zero_rm | |||
############################################################################ | |||
@@ -227,7 +227,7 @@ class TorchPaddleMixModuleTestCase(unittest.TestCase): | |||
self.assertDictEqual(state_dict, new_state_dict) | |||
finally: | |||
synchronize_safe_rm(path) | |||
rank_zero_rm(path) | |||
def if_device_correct(self, device): | |||