Browse Source

更新Trainer的部分文档; 更新topk_saver文档

tags/v1.0.0alpha
yh_cc 3 years ago
parent
commit
ceb30937b8
2 changed files with 67 additions and 48 deletions
  1. +65
    -47
      fastNLP/core/callbacks/topk_saver.py
  2. +2
    -1
      fastNLP/core/controllers/trainer.py

+ 65
- 47
fastNLP/core/callbacks/topk_saver.py View File

@@ -1,8 +1,11 @@
__all__ = [
'TopkSaver'
]
import json import json
import os import os
from copy import deepcopy from copy import deepcopy
from pathlib import Path from pathlib import Path
from typing import Optional, Dict, Tuple
from typing import Optional, Dict, Tuple, Callable, Union


from fastNLP.core.utils import rank_zero_rm from fastNLP.core.utils import rank_zero_rm
from fastNLP.core.log import logger from fastNLP.core.log import logger
@@ -13,17 +16,20 @@ from .has_monitor_callback import MonitorUtility




class Saver: class Saver:
def __init__(self, folder, only_state_dict, model_save_fn, **kwargs):
def __init__(self, folder:str=None, save_object:str='model', only_state_dict:bool=True,
model_save_fn:Callable=None, **kwargs):
""" """
执行保存的对象。保存的文件组织结构为 执行保存的对象。保存的文件组织结构为
- folder # 当前初始化的参数 - folder # 当前初始化的参数
- YYYY-mm-dd-HH_MM_SS_fffff/ # 自动根据当前脚本的启动时间创建的 - YYYY-mm-dd-HH_MM_SS_fffff/ # 自动根据当前脚本的启动时间创建的
- folder_name # 由 save() 调用时传入。 - folder_name # 由 save() 调用时传入。


:param folder:
:param only_state_dict:
:param model_save_fn:
:param kwargs:
:param folder: 保存在哪个文件夹下,默认为当前 folder 下。
:param save_object: 可选 ['trainer', 'model'],表示在保存时的保存对象为 trainer+model 还是 只是model 。
:param only_state_dict: 保存时是否仅保存权重,在 model_save_fn 不为 None 时无意义。
:param model_save_fn: 个性化的保存函数,当触发保存操作时,就调用这个函数,这个函数应当接受一个文件夹作为参数,不返回任何东西。
如果传入了 model_save_fn 函数,fastNLP 将不再进行模型相关的保存。在多卡场景下,我们只在 rank 0 上会运行该函数。
:param kwargs: 更多需要传递给 Trainer.save() 或者 Trainer.save_model() 接口的参数。
""" """
if folder is None: if folder is None:
logger.warning( logger.warning(
@@ -39,21 +45,26 @@ class Saver:
self.only_state_dict = only_state_dict self.only_state_dict = only_state_dict
self.model_save_fn = model_save_fn self.model_save_fn = model_save_fn
self.kwargs = kwargs self.kwargs = kwargs
self.eval_results = kwargs.get('eval_results', True)
self.save_object = save_object
self.save_fn_name = 'save' if save_object == 'trainer' else 'save_model'

self.timestamp_path = self.folder.joinpath(os.environ[FASTNLP_LAUNCH_TIME]) self.timestamp_path = self.folder.joinpath(os.environ[FASTNLP_LAUNCH_TIME])


@rank_zero_call @rank_zero_call
def save(self, save_fn, folder_name):
def save(self, trainer, folder_name):
""" """
执行保存的函数,将数据保存在 folder/timestamp/folder_name 下。其中 folder 为用户在初始化指定,
timestamp 为当前脚本的启动时间。
执行保存的函数,将数据保存在
- folder/
- YYYY-mm-dd-HH_MM_SS_fffff/ # 自动根据当前脚本的启动时间创建的
- folder_name # 当前函数参数


:param save_fn: 调用的保存函数,应该可接受参数 folder:str, only_state_dict: bool, model_save_fn: callable, kwargs
:param trainer: Trainer 对象
:param folder_name: 保存的 folder 名称,将被创建。 :param folder_name: 保存的 folder 名称,将被创建。
:return: 返回实际发生保存的 folder 绝对路径。如果为 None 则没有创建。 :return: 返回实际发生保存的 folder 绝对路径。如果为 None 则没有创建。
""" """
folder = self.timestamp_path.joinpath(folder_name) folder = self.timestamp_path.joinpath(folder_name)
folder.mkdir(parents=True, exist_ok=True) folder.mkdir(parents=True, exist_ok=True)
save_fn = getattr(trainer, self.save_fn_name)
save_fn( save_fn(
folder=folder, folder=folder,
only_state_dict=self.only_state_dict, only_state_dict=self.only_state_dict,
@@ -79,7 +90,7 @@ class Saver:
""" """
移除 folder/timestamp/folder_name 。其中 folder 为用户在初始化指定, timestamp 为当前脚本的启动时间。 移除 folder/timestamp/folder_name 。其中 folder 为用户在初始化指定, timestamp 为当前脚本的启动时间。


:param folder_name:
:param folder_name: 需要移除的路径。
:return: :return:
""" """
folder = self.timestamp_path.joinpath(folder_name) folder = self.timestamp_path.joinpath(folder_name)
@@ -101,7 +112,7 @@ class Saver:
self.timestamp_path = Path(timestamp_path) self.timestamp_path = Path(timestamp_path)


def __str__(self): def __str__(self):
return 'saver' # saver是无状态的,不需要有特定名字
return f'saver:{self.save_object}'




class TopkQueue: class TopkQueue:
@@ -113,9 +124,9 @@ class TopkQueue:
""" """
assert isinstance(topk, int) assert isinstance(topk, int)
self.topk = topk self.topk = topk
self.topk_dict = {} # 其中 key 为保存的
self.topk_dict = {} # 其中 key 为保存的内容, value 是对应的性能。


def push(self, key, value) -> Optional[Tuple[str, float]]:
def push(self, key, value) -> Optional[Tuple[Union[str, None], Union[float, None]]]:
""" """
将 key/value 推入 topk 的 queue 中,以 value 为标准,如果满足 topk 则保留此次推入的信息,同时如果新推入的数据将之前的数据给 将 key/value 推入 topk 的 queue 中,以 value 为标准,如果满足 topk 则保留此次推入的信息,同时如果新推入的数据将之前的数据给
挤出了 topk ,则会返回被挤出的 (key, value);如果返回为 (None, None),说明满足 topk 且没有数据被挤出。如果不满足 topk ,则返回 挤出了 topk ,则会返回被挤出的 (key, value);如果返回为 (None, None),说明满足 topk 且没有数据被挤出。如果不满足 topk ,则返回
@@ -153,50 +164,54 @@ class TopkQueue:
return f'topk-{self.topk}' return f'topk-{self.topk}'


def __bool__(self): def __bool__(self):
# 当 topk 为 0 时,表明该 topk_queue 无意义。
# 当 topk 为 0 时,表明该 topk_queue 无意义。
return self.topk != 0 return self.topk != 0




class TopkSaver(MonitorUtility, Saver): class TopkSaver(MonitorUtility, Saver):
def __init__(self, topk, monitor, larger_better, folder, only_state_dict,
model_save_fn, save_evaluate_results,
save_object, **kwargs):
"""
用来保存识别 tokp 模型并保存。

:param topk:
:param monitor:
:param larger_better:
:param folder:
:param only_state_dict:
:param model_save_fn:
:param save_evaluate_results:
:param save_object:
:param kwargs:
def __init__(self, topk:int, monitor:str, larger_better:bool=True, folder:str=None, save_object:str='model',
only_state_dict:bool=True, model_save_fn:Callable=None, save_evaluate_results:bool=True,
**kwargs):
"""
用来保存识别 topk 模型并保存,也可以仅当一个保存 saver 使用。保存路径为
- folder/
- YYYY-mm-dd-HH_MM_SS_fffff/ # 自动根据当前脚本的启动时间创建的
- {save_object}-epoch_{epoch_idx}-batch_{global_batch_idx}-{topk_monitor}_{monitor_value}/ # 满足topk条件存储文件名

:param topk: 保存 topk 多少的模型,-1 为保存所有模型;0 为都不保存;大于 0 的数为保存 topk 个。
:param monitor: 监控哪个指标判断是否是 topk 的。监控的 metric 值。如果在 evaluation 结果中没有找到完全一致的名称,将使用
最短公共字符串算法 找到最匹配的那个作为 monitor 。如果为 None,将尝试使用 Trainer 设置的 monitor 。也可以传入一个函数,
接受参数为 evaluation 的结果(字典类型),返回一个 float 值作为 monitor 的结果,如果当前结果中没有相关的 monitor 值请
返回 None 。
:param larger_better: 该 monitor 是否越大越好。
:param folder: 保存在哪个文件夹下,默认为当前 folder 下。
:param save_object: 可选 ['trainer', 'model'],表示在保存时的保存对象为 trainer+model 还是 只是model 。
:param only_state_dict: 保存时是否仅保存权重,在 model_save_fn 不为 None 时无意义。
:param model_save_fn: 个性化的保存函数,当触发保存操作时,就调用这个函数,这个函数应当接受一个文件夹作为参数,不返回任何东西。
如果传入了 model_save_fn 函数,fastNLP 将不再进行模型相关的保存。在多卡场景下,我们只在 rank 0 上会运行该函数。
:param save_evaluate_results: 是否保存 evaluate 的结果。如果为 True ,在保存 topk 模型的 folder 中还将额外保存一个
fastnlp_evaluate_results.json 文件,记录当前的 results。仅在设置了 topk 的场景下有用,默认为 True 。
:param kwargs: 更多需要传递给 Trainer.save() 或者 Trainer.save_model() 接口的参数。
""" """
MonitorUtility.__init__(self, monitor, larger_better) MonitorUtility.__init__(self, monitor, larger_better)
Saver.__init__(self, folder, only_state_dict, model_save_fn, **kwargs)
Saver.__init__(self, folder, save_object, only_state_dict, model_save_fn, **kwargs)


if monitor is not None and topk == 0: if monitor is not None and topk == 0:
raise RuntimeError("`monitor` is set, but `topk` is 0.") raise RuntimeError("`monitor` is set, but `topk` is 0.")
if topk != 0 and monitor is None: if topk != 0 and monitor is None:
raise RuntimeError("`topk` is set, but `monitor` is None.") raise RuntimeError("`topk` is set, but `monitor` is None.")


assert save_object in ['trainer', 'model']

self.saver = Saver(folder, only_state_dict, model_save_fn, **kwargs)
self.topk_queue = TopkQueue(topk) self.topk_queue = TopkQueue(topk)
self.save_evaluate_results = save_evaluate_results self.save_evaluate_results = save_evaluate_results
self.save_object = save_object
self.save_fn_name = 'save' if save_object == 'trainer' else 'save_model'


@rank_zero_call @rank_zero_call
def save_topk(self, trainer, results: Dict) -> Optional[str]: def save_topk(self, trainer, results: Dict) -> Optional[str]:
""" """
根据 results 是否满足 topk 的相关设定决定是否保存,如果发生了保存,将返回保存的文件夹。
根据 results 是否满足 topk 的相关设定决定是否保存,如果发生了保存,将返回保存的文件夹。如果返回为 None ,则说明此次没有满足
topk 要求,没有发生保存。


:param trainer: :param trainer:
:param results:
:param results: evaluate 的结果。
:return: :return:
""" """
if self.monitor is not None and self.topk_queue: if self.monitor is not None and self.topk_queue:
@@ -220,14 +235,10 @@ class TopkSaver(MonitorUtility, Saver):
self.rm(pop_key) self.rm(pop_key)
return folder return folder


def save(self, trainer, folder_name):
fn = getattr(trainer, self.save_fn_name)
return super().save(fn, folder_name)

def state_dict(self): def state_dict(self):
states = { states = {
'topk_queue': self.topk_queue.state_dict(), 'topk_queue': self.topk_queue.state_dict(),
'saver': self.saver.state_dict()
'timestamp_path': str(self.timestamp_path),
} }
if isinstance(self._real_monitor, str): if isinstance(self._real_monitor, str):
states['_real_monitor'] = self._real_monitor states['_real_monitor'] = self._real_monitor
@@ -236,11 +247,18 @@ class TopkSaver(MonitorUtility, Saver):


def load_state_dict(self, states): def load_state_dict(self, states):
topk_queue_states = states['topk_queue'] topk_queue_states = states['topk_queue']
saver_states = states['saver']
self.topk_queue.load_state_dict(topk_queue_states) self.topk_queue.load_state_dict(topk_queue_states)
self.saver.load_state_dict(saver_states)

timestamp_path = states['timestamp_path']
if not os.path.exists(timestamp_path):
logger.info(f"The resuming checkpoint folder {timestamp_path} is not exists, checkpoint will save to "
f" {self.timestamp_path.absolute()}.")
else:
logger.info(f"Resume to save checkpoint in path: {timestamp_path}.")
self.timestamp_path = Path(timestamp_path)

if '_real_monitor' in states: if '_real_monitor' in states:
self._real_monitor = states["_real_monitor"] self._real_monitor = states["_real_monitor"]


def __str__(self): def __str__(self):
return f'topk-{self.topk_queue}#saver-{self.saver}#save_object-{self.save_object}'
return f'topk-{self.topk_queue}#save_object-{self.save_object}'

+ 2
- 1
fastNLP/core/controllers/trainer.py View File

@@ -120,7 +120,8 @@ class Trainer(TrainerEventTrigger):
torch_non_blocking: 表示用于 pytorch 的 tensor 的 to 方法的参数 non_blocking; torch_non_blocking: 表示用于 pytorch 的 tensor 的 to 方法的参数 non_blocking;
data_device: 表示如果用户的模型 device (在 Driver 中对应为参数 model_device)为 None 时,我们会将数据迁移到 data_device 上; data_device: 表示如果用户的模型 device (在 Driver 中对应为参数 model_device)为 None 时,我们会将数据迁移到 data_device 上;
注意如果 model_device 为 None,那么 data_device 不会起作用; 注意如果 model_device 为 None,那么 data_device 不会起作用;
torch_ddp_kwargs: 用于配置 pytorch 的 DistributedDataParallel 初始化时的参数;
torch_ddp_kwargs: 用于配置 pytorch 的 DistributedDataParallel 初始化时的参数;仅用于 pytorch ddp 训练。例如传入
{'find_unused_parameters': True} 来解决有有参数不参与前向运算导致的报错等。
set_grad_to_none: 是否在训练过程中在每一次 optimizer 更新后将 grad 置为 None; set_grad_to_none: 是否在训练过程中在每一次 optimizer 更新后将 grad 置为 None;
use_dist_sampler: 表示是否使用分布式的 sampler 。在多卡时,分布式 sampler 将自动决定每张卡上读取的 sample ,使得一个epoch use_dist_sampler: 表示是否使用分布式的 sampler 。在多卡时,分布式 sampler 将自动决定每张卡上读取的 sample ,使得一个epoch
内所有卡的 sample 加起来为一整个数据集的 sample。默认会根据 driver 是否为分布式进行设置。 内所有卡的 sample 加起来为一整个数据集的 sample。默认会根据 driver 是否为分布式进行设置。


Loading…
Cancel
Save