
Split CheckpointCallback into ModelCheckpointCallback and TrainerCheckpointCallback; adjusted parts of the implementation accordingly

tags/v1.0.0alpha
yh_cc · 3 years ago · commit ce835212e6
9 changed files with 280 additions and 185 deletions
  1. fastNLP/core/callbacks/callback_manager.py (+3, -3)
  2. fastNLP/core/callbacks/checkpoint_callback.py (+183, -110)
  3. fastNLP/core/callbacks/load_best_model_callback.py (+1, -1)
  4. fastNLP/core/controllers/trainer.py (+18, -3)
  5. fastNLP/core/drivers/driver.py (+3, -0)
  6. fastNLP/core/drivers/torch_driver/torch_driver.py (+0, -2)
  7. fastNLP/core/utils/utils.py (+3, -0)
  8. fastNLP/envs/set_env_on_import.py (+1, -1)
  9. tests/core/callbacks/test_checkpoint_callback_torch.py (+68, -65)

fastNLP/core/callbacks/callback_manager.py (+3, -3)

@@ -8,7 +8,7 @@ __all__ = [

from .callback_events import Events
from .callback import Callback
from .checkpoint_callback import CheckpointCallback
from .checkpoint_callback import TrainerCheckpointCallback
from .progress_callback import ProgressCallback, choose_progress_callback
from fastNLP.core.log import logger

@@ -98,7 +98,7 @@ class CallbackManager:
:return:
"""
for each_callback in self.class_callbacks:
if isinstance(each_callback, CheckpointCallback) and each_callback.is_trainer_checkpoint:
if isinstance(each_callback, TrainerCheckpointCallback):
self._has_trainer_checkpoint = True
self.dissect_one_callback(each_callback)

@@ -210,7 +210,7 @@ class CallbackManager:
each_callback.on_load_checkpoint(trainer, None)

@property
def has_trainer_chechpoint(self) -> bool:
def has_trainer_checkpoint(self) -> bool:
return self._has_trainer_checkpoint

@_transfer


fastNLP/core/callbacks/checkpoint_callback.py (+183, -110)

@@ -1,12 +1,12 @@
import os
from typing import Union, Optional, Callable, Dict, Sequence
from pathlib import Path
from functools import partial
from time import sleep

__all__ = [
'CheckpointCallback'
]
import os
from typing import Union, Optional, Callable, Dict, Sequence, Any, Mapping
from pathlib import Path
from abc import ABC
import sys


import fastNLP
from .callback import Callback, Filter
@@ -14,35 +14,37 @@ from fastNLP.core.callbacks.utils import _get_monitor_value
from fastNLP.core.log import logger
from fastNLP.envs import FASTNLP_LAUNCH_TIME
from fastNLP.core.utils import synchronize_safe_rm, synchronize_mkdir
from fastNLP.core.utils import apply_to_collection


class CheckpointCallback(Callback):
class CanItemDataType(ABC):
"""
1. Because only the 'Trainer' has callbacks, evaluating the metric is effectively what validate already does;
2. 'save_last' defaults to True, i.e. model_checkpoint's default behavior is to save the latest model at the end of every epoch, under the name last.pth.tar;
3. In principle one model_checkpoint instance is responsible for monitoring a single monitor; if the user specifies several monitors during training, e.g. "acc1",
"acc2", ..., we create a separate model_checkpoint instance for each of them;
4. In principle, in the actual saving process, topk mode and fixed-frequency mode are completely independent; we should at least take measures to guarantee the two use different names;
Detect objects that can be transferred (i.e. whose class exposes an item() method).

"""

@classmethod
def __subclasshook__(cls, subclass: Any) -> Union[bool, Any]:
if cls is CanItemDataType:
item = getattr(subclass, 'item', None)
return callable(item)
return NotImplemented
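
The __subclasshook__ above turns CanItemDataType into a structural check: any object whose class exposes a callable item attribute (zero-dimensional torch tensors, numpy scalars, ...) passes isinstance without inheriting from the ABC, which is what later lets apply_to_collection convert such leaves with lambda x: x.item(). A self-contained sketch of the mechanism (HasItem is a renamed stand-in for CanItemDataType):

    from abc import ABC
    from typing import Any, Union

    class HasItem(ABC):  # same trick as CanItemDataType above
        @classmethod
        def __subclasshook__(cls, subclass: Any) -> Union[bool, Any]:
            if cls is HasItem:
                return callable(getattr(subclass, 'item', None))
            return NotImplemented

    class FakeScalar:
        def item(self) -> float:
            return 0.5

    assert isinstance(FakeScalar(), HasItem)   # structural match via .item()
    assert not isinstance("text", HasItem)     # str has no .item()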



class CheckpointCallback(Callback):
def __init__(
self,
monitor,
is_trainer_checkpoint: Optional[bool] = False,

save_folder: Optional[Union[str, Path]] = None,

save_every_n_epochs: Optional[int] = None,
save_every_n_global_batches: Optional[int] = None,
save_every_n_batches: Optional[int] = None,
save_last: bool = True,
save_topk: Optional[int] = None,
save_on_exception: Optional[Union[BaseException, Sequence[BaseException]]] = None,

larger_better: bool = True,
only_state_dict: bool = True,

model_save_fn: Optional[Callable] = None,

**kwargs,
):
if monitor is None and save_topk is not None:
@@ -51,9 +53,6 @@ class CheckpointCallback(Callback):
if monitor is not None and not isinstance(monitor, str):
raise ValueError("Parameter `monitor` should be of 'str' type.")

if not isinstance(is_trainer_checkpoint, bool):
raise TypeError("Parameter 'is_trainer_checkpoint' can only be `bool` type.")

if save_folder is None:
logger.warning(
"Parameter `path` is None, and we will use the current work directory to find and load your model.")
@@ -67,15 +66,15 @@ class CheckpointCallback(Callback):
if not isinstance(save_every_n_epochs, int) or save_every_n_epochs < 1:
raise ValueError("parameter save_after_epoch_num should be an int and greater than or equal to 1.")

# A neat trick here: the state recorded inside 'Filter', e.g. 'num_called', is global to the instance, while the function passed
# into each __call__ is supplied on the fly; so 'Filter' keeps its frequency-control logic intact even though the function it runs can differ on every call;
self._filter_every_n_epochs = Filter(every=save_every_n_epochs)
else:
save_every_n_epochs = sys.maxsize # so no epoch counter can ever divide it, i.e. the save never triggers

if save_every_n_global_batches is not None:
if not isinstance(save_every_n_global_batches, int) or save_every_n_global_batches < 1:
if save_every_n_batches is not None:
if not isinstance(save_every_n_batches, int) or save_every_n_batches < 1:
raise ValueError(
"parameter save_every_n_global_batches should be an int and greater than or equal to 1.")
self._filter_every_n_global_batches = Filter(every=save_every_n_global_batches)
"parameter save_every_n_batches should be an int and greater than or equal to 1.")
else:
save_every_n_batches = sys.maxsize # so no batch counter can ever divide it, i.e. the save never triggers

if save_topk is not None:
if not isinstance(save_topk, int) or save_topk < 1:
@@ -89,12 +88,12 @@ class CheckpointCallback(Callback):
if not issubclass(exception, BaseException):
raise TypeError("Each exception in parameter `save_on_exception` can only be "
"`BaseException` type.")

else:
save_on_exception = []
self.monitor = monitor
self.is_trainer_checkpoint = is_trainer_checkpoint
self.save_folder = Path(save_folder)
self.save_every_n_epochs = save_every_n_epochs
self.save_every_n_global_batches = save_every_n_global_batches
self.save_every_n_batches = save_every_n_batches
self.save_last = save_last
self.save_topk = save_topk
self.larger_better = larger_better
@@ -107,7 +106,7 @@ class CheckpointCallback(Callback):
self._topk_model = {}
self._topn = 0 # how many of the best models have been saved so far;

# Because in `_get_validate_metric`, when `monitor` cannot be found in the returned `validate_res` dict, we use the value of the first key found by fuzzy matching
# Because in `_get_validate_metric`, when `monitor` cannot be found in the returned `validate_res` dict, we use the value of the matched
# key as the result; the problem is that the names of the sub_metrics returned by the user's metric may be confusable, and if the next
# training run changes the order in which those sub_metrics are returned, fuzzy matching would pick up a different key and value than before, which is clearly not reasonable behavior;
# we therefore use this variable to record the key we obtained through fuzzy matching;
@@ -115,76 +114,83 @@ class CheckpointCallback(Callback):

# Note: we should guarantee that only process 0 performs this operation, because when the user launches processes with
# python -m torch.distributed.launch, the value of FASTNLP_LAUNCH_TIME differs from process to process;
self.log_filepath = self.save_folder.joinpath(os.environ[FASTNLP_LAUNCH_TIME])
self.timestamp_path = self.save_folder.joinpath(os.environ[FASTNLP_LAUNCH_TIME])
# We only need to ensure that the folder creation happens on process 0; the other processes will not execute the later, actual save operations anyway;
synchronize_mkdir(self.log_filepath)
synchronize_mkdir(self.timestamp_path)

def on_validate_end(self, trainer, validate_res):
self._save_topk(trainer, validate_res)

def on_train_epoch_end(self, trainer: "fastNLP.Trainer"):
self._save_every_n_epochs(trainer)
self._save_last(trainer)
if trainer.cur_epoch_idx % self.save_every_n_epochs == 0:
folder_name = f'{self.folder_prefix}-epoch_{trainer.cur_epoch_idx}'
self.save(trainer, folder_name=folder_name)
if self.save_last:
folder_name = f'{self.folder_prefix}-last'
self.save(trainer, folder_name=folder_name)

def on_train_batch_end(self, trainer):
self._save_every_n_global_batches(trainer)
if trainer.global_forward_batches % self.save_every_n_batches == 0:
folder_name = f'{self.folder_prefix}-epoch_{trainer.cur_epoch_idx}-batch_{trainer.global_forward_batches}'
self.save(trainer, folder_name=folder_name)
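
Both hooks above now gate saving with a plain modulo test instead of the removed Filter objects; a disabled frequency is encoded as sys.maxsize in __init__, so the remainder never reaches zero. A minimal sketch of the rule, with invented counter values:

    import sys

    save_every_n_epochs = 3
    save_every_n_batches = sys.maxsize  # "disabled": no batch counter ever divides it

    assert [e for e in range(1, 11) if e % save_every_n_epochs == 0] == [3, 6, 9]
    assert all(b % save_every_n_batches != 0 for b in range(1, 10_000))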

def on_exception(self, trainer, exception: BaseException):
if self.save_on_exception is not None and exception.__class__ in self.save_on_exception:
folder = self._get_checkpoint_real_save_folder(trainer=trainer, topk=False, metric=None)
folder = folder + f"_{exception.__class__.__name__}"
self._save_fn(trainer=trainer, topk=False, metric=None, substitute_folder=folder)
if exception.__class__ in self.save_on_exception:
folder_name = f'{self.folder_prefix}-epoch_{trainer.cur_epoch_idx}-batch_{trainer.global_forward_batches}-' \
f'exception_{exception.__class__.__name__}'
self.save(trainer=trainer, folder_name=folder_name)

def on_sanity_check_end(self, trainer, sanity_check_res):
# Mainly verify that the monitor exists.
self._get_validate_metric(sanity_check_res)

def on_save_checkpoint(self, trainer) -> Dict:
"""
We need to save the state of the several filters inside CheckpointCallback;
Save timestamp_path so that a later run can resume training and keep saving into this folder.
The state of _topk_model.
The value of _real_monitor.
"""

states = {}
if self.save_every_n_epochs is not None:
states["_filter_every_n_epochs"] = self._filter_every_n_epochs.state_dict()
if self.save_every_n_global_batches is not None:
states["_filter_every_n_global_batches"] = self._filter_every_n_global_batches.state_dict()
states["real_monitor"] = self._real_monitor
states['timestamp_path'] = str(self.timestamp_path.absolute())
states['_topk_model'] = apply_to_collection(self._topk_model, dtype=CanItemDataType,
function=lambda x:x.item())
states['save_topk'] = 0 if self.save_topk is None else self.save_topk
states['_real_monitor'] = self._real_monitor
return states

def on_load_checkpoint(self, trainer, states: Optional[Dict]):
if self.save_every_n_epochs is not None:
self._filter_every_n_epochs.load_state_dict(states["_filter_every_n_epochs"])
if self.save_every_n_global_batches is not None:
self._filter_every_n_global_batches.load_state_dict(states["_filter_every_n_global_batches"])
timestamp_path = states['timestamp_path']
if not os.path.exists(timestamp_path):
logger.info(f"The resuming save folder {timestamp_path} is not exists, will checkpoint save to "
f" {self.timestamp_path.absolute()}.")
else:
logger.info(f"Resume to save in path: {timestamp_path}.")
self.timestamp_path = Path(timestamp_path)
_topk_model = states['_topk_model']
save_topk = None if int(states['save_topk']) == 0 else int(states['save_topk'])
if save_topk is not None and self.save_topk is not None:
assert self.save_topk == save_topk, f"The checkpoint set save_topk={save_topk}, while this callback set it " \
f"as {self.save_topk}."
self._topk_model.update(_topk_model)
self._real_monitor = states["real_monitor"]

def _save_every_n_epochs(self, trainer: "fastNLP.Trainer"):
if self.save_every_n_epochs is not None:
if self.is_trainer_checkpoint:
_fn_every_n_epochs = trainer.save
else:
_fn_every_n_epochs = trainer.save_model
_fn_every_n_epochs = partial(self._save_fn, trainer, False, None, _fn_every_n_epochs, None)
_fn_every_n_epochs = self._filter_every_n_epochs(_fn_every_n_epochs)
_fn_every_n_epochs()

def _save_every_n_global_batches(self, trainer: "fastNLP.Trainer"):
if self.save_every_n_global_batches is not None:
if self.is_trainer_checkpoint:
_fn_every_n_global_batches = trainer.save
else:
_fn_every_n_global_batches = trainer.save_model
_fn_every_n_global_batches = partial(self._save_fn, trainer, False, None, _fn_every_n_global_batches, None)
_fn_every_n_global_batches = self._filter_every_n_global_batches(_fn_every_n_global_batches)
_fn_every_n_global_batches()

def _save_topk(self, trainer: "fastNLP.Trainer", validate_res: Dict):
"""
Decides, based on validate_res, which models should be saved. Automatically removes folders that no longer satisfy the topk condition.

:param trainer:
:param validate_res:
:return:
"""
if self.save_topk is not None:
_metric_value = self._get_validate_metric(validate_res)
_saved_name = self._get_checkpoint_real_save_folder(trainer=trainer, topk=True, metric=_metric_value)
folder_name = f"{self.folder_prefix}-epoch_{trainer.cur_epoch_idx}-batch_{trainer.global_forward_batches}" \
f"-{self._real_monitor}_{_metric_value}"

_should_save = False
if self._topn < self.save_topk:
self._topk_model[_saved_name] = _metric_value
self._topk_model[folder_name] = _metric_value
self._topn += 1
_should_save = True
else:
@@ -192,39 +198,27 @@ class CheckpointCallback(Callback):
key=lambda x: self._topk_model[x])
if (self.larger_better and _metric_value > self._topk_model[_least_valuable_model]) or \
(self.larger_better is False and _metric_value < self._topk_model[_least_valuable_model]):
self._topk_model[_saved_name] = _metric_value
self._topk_model[folder_name] = _metric_value
_should_save = True
self._topk_model.pop(_least_valuable_model)
synchronize_safe_rm(self.log_filepath.joinpath(_least_valuable_model))
synchronize_safe_rm(self.timestamp_path.joinpath(_least_valuable_model))

assert len(self._topk_model) == self.save_topk == self._topn

if _should_save:
self._save_fn(trainer=trainer, topk=True, metric=_metric_value, substitute_folder=_saved_name)
self.save(trainer, folder_name=folder_name)
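
The top-k bookkeeping above boils down to: keep a dict mapping folder_name to the monitor value, fill it until it holds save_topk entries, and from then on evict the least valuable entry (and delete its folder) whenever a better result arrives. A standalone sketch with invented numbers:

    save_topk, larger_better = 2, True
    topk_model = {}  # folder_name -> monitor value

    def offer(folder_name, metric):
        """Return True if this result deserves a save; evict the worst entry when full."""
        if len(topk_model) < save_topk:
            topk_model[folder_name] = metric
            return True
        worst = min(topk_model, key=topk_model.get) if larger_better else max(topk_model, key=topk_model.get)
        if (larger_better and metric > topk_model[worst]) or (not larger_better and metric < topk_model[worst]):
            topk_model.pop(worst)  # the real callback also removes the folder on disk via synchronize_safe_rm
            topk_model[folder_name] = metric
            return True
        return False

    assert offer("model-epoch_1-batch_10-acc_0.5", 0.5)
    assert offer("model-epoch_2-batch_20-acc_0.7", 0.7)
    assert offer("model-epoch_3-batch_30-acc_0.6", 0.6)  # evicts the 0.5 entry
    assert sorted(topk_model.values()) == [0.6, 0.7]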

def _save_last(self, trainer: "fastNLP.Trainer"):
if self.save_last:
self._save_fn(trainer=trainer, topk=False, metric=None, substitute_folder="last")

def _save_fn(self, trainer, topk: bool = False, metric: Optional[Union[int, float]] = None,
substitute_fn: Optional[Callable] = None, substitute_folder: Optional[str] = None):
# First, based on the current epoch and batch, create under parent_path/FASTNLP_LAUNCH_TIME a subfolder named epoch-batch-monitor or
# epoch-batch-monitor-monitor_value;
if substitute_folder is None:
folder = self.log_filepath.joinpath(self._get_checkpoint_real_save_folder(trainer, topk, metric))
else:
folder = self.log_filepath.joinpath(substitute_folder)
def save(self, trainer, folder_name):
"""
Performs the actual save, writing the data under save_folder/timestamp/folder_name.

:param trainer:
:param folder_name:
:return:
"""
folder = self.timestamp_path.joinpath(folder_name)
synchronize_mkdir(folder)

# Then call the trainer's save_model (for saving the model) or save (for checkpoint-resume training);
if substitute_fn is not None:
_fn = substitute_fn
else:
if self.is_trainer_checkpoint:
_fn = trainer.save
else:
_fn = trainer.save_model
_fn = getattr(trainer, self.save_fn_name)
_fn(
folder=folder,
only_state_dict=self.only_state_dict,
@@ -243,18 +237,95 @@ class CheckpointCallback(Callback):
self._real_monitor = use_monitor
return value

def _get_checkpoint_real_save_folder(self, trainer: "fastNLP.Trainer", topk: bool = False,
metric: Optional[Union[int, float]] = None) -> str:
@property
def folder_prefix(self):
raise NotImplementedError("The `folder_prefix` is not specified")

@property
def save_fn_name(self):
raise NotImplementedError("The `save_fn_name` is not specified.")


class ModelCheckpointCallback(CheckpointCallback):
"""
The callback that saves model checkpoints. The folder layout and naming scheme of its saves are as follows:

- save_folder/
- YYYY-mm-dd-HH_MM_SS_fffff/ # created automatically from the launch time of the current script
- model-epoch_{epoch_idx}/ # model saved because the save_every_n_epochs condition was met
- model-epoch_{epoch_idx}-batch_{global_batch_idx}/ # model saved because the save_every_n_batches condition was met
- model-last/ # save of the last epoch
- model-epoch_{epoch_idx}-batch_{global_batch_idx}-exception_{exception_type}/ # saved on exception.
- model-epoch_{epoch_idx}-batch_{global_batch_idx}-{monitor}_{monitor_value}/ # folder name for saves satisfying the topk condition

If model_save_fn is None, a fastnlp_model.pkl.tar file is generated inside each of the folders above.
If model_save_fn is not None, fastNLP passes the folder's absolute path to that function and creates no files inside the folder itself.

:param monitor: name of the metric to monitor. If no exactly matching name is found in the evaluation results, the closest match
found by the shortest-common-string algorithm is used as the monitor.
:param save_folder: the folder to save into; fastNLP creates a timestamped subfolder beneath it and saves there, so different
runs are saved into different timestamped folders. If None, the current directory is used by default.
:param save_every_n_epochs: save once every this many epochs.
:param save_every_n_batches: save once every this many batches.
:param save_last: if True, save at the end of every epoch, overwriting the previous save.
:param save_topk: keep the topK saves ranked by the monitor result.
:param save_on_exception: whether to save when an exception is raised. Pass the exception classes that should be caught.
:param larger_better: whether a larger monitor value is better.
:param only_state_dict: whether to save only the model's state_dict. Ignored when model_save_fn is not None.
:param model_save_fn: a customized save function; whenever a save is triggered, this function is called. It should accept a folder as its argument and return nothing.
If model_save_fn is passed in, fastNLP performs no model-related saving of its own. In the multi-card setting, we run this function on rank 0 only.
:param kwargs:
"""
@property
def save_fn_name(self):
return 'save_model'

@property
def callback_name(self):
"""
Get the real name under which the current model is saved;
the metric argument only takes effect when mode is 'topk';
This value decides whether two CheckpointCallback instances can share checkpoint-resume state.
:return:
"""
cur_epoch_idx = trainer.cur_epoch_idx
global_forward_batches = trainer.global_forward_batches
_other = ""
if topk:
_other = f"_{metric}"
return f"epoch_{cur_epoch_idx}-global_batch_{global_forward_batches}-{self._real_monitor}{_other}"
return f"model_checkpoint#monitor-{self.monitor}#topK-{self.save_topk}#only_state_dict-{self.only_state_dict}"

@property
def folder_prefix(self):
return 'model'


class TrainerCheckpointCallback(CheckpointCallback):
"""
The callback that saves Trainer checkpoints. The folder layout and naming scheme of its saves are as follows:

- save_folder/
- YYYY-mm-dd-HH_MM_SS_fffff/ # created automatically from the launch time of the current script
- trainer-epoch_{epoch_idx}/ # saved because the save_every_n_epochs condition was met
- trainer-epoch_{epoch_idx}-batch_{global_batch_idx}/ # saved because the save_every_n_batches condition was met
- trainer-last/ # save of the last epoch
- trainer-epoch_{epoch_idx}-batch_{global_batch_idx}-exception_{exception_type}/ # saved on exception.
- trainer-epoch_{epoch_idx}-batch_{global_batch_idx}-{monitor}_{monitor_value}/ # folder name for saves satisfying the topk condition

If model_save_fn is None, two files are generated inside each of the folders above: fastnlp_trainer.pkl.tar and fastnlp_model.pkl.tar.
If model_save_fn is not None, fastNLP generates only the fastnlp_trainer.pkl.tar file in each folder.

:param monitor: name of the metric to monitor. If no exactly matching name is found in the evaluation results, the closest match
found by the shortest-common-string algorithm is used as the monitor.
:param save_folder: the folder to save into; fastNLP creates a timestamped subfolder beneath it and saves there, so different
runs are saved into different timestamped folders. If None, the current directory is used by default.
:param save_every_n_epochs: save once every this many epochs.
:param save_every_n_batches: save once every this many batches.
:param save_last: if True, save at the end of every epoch, overwriting the previous save.
:param save_topk: keep the topK saves ranked by the monitor result.
:param save_on_exception: whether to save when an exception is raised.
:param larger_better: whether a larger monitor value is better.
:param only_state_dict: whether to save only the model's state_dict. Meaningless when model_save_fn is not None.
:param model_save_fn: a customized save function; whenever a save is triggered, this function is called. It should accept a folder as its argument and return nothing.
If model_save_fn is passed in, fastNLP performs no model-related saving of its own. In the multi-card setting, we run this function on rank 0 only.
:param kwargs:
"""
@property
def save_fn_name(self):
return 'save'

@property
def callback_name(self):
@@ -262,6 +333,8 @@ class CheckpointCallback(Callback):
This value decides whether two CheckpointCallback instances can share checkpoint-resume state;
:return:
"""
return f"monitor-{self.monitor}#trainer_checkpoint-{self.is_trainer_checkpoint}#only_state_dict-{self.only_state_dict}"

return f"trainer_checkpoint#monitor-{self.monitor}#topK-{self.save_topk}#only_state_dict-{self.only_state_dict}"

@property
def folder_prefix(self):
return 'trainer'
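
To make the split concrete, a hypothetical usage sketch follows; `model`, `optimizer`, `train_dataloader`, and `metrics` are placeholders, the import paths mirror the test file below, and `trainer.run()` is assumed to be the launch call:

    from fastNLP.core.controllers.trainer import Trainer
    from fastNLP.core.callbacks.checkpoint_callback import ModelCheckpointCallback, TrainerCheckpointCallback

    callbacks = [
        # model weights only -> model-* folders containing fastnlp_model.pkl.tar
        ModelCheckpointCallback(monitor="acc", save_folder="./checkpoints",
                                save_every_n_epochs=1, save_topk=2, save_last=True),
        # full resume state -> trainer-* folders containing fastnlp_trainer.pkl.tar (+ model)
        TrainerCheckpointCallback(monitor="acc", save_folder="./checkpoints",
                                  save_every_n_batches=500, save_last=True),
    ]

    trainer = Trainer(model=model, driver="torch", device=0,
                      optimizers=optimizer, train_dataloader=train_dataloader,
                      metrics=metrics, n_epochs=10, callbacks=callbacks)
    trainer.run()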

fastNLP/core/callbacks/load_best_model_callback.py (+1, -1)

@@ -31,7 +31,7 @@ class LoadBestModelCallback(Callback):
please complete the saving of the model inside the function.
:param model_load_fn: the function that loads the model; model_save_fn and model_load_fn must both be non-None together. Its input is an
already-created folder, it returns nothing; please complete the loading of the model inside the function.
:param delete_after_train: whether to delete the model after the best model has been loaded.
:param delete_after_train: whether to delete the saved model after training finishes.
"""
if model_load_fn is not None:
assert callable(model_load_fn), "`model_load_fn` must be a callable object."


fastNLP/core/controllers/trainer.py (+18, -3)

@@ -251,7 +251,7 @@ class Trainer(TrainerEventTrigger):
self.driver.set_deterministic_dataloader(self.dataloader)

self.dataloader = self.driver.set_dist_repro_dataloader(dataloader=self.train_dataloader, dist=_dist_sampler,
reproducible=self.callback_manager.has_trainer_chechpoint)
reproducible=self.callback_manager.has_trainer_checkpoint)

self.set_grad_to_none = kwargs.get("set_grad_to_none", True)
self.on_after_trainer_initialized(self.driver)
@@ -509,7 +509,7 @@ class Trainer(TrainerEventTrigger):

:param folder: the folder to save the model into;
:param only_state_dict: whether to save only the model's `state_dict`;
:param save_fn: a user-customized function that replaces this save function's own saving logic;
:param model_save_fn: a user-customized function that replaces this save function's own saving logic;
:param kwargs: additional arguments for some drivers' model-saving functions;
"""

@@ -534,7 +534,16 @@ class Trainer(TrainerEventTrigger):

def load_model(self, folder: Union[str, Path, BinaryIO, io.BytesIO], only_state_dict: bool = False,
model_load_fn: Optional[Callable] = None, **kwargs):
"""
Load the model.

:param folder: the folder to read the model from; by default this tries to read the fastnlp_model.pkl.tar file inside it. When model_load_fn is
not None, the folder is passed directly to model_load_fn.
:param only_state_dict: whether the file to read contains only the model weights. Meaningless when model_load_fn is not None.
:param model_load_fn: a callable that takes a folder as its argument and returns nothing.
:param kwargs:
:return:
"""
self.on_load_model()
self.driver.barrier()
if not isinstance(folder, (io.BytesIO, BinaryIO)):
@@ -555,7 +564,13 @@ class Trainer(TrainerEventTrigger):

def save(self, folder: Union[str, Path], only_state_dict: bool = True, model_save_fn: Optional[Callable] = None, **kwargs):
r"""
The save function used for checkpoint-resume training;
The save function used to checkpoint the Trainer for resumed training;

:param folder:
:param only_state_dict:
:param model_save_fn:
:param kwargs:
:return:
"""
self.driver.barrier()
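
A brief usage sketch of save_model / load_model as documented above; the folder path and custom function are illustrative, and `trainer` and `model` are assumed to exist in the surrounding script:

    import os
    import torch

    # default path: writes fastnlp_model.pkl.tar inside the folder
    trainer.save_model(folder="./my_model", only_state_dict=True)
    trainer.load_model(folder="./my_model", only_state_dict=True)

    # custom layout: fastNLP only passes the folder through, writing nothing itself
    def my_model_save_fn(folder):
        torch.save(model.state_dict(), os.path.join(folder, "weights.pt"))

    trainer.save_model(folder="./my_model", model_save_fn=my_model_save_fn)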



fastNLP/core/drivers/driver.py (+3, -0)

@@ -68,9 +68,12 @@ class Driver(ABC):
def set_sampler_epoch(self, dataloader, cur_epoch_idx):
r"""
For a distributed sampler, e.g. torch's DistributedSampler, the random seed needs to be set before every epoch to guarantee that the shuffle is identical on every process;
in the dataloader, the component that actually takes effect may be either the batch_sampler or the sampler.

:param dataloader: the dataloader whose epoch needs to be set.
:param cur_epoch_idx: the index of the current epoch;
"""

@abstractmethod
def train_step(self, batch):
"""


fastNLP/core/drivers/torch_driver/torch_driver.py (+0, -2)

@@ -143,8 +143,6 @@ class TorchDriver(Driver):

:param filepath: the folder to save into;
:param only_state_dict: whether to save only the weights;
:param model_save_fn:

:return:
"""
model = self.unwrap_model()


fastNLP/core/utils/utils.py (+3, -0)

@@ -44,6 +44,9 @@ __all__ = [
]





def get_fn_arg_names(fn: Callable) -> List[str]:
r"""
Return the names of all parameters of a function;


fastNLP/envs/set_env_on_import.py (+1, -1)

@@ -65,7 +65,7 @@ def set_env_on_import():

# some variables used internally by fastNLP
if FASTNLP_LAUNCH_TIME not in os.environ:
cur_time = f"{datetime.datetime.now().strftime('%Y-%m-%d-%H_%M_%S_%M_%f')}"
cur_time = f"{datetime.datetime.now().strftime('%Y-%m-%d-%H_%M_%S_%f')}"
os.environ[FASTNLP_LAUNCH_TIME] = cur_time
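
The old pattern duplicated the minute directive %M where only the microseconds should follow the seconds. A quick check of both formats on a fixed datetime:

    import datetime

    dt = datetime.datetime(2022, 4, 1, 13, 7, 59, 123456)
    assert dt.strftime('%Y-%m-%d-%H_%M_%S_%M_%f') == '2022-04-01-13_07_59_07_123456'  # old: minute repeated
    assert dt.strftime('%Y-%m-%d-%H_%M_%S_%f') == '2022-04-01-13_07_59_123456'        # fixed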

# set the corresponding values


tests/core/callbacks/test_checkpoint_callback_torch.py (+68, -65)

@@ -8,7 +8,7 @@ import torch.distributed as dist
from pathlib import Path
import re

from fastNLP.core.callbacks.checkpoint_callback import CheckpointCallback
from fastNLP.core.callbacks.checkpoint_callback import ModelCheckpointCallback, TrainerCheckpointCallback
from fastNLP.core.controllers.trainer import Trainer
from fastNLP.envs import FASTNLP_MODEL_FILENAME, FASTNLP_CHECKPOINT_FILENAME, FASTNLP_LAUNCH_TIME, FASTNLP_DISTRIBUTED_CHECK

@@ -80,16 +80,23 @@ def test_model_checkpoint_callback_1(
version,
only_state_dict
):
# def test_model_checkpoint_callback_1(
# model_and_optimizers: TrainerParameters,
# driver='torch_ddp',
# device=[0, 1],
# version=1,
# only_state_dict=True
# ):
path = Path.cwd().joinpath(f"test_model_checkpoint")
path.mkdir(exist_ok=True, parents=True)

if version == 0:
callbacks = [
CheckpointCallback(
ModelCheckpointCallback(
monitor="acc",
save_folder=path,
save_every_n_epochs=1,
save_every_n_global_batches=123, # avoid colliding with the epoch-based saves;
save_every_n_batches=123, # avoid colliding with the epoch-based saves;
save_topk=None,
save_last=False,
save_on_exception=None,
@@ -98,11 +105,11 @@ def test_model_checkpoint_callback_1(
]
elif version == 1:
callbacks = [
CheckpointCallback(
ModelCheckpointCallback(
monitor="acc",
save_folder=path,
save_every_n_epochs=3,
save_every_n_global_batches=None,
save_every_n_batches=None,
save_topk=2,
save_last=True,
save_on_exception=None,
@@ -121,7 +128,6 @@ def test_model_checkpoint_callback_1(
input_mapping=model_and_optimizers.input_mapping,
output_mapping=model_and_optimizers.output_mapping,
metrics=model_and_optimizers.metrics,

n_epochs=10,
callbacks=callbacks,
output_from_new_proc="all"
@@ -134,31 +140,31 @@ def test_model_checkpoint_callback_1(
if version == 0:

if driver == "torch":
assert "epoch_10-global_batch_250-acc" in all_saved_model_paths
assert "epoch_4-global_batch_123-acc" in all_saved_model_paths
assert "model-epoch_10" in all_saved_model_paths
assert "model-epoch_4-batch_123" in all_saved_model_paths

epoch_save_path = all_saved_model_paths["epoch_10-global_batch_250-acc"]
step_save_path = all_saved_model_paths["epoch_4-global_batch_123-acc"]
epoch_save_path = all_saved_model_paths["model-epoch_10"]
step_save_path = all_saved_model_paths["model-epoch_4-batch_123"]

assert len(all_saved_model_paths) == 12
# folder names differ under ddp, because ddp finishes the same data in fewer steps;
else:
assert "epoch_6-global_batch_78-acc" in all_saved_model_paths
assert "epoch_9-global_batch_123-acc" in all_saved_model_paths
assert "model-epoch_6" in all_saved_model_paths
assert "model-epoch_9-batch_123" in all_saved_model_paths

epoch_save_path = all_saved_model_paths["epoch_6-global_batch_78-acc"]
step_save_path = all_saved_model_paths["epoch_9-global_batch_123-acc"]
epoch_save_path = all_saved_model_paths["model-epoch_6"]
step_save_path = all_saved_model_paths["model-epoch_9-batch_123"]

assert len(all_saved_model_paths) == 11
all_state_dicts = [epoch_save_path, step_save_path]

elif version == 1:

pattern = re.compile("epoch_[0-9]+-global_batch_[0-9]+-[a-z|A-Z]+_[0-9]*.?[0-9]*")
pattern = re.compile("model-epoch_[0-9]+-batch_[0-9]+-[a-zA-Z#]+_[0-9]*.?[0-9]*")

if driver == "torch":
assert "epoch_9-global_batch_225-acc" in all_saved_model_paths
assert "last" in all_saved_model_paths
assert "model-epoch_9" in all_saved_model_paths
assert "model-last" in all_saved_model_paths
aLL_topk_folders = []
for each_folder_name in all_saved_model_paths:
each_folder_name = pattern.findall(each_folder_name)
@@ -166,15 +172,15 @@ def test_model_checkpoint_callback_1(
aLL_topk_folders.append(each_folder_name[0])
assert len(aLL_topk_folders) == 2

epoch_save_path = all_saved_model_paths["epoch_9-global_batch_225-acc"]
last_save_path = all_saved_model_paths["last"]
epoch_save_path = all_saved_model_paths["model-epoch_9"]
last_save_path = all_saved_model_paths["model-last"]
topk_save_path = all_saved_model_paths[aLL_topk_folders[0]]

assert len(all_saved_model_paths) == 6
# folder names differ under ddp, because ddp finishes the same data in fewer steps;
else:
assert "epoch_9-global_batch_117-acc" in all_saved_model_paths
assert "last" in all_saved_model_paths
assert "model-epoch_9" in all_saved_model_paths
assert "model-last" in all_saved_model_paths

aLL_topk_folders = []
for each_folder_name in all_saved_model_paths:
@@ -183,8 +189,8 @@ def test_model_checkpoint_callback_1(
aLL_topk_folders.append(each_folder_name[0])
assert len(aLL_topk_folders) == 2

epoch_save_path = all_saved_model_paths["epoch_9-global_batch_117-acc"]
last_save_path = all_saved_model_paths["last"]
epoch_save_path = all_saved_model_paths["model-epoch_9"]
last_save_path = all_saved_model_paths["model-last"]
topk_save_path = all_saved_model_paths[aLL_topk_folders[0]]

assert len(all_saved_model_paths) == 6
@@ -212,7 +218,7 @@ def test_model_checkpoint_callback_1(

finally:
synchronize_safe_rm(path)
# pass
pass

if dist.is_initialized():
dist.destroy_process_group()
@@ -238,11 +244,11 @@ def test_model_checkpoint_callback_2(
raise NotImplementedError

callbacks = [
CheckpointCallback(
ModelCheckpointCallback(
monitor="acc1",
save_folder=path,
save_every_n_epochs=None,
save_every_n_global_batches=None,
save_every_n_batches=None,
save_topk=None,
save_last=False,
save_on_exception=NotImplementedError,
@@ -279,12 +285,12 @@ def test_model_checkpoint_callback_2(
all_saved_model_paths = {w.name: w for w in path.joinpath(os.environ[FASTNLP_LAUNCH_TIME]).iterdir()}

if driver == "torch":
assert "epoch_4-global_batch_100-acc_NotImplementedError" in all_saved_model_paths
exception_model_path = all_saved_model_paths["epoch_4-global_batch_100-acc_NotImplementedError"]
assert "model-epoch_4-batch_100-exception_NotImplementedError" in all_saved_model_paths
exception_model_path = all_saved_model_paths["model-epoch_4-batch_100-exception_NotImplementedError"]
# folder names differ under ddp, because ddp finishes the same data in fewer steps;
else:
assert "epoch_4-global_batch_52-acc_NotImplementedError" in all_saved_model_paths
exception_model_path = all_saved_model_paths["epoch_4-global_batch_52-acc_NotImplementedError"]
assert "model-epoch_4-batch_52-exception_NotImplementedError" in all_saved_model_paths
exception_model_path = all_saved_model_paths["model-epoch_4-batch_52-exception_NotImplementedError"]

assert len(all_saved_model_paths) == 1
all_state_dicts = [exception_model_path]
@@ -332,12 +338,11 @@ def test_trainer_checkpoint_callback_1(

if version == 0:
callbacks = [
CheckpointCallback(
TrainerCheckpointCallback(
monitor="acc",
is_trainer_checkpoint=True,
save_folder=path,
save_every_n_epochs=7,
save_every_n_global_batches=123, # avoid colliding with the epoch-based saves;
save_every_n_batches=123, # avoid colliding with the epoch-based saves;
save_topk=None,
save_last=False,
save_on_exception=None,
@@ -346,12 +351,11 @@ def test_trainer_checkpoint_callback_1(
]
elif version == 1:
callbacks = [
CheckpointCallback(
TrainerCheckpointCallback(
monitor="acc",
is_trainer_checkpoint=True,
save_folder=path,
save_every_n_epochs=None,
save_every_n_global_batches=None,
save_every_n_batches=None,
save_topk=2,
save_last=True,
save_on_exception=None,
@@ -383,31 +387,31 @@ def test_trainer_checkpoint_callback_1(
if version == 0:

if driver == "torch":
assert "epoch_7-global_batch_175-acc" in all_saved_model_paths
assert "epoch_4-global_batch_123-acc" in all_saved_model_paths
assert "trainer-epoch_7" in all_saved_model_paths
assert "trainer-epoch_4-batch_123" in all_saved_model_paths

epoch_save_path = all_saved_model_paths["epoch_7-global_batch_175-acc"]
step_save_path = all_saved_model_paths["epoch_4-global_batch_123-acc"]
epoch_save_path = all_saved_model_paths["trainer-epoch_7"]
step_save_path = all_saved_model_paths["trainer-epoch_4-batch_123"]

assert len(all_saved_model_paths) == 3
# folder names differ under ddp, because ddp finishes the same data in fewer steps;
else:
assert "epoch_7-global_batch_91-acc" in all_saved_model_paths
assert "epoch_9-global_batch_123-acc" in all_saved_model_paths
assert "trainer-epoch_7" in all_saved_model_paths
assert "trainer-epoch_9-batch_123" in all_saved_model_paths

epoch_save_path = all_saved_model_paths["epoch_7-global_batch_91-acc"]
step_save_path = all_saved_model_paths["epoch_9-global_batch_123-acc"]
epoch_save_path = all_saved_model_paths["trainer-epoch_7"]
step_save_path = all_saved_model_paths["trainer-epoch_9-batch_123"]

assert len(all_saved_model_paths) == 2
all_state_dicts = [epoch_save_path, step_save_path]

elif version == 1:

pattern = re.compile("epoch_[0-9]+-global_batch_[0-9]+-[a-z|A-Z]+_[0-9]*.?[0-9]*")
pattern = re.compile("trainer-epoch_[0-9]+-batch_[0-9]+-[a-zA-Z#]+_[0-9]*.?[0-9]*")

# all_saved_model_paths = {w.name: w for w in path.joinpath(os.environ[FASTNLP_LAUNCH_TIME]).iterdir()}
if driver == "torch":
assert "last" in all_saved_model_paths
assert "trainer-last" in all_saved_model_paths
aLL_topk_folders = []
for each_folder_name in all_saved_model_paths:
each_folder_name = pattern.findall(each_folder_name)
@@ -415,13 +419,13 @@ def test_trainer_checkpoint_callback_1(
aLL_topk_folders.append(each_folder_name[0])
assert len(aLL_topk_folders) == 2

last_save_path = all_saved_model_paths["last"]
last_save_path = all_saved_model_paths["trainer-last"]
topk_save_path = all_saved_model_paths[aLL_topk_folders[0]]

assert len(all_saved_model_paths) == 3
# folder names differ under ddp, because ddp finishes the same data in fewer steps;
else:
assert "last" in all_saved_model_paths
assert "trainer-last" in all_saved_model_paths

aLL_topk_folders = []
for each_folder_name in all_saved_model_paths:
@@ -430,7 +434,7 @@ def test_trainer_checkpoint_callback_1(
aLL_topk_folders.append(each_folder_name[0])
assert len(aLL_topk_folders) == 2

last_save_path = all_saved_model_paths["last"]
last_save_path = all_saved_model_paths["trainer-last"]
topk_save_path = all_saved_model_paths[aLL_topk_folders[0]]

assert len(all_saved_model_paths) == 3
@@ -474,10 +478,11 @@ def test_trainer_checkpoint_callback_2(
device,
version
):
pytest.skip("Skip transformers test for now.")
path = Path.cwd().joinpath(f"test_model_checkpoint")
path.mkdir(exist_ok=True, parents=True)

import transformers
import transformers # version 4.16.2
import torch
from torchmetrics import Accuracy
from transformers import AutoModelForSequenceClassification
@@ -587,12 +592,11 @@ def test_trainer_checkpoint_callback_2(

if version == 0:
callbacks = [
CheckpointCallback(
TrainerCheckpointCallback(
monitor="acc",
is_trainer_checkpoint=True,
save_folder=path,
save_every_n_epochs=None,
save_every_n_global_batches=50,
save_every_n_batches=50,
save_topk=None,
save_last=False,
save_on_exception=None,
@@ -601,12 +605,11 @@ def test_trainer_checkpoint_callback_2(
]
elif version == 1:
callbacks = [
CheckpointCallback(
TrainerCheckpointCallback(
monitor="acc",
is_trainer_checkpoint=True,
save_folder=path,
save_every_n_epochs=None,
save_every_n_global_batches=None,
save_every_n_batches=None,
save_topk=1,
save_last=True,
save_on_exception=None,
@@ -638,27 +641,27 @@ def test_trainer_checkpoint_callback_2(
if version == 0:

if driver == "torch":
assert "epoch_1-global_batch_200-acc" in all_saved_model_paths
assert "trainer-epoch_1-batch_200" in all_saved_model_paths

epoch_save_path = all_saved_model_paths["epoch_1-global_batch_200-acc"]
epoch_save_path = all_saved_model_paths["trainer-epoch_1-batch_200"]

assert len(all_saved_model_paths) == 4
# folder names differ under ddp, because ddp finishes the same data in fewer steps;
else:
assert "epoch_1-global_batch_100-acc" in all_saved_model_paths
assert "trainer-epoch_1-batch_100" in all_saved_model_paths

epoch_save_path = all_saved_model_paths["epoch_1-global_batch_100-acc"]
epoch_save_path = all_saved_model_paths["trainer-epoch_1-batch_100"]

assert len(all_saved_model_paths) == 2
all_state_dicts = [epoch_save_path]

elif version == 1:

pattern = re.compile("epoch_[0-9]+-global_batch_[0-9]+-[a-z|A-Z]+_[0-9]*.?[0-9]*")
pattern = re.compile("trainer-epoch_[0-9]+-batch_[0-9]+-[a-zA-Z#]+_[0-9]*.?[0-9]*")

# all_saved_model_paths = {w.name: w for w in path.joinpath(os.environ[FASTNLP_LAUNCH_TIME]).iterdir()}
if driver == "torch":
assert "last" in all_saved_model_paths
assert "trainer-last" in all_saved_model_paths
aLL_topk_folders = []
for each_folder_name in all_saved_model_paths:
each_folder_name = pattern.findall(each_folder_name)
@@ -666,13 +669,13 @@ def test_trainer_checkpoint_callback_2(
aLL_topk_folders.append(each_folder_name[0])
assert len(aLL_topk_folders) == 1

last_save_path = all_saved_model_paths["last"]
last_save_path = all_saved_model_paths["trainer-last"]
topk_save_path = all_saved_model_paths[aLL_topk_folders[0]]

assert len(all_saved_model_paths) == 2
# folder names differ under ddp, because ddp finishes the same data in fewer steps;
else:
assert "last" in all_saved_model_paths
assert "trainer-last" in all_saved_model_paths

aLL_topk_folders = []
for each_folder_name in all_saved_model_paths:
@@ -681,7 +684,7 @@ def test_trainer_checkpoint_callback_2(
aLL_topk_folders.append(each_folder_name[0])
assert len(aLL_topk_folders) == 1

last_save_path = all_saved_model_paths["last"]
last_save_path = all_saved_model_paths["trainer-last"]
topk_save_path = all_saved_model_paths[aLL_topk_folders[0]]

assert len(all_saved_model_paths) == 2

