|
|
@@ -1,8 +1,11 @@ |
|
|
|
__all__ = [ |
|
|
|
'TopkSaver' |
|
|
|
] |
|
|
|
import json |
|
|
|
import os |
|
|
|
from copy import deepcopy |
|
|
|
from pathlib import Path |
|
|
|
from typing import Optional, Dict, Tuple |
|
|
|
from typing import Optional, Dict, Tuple, Callable, Union |
|
|
|
|
|
|
|
from fastNLP.core.utils import rank_zero_rm |
|
|
|
from fastNLP.core.log import logger |
|
|
@@ -13,17 +16,20 @@ from .has_monitor_callback import MonitorUtility |
|
|
|
|
|
|
|
|
|
|
|
class Saver: |
|
|
|
def __init__(self, folder, only_state_dict, model_save_fn, **kwargs): |
|
|
|
def __init__(self, folder:str=None, save_object:str='model', only_state_dict:bool=True, |
|
|
|
model_save_fn:Callable=None, **kwargs): |
|
|
|
""" |
|
|
|
执行保存的对象。保存的文件组织结构为 |
|
|
|
- folder # 当前初始化的参数 |
|
|
|
- YYYY-mm-dd-HH_MM_SS_fffff/ # 自动根据当前脚本的启动时间创建的 |
|
|
|
- folder_name # 由 save() 调用时传入。 |
|
|
|
|
|
|
|
:param folder: |
|
|
|
:param only_state_dict: |
|
|
|
:param model_save_fn: |
|
|
|
:param kwargs: |
|
|
|
:param folder: 保存在哪个文件夹下,默认为当前 folder 下。 |
|
|
|
:param save_object: 可选 ['trainer', 'model'],表示在保存时的保存对象为 trainer+model 还是 只是model 。 |
|
|
|
:param only_state_dict: 保存时是否仅保存权重,在 model_save_fn 不为 None 时无意义。 |
|
|
|
:param model_save_fn: 个性化的保存函数,当触发保存操作时,就调用这个函数,这个函数应当接受一个文件夹作为参数,不返回任何东西。 |
|
|
|
如果传入了 model_save_fn 函数,fastNLP 将不再进行模型相关的保存。在多卡场景下,我们只在 rank 0 上会运行该函数。 |
|
|
|
:param kwargs: 更多需要传递给 Trainer.save() 或者 Trainer.save_model() 接口的参数。 |
|
|
|
""" |
|
|
|
if folder is None: |
|
|
|
logger.warning( |
|
|
@@ -39,21 +45,26 @@ class Saver: |
|
|
|
self.only_state_dict = only_state_dict |
|
|
|
self.model_save_fn = model_save_fn |
|
|
|
self.kwargs = kwargs |
|
|
|
self.eval_results = kwargs.get('eval_results', True) |
|
|
|
self.save_object = save_object |
|
|
|
self.save_fn_name = 'save' if save_object == 'trainer' else 'save_model' |
|
|
|
|
|
|
|
self.timestamp_path = self.folder.joinpath(os.environ[FASTNLP_LAUNCH_TIME]) |
|
|
|
|
|
|
|
@rank_zero_call |
|
|
|
def save(self, save_fn, folder_name): |
|
|
|
def save(self, trainer, folder_name): |
|
|
|
""" |
|
|
|
执行保存的函数,将数据保存在 folder/timestamp/folder_name 下。其中 folder 为用户在初始化指定, |
|
|
|
timestamp 为当前脚本的启动时间。 |
|
|
|
执行保存的函数,将数据保存在 |
|
|
|
- folder/ |
|
|
|
- YYYY-mm-dd-HH_MM_SS_fffff/ # 自动根据当前脚本的启动时间创建的 |
|
|
|
- folder_name # 当前函数参数 |
|
|
|
|
|
|
|
:param save_fn: 调用的保存函数,应该可接受参数 folder:str, only_state_dict: bool, model_save_fn: callable, kwargs |
|
|
|
:param trainer: Trainer 对象 |
|
|
|
:param folder_name: 保存的 folder 名称,将被创建。 |
|
|
|
:return: 返回实际发生保存的 folder 绝对路径。如果为 None 则没有创建。 |
|
|
|
""" |
|
|
|
folder = self.timestamp_path.joinpath(folder_name) |
|
|
|
folder.mkdir(parents=True, exist_ok=True) |
|
|
|
save_fn = getattr(trainer, self.save_fn_name) |
|
|
|
save_fn( |
|
|
|
folder=folder, |
|
|
|
only_state_dict=self.only_state_dict, |
|
|
@@ -79,7 +90,7 @@ class Saver: |
|
|
|
""" |
|
|
|
移除 folder/timestamp/folder_name 。其中 folder 为用户在初始化指定, timestamp 为当前脚本的启动时间。 |
|
|
|
|
|
|
|
:param folder_name: |
|
|
|
:param folder_name: 需要移除的路径。 |
|
|
|
:return: |
|
|
|
""" |
|
|
|
folder = self.timestamp_path.joinpath(folder_name) |
|
|
@@ -101,7 +112,7 @@ class Saver: |
|
|
|
self.timestamp_path = Path(timestamp_path) |
|
|
|
|
|
|
|
def __str__(self): |
|
|
|
return 'saver' # saver是无状态的,不需要有特定名字 |
|
|
|
return f'saver:{self.save_object}' |
|
|
|
|
|
|
|
|
|
|
|
class TopkQueue: |
|
|
@@ -113,9 +124,9 @@ class TopkQueue: |
|
|
|
""" |
|
|
|
assert isinstance(topk, int) |
|
|
|
self.topk = topk |
|
|
|
self.topk_dict = {} # 其中 key 为保存的 |
|
|
|
self.topk_dict = {} # 其中 key 为保存的内容, value 是对应的性能。 |
|
|
|
|
|
|
|
def push(self, key, value) -> Optional[Tuple[str, float]]: |
|
|
|
def push(self, key, value) -> Optional[Tuple[Union[str, None], Union[float, None]]]: |
|
|
|
""" |
|
|
|
将 key/value 推入 topk 的 queue 中,以 value 为标准,如果满足 topk 则保留此次推入的信息,同时如果新推入的数据将之前的数据给 |
|
|
|
挤出了 topk ,则会返回被挤出的 (key, value);如果返回为 (None, None),说明满足 topk 且没有数据被挤出。如果不满足 topk ,则返回 |
|
|
@@ -153,50 +164,54 @@ class TopkQueue: |
|
|
|
return f'topk-{self.topk}' |
|
|
|
|
|
|
|
def __bool__(self): |
|
|
|
# 仅当 topk 为 0 时,表明该 topk_queue 无意义。 |
|
|
|
# 当 topk 为 0 时,表明该 topk_queue 无意义。 |
|
|
|
return self.topk != 0 |
|
|
|
|
|
|
|
|
|
|
|
class TopkSaver(MonitorUtility, Saver): |
|
|
|
def __init__(self, topk, monitor, larger_better, folder, only_state_dict, |
|
|
|
model_save_fn, save_evaluate_results, |
|
|
|
save_object, **kwargs): |
|
|
|
""" |
|
|
|
用来保存识别 tokp 模型并保存。 |
|
|
|
|
|
|
|
:param topk: |
|
|
|
:param monitor: |
|
|
|
:param larger_better: |
|
|
|
:param folder: |
|
|
|
:param only_state_dict: |
|
|
|
:param model_save_fn: |
|
|
|
:param save_evaluate_results: |
|
|
|
:param save_object: |
|
|
|
:param kwargs: |
|
|
|
def __init__(self, topk:int, monitor:str, larger_better:bool=True, folder:str=None, save_object:str='model', |
|
|
|
only_state_dict:bool=True, model_save_fn:Callable=None, save_evaluate_results:bool=True, |
|
|
|
**kwargs): |
|
|
|
""" |
|
|
|
用来保存识别 topk 模型并保存,也可以仅当一个保存 saver 使用。保存路径为 |
|
|
|
- folder/ |
|
|
|
- YYYY-mm-dd-HH_MM_SS_fffff/ # 自动根据当前脚本的启动时间创建的 |
|
|
|
- {save_object}-epoch_{epoch_idx}-batch_{global_batch_idx}-{topk_monitor}_{monitor_value}/ # 满足topk条件存储文件名 |
|
|
|
|
|
|
|
:param topk: 保存 topk 多少的模型,-1 为保存所有模型;0 为都不保存;大于 0 的数为保存 topk 个。 |
|
|
|
:param monitor: 监控哪个指标判断是否是 topk 的。监控的 metric 值。如果在 evaluation 结果中没有找到完全一致的名称,将使用 |
|
|
|
最短公共字符串算法 找到最匹配的那个作为 monitor 。如果为 None,将尝试使用 Trainer 设置的 monitor 。也可以传入一个函数, |
|
|
|
接受参数为 evaluation 的结果(字典类型),返回一个 float 值作为 monitor 的结果,如果当前结果中没有相关的 monitor 值请 |
|
|
|
返回 None 。 |
|
|
|
:param larger_better: 该 monitor 是否越大越好。 |
|
|
|
:param folder: 保存在哪个文件夹下,默认为当前 folder 下。 |
|
|
|
:param save_object: 可选 ['trainer', 'model'],表示在保存时的保存对象为 trainer+model 还是 只是model 。 |
|
|
|
:param only_state_dict: 保存时是否仅保存权重,在 model_save_fn 不为 None 时无意义。 |
|
|
|
:param model_save_fn: 个性化的保存函数,当触发保存操作时,就调用这个函数,这个函数应当接受一个文件夹作为参数,不返回任何东西。 |
|
|
|
如果传入了 model_save_fn 函数,fastNLP 将不再进行模型相关的保存。在多卡场景下,我们只在 rank 0 上会运行该函数。 |
|
|
|
:param save_evaluate_results: 是否保存 evaluate 的结果。如果为 True ,在保存 topk 模型的 folder 中还将额外保存一个 |
|
|
|
fastnlp_evaluate_results.json 文件,记录当前的 results。仅在设置了 topk 的场景下有用,默认为 True 。 |
|
|
|
:param kwargs: 更多需要传递给 Trainer.save() 或者 Trainer.save_model() 接口的参数。 |
|
|
|
""" |
|
|
|
MonitorUtility.__init__(self, monitor, larger_better) |
|
|
|
Saver.__init__(self, folder, only_state_dict, model_save_fn, **kwargs) |
|
|
|
Saver.__init__(self, folder, save_object, only_state_dict, model_save_fn, **kwargs) |
|
|
|
|
|
|
|
if monitor is not None and topk == 0: |
|
|
|
raise RuntimeError("`monitor` is set, but `topk` is 0.") |
|
|
|
if topk != 0 and monitor is None: |
|
|
|
raise RuntimeError("`topk` is set, but `monitor` is None.") |
|
|
|
|
|
|
|
assert save_object in ['trainer', 'model'] |
|
|
|
|
|
|
|
self.saver = Saver(folder, only_state_dict, model_save_fn, **kwargs) |
|
|
|
self.topk_queue = TopkQueue(topk) |
|
|
|
self.save_evaluate_results = save_evaluate_results |
|
|
|
self.save_object = save_object |
|
|
|
self.save_fn_name = 'save' if save_object == 'trainer' else 'save_model' |
|
|
|
|
|
|
|
@rank_zero_call |
|
|
|
def save_topk(self, trainer, results: Dict) -> Optional[str]: |
|
|
|
""" |
|
|
|
根据 results 是否满足 topk 的相关设定决定是否保存,如果发生了保存,将返回保存的文件夹。 |
|
|
|
根据 results 是否满足 topk 的相关设定决定是否保存,如果发生了保存,将返回保存的文件夹。如果返回为 None ,则说明此次没有满足 |
|
|
|
topk 要求,没有发生保存。 |
|
|
|
|
|
|
|
:param trainer: |
|
|
|
:param results: |
|
|
|
:param results: evaluate 的结果。 |
|
|
|
:return: |
|
|
|
""" |
|
|
|
if self.monitor is not None and self.topk_queue: |
|
|
@@ -220,14 +235,10 @@ class TopkSaver(MonitorUtility, Saver): |
|
|
|
self.rm(pop_key) |
|
|
|
return folder |
|
|
|
|
|
|
|
def save(self, trainer, folder_name): |
|
|
|
fn = getattr(trainer, self.save_fn_name) |
|
|
|
return super().save(fn, folder_name) |
|
|
|
|
|
|
|
def state_dict(self): |
|
|
|
states = { |
|
|
|
'topk_queue': self.topk_queue.state_dict(), |
|
|
|
'saver': self.saver.state_dict() |
|
|
|
'timestamp_path': str(self.timestamp_path), |
|
|
|
} |
|
|
|
if isinstance(self._real_monitor, str): |
|
|
|
states['_real_monitor'] = self._real_monitor |
|
|
@@ -236,11 +247,18 @@ class TopkSaver(MonitorUtility, Saver): |
|
|
|
|
|
|
|
def load_state_dict(self, states): |
|
|
|
topk_queue_states = states['topk_queue'] |
|
|
|
saver_states = states['saver'] |
|
|
|
self.topk_queue.load_state_dict(topk_queue_states) |
|
|
|
self.saver.load_state_dict(saver_states) |
|
|
|
|
|
|
|
timestamp_path = states['timestamp_path'] |
|
|
|
if not os.path.exists(timestamp_path): |
|
|
|
logger.info(f"The resuming checkpoint folder {timestamp_path} is not exists, checkpoint will save to " |
|
|
|
f" {self.timestamp_path.absolute()}.") |
|
|
|
else: |
|
|
|
logger.info(f"Resume to save checkpoint in path: {timestamp_path}.") |
|
|
|
self.timestamp_path = Path(timestamp_path) |
|
|
|
|
|
|
|
if '_real_monitor' in states: |
|
|
|
self._real_monitor = states["_real_monitor"] |
|
|
|
|
|
|
|
def __str__(self): |
|
|
|
return f'topk-{self.topk_queue}#saver-{self.saver}#save_object-{self.save_object}' |
|
|
|
return f'topk-{self.topk_queue}#save_object-{self.save_object}' |