
1. Merge ModelCheckpointCallback and TrainerCheckpointCallback into a single CheckpointCallback; 2. Add MoreEvaluateCallback

tags/v1.0.0alpha
yh_cc 2 years ago
parent commit aa95513055
30 changed files with 1067 additions and 585 deletions
  1. fastNLP/core/callbacks/__init__.py (+6, -4)
  2. fastNLP/core/callbacks/callback.py (+10, -1)
  3. fastNLP/core/callbacks/callback_manager.py (+3, -5)
  4. fastNLP/core/callbacks/checkpoint_callback.py (+99, -287)
  5. fastNLP/core/callbacks/has_monitor_callback.py (+61, -36)
  6. fastNLP/core/callbacks/load_best_model_callback.py (+3, -3)
  7. fastNLP/core/callbacks/more_evaluate_callback.py (+174, -0)
  8. fastNLP/core/callbacks/progress_callback.py (+4, -3)
  9. fastNLP/core/callbacks/topk_saver.py (+246, -0)
  10. fastNLP/core/callbacks/utils.py (+3, -2)
  11. fastNLP/core/controllers/evaluator.py (+5, -7)
  12. fastNLP/core/controllers/trainer.py (+61, -64)
  13. fastNLP/core/controllers/utils/utils.py (+1, -1)
  14. fastNLP/core/dataloaders/torch_dataloader/fdl.py (+3, -2)
  15. fastNLP/core/drivers/driver.py (+0, -1)
  16. fastNLP/core/drivers/torch_driver/ddp.py (+1, -1)
  17. fastNLP/core/drivers/torch_driver/torch_driver.py (+8, -4)
  18. fastNLP/core/metrics/accuracy.py (+1, -1)
  19. fastNLP/core/samplers/unrepeated_sampler.py (+2, -0)
  20. fastNLP/core/utils/__init__.py (+3, -3)
  21. fastNLP/core/utils/utils.py (+15, -18)
  22. fastNLP/envs/env.py (+2, -2)
  23. tests/core/callbacks/test_checkpoint_callback_torch.py (+62, -109)
  24. tests/core/callbacks/test_more_evaluate_callback.py (+263, -0)
  25. tests/core/controllers/test_trainer_wo_evaluator_torch.py (+2, -2)
  26. tests/core/drivers/paddle_driver/test_single_device.py (+7, -7)
  27. tests/core/log/test_logger.py (+8, -8)
  28. tests/core/utils/test_cache_results.py (+10, -10)
  29. tests/envs/test_set_backend.py (+2, -2)
  30. tests/modules/mix_modules/test_mix_module.py (+2, -2)

fastNLP/core/callbacks/__init__.py (+6, -4)

@@ -4,8 +4,7 @@ __all__ = [
'EventsList',
'Filter',
'CallbackManager',
'ModelCheckpointCallback',
'TrainerCheckpointCallback',
'CheckpointCallback',
'choose_progress_callback',
'ProgressCallback',
'RichCallback',
@@ -13,18 +12,21 @@ __all__ = [
'LoadBestModelCallback',
"EarlyStopCallback",

'MoreEvaluateCallback',

"TorchWarmupCallback",
"TorchGradClipCallback"
"TorchGradClipCallback",
]


from .callback import Callback
from .callback_events import EventsList, Events, Filter
from .callback_manager import CallbackManager
from .checkpoint_callback import ModelCheckpointCallback, TrainerCheckpointCallback
from .checkpoint_callback import CheckpointCallback
from .progress_callback import choose_progress_callback, ProgressCallback, RichCallback
from .lr_scheduler_callback import LRSchedCallback
from .load_best_model_callback import LoadBestModelCallback
from .early_stop_callback import EarlyStopCallback
from .torch_callbacks import *
from .more_evaluate_callback import MoreEvaluateCallback


fastNLP/core/callbacks/callback.py (+10, -1)

@@ -236,7 +236,7 @@ class Callback:
结束 validate 时调用,并把 validate 的结果传入。

:param trainer:
:param results:
:param results: Evaluate 的结果,一般是个 dict 。
:return:
"""
pass
@@ -250,6 +250,15 @@ class Callback:
"""
return self.__class__.__name__

@property
def need_reproducible_sampler(self) -> bool:
"""
当前 callback 是否需要能够复现的 sampler 。一般用于 checkpoint 类的 callback 。

:return:
"""
return False


class _CallbackWrapper(Callback):
"""


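The new `need_reproducible_sampler` property gives every callback a way to request a sampler whose state can be saved and restored; the CallbackManager now ORs this flag across all callbacks instead of special-casing TrainerCheckpointCallback. A minimal sketch of a custom callback opting in (the class itself is hypothetical):

from fastNLP.core.callbacks import Callback

class MyResumableCallback(Callback):
    # Hypothetical callback that checkpoints its own state and therefore needs
    # a sampler that can be reproduced after resuming training.
    @property
    def need_reproducible_sampler(self) -> bool:
        return True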
fastNLP/core/callbacks/callback_manager.py (+3, -5)

@@ -8,7 +8,6 @@ __all__ = [

from .callback_events import Events
from .callback import Callback
from .checkpoint_callback import TrainerCheckpointCallback
from .progress_callback import ProgressCallback, choose_progress_callback
from fastNLP.core.log import logger

@@ -45,7 +44,7 @@ class CallbackManager:

:param callbacks: 初始化时可以传入的一系列 callback 类,通常为用户在初始化 'Trainer' 时直接传入的 callback 类;
"""
self._has_trainer_checkpoint = False
self._need_reproducible_sampler = False

_has_progress_callback = False
_callbacks = []
@@ -98,8 +97,7 @@ class CallbackManager:
:return:
"""
for each_callback in self.class_callbacks:
if isinstance(each_callback, TrainerCheckpointCallback):
self._has_trainer_checkpoint = True
self._need_reproducible_sampler |= each_callback.need_reproducible_sampler
self.dissect_one_callback(each_callback)

def dissect_one_callback(self, callback: Callback):
@@ -211,7 +209,7 @@ class CallbackManager:

@property
def has_trainer_checkpoint(self) -> bool:
return self._has_trainer_checkpoint
return self._need_reproducible_sampler

@_transfer
def on_after_trainer_initialized(self, trainer):


fastNLP/core/callbacks/checkpoint_callback.py (+99, -287)

@@ -1,339 +1,151 @@
__all__ = [
'ModelCheckpointCallback',
'TrainerCheckpointCallback'
'CheckpointCallback'
]
import os
from typing import Union, Optional, Callable, Dict, Sequence, Any, Mapping
from typing import Union, Optional, Callable, Dict, Sequence
from pathlib import Path
import sys
from copy import deepcopy


import fastNLP
from .has_monitor_callback import HasMonitorCallback
from fastNLP.core.log import logger
from fastNLP.envs import FASTNLP_LAUNCH_TIME, FASTNLP_GLOBAL_RANK
from fastNLP.core.utils import synchronize_safe_rm, synchronize_mkdir
from .topk_saver import TopkSaver
from .callback import Callback


class CheckpointCallback(HasMonitorCallback):
def __init__(
self,
monitor:Optional[Union[str, Callable]]=None,
save_folder: Optional[Union[str, Path]] = None,
save_every_n_epochs: Optional[int] = None,
save_every_n_batches: Optional[int] = None,
save_last: bool = False,
save_topk: Optional[int] = None,
save_on_exception: Optional[Union[BaseException, Sequence[BaseException]]] = None,
larger_better: bool = True,
only_state_dict: bool = True,
model_save_fn: Optional[Callable] = None,
**kwargs,
):
class CheckpointCallback(Callback):
def __init__(self, folder: Optional[Union[str, Path]] = None, every_n_epochs: Optional[int] = None,
every_n_batches: Optional[int] = None, last: bool = False,
on_exceptions: Optional[Union[BaseException, Sequence[BaseException]]] = None, topk: int = 0,
monitor: Optional[Union[str, Callable]] = None, larger_better: bool = True,
only_state_dict: bool = True, model_save_fn: Optional[Callable] = None, save_object: str = 'model',
save_evaluate_results=True, **kwargs):
"""
请使用 ModelCheckpointCallback 与 TrainerCheckpointCallback 。
保存模型 checkpoint 的 callback ,其保存的文件目录以及文件名命名规则如下

- folder/
- YYYY-mm-dd-HH_MM_SS_fffff/ # 自动根据当前脚本的启动时间创建的
- {save_object}-epoch_{epoch_idx}/ # 满足 every_n_epochs 条件保存的模型
- {save_object}-epoch_{epoch_idx}-batch_{global_batch_idx}/ # 满足 every_n_batches 保存的模型
- {save_object}-last/ # 最后一个 epoch 的保存
- {save_object}-epoch_{epoch_idx}-batch_{global_batch_idx}-exception_{exception_type}/ # exception时保存。
- {save_object}-epoch_{epoch_idx}-batch_{global_batch_idx}-{monitor}_{monitor_value}/ # 满足topk条件存储文件名

model_save_fn 为 None ,则以上每个 folder 中,将生成 fastnlp_model.pkl.tar 文件。
若 model_save_fn 不为 None,则 fastNLP 将 folder 绝对路径传递给该函数,fastNLP 在该 folder 下不进行模型保存。

:param monitor: 监控的 metric 值。如果在 evaluation 结果中没有找到完全一致的名称,将使用 最短公共字符串算法 找到最匹配
的那个作为 monitor 。如果为 None,将尝试使用 Trainer 设置的 monitor 。也可以传入一个函数,接受参数为 evaluation 的结
果(字典类型),返回一个 float 值作为 monitor 的结果。
:param save_folder: 保存的文件夹,fastNLP 将在该文件下以时间戳创建子文件夹,并在里面保存。因此不同次运行可以将被保存到不同的
果(字典类型),返回一个 float 值作为 monitor 的结果,如果当前结果中没有相关的 monitor 值请返回 None
:param folder: 保存的文件夹,fastNLP 将在该文件下以时间戳创建子文件夹,并在里面保存。因此不同次运行可以将被保存到不同的
时间戳文件夹中。如果为 None ,默认使用当前文件夹。
:param save_every_n_epochs: 多少个 epoch 保存一次。
:param save_every_n_batches: 多少个 batch 保存一次。
:param save_last: 如果为 True ,将在每次 epoch 运行结束都保存一次,会覆盖之前的保存。
:param save_topk: 保存 monitor 结果 topK 个。
:param save_on_exception: 在出异常信息时,是否保存。传入需要捕获的异常的类。
:param every_n_epochs: 多少个 epoch 保存一次。
:param every_n_batches: 多少个 batch 保存一次。
:param last: 如果为 True ,将在每次 epoch 运行结束都保存一次,会覆盖之前的保存。
:param topk: 保存 monitor 结果 topK 个。
:param on_exceptions: 在出异常信息时,是否保存。传入需要捕获的异常的类。
:param larger_better: monitor 的值是否时越大越好。
:param only_state_dict: 保存模型时是否只保存 state_dict 。当 model_save_fn 不为 None 时,该参数无效。
:param model_save_fn: 个性化的保存函数,当触发保存操作时,就调用这个函数,这个函数应当接受一个文件夹作为参数,不返回任何东西。
如果传入了 model_save_fn 函数,fastNLP 将不再进行模型相关的保存。在多卡场景下,我们只在 rank 0 上会运行该函数。
:param save_object: 可选 ['trainer', 'model'],表示在保存时的保存对象为 trainer+model 还是 只是model 。
:param save_evaluate_results: 是否保存 evaluate 的结果。如果为 True ,在保存 topk 模型的 folder 中还将额外保存一个
fastnlp_evaluate_results.json 文件,记录当前的 results。仅在设置了 topk 的场景下有用,默认为 True 。
:param kwargs:
"""
super().__init__(monitor=monitor, larger_better=larger_better,
must_have_monitor=save_topk is not None)
if save_folder is None:
super().__init__()
if folder is None:
logger.warning(
"Parameter `path` is None, and we will use the current work directory to find and load your model.")
save_folder = Path.cwd()
save_folder = Path(save_folder)
if not save_folder.exists():
raise NotADirectoryError(f"Path '{save_folder.absolute()}' is not existed!")
elif save_folder.is_file():
raise ValueError("Parameter `save_folder` should be a directory instead of a file.")

if save_every_n_epochs is not None:
if not isinstance(save_every_n_epochs, int) or save_every_n_epochs < 1:
raise ValueError("parameter save_after_epoch_num should be an int and greater than or equal to 1.")

"Parameter `folder` is None, and we will use the current work directory to find and load your model.")
folder = Path.cwd()
folder = Path(folder)
if not folder.exists():
raise NotADirectoryError(f"Path '{folder.absolute()}' is not existed!")
elif folder.is_file():
raise ValueError("Parameter `folder` should be a directory instead of a file.")

if every_n_epochs is not None:
if not isinstance(every_n_epochs, int) or every_n_epochs < 1:
raise ValueError("Parameter `every_n_epochs` should be an int and greater than or equal to 1.")
else:
save_every_n_epochs = sys.maxsize # 使得没有数字可以整除
every_n_epochs = sys.maxsize # 使得没有数字可以整除

if save_every_n_batches is not None:
if not isinstance(save_every_n_batches, int) or save_every_n_batches < 1:
raise ValueError(
"parameter save_every_n_batches should be an int and greater than or equal to 1.")
if every_n_batches is not None:
if not isinstance(every_n_batches, int) or every_n_batches < 1:
raise ValueError("Parameter `every_n_batches` should be an int and greater than or equal to 1.")
else:
save_every_n_batches = sys.maxsize # 使得没有数字可以整除
every_n_batches = sys.maxsize # 使得没有数字可以整除

if save_topk is not None:
if not isinstance(save_topk, int) or save_topk < 1:
raise ValueError("parameter save_topk should be an int and greater than or equal to 1.")
if topk is not None:
if not isinstance(topk, int):
raise ValueError("Parameter `topk` should be an int.")
else:
topk = 0

if save_on_exception is not None:
if not isinstance(save_on_exception, Sequence):
save_on_exception = [save_on_exception]
if on_exceptions is not None:
if not isinstance(on_exceptions, Sequence):
on_exceptions = [on_exceptions]

for exception in save_on_exception:
for exception in on_exceptions:
if not issubclass(exception, BaseException):
raise TypeError("Each exception in parameter `save_on_exception` can only be "
raise TypeError("Each exception in parameter `on_exception` can only be "
"`BaseException` type.")
else:
save_on_exception = []
on_exceptions = []

self.save_folder = save_folder
self.save_every_n_epochs = save_every_n_epochs
self.save_every_n_batches = save_every_n_batches
self.save_last = save_last
self.save_topk = save_topk
self.only_state_dict = only_state_dict
self.model_save_fn = model_save_fn
self.save_on_exception = save_on_exception
self.kwargs = kwargs
self.topk_saver = TopkSaver(topk, monitor, larger_better, folder, only_state_dict,
model_save_fn, save_evaluate_results,
save_object, **kwargs)
self.topk = topk
self.save_object = save_object

# 这些参数是专门留给 topk 模式专门使用的;
self._topk_model = {}
self._topn = 0 # 表示目前已经保存了几个最好的模型;
self.every_n_epochs = every_n_epochs
self.every_n_batches = every_n_batches
self.last = last
self.exceptions = on_exceptions

# 注意这里应当保证只有进程 0 在执行这个操作,因为当用户使用 python -m torch.distributed.launch 来拉起进程的时候,
# FASTNLP_LAUNCH_TIME 在每一个进程上的值是不一样的;
self.timestamp_path = self.save_folder.joinpath(os.environ[FASTNLP_LAUNCH_TIME])
# 该 folder 只在保存真的要发生的时候再创建。
@property
def need_reproducible_sampler(self) -> bool:
return self.save_object == 'trainer'

def on_after_trainer_initialized(self, trainer, driver):
if self.save_topk is not None:
super().on_after_trainer_initialized(trainer, driver)
if self.save_topk is not None and trainer.evaluator is None:
logger.warning("You set `save_topk`, but `evaluate_dataloaders` is not set in Trainer.")
if self.topk_saver.topk_queue: # 需要设置 monitor
if self.topk_saver.monitor is None:
self.topk_saver.set_monitor(monitor=trainer.monitor, larger_better=trainer.larger_better)
if self.topk_saver.topk_queue and trainer.evaluator is None:
logger.warning(f"You set `topk={self.topk}`, but `evaluate_dataloaders` is not set in Trainer.")

def on_validate_end(self, trainer, results):
self._save_topk(trainer, results)
# 如果发生了保存,则返回的 folder 不为 None
folder = self.topk_saver.save_topk(trainer, results)

def on_train_epoch_end(self, trainer: "fastNLP.Trainer"):
if trainer.cur_epoch_idx % self.save_every_n_epochs == 0:
folder_name = f'{self.folder_prefix}-epoch_{trainer.cur_epoch_idx}'
self.save(trainer, folder_name=folder_name)
if self.save_last:
folder_name = f'{self.folder_prefix}-last'
self.save(trainer, folder_name=folder_name)
if trainer.cur_epoch_idx % self.every_n_epochs == 0:
folder_name = f'{self.save_object}-epoch_{trainer.cur_epoch_idx}'
self.topk_saver.save(trainer, folder_name=folder_name)
if self.last:
folder_name = f'{self.save_object}-last'
self.topk_saver.save(trainer, folder_name=folder_name)

def on_train_batch_end(self, trainer):
if trainer.global_forward_batches % self.save_every_n_batches == 0:
folder_name = f'{self.folder_prefix}-epoch_{trainer.cur_epoch_idx}-batch_{trainer.global_forward_batches}'
self.save(trainer, folder_name=folder_name)
if trainer.global_forward_batches % self.every_n_batches == 0:
folder_name = f'{self.save_object}-epoch_{trainer.cur_epoch_idx}-batch_{trainer.global_forward_batches}'
self.topk_saver.save(trainer, folder_name=folder_name)

def on_exception(self, trainer, exception: BaseException):
if exception.__class__ in self.save_on_exception:
folder_name = f'{self.folder_prefix}-epoch_{trainer.cur_epoch_idx}-batch_{trainer.global_forward_batches}-' \
f'exception_{exception.__class__.__name__}'
self.save(trainer=trainer, folder_name=folder_name)
if exception.__class__ in self.exceptions:
folder_name = f'{self.save_object}-epoch_{trainer.cur_epoch_idx}-batch_{trainer.global_forward_batches}-' \
f'exception_{exception.__class__.__name__}'
self.topk_saver.save(trainer, folder_name=folder_name)

def on_save_checkpoint(self, trainer) -> Dict:
"""
保存 timestamp_path 使得之后可以继续训练并保存到该文件夹。
topk_model的状态
_real_monitor的值
保存状态,以便之后可以继续使用
"""

states = {}
states['timestamp_path'] = str(self.timestamp_path.absolute())
states['_topk_model'] = deepcopy(self._topk_model)
states['save_topk'] = 0 if self.save_topk is None else self.save_topk
if isinstance(self._real_monitor, str):
states['_real_monitor'] = self._real_monitor
states['topk_saver'] = self.topk_saver.state_dict()
return states

def on_load_checkpoint(self, trainer, states: Optional[Dict]):
timestamp_path = states['timestamp_path']
if not os.path.exists(timestamp_path):
logger.info(f"The resuming checkpoint folder {timestamp_path} is not exists, will checkpoint save to "
f" {self.timestamp_path.absolute()}.")
else:
logger.info(f"Resume to checkpoint in path: {timestamp_path}.")
self.timestamp_path = Path(timestamp_path)
_topk_model = states['_topk_model']
save_topk = None if int(states['save_topk']) == 0 else int(states['save_topk'])
if save_topk is not None and self.save_topk is not None:
assert self.save_topk == save_topk, f"The checkpoint set save_topk={save_topk}, while this callback set it " \
f"as {save_topk}."
self._topk_model.update(self._topk_model)

self._real_monitor = states["_real_monitor"]

def _save_topk(self, trainer: "fastNLP.Trainer", results: Dict):
"""
根据validate_res决定保存哪些model的函数。会自动移除掉不满足topk的文件夹。

:param trainer:
:param results:
:return:
"""
if self.save_topk is not None:
monitor_value = self.get_monitor_value(results=results)
if monitor_value is None:
return
folder_name = f"{self.folder_prefix}-epoch_{trainer.cur_epoch_idx}-batch_{trainer.global_forward_batches}" \
f"-{self._real_monitor}_{monitor_value}"

_should_save = False
if self._topn < self.save_topk:
self._topk_model[folder_name] = monitor_value
self._topn += 1
_should_save = True
else:
_least_valuable_model = (min if self.larger_better else max)(self._topk_model,
key=lambda x: self._topk_model[x])
if self.is_former_monitor_value_better(monitor_value, self._topk_model[_least_valuable_model]):
self._topk_model[folder_name] = monitor_value
_should_save = True
self._topk_model.pop(_least_valuable_model)
synchronize_safe_rm(self.timestamp_path.joinpath(_least_valuable_model))

assert len(self._topk_model) == self.save_topk == self._topn

if _should_save:
self.save(trainer, folder_name=folder_name)

def save(self, trainer, folder_name):
"""
执行保存的函数,将数据保存在 save_folder/timestamp/folder_name 下。

:param trainer:
:param folder_name:
:return:
"""
folder = self.timestamp_path.joinpath(folder_name)
if int(os.environ.get(FASTNLP_GLOBAL_RANK, 0)) == 0: # 只在进程0上创建
synchronize_mkdir(folder)
_fn = getattr(trainer, self.save_fn_name)
_fn(
folder=folder,
only_state_dict=self.only_state_dict,
model_save_fn=self.model_save_fn,
**self.kwargs
)

@property
def folder_prefix(self):
raise NotImplementedError("The `folder_prefix` is not specified")

@property
def save_fn_name(self):
raise NotImplementedError("The `save_fn_name` is not specified.")


class ModelCheckpointCallback(CheckpointCallback):
"""
保存模型 checkpoint 的 callback ,其保存的文件目录以及文件名命名规则如下

- save_folder/
- YYYY-mm-dd-HH_MM_SS_fffff/ # 自动根据当前脚本的启动时间创建的
- model-epoch_{epoch_idx}/ # 满足 save_every_n_epochs 条件保存的模型
- model-epoch_{epoch_idx}-batch_{global_batch_idx}/ # 满足 save_every_n_batches 保存的模型
- model-last/ # 最后一个 epoch 的保存
- model-epoch_{epoch_idx}-batch_{global_batch_idx}-exception_{exception_type}/ # exception时保存。
- model-epoch_{epoch_idx}-batch_{global_batch_idx}-{monitor}_{monitor_value}/ # 满足topk条件存储文件名
topk_saver_states = states['topk_saver']
self.topk_saver.load_state_dict(topk_saver_states)

model_save_fn 为 None ,则以上每个 folder 中,将生成 fastnlp_model.pkl.tar 文件。
若 model_save_fn 不为 None,则 fastNLP 将 folder 绝对路径传递给该函数,fastNLP 不在该 folder 下创建任何文件。

:param monitor: 监控的 metric 值。如果在 evaluation 结果中没有找到完全一致的名称,将使用 最短公共字符串算法 找到最匹配
的那个作为 monitor 。如果为 None,将尝试使用 Trainer 设置的 monitor 。也可以传入一个函数,接受参数为 evaluation 的结
果(字典类型),返回一个 float 值作为 monitor 的结果。
:param save_folder: 保存的文件夹,fastNLP 将在该文件下以时间戳创建子文件夹,并在里面保存。因此不同次运行可以将被保存到不同的
时间戳文件夹中。如果为 None ,默认使用当前文件夹。
:param save_every_n_epochs: 多少个 epoch 保存一次。
:param save_every_n_batches: 多少个 batch 保存一次。
:param save_last: 如果为 True ,将在每次 epoch 运行结束都保存一次,会覆盖之前的保存。
:param save_topk: 保存 monitor 结果 topK 个。
:param save_on_exception: 在出异常信息时,是否保存。传入需要捕获的异常的类。
:param larger_better: monitor 的值是否时越大越好。
:param only_state_dict: 保存模型时是否只保存 state_dict 。当 model_save_fn 不为 None 时,该参数无效。
:param model_save_fn: 个性化的保存函数,当触发保存操作时,就调用这个函数,这个函数应当接受一个文件夹作为参数,不返回任何东西。
如果传入了 model_save_fn 函数,fastNLP 将不再进行模型相关的保存。在多卡场景下,我们只在 rank 0 上会运行该函数。
:param kwargs:
"""
@property
def save_fn_name(self):
"""
调用 Trainer 中的哪个函数。

:return:
"""
return 'save_model'

@property
def callback_name(self):
"""
通过该值决定两个 CheckpointCallback 实例是否可以共用断点重训的状态;
:return:
"""
return f"model_checkpoint#monitor-{self.monitor_name}#topK-{self.save_topk}#only_state_dict-{self.only_state_dict}"

@property
def folder_prefix(self):
return 'model'


class TrainerCheckpointCallback(CheckpointCallback):
"""
保存 Trainer checkpoint 的 callback ,其保存的文件目录以及文件名命名规则如下

- save_folder/
- YYYY-mm-dd-HH_MM_SS_fffff/ # 自动根据当前脚本的启动时间创建的
- trainer-epoch_{epoch_idx}/ # 满足 save_every_n_epochs 条件保存的模型
- trainer-epoch_{epoch_idx}-batch_{global_batch_idx}/ # 满足 save_every_n_batches 保存的模型
- trainer-last/ # 最后一个 epoch 的保存
- trainer-epoch_{epoch_idx}-batch_{global_batch_idx}-exception_{exception_type}/ # exception时保存。
- trainer-epoch_{epoch_idx}-batch_{global_batch_idx}-{monitor}_{monitor_value}/ # 满足topk条件存储文件名

model_save_fn 为 None ,则以上每个 folder 中,将生成两个文件:fastnlp_trainer.pkl.tar 以及 fastnlp_model.pkl.tar 。
若 model_save_fn 不为 None,则 fastNLP 只会在每个 folder 下生成 fastnlp_trainer.pkl.tar 文件。

:param monitor: 监控的 metric 值。如果在 evaluation 结果中没有找到完全一致的名称,将使用 最短公共字符串算法 找到最匹配
的那个作为 monitor 。如果为 None,将尝试使用 Trainer 设置的 monitor 。也可以传入一个函数,接受参数为 evaluation 的结
果(字典类型),返回一个 float 值作为 monitor 的结果。
:param save_folder: 保存的文件夹,fastNLP 将在该文件下以时间戳创建子文件夹,并在里面保存。因此不同次运行可以将被保存到不同的
时间戳文件夹中。如果为 None ,默认使用当前文件夹。
:param save_every_n_epochs: 多少个 epoch 保存一次。
:param save_every_n_batches: 多少个 batch 保存一次。
:param save_last: 如果为 True ,将在每次 epoch 运行结束都保存一次,会覆盖之前的保存。
:param save_topk: 保存 monitor 结果 topK 个。
:param save_on_exception: 在出异常信息时,是否保存。
:param larger_better: monitor 的值是否时越大越好。
:param only_state_dict: 保存模型时是否只保存 state_dict 。当 model_save_fn 不为 None 时,该参数无意义。
:param model_save_fn: 个性化的保存函数,当触发保存操作时,就调用这个函数,这个函数应当接受一个文件夹作为参数,不返回任何东西。
如果传入了 model_save_fn 函数,fastNLP 将不再进行模型相关的保存。在多卡场景下,我们只在 rank 0 上会运行该函数。
:param kwargs:
"""
@property
def save_fn_name(self):
"""
调用 Trainer 中的哪个函数。

:return:
"""
return 'save'

@property
def callback_name(self):
"""
通过该值决定两个 CheckpointCallback 实例是否可以共用断点重训的状态;
:return:
"""

return f"trainer_checkpoint#monitor-{self.monitor_name}#topK-{self.save_topk}#only_state_dict-{self.only_state_dict}"

@property
def folder_prefix(self):
return 'trainer'
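Putting the merged interface together, a minimal usage sketch of the new CheckpointCallback against the signature introduced above (the folder name and the monitor key are placeholders):

from fastNLP.core.callbacks import CheckpointCallback

ckpt_cb = CheckpointCallback(
    folder='checkpoints',             # a timestamped sub-folder is created inside
    every_n_epochs=1,                 # save at the end of every epoch
    topk=3,                           # keep the 3 best results according to `monitor`
    monitor='acc',                    # placeholder key, fuzzily matched against evaluate results
    larger_better=True,
    save_object='model',              # 'trainer' would additionally save the training state
    on_exceptions=KeyboardInterrupt,  # also save when this exception is raised
)
# trainer = Trainer(..., callbacks=[ckpt_cb])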

fastNLP/core/callbacks/has_monitor_callback.py (+61, -36)

@@ -1,10 +1,12 @@
__all__ = [
'HasMonitorCallback',
'ExecuteOnceBetterMonitor'
'ExecuteOnceBetterMonitor',
'MonitorUtility'
]

from typing import Dict, Union, Any
from abc import ABC
import functools

from fastNLP.core.utils import apply_to_collection
from fastNLP.core.callbacks import Callback
@@ -27,21 +29,13 @@ class CanItemDataType(ABC):
return NotImplemented


class MonitorUtility:
"""
计算 monitor 的相关函数

class HasMonitorCallback(Callback):
def __init__(self, monitor, larger_better, must_have_monitor=False):
"""
该 callback 不直接进行使用,作为其它相关 callback 的父类使用,如果 callback 有使用 monitor 可以继承该函数里面实现了
(1)判断monitor合法性;(2)在需要时, 根据trainer的monitor设置自己的monitor名称。

:param monitor: 监控的 metric 值。如果在 evaluation 结果中没有找到完全一致的名称,将使用 最短公共字符串算法 找到最匹配
的那个作为 monitor 。如果为 None,将尝试使用 Trainer 设置的 monitor 。也可以传入一个函数,接受参数为 evaluation 的结
果(字典类型),返回一个 float 值作为 monitor 的结果。
:param larger_better: monitor 是否时越大越好
:param must_have_monitor: 这个 callback 是否必须有 monitor 设置。如果设置为 True ,且没检测到设置 monitor 会报错。
"""
"""
def __init__(self, monitor, larger_better):
self.set_monitor(monitor, larger_better)
self.must_have_moinitor = must_have_monitor

def set_monitor(self, monitor, larger_better):
if callable(monitor): # 检查是否能够接受一个参数
@@ -57,26 +51,14 @@ class HasMonitorCallback(Callback):
self.monitor_value = float('inf')
self._real_monitor = self.monitor

def on_after_trainer_initialized(self, trainer, driver):
def itemize_results(self, results):
"""
如果本身的 monitor 没有设置,则根据 Trainer 中的 monitor 设置 monitor 。
同时对于必须要有 monitor 设置的 callback ,该函数会进行检查。
将结果中有 .item() 方法的都调用一下,使得可以结果可以保存

:param trainer:
:param driver:
:param results:
:return:
"""
if self.monitor is None and trainer.monitor is not None:
self.set_monitor(monitor=trainer.monitor, larger_better=trainer.larger_better)
if self.must_have_moinitor and self.monitor is None:
raise RuntimeError(f"No `monitor` is set for {self.__class__.__name__}. "
f"You can set it in the initialization or through Trainer.")


def on_sanity_check_end(self, trainer, sanity_check_res):
# 主要核对一下 monitor 是否存在。
if self.monitor is not None:
self.get_monitor_value(results=sanity_check_res)
return apply_to_collection(results, dtype=CanItemDataType, function=lambda x: x.item())

def get_monitor_value(self, results:Dict)->Union[float, None]:
"""
@@ -85,10 +67,10 @@ class HasMonitorCallback(Callback):
:param results:
:return: 如果为 None ,表明此次没有找到合适的monitor
"""
if len(results)==0:
if len(results) == 0 or self.monitor is None:
return None
# 保证所有的 tensor 都被转换为了 python 特定的类型
results = apply_to_collection(results, dtype=CanItemDataType, function=lambda x: x.item())
results = self.itemize_results(results)
use_monitor, monitor_value = _get_monitor_value(monitor=self.monitor,
real_monitor=self._real_monitor,
res=results)
@@ -97,7 +79,7 @@ class HasMonitorCallback(Callback):
# 第一次运行
if isinstance(self.monitor, str) and self._real_monitor == self.monitor and use_monitor != self.monitor:
logger.warning(f"We can not find `{self.monitor}` in the evaluation result (with keys as {list(results.keys())}), "
f"we use the `{use_monitor}` as the monitor for `{self.__class__.__name__}`.")
f"we use the `{use_monitor}` as the monitor for `{self.__class__.__name__}`.")
# 检测到此次和上次不同。
elif isinstance(self.monitor, str) and self._real_monitor != self.monitor and use_monitor != self._real_monitor:
logger.warning(f"Change of monitor detected for `{self.__class__.__name__}`. "
@@ -165,7 +147,10 @@ class HasMonitorCallback(Callback):
"""
if callable(self.monitor):
try:
monitor_name = self.monitor.__qualname__
monitor = self.monitor
while isinstance(monitor, functools.partial):
monitor = monitor.func
monitor_name = monitor.__qualname__
except:
monitor_name = self.monitor.__name__
elif self.monitor is None:
@@ -176,6 +161,46 @@ class HasMonitorCallback(Callback):
return monitor_name



class HasMonitorCallback(MonitorUtility, Callback):
def __init__(self, monitor, larger_better, must_have_monitor=False):
"""
该 callback 不直接进行使用,作为其它相关 callback 的父类使用,如果 callback 有使用 monitor 可以继承该函数里面实现了
(1)判断monitor合法性;(2)在需要时, 根据trainer的monitor设置自己的monitor名称。

:param monitor: 监控的 metric 值。如果在 evaluation 结果中没有找到完全一致的名称,将使用 最短公共字符串算法 找到最匹配
的那个作为 monitor 。如果为 None,将尝试使用 Trainer 设置的 monitor 。也可以传入一个函数,接受参数为 evaluation 的结
果(字典类型),返回一个 float 值作为 monitor 的结果,如果当前结果中没有相关的 monitor 值请返回 None 。
:param larger_better: monitor 是否时越大越好
:param must_have_monitor: 这个 callback 是否必须有 monitor 设置。如果设置为 True ,且没检测到设置 monitor 会报错。
"""
super().__init__(monitor, larger_better)
self.must_have_monitor = must_have_monitor

def on_after_trainer_initialized(self, trainer, driver):
"""
如果本身的 monitor 没有设置,则根据 Trainer 中的 monitor 设置 monitor 。
同时对于必须要有 monitor 设置的 callback ,该函数会进行检查。

:param trainer:
:param driver:
:return:
"""
if self.monitor is None and trainer.monitor is not None:
self.set_monitor(monitor=trainer.monitor, larger_better=trainer.larger_better)
if self.must_have_monitor and self.monitor is None:
raise RuntimeError(f"No `monitor` is set for {self.__class__.__name__}. "
f"You can set it in the initialization or through Trainer.")
if self.must_have_monitor and self.monitor is not None and trainer.evaluator is None:
raise RuntimeError(f"No `evaluate_dataloaders` is set for Trainer. But Callback: {self.__class__.__name__}"
f" need to watch the monitor:`{self.monitor_name}`.")

def on_sanity_check_end(self, trainer, sanity_check_res):
# 主要核对一下 monitor 是否存在。
if self.monitor is not None:
self.get_monitor_value(results=sanity_check_res)


class ExecuteOnceBetterMonitor(HasMonitorCallback):
def __init__(self, monitor, larger_better, execute_fn):
"""
@@ -183,13 +208,13 @@ class ExecuteOnceBetterMonitor(HasMonitorCallback):

:param monitor: 监控的 metric 值。如果在 evaluation 结果中没有找到完全一致的名称,将使用 最短公共字符串算法 找到最匹配
的那个作为 monitor 。如果为 None,将尝试使用 Trainer 设置的 monitor 。也可以传入一个函数,接受参数为 evaluation 的结
果(字典类型),返回一个 float 值作为 monitor 的结果。
果(字典类型),返回一个 float 值作为 monitor 的结果,如果当前结果中没有相关的 monitor 值请返回 None
:param larger_better: monitor 是否时越大越好
:param execute_fn: 一个可执行的函数,不接受任何参数,不反回值。在 monitor 取得更好结果的时候会调用。
"""
super().__init__(monitor, larger_better, must_have_monitor=True)
_check_valid_parameters_number(execute_fn, expected_params=[], fn_name='execute_fn')
self.execute_fn = execute_fn()
self.execute_fn = execute_fn

def on_validate_end(self, trainer, results):
if self.is_better_results(results):

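Because a monitor may now also be a callable that returns None when no suitable value is present (and functools.partial objects are unwrapped when building the monitor name), a short sketch with a placeholder result key:

from fastNLP.core.callbacks import CheckpointCallback

def bleu_monitor(results: dict):
    # Return the value to monitor, or None so the callback skips this evaluation.
    return results.get('bleu#gen', None)   # placeholder key

ckpt_cb = CheckpointCallback(folder='checkpoints', topk=2, monitor=bleu_monitor)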
fastNLP/core/callbacks/load_best_model_callback.py (+3, -3)

@@ -23,7 +23,7 @@ class LoadBestModelCallback(HasMonitorCallback):

:param str monitor: 监控的 metric 值。如果在 evaluation 结果中没有找到完全一致的名称,将使用 最短公共字符串算法 找到最匹配
的那个作为 monitor 。如果为 None,将尝试使用 Trainer 设置的 monitor 。也可以传入一个函数,接受参数为 evaluation 的结
果(字典类型),返回一个 float 值作为 monitor 的结果。
果(字典类型),返回一个 float 值作为 monitor 的结果,如果当前结果中没有相关的 monitor 值请返回 None
:param larger_better: 该 metric 值是否是越大越好。
:param save_folder: 保存的文件夹,如果为空,则保存在内存中。不为空,则保存一份权重到文件中,当为多机训练,且本值不为空时,请确保
不同的机器均可访问当该路径。当 model_save_fn 不为 None 时该值一定不能为空。
@@ -72,7 +72,7 @@ class LoadBestModelCallback(HasMonitorCallback):
logger.debug(f"Synchronize best model save folder: {self.real_save_folder} for LoadBestModelCallback.")
except NotImplementedError:
raise RuntimeError(f"Currently {driver.__class__.__name__} does not support using `save_folder` to "
f"save best model when launch using script.")
f"save best model when launch using module.")

super().on_after_trainer_initialized(trainer, driver)

@@ -87,7 +87,7 @@ class LoadBestModelCallback(HasMonitorCallback):
trainer.save_model(folder=self.buffer, only_state_dict=self.only_state_dict)

def on_train_end(self, trainer):
logger.info(f"Loading best model with {self._real_monitor}: {self.monitor_value}...")
logger.info(f"Loading best model with {self.monitor_name}: {self.monitor_value}...")
if self.real_save_folder:
trainer.load_model(folder=self.real_save_folder, only_state_dict=self.only_state_dict,
model_load_fn=self.model_load_fn)


fastNLP/core/callbacks/more_evaluate_callback.py (+174, -0)

@@ -0,0 +1,174 @@
__all__ = [
'MoreEvaluateCallback'
]

import os
from typing import Union, Callable, Optional, Dict

from fastNLP.core.log import logger
from .has_monitor_callback import HasMonitorCallback
from .topk_saver import TopkSaver


class MoreEvaluateCallback(HasMonitorCallback):
def __init__(self, dataloaders, metrics:Dict, evaluate_every:Optional[Union[int, Callable]]=-1,
watch_monitor:Union[str, Callable]=None, watch_monitor_larger_better:bool=True,
evaluate_fn=None, num_eval_sanity_batch=2,
topk=0, topk_monitor=None, topk_larger_better=True,
folder=None, only_state_dict=True, save_object='model', model_save_fn=None,
save_evaluate_results=True, save_kwargs=None,
**kwargs):
"""
当评测时需要调用不同的 evaluate_fn (例如在大部分生成任务中,一般使用训练 loss 作为训练过程中的 evaluate ;但同时在训练到
一定 epoch 数量之后,会让 model 生成的完整的数据评测 bleu 等。此刻就可能需要两种不同的 evaluate_fn ),只使用 Trainer
无法满足需求,可以通过调用本 callback 进行。如果需要根据本 callback 中的评测结果进行模型保存,请传入 topk 以及
topk_monitor 等相关参数。可以通过 evaluate_every 或 watch_monitor 控制触发进行 evaluate 的条件。

如果设置了 evaluate 结果更好就保存的话,将按如下文件结构进行保存
- folder/
- YYYY-mm-dd-HH_MM_SS_fffff/ # 自动根据当前脚本的启动时间创建的
- {save_object}-epoch_{epoch_idx}-batch_{global_batch_idx}-{topk_monitor}_{monitor_value}/ # 满足topk条件存储文件名

:param dataloaders: 需要评估的数据
:param metrics: 使用的 metrics 。
:param evaluate_every: 可以为负数、正数和函数;(1) 为负整数时表示每隔几个 epoch validate 一次;(2) 为正整数则表示每隔几个 batch
evaluate 一次;(3) 为函数时表示用户自己传入的用于控制 validate 的频率的函数,该函数的应该接受 trainer 对象作为参数,并返回
一个 bool 值,返回为 True 说明需要进行 evaluate ;将在每个 batch 结束后调用该函数判断是否需要 evaluate 。
:param watch_monitor: 这个值用来表示监控的 Trainer 中的 evaluate 结果的,当该值不为 None ,evaluate_every 失效。本参数的
意义是,当检测到 Trainer 中 evaluate results 的 {watch_monitor} 的结果更好时,则进行一次 evaluate 。该参数有两种
取值: (1) str 类型,监控的 metric 值。如果在 evaluation 结果中没有找到完全一致的名称,将使用 最短公共字符串算法 找到最
匹配的那个作为 monitor ; (2) 也可以传入一个函数,接受参数为 evaluation 的结果(字典类型),返回一个 float 值作为 monitor
的结果,如果当前结果中没有相关的monitor 值请返回 None 。
:param watch_monitor_larger_better: watch_monitor 是否越大越好。
:param evaluate_fn: 用来控制 `Evaluator` 在评测的前向传播过程中是调用哪一个函数,例如是 `model.evaluate_step` 还是
`model.forward`;(1) 如果该值是 None,那么我们会默认使用 `evaluate_step` 当做前向传播的函数,如果在模型中没有
找到该方法,则使用 `model.forward` 函数;(2) 如果为 str 类型,则尝试从 model 中寻找该方法,找不到则报错。
:param num_eval_sanity_batch: 在初始化 Evaluator 后运行多少个 sanity check 的 batch ,检测一下。
:param topk: 如果需要根据当前 callback 中的 evaluate 结果保存模型或 Trainer ,可以通过设置 tokp 实现。(1)为 -1 表示每次
evaluate 后都保存;(2)为 0 (默认),表示不保存;(3)为整数,表示保存性能最 topk 个。
:param topk_monitor: 如果需要根据当前 callback 中的 evaluate 结果保存。这个参数是指在当前 callback 中的 evaluate 结果寻找
:param topk_larger_better: topk_monitor 的值是否时越大越好。
:param folder: 保存的文件夹,fastNLP 将在该文件下以时间戳创建子文件夹,并在里面保存。因此不同次运行可以将被保存到不同的
时间戳文件夹中。如果为 None ,默认使用当前文件夹。
:param only_state_dict: 保存模型时是否只保存 state_dict 。当 model_save_fn 不为 None 时,该参数无效。
:param save_object: 可选 ['trainer', 'model'],表示在保存时的保存对象为 trainer+model 还是 只是model 。
:param model_save_fn: 个性化的保存函数,当触发保存操作时,就调用这个函数,这个函数应当接受一个文件夹作为参数,不返回任何东西。
如果传入了 model_save_fn 函数,fastNLP 将不再进行模型相关的保存。在多卡场景下,我们只在 rank 0 上会运行该函数。
:param save_evaluate_results: 是否保存 evaluate 的结果。如果为 True ,在保存 topk 模型的 folder 中还将额外保存一个
fastnlp_evaluate_results.json 文件,记录当前的 results。仅在设置了 topk 的场景下有用,默认为 True 。
:param save_kwargs: dict。更多的保存相关的参数。
:param kwargs: 其它与 Evaluator 相关的初始化参数,如果不传入,将从 Trainer 中获取。请特别留意 evaluate_fn 的设置。
"""
super(MoreEvaluateCallback, self).__init__(watch_monitor, watch_monitor_larger_better,
must_have_monitor=False)

if watch_monitor is None and evaluate_every is None:
raise RuntimeError("`evaluate_every` and `watch_monitor` cannot be None at the same time.")
if watch_monitor is not None and evaluate_every is not None:
raise RuntimeError("`evaluate_every` and `watch_monitor` cannot be set at the same time.")
self.watch_monitor = watch_monitor

if topk_monitor is not None and topk == 0:
raise RuntimeError("`topk_monitor` is set, but `topk` is 0.")
if topk != 0 and topk_monitor is None:
raise RuntimeError("`topk` is set, but `topk_monitor` is None.")
assert save_object in ['trainer', 'model']

self.dataloaders = dataloaders
self.metrics = metrics
self.evaluate_every = evaluate_every
self.evaluate_fn = evaluate_fn
self.num_eval_sanity_batch = num_eval_sanity_batch
if save_kwargs is None:
save_kwargs = {}
self.topk_saver = TopkSaver(topk=topk, monitor=topk_monitor, larger_better=topk_larger_better,
folder=folder, only_state_dict=only_state_dict,
model_save_fn=model_save_fn, save_evaluate_results=save_evaluate_results,
save_object=save_object, **save_kwargs)
self.kwargs = kwargs

@property
def need_reproducible_sampler(self) -> bool:
return self.topk_saver.save_object == 'trainer'

def on_after_trainer_initialized(self, trainer, driver):
# 如果是需要 watch 的,不能没有 evaluator
if self.watch_monitor is not None:
assert trainer.evaluator is not None, f"You set `watch_monitor={self.watch_monitor}`, but no " \
f"evaluate_dataloaders is provided in Trainer."

if trainer.evaluate_fn is self.evaluate_fn:
logger.warning_once("The `evaluate_fn` is the same as in Trainer, there seems no need to use "
"`MoreEvaluateCallback`.")

# 初始化 evaluator , 同时避免调用 super 对 monitor 赋值
kwargs = {
'model': self.kwargs.get('model', trainer.model),
'dataloaders': self.dataloaders,
'metrics': self.metrics,
'driver': self.kwargs.get('driver', trainer.driver),
'device': self.kwargs.get('device', trainer.device),
'batch_step_fn': self.kwargs.get('batch_step_fn', trainer.evaluate_batch_step_fn),
'evaluate_fn': self.evaluate_fn,
'input_mapping': self.kwargs.get('input_mapping', trainer.input_mapping),
'output_mapping': self.kwargs.get('output_mapping', trainer.output_mapping),
'fp16': self.kwargs.get('fp16', trainer.fp16),
'use_dist_sampler': self.kwargs.get('use_dist_sampler',
trainer.kwargs.get('eval_use_dist_sampler', None)),
'progress_bar': self.kwargs.get('progress_bar', trainer.kwargs.get('progress_bar', 'auto')),
'verbose': self.kwargs.get('verbose', 1)
}

for key, value in self.kwargs.items():
if key not in kwargs:
kwargs[key] = value
from fastNLP.core.controllers.evaluator import Evaluator
self.evaluator = Evaluator(**kwargs)
if self.num_eval_sanity_batch>0:
results = self.evaluator.run(num_eval_batch_per_dl=self.num_eval_sanity_batch)
self.topk_saver.get_monitor_value(results)

def on_validate_end(self, trainer, results):
if self.is_better_results(results, keep_if_better=True):
results = self.evaluator.run()
self.topk_saver.save_topk(trainer, results)

def on_train_epoch_end(self, trainer):
if self.watch_monitor is not None:
return
if isinstance(self.evaluate_every, int) and self.evaluate_every < 0:
validate_every = -self.evaluate_every
if trainer.cur_epoch_idx % validate_every == 0:
results = self.evaluator.run()
self.topk_saver.save_topk(trainer, results)

def on_train_batch_end(self, trainer):
if self.watch_monitor is not None:
return
if callable(self.evaluate_every):
if self.evaluate_every(self):
results = self.evaluator.run()
self.topk_saver.save_topk(trainer, results)
elif self.evaluate_every > 0 and trainer.global_forward_batches % self.evaluate_every == 0:
results = self.evaluator.run()
self.topk_saver.save_topk(trainer, results)

def on_save_checkpoint(self, trainer) -> Dict:
states = {'topk_saver': self.topk_saver.state_dict()}
if isinstance(self._real_monitor, str):
states['_real_monitor'] = self._real_monitor
states['monitor_value'] = self.monitor_value
return states

def on_load_checkpoint(self, trainer, states: Optional[Dict]):
topk_saver_states = states['topk_saver']
self.topk_saver.load_state_dict(topk_saver_states)
if '_real_monitor' in states:
self._real_monitor = states["_real_monitor"]
self.monitor_value = states['monitor_value']

@property
def callback_name(self):
metric_names = '+'.join(sorted(self.metrics.keys()))
return f'more_evaluate_callback#metric_name-{metric_names}#monitor-{self.monitor_name}#topk_saver:{self.topk_saver}'
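A usage sketch of the new callback; the generation dataloader, the metric object and the `generate_step` method are hypothetical stand-ins for whatever the second, more expensive evaluation needs:

from fastNLP.core.callbacks import MoreEvaluateCallback

more_eval = MoreEvaluateCallback(
    dataloaders=gen_dev_dataloader,     # hypothetical dataloader for the expensive evaluation
    metrics={'bleu': BleuMetric()},     # hypothetical metric object
    evaluate_every=-5,                  # negative => every 5 epochs, positive => every N batches
    evaluate_fn='generate_step',        # hypothetical method looked up on the model
    topk=1,                             # keep only the single best checkpoint from this evaluation
    topk_monitor='bleu',
    folder='checkpoints',
    save_object='model',
)
# trainer = Trainer(..., callbacks=[more_eval])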


fastNLP/core/callbacks/progress_callback.py (+4, -3)

@@ -9,7 +9,6 @@ __all__ = [
]

from .has_monitor_callback import HasMonitorCallback
from fastNLP.core.callbacks.utils import _get_monitor_value
from fastNLP.core.utils import f_rich_progress
from fastNLP.core.log import logger

@@ -42,7 +41,8 @@ class RichCallback(ProgressCallback):
:param loss_round_ndigit: 显示的 loss 保留多少位有效数字
:param monitor: 当检测到这个key的结果更好时,会打印出不同的颜色进行提示。监控的 metric 值。如果在 evaluation 结果中没有找到
完全一致的名称,将使用 最短公共字符串算法 找到最匹配的那个作为 monitor 。如果为 None,将尝试使用 Trainer 设置的 monitor
。也可以传入一个函数,接受参数为 evaluation 的结果(字典类型),返回一个 float 值作为 monitor 的结果。
。也可以传入一个函数,接受参数为 evaluation 的结果(字典类型),返回一个 float 值作为 monitor 的结果,如果当前结果中没有
相关的 monitor 值请返回 None 。
:param larger_better: 是否是 monitor 的结果越大越好。
:param format_json: 是否格式化 json 再打印
"""
@@ -135,7 +135,8 @@ class RawTextCallback(ProgressCallback):
:param loss_round_ndigit: 显示的 loss 保留多少位有效数字
:param monitor: 当检测到这个key的结果更好时,会打印出不同的颜色进行提示。监控的 metric 值。如果在 evaluation 结果中没有找到
完全一致的名称,将使用 最短公共字符串算法 找到最匹配的那个作为 monitor 。如果为 None,将尝试使用 Trainer 设置的 monitor
。也可以传入一个函数,接受参数为 evaluation 的结果(字典类型),返回一个 float 值作为 monitor 的结果。
。也可以传入一个函数,接受参数为 evaluation 的结果(字典类型),返回一个 float 值作为 monitor 的结果,如果当前结果中没有
相关的 monitor 值请返回 None 。
:param larger_better: 是否是monitor的结果越大越好。
:param format_json: 是否format json再打印
"""


fastNLP/core/callbacks/topk_saver.py (+246, -0)

@@ -0,0 +1,246 @@
import json
import os
from copy import deepcopy
from pathlib import Path
from typing import Optional, Dict, Tuple

from fastNLP.core.utils import rank_zero_rm
from fastNLP.core.log import logger
from fastNLP.envs import FASTNLP_LAUNCH_TIME
from fastNLP.envs import rank_zero_call
from fastNLP.envs.env import FASTNLP_EVALUATE_RESULT_FILENAME
from .has_monitor_callback import MonitorUtility


class Saver:
def __init__(self, folder, only_state_dict, model_save_fn, **kwargs):
"""
执行保存的对象。保存的文件组织结构为
- folder # 当前初始化的参数
- YYYY-mm-dd-HH_MM_SS_fffff/ # 自动根据当前脚本的启动时间创建的
- folder_name # 由 save() 调用时传入。

:param folder:
:param only_state_dict:
:param model_save_fn:
:param kwargs:
"""
if folder is None:
logger.warning(
"Parameter `folder` is None, and we will use the current work directory to find and load your model.")
folder = Path.cwd()
folder = Path(folder)
if not folder.exists():
raise NotADirectoryError(f"Path '{folder.absolute()}' is not existed!")
elif folder.is_file():
raise ValueError("Parameter `folder` should be a directory instead of a file.")

self.folder = folder
self.only_state_dict = only_state_dict
self.model_save_fn = model_save_fn
self.kwargs = kwargs
self.eval_results = kwargs.get('eval_results', True)
self.timestamp_path = self.folder.joinpath(os.environ[FASTNLP_LAUNCH_TIME])

@rank_zero_call
def save(self, save_fn, folder_name):
"""
执行保存的函数,将数据保存在 folder/timestamp/folder_name 下。其中 folder 为用户在初始化指定,
timestamp 为当前脚本的启动时间。

:param save_fn: 调用的保存函数,应该可接受参数 folder:str, only_state_dict: bool, model_save_fn: callable, kwargs
:param folder_name: 保存的 folder 名称,将被创建。
:return: 返回实际发生保存的 folder 绝对路径。如果为 None 则没有创建。
"""
folder = self.timestamp_path.joinpath(folder_name)
folder.mkdir(parents=True, exist_ok=True)
save_fn(
folder=folder,
only_state_dict=self.only_state_dict,
model_save_fn=self.model_save_fn,
**self.kwargs
)
return str(os.path.abspath(folder))

@rank_zero_call
def save_json(self, results, path):
"""
以 json 格式保存 results 到 path 中

:param results:
:param path:
:return:
"""
with open(path, 'w', encoding='utf8') as f:
json.dump(results, f, indent=2)

@rank_zero_call
def rm(self, folder_name):
"""
移除 folder/timestamp/folder_name 。其中 folder 为用户在初始化指定, timestamp 为当前脚本的启动时间。

:param folder_name:
:return:
"""
folder = self.timestamp_path.joinpath(folder_name)
rank_zero_rm(folder)

def state_dict(self):
states = {
'timestamp_path': str(self.timestamp_path),
}
return states

def load_state_dict(self, states):
timestamp_path = states['timestamp_path']
if not os.path.exists(timestamp_path):
logger.info(f"The resuming checkpoint folder {timestamp_path} is not exists, checkpoint will save to "
f" {self.timestamp_path.absolute()}.")
else:
logger.info(f"Resume to save checkpoint in path: {timestamp_path}.")
self.timestamp_path = Path(timestamp_path)

def __str__(self):
return 'saver' # saver是无状态的,不需要有特定名字


class TopkQueue:
def __init__(self, topk):
"""
用于维护处于 topk 的 key, value 对。

:param int topk: 整数,-1 表示所有数据都是 topk 的; 如果是 0, 表示没有任何数据是满足 topk 的。
"""
assert isinstance(topk, int)
self.topk = topk
self.topk_dict = {} # 其中 key 为保存的

def push(self, key, value) -> Optional[Tuple[str, float]]:
"""
将 key/value 推入 topk 的 queue 中,以 value 为标准,如果满足 topk 则保留此次推入的信息,同时如果新推入的数据将之前的数据给
挤出了 topk ,则会返回被挤出的 (key, value);如果返回为 (None, None),说明满足 topk 且没有数据被挤出。如果不满足 topk ,则返回
推入的 (key, value) 本身。这里排序只根据 value 是否更大了判断,因此如果有的情况是越小越好,请在输入前取负号。

:param str key:
:param float value: 如果为 None, 则不做任何操作。
:return: (1)返回输入的 (key, value) ,说明不满足 topk; (2) 返回(None, None),说明满足 topk 且没有被挤出过去的记录; (3)
返回非输入的 (key, value) , 说明输入满足 topk,且挤出了之前的记录。
"""
if value is None:
return key, value
if self.topk < 0:
return None, None
if self.topk == 0:
return key, value
if len(self.topk_dict)<self.topk:
self.topk_dict[key] = value
return None, None
min_key = min(self.topk_dict, key=lambda x:self.topk_dict[x])
if self.topk_dict[min_key] > value:
return key, value
else:
min_value = self.topk_dict.pop(min_key)
self.topk_dict[key] = value
return min_key, min_value

def state_dict(self):
return deepcopy(self.topk_dict)

def load_state_dict(self, states):
self.topk_dict.update(states)

def __str__(self):
return f'topk-{self.topk}'

def __bool__(self):
# 仅当 topk 为 0 时,表明该 topk_queue 无意义。
return self.topk != 0


class TopkSaver(MonitorUtility, Saver):
def __init__(self, topk, monitor, larger_better, folder, only_state_dict,
model_save_fn, save_evaluate_results,
save_object, **kwargs):
"""
用来保存识别 tokp 模型并保存。

:param topk:
:param monitor:
:param larger_better:
:param folder:
:param only_state_dict:
:param model_save_fn:
:param save_evaluate_results:
:param save_object:
:param kwargs:
"""
MonitorUtility.__init__(self, monitor, larger_better)
Saver.__init__(self, folder, only_state_dict, model_save_fn, **kwargs)

if monitor is not None and topk == 0:
raise RuntimeError("`monitor` is set, but `topk` is 0.")
if topk != 0 and monitor is None:
raise RuntimeError("`topk` is set, but `monitor` is None.")

assert save_object in ['trainer', 'model']

self.saver = Saver(folder, only_state_dict, model_save_fn, **kwargs)
self.topk_queue = TopkQueue(topk)
self.save_evaluate_results = save_evaluate_results
self.save_object = save_object
self.save_fn_name = 'save' if save_object == 'trainer' else 'save_model'

@rank_zero_call
def save_topk(self, trainer, results: Dict) -> Optional[str]:
"""
根据 results 是否满足 topk 的相关设定决定是否保存,如果发生了保存,将返回保存的文件夹。

:param trainer:
:param results:
:return:
"""
if self.monitor is not None and self.topk_queue:
monitor_value = self.get_monitor_value(results)
if monitor_value is None:
return
key = f"{self.save_object}-epoch_{trainer.cur_epoch_idx}-batch_{trainer.global_forward_batches}" \
f"-{self.monitor_name}_{monitor_value}"
pop_key, pop_value = self.topk_queue.push(key, monitor_value if self.larger_better else -monitor_value)
if pop_key == key: # 说明不足以构成 topk,被退回了
return None
folder = self.save(trainer, key)
if self.save_evaluate_results and folder:
try:
self.save_json(self.itemize_results(results),
os.path.join(folder, FASTNLP_EVALUATE_RESULT_FILENAME))
except:
logger.exception(f"Fail to save evaluate results to {folder}")

if pop_key and pop_key != key: # 说明需要移除之前的 topk
self.rm(pop_key)
return folder

def save(self, trainer, folder_name):
fn = getattr(trainer, self.save_fn_name)
return super().save(fn, folder_name)

def state_dict(self):
states = {
'topk_queue': self.topk_queue.state_dict(),
'saver': self.saver.state_dict()
}
if isinstance(self._real_monitor, str):
states['_real_monitor'] = self._real_monitor

return states

def load_state_dict(self, states):
topk_queue_states = states['topk_queue']
saver_states = states['saver']
self.topk_queue.load_state_dict(topk_queue_states)
self.saver.load_state_dict(saver_states)
if '_real_monitor' in states:
self._real_monitor = states["_real_monitor"]

def __str__(self):
return f'topk-{self.topk_queue}#saver-{self.saver}#save_object-{self.save_object}'
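The eviction rule of TopkQueue.push is easiest to see in a tiny run-through (values are treated as larger-is-better; negate them before pushing for the opposite ordering):

from fastNLP.core.callbacks.topk_saver import TopkQueue

q = TopkQueue(topk=2)
q.push('ckpt-a', 0.80)   # -> (None, None): accepted, nothing evicted
q.push('ckpt-b', 0.85)   # -> (None, None): accepted, nothing evicted
q.push('ckpt-c', 0.70)   # -> ('ckpt-c', 0.70): rejected, not within the top 2
q.push('ckpt-d', 0.90)   # -> ('ckpt-a', 0.80): accepted, 'ckpt-a' falls out of the top 2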

fastNLP/core/callbacks/utils.py (+3, -2)

@@ -1,4 +1,6 @@
from typing import Optional, Union
import os

from fastNLP.core.log.logger import logger
from difflib import SequenceMatcher
from fastNLP.core.utils.utils import _get_fun_msg
@@ -15,7 +17,7 @@ def _get_monitor_value(monitor: Union[callable, str], real_monitor: Optional[str
:return: 返回两个值(str, value),其中str就是最终要到的key,value就是这个key对应的value。如果value为None说明当前results中没有
找到对应的 monitor
"""
if len(res)==0:
if len(res) == 0 or monitor is None:
return monitor, None

if callable(monitor):
@@ -56,4 +58,3 @@ def _match_length(a:str, b:str)->int:
match = SequenceMatcher(None, short, long).find_longest_match(0, len(short), 0, len(long))
return match.size



fastNLP/core/controllers/evaluator.py (+5, -7)

@@ -38,7 +38,7 @@ class Evaluator:
driver: Union[str, Driver] = 'torch',
device: Optional[Union[int, List[int], str]] = None,
batch_step_fn: Optional[callable] = None,
evaluate_fn: Optional[str] = None, # 首先尝试找 evaluate_step, 找不到 forward, callable
evaluate_fn: Optional[str] = None,
input_mapping: Optional[Union[Callable, Dict]] = None,
output_mapping: Optional[Union[Callable, Dict]] = None,
model_wo_auto_param_call: bool = False,
@@ -57,8 +57,9 @@ class Evaluator:
:param batch_step_fn: callable的对象,接受 (evaluator, batch) 作为参数,其中 evaluator 为 Evaluator 对象,batch 为
DataLoader 中返回的对象。一个 batch_step_fn 的例子可参考 fastNLP.core.controller.loops.evaluate_batch_loop 的
batch_step_fn 函数。
:param evaluate_fn: 用来控制 `Evaluator` 在评测的前向传播过程中是调用哪一个函数,例如是 `model.evaluate_step` 还是 `model.forward`;
默认为 None,如果该值是 None,那么我们会默认使用 `evaluate_step` 当做前向传播的函数,如果在模型中没有找到该方法,则使用 `model.forward` 函数;
:param evaluate_fn: 用来控制 `Evaluator` 在评测的前向传播过程中是调用哪一个函数,例如是 `model.evaluate_step` 还是
`model.forward`;(1) 如果该值是 None,那么我们会默认使用 `evaluate_step` 当做前向传播的函数,如果在模型中没有
找到该方法,则使用 `model.forward` 函数;(2) 如果为 str 类型,则尝试从 model 中寻找该方法,找不到则报错。
:param input_mapping: 对 dataloader 中输出的内容将通过 input_mapping 处理之后再输入到 model 以及 metric 中
:param output_mapping: 对 model 输出的内容,将通过 output_mapping 处理之后再输入到 metric 中。
:param model_wo_auto_param_call: 是否关闭在训练时调用我们的 auto_param_call 来自动匹配 batch 和 forward 函数的参数的行为;
@@ -69,6 +70,7 @@ class Evaluator:
:param kwargs:
bool model_use_eval_mode: 是否在 evaluate 的时候将 model 的状态设置成 eval 状态。在 eval 状态下,model 的dropout
与 batch normalization 将会关闭。默认为True。
TODO 还没完成。
Union[bool] auto_tensor_conversion_for_metric: 是否自动将输出中的
tensor 适配到 metrics 支持的。例如 model 输出是 paddlepaddle 的 tensor ,但是想利用 torchmetrics 的metric对象,
当 auto_tensor_conversion_for_metric 为True时,fastNLP 将自动将输出中 paddle 的 tensor (其它非 tensor 的参数
@@ -119,10 +121,6 @@ class Evaluator:
self._metric_wrapper = None
_ = self.metrics_wrapper # 触发检查

if self._dist_sampler is not None and not self.driver.is_distributed():
logger.warning_once("Running in a non-distributed driver, but with distributed sampler, it may cause "
"different process evaluating on different data.")

if evaluate_fn is not None and not isinstance(evaluate_fn, str):
raise TypeError("Parameter `evaluate_fn` can only be `str` type when it is not None.")
self._evaluate_step, self._evaluate_step_signature_fn = \


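For reference, a standalone Evaluator sketch exercising the clarified evaluate_fn contract; the model, the dataloader and AccMetric (the placeholder metric name used in the Trainer docstring) are stand-ins:

from fastNLP.core.controllers.evaluator import Evaluator

evaluator = Evaluator(
    model=my_model,                 # placeholder model
    dataloaders=test_dataloader,    # placeholder dataloader
    metrics={'acc': AccMetric()},   # placeholder metric, named as in the Trainer docstring
    driver='torch',
    evaluate_fn=None,               # None => try `model.evaluate_step`, fall back to `model.forward`
)
results = evaluator.run()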
fastNLP/core/controllers/trainer.py (+61, -64)

@@ -14,10 +14,10 @@ __all__ = [

from .loops import Loop, TrainBatchLoop
from .utils import State, TrainerState
from .utils.utils import check_validate_every
from .utils.utils import check_evaluate_every
from .evaluator import Evaluator
from fastNLP.core.controllers.utils.utils import TrainerEventTrigger, _TruncatedDataLoader
from fastNLP.core.callbacks import Callback, CallbackManager, Events, EventsList, Filter
from fastNLP.core.callbacks import Callback, CallbackManager, Events, EventsList
from fastNLP.core.callbacks.callback import _CallbackWrapper
from fastNLP.core.callbacks.callback_events import _SingleEventState
from fastNLP.core.drivers import Driver
@@ -26,7 +26,7 @@ from fastNLP.core.utils import get_fn_arg_names, match_and_substitute_params, nu
from fastNLP.core.utils.utils import _check_valid_parameters_number
from fastNLP.envs import rank_zero_call
from fastNLP.core.log import logger
from fastNLP.envs import FASTNLP_MODEL_FILENAME
from fastNLP.envs import FASTNLP_MODEL_FILENAME, FASTNLP_CHECKPOINT_FILENAME
from fastNLP.core.utils.exceptions import EarlyStopException


@@ -94,9 +94,9 @@ class Trainer(TrainerEventTrigger):
evaluate_step 这个函数,如果没有则使用 forward 函数。
:param callbacks: 训练当中触发的 callback 类,该参数应当为一个列表,其中的每一个元素都应当继承 `Callback` 类;
:param metrics: 应当为一个字典,其中 key 表示 monitor,例如 {"acc1": AccMetric(), "acc2": AccMetric()};
:param evaluate_every: 可以为负数、正数或者函数;为负数时表示每隔几个 epoch validate 一次;为正数则表示每隔几个 batch validate 一次;
为函数时表示用户自己传入的用于控制 Trainer 中的 validate 的频率的函数,该函数的应该接受当前 trainer 对象作为参数,并
返回一个 bool 值,返回为 True 说明需要进行 validate ;将在每个 batch 结束后调用该函数判断是否需要 validate 。
:param evaluate_every: 可以为负数、正数或者函数;为负数时表示每隔几个 epoch evaluate 一次;为正数则表示每隔几个 batch evaluate 一次;
为函数时表示用户自己传入的用于控制 Trainer 中的 evaluate 的频率的函数,该函数的应该接受当前 trainer 对象作为参数,并
返回一个 bool 值,返回为 True 说明需要进行 evaluate ;将在每个 batch 结束后调用该函数判断是否需要 evaluate 。
:param input_mapping: 应当为一个字典或者一个函数,表示在当前 step 拿到一个 batch 的训练数据后,应当做怎样的映射处理;如果其是
一个字典,并且 batch 也是一个 `Dict`,那么我们会把 batch 中同样在 input_mapping 中的 key 修改为 input_mapping 的对应 key 的
value;如果 batch 是一个 `dataclass`,那么我们会先将该 dataclass 转换为一个 Dict,然后再进行上述转换;如果 batch 此时是其它
@@ -124,7 +124,7 @@ class Trainer(TrainerEventTrigger):
set_grad_to_none: 是否在训练过程中在每一次 optimizer 更新后将 grad 置为 None;
use_dist_sampler: 表示是否使用分布式的 sampler 。在多卡时,分布式 sampler 将自动决定每张卡上读取的 sample ,使得一个epoch
内所有卡的 sample 加起来为一整个数据集的 sample。默认会根据 driver 是否为分布式进行设置。
use_eval_dist_sampler: 表示在 Evaluator 中在使用 TorchDDPDriver 的时候是否将 dataloader 的 sampler 替换为分布式的 sampler;默认为 True;
eval_use_dist_sampler: 表示在 Evaluator 中在使用 TorchDDPDriver 的时候是否将 dataloader 的 sampler 替换为分布式的 sampler;默认为 True;
output_from_new_proc: 应当为一个字符串,表示在多进程的 driver 中其它进程的输出流应当被做如何处理;其值应当为以下之一:
["all", "ignore", "only_error"];当该参数的值不是以上值时,该值应当表示一个文件夹的名字,我们会将其他 rank 的输出流重定向到
log 文件中,然后将 log 文件保存在通过该参数值设定的文件夹中;默认为 "only_error";
@@ -214,13 +214,13 @@ class Trainer(TrainerEventTrigger):

""" 设置内部的 Evaluator """
if metrics is None and evaluate_dataloaders is not None:
raise ValueError("You have set 'evaluate_dataloader' but forget to set 'metrics'.")
raise ValueError("You have set 'evaluate_dataloaders' but forget to set 'metrics'.")

if metrics is not None and evaluate_dataloaders is None:
raise ValueError("You have set 'metrics' but forget to set 'evaluate_dataloader'.")
raise ValueError("You have set 'metrics' but forget to set 'evaluate_dataloaders'.")

self.metrics = metrics
self.validate_every = evaluate_every
self.evaluate_every = evaluate_every

self.driver.setup()
self.driver.barrier()
@@ -235,7 +235,7 @@ class Trainer(TrainerEventTrigger):
self.monitor = monitor
self.larger_better = larger_better
if metrics is not None and evaluate_dataloaders is not None:
check_validate_every(evaluate_every)
check_evaluate_every(evaluate_every)
self.evaluator = Evaluator(
model=model,
dataloaders=evaluate_dataloaders,
@@ -248,7 +248,7 @@ class Trainer(TrainerEventTrigger):
output_mapping=output_mapping,
fp16=fp16,
verbose=0,
use_dist_sampler=kwargs.get("use_eval_dist_sampler", None),
use_dist_sampler=kwargs.get("eval_use_dist_sampler", None),
progress_bar=kwargs.get('progress_bar', 'auto')
)

@@ -261,11 +261,14 @@ class Trainer(TrainerEventTrigger):
self.driver.set_deterministic_dataloader(self.dataloader)

self.dataloader = self.driver.set_dist_repro_dataloader(dataloader=self.train_dataloader, dist=_dist_sampler,
reproducible=self.callback_manager.has_trainer_checkpoint)
reproducible=self.callback_manager._need_reproducible_sampler)

self.set_grad_to_none = kwargs.get("set_grad_to_none", True)
self.on_after_trainer_initialized(self.driver)

self.evaluate_batch_step_fn = evaluate_batch_step_fn
self.kwargs = kwargs

self.on_after_trainer_initialized(self.driver)
self.driver.barrier()

def run(self, num_train_batch_per_epoch: int = -1, num_eval_batch_per_dl: int = -1,
@@ -364,10 +367,10 @@ class Trainer(TrainerEventTrigger):
:return:
"""
if self.evaluator is not None:
if callable(self.validate_every):
if self.validate_every(self):
if callable(self.evaluate_every):
if self.evaluate_every(self):
self.run_evaluate()
elif self.validate_every > 0 and self.global_forward_batches % self.validate_every == 0:
elif self.evaluate_every > 0 and self.global_forward_batches % self.evaluate_every == 0:
self.run_evaluate()

def epoch_validate(self):
@@ -377,8 +380,8 @@ class Trainer(TrainerEventTrigger):
:return:
"""
if self.evaluator is not None:
if isinstance(self.validate_every, int) and self.validate_every < 0:
validate_every = -self.validate_every
if isinstance(self.evaluate_every, int) and self.evaluate_every < 0:
validate_every = -self.evaluate_every
if self.cur_epoch_idx % validate_every == 0:
self.run_evaluate()

@@ -427,7 +430,7 @@ class Trainer(TrainerEventTrigger):
self._custom_callbacks[None] = []
if self.marker is not None:
if len(self._custom_callbacks[self.marker]) == 0:
print(f"You have set `trainer.marker = {self.marker}`, but there are no callback function matched "
logger.info(f"You have set `trainer.marker = {self.marker}`, but there are no callback function matched "
f"`{self.marker}` that is added through function `Trainer.on`")
_own_callbacks += self._custom_callbacks[self.marker]
for each_callback in _own_callbacks:
@@ -528,10 +531,10 @@ class Trainer(TrainerEventTrigger):
r"""
用于帮助用户保存模型的辅助函数,具体实际的保存模型的操作由具体的 driver 实现;

:param folder: 保存模型的地址;
:param only_state_dict: 是否只保存模型的 `state_dict`;
:param folder: 保存模型的文件夹。如果没有传入 model_save_fn 参数,则在这个文件夹下创建 fastnlp_model.pkl.tar 文件。
:param only_state_dict: 仅在 model_save_fn 为空时,有效。是否只保存模型的 `state_dict`;
:param model_save_fn: 用户自己定制的用来替换该保存函数本身保存逻辑的函数;
:param kwargs: 一些 driver 的保存模型的函数的参数另有其它;
:param kwargs:
"""

self.on_save_model()
@@ -568,14 +571,19 @@ class Trainer(TrainerEventTrigger):
self.on_load_model()
self.driver.barrier()
if not isinstance(folder, (io.BytesIO, BinaryIO)):
if model_load_fn is not None:
if not callable(model_load_fn):
raise ValueError("Parameter `model_save_fn` should be `Callable` type when it is not None.")
rank_zero_call(model_load_fn)(folder)
else:
if isinstance(folder, str):
folder = Path(folder)
self.driver.load_model(folder.joinpath(FASTNLP_MODEL_FILENAME), only_state_dict, **kwargs)
try:
if model_load_fn is not None:
if not callable(model_load_fn):
raise ValueError("Parameter `model_save_fn` should be `Callable` type when it is not None.")
rank_zero_call(model_load_fn)(folder)
else:
if isinstance(folder, str):
folder = Path(folder)
self.driver.load_model(folder.joinpath(FASTNLP_MODEL_FILENAME), only_state_dict, **kwargs)
except FileNotFoundError as e:
if FASTNLP_MODEL_FILENAME not in os.listdir(folder):
logger.error(f"fastNLP model checkpoint file:{FASTNLP_MODEL_FILENAME} is not found in {folder}.")
raise e
else:
if model_load_fn is not None:
raise RuntimeError("It is not allowed to specify a `model_save_fn` parameter with `folder` being "
@@ -585,11 +593,13 @@ class Trainer(TrainerEventTrigger):

def save(self, folder: Union[str, Path], only_state_dict: bool = True, model_save_fn: Optional[Callable] = None, **kwargs):
r"""
用于断点重训 Trainer 的保存函数;
用于断点重训 Trainer 的保存函数

:param folder:
:param only_state_dict:
:param model_save_fn:
:param folder: the folder to save into; two files will be created in it: fastnlp_checkpoint.pkl.tar and fastnlp_model.pkl.tar.
If model_save_fn is not None, no fastnlp_model.pkl.tar file is created.
:param only_state_dict: only effective when model_save_fn is None; whether to save only the model weights.
:param model_save_fn: if saving the model needs special handling, pass this function to customize the saving process; it should accept a folder (in fact the folder
parameter above) and does not need to return anything.
:param kwargs:
:return:
"""
@@ -602,17 +612,6 @@ class Trainer(TrainerEventTrigger):
'num_consumed_batches': self.batch_idx_in_epoch - getattr(self, 'start_batch_idx_in_epoch', 0)
}

# 3. validate filter state;
if self.evaluator is not None:
val_filter_state = {}
if hasattr(self.step_validate, "__fastNLP_filter__"):
val_filter_state["step_validate"] = self.step_validate.__fastNLP_filter__.state_dict()
if hasattr(self.epoch_validate, "__fastNLP_filter__"):
val_filter_state["epoch_validate"] = self.epoch_validate.__fastNLP_filter__.state_dict()
states["val_filter_state"] = val_filter_state
else:
states["val_filter_state"] = None

if isinstance(folder, str):
folder = Path(folder)

@@ -649,32 +648,30 @@ class Trainer(TrainerEventTrigger):
dataloader = self.dataloader
if not resume_training:
dataloader = None

if model_load_fn is not None:
if not callable(model_load_fn):
raise ValueError("Parameter `model_save_fn` should be `Callable`.")
rank_zero_call(model_load_fn)(folder)
states = self.driver.load(folder=folder, dataloader=dataloader, should_load_model=False, **kwargs)
else:
states = self.driver.load(folder=folder, dataloader=dataloader, only_state_dict=only_state_dict, should_load_model=True, **kwargs)
try:
if model_load_fn is not None:
if not callable(model_load_fn):
raise ValueError("Parameter `model_save_fn` should be `Callable`.")
rank_zero_call(model_load_fn)(folder)
states = self.driver.load(folder=folder, dataloader=dataloader, should_load_model=False, **kwargs)
else:
states = self.driver.load(folder=folder, dataloader=dataloader, only_state_dict=only_state_dict, should_load_model=True, **kwargs)
except FileNotFoundError as e:
if FASTNLP_CHECKPOINT_FILENAME not in os.listdir(folder) and FASTNLP_MODEL_FILENAME in os.listdir(folder):
logger.error("It seems that you are trying to load the trainer checkpoint from a model checkpoint folder.")
elif FASTNLP_CHECKPOINT_FILENAME not in os.listdir(folder):
logger.error(f"fastNLP Trainer checkpoint file:{FASTNLP_CHECKPOINT_FILENAME} is not found in {folder}.")
raise e

if not resume_training:
return

self.dataloader = states.pop('dataloader')

# 2. validate filter state;
if self.evaluator is not None:
val_filter_state = states["val_filter_state"]
if hasattr(self.step_validate, "__fastNLP_filter__"):
self.step_validate.__fastNLP_filter__.load_state_dict(val_filter_state["step_validate"])
if hasattr(self.epoch_validate, "__fastNLP_filter__"):
self.epoch_validate.__fastNLP_filter__.load_state_dict(val_filter_state["epoch_validate"])

# 3. restore the state of trainer_state;
# 1. restore the state of trainer_state;
self.trainer_state.load_state_dict(states["trainer_state"])

# 4. adjust trainer_state.batch_idx_in_epoch
# 2. adjust trainer_state.batch_idx_in_epoch
# the sampler here is a RandomSampler-like sampler, not a batch_sampler;
# the principle is that 'number of batches still to be produced' + 'batch_idx_in_epoch' should equal 'the total number of batches had training not been interrupted'. Since
# 'the number of batches still to be produced' is determined by how many samples remain, the equation can only be satisfied by adjusting 'batch_idx_in_epoch'

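Taken together, the save_model/load_model and save/load hunks above describe two checkpoint layouts: a model-only folder containing fastnlp_model.pkl.tar, and a full trainer folder that additionally holds fastnlp_checkpoint.pkl.tar and no longer stores the removed validate-filter state. A rough usage sketch, assuming an already constructed trainer such as the ones in the tests below; the folder path is a placeholder:

    from pathlib import Path

    ckpt = Path("checkpoints/run_0")            # hypothetical folder

    # model-only checkpoint: writes fastnlp_model.pkl.tar unless model_save_fn is given
    trainer.save_model(ckpt, only_state_dict=True)
    trainer.load_model(ckpt, only_state_dict=True)

    # full checkpoint for resumable training: fastnlp_checkpoint.pkl.tar (+ fastnlp_model.pkl.tar)
    trainer.save(ckpt, only_state_dict=True)
    trainer.load(ckpt, resume_training=True)    # resume_training as in the load() hunk above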

+ 1
- 1
fastNLP/core/controllers/utils/utils.py View File

@@ -126,7 +126,7 @@ class _TruncatedDataLoader:
return getattr(self.dataloader, item)


def check_validate_every(validate_every):
def check_evaluate_every(validate_every):
if not callable(validate_every) and (not isinstance(validate_every, int) or validate_every == 0):
raise ValueError("Parameter 'evaluate_every' should be set to 'int' type and either < 0 or > 0.")
if callable(validate_every):

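Together with the trainer hunks above, check_evaluate_every pins down the accepted values for evaluate_every: a positive int evaluates every that many batches (step_validate), a negative int every -evaluate_every epochs (epoch_validate), 0 is rejected, and a callable is invoked with the trainer and triggers evaluation when it returns True. A minimal sketch of the callable form, assuming evaluate_every is the Trainer keyword implied by this commit; the model, dataloaders, optimizer and metric are placeholders:

    from fastNLP import Trainer                      # assuming the top-level re-export
    from fastNLP.core.metrics import Accuracy

    def evaluate_every(trainer) -> bool:
        # evaluate every 100 global batches, but only from the second epoch onwards
        return trainer.cur_epoch_idx >= 1 and trainer.global_forward_batches % 100 == 0

    trainer = Trainer(model=model, driver="torch", device=0,
                      optimizers=optimizer,
                      train_dataloader=train_dataloader,
                      evaluate_dataloaders=val_dataloader,
                      metrics={"acc": Accuracy()},
                      evaluate_every=evaluate_every,   # or an int: 100 = every 100 batches, -1 = every epoch
                      n_epochs=5)
    trainer.run()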

+ 3
- 2
fastNLP/core/dataloaders/torch_dataloader/fdl.py View File

@@ -11,6 +11,7 @@ from fastNLP.core.collators.collator import _MultiCollator
from fastNLP.core.utils.utils import indice_collate_wrapper
from fastNLP.io.data_bundle import DataBundle
from fastNLP.envs.imports import _NEED_IMPORT_TORCH
from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler, UnrepeatedSampler

if _NEED_IMPORT_TORCH:
from torch.utils.data import DataLoader, Sampler
@@ -48,8 +49,8 @@ class TorchDataLoader(DataLoader):
"""

def __init__(self, dataset, batch_size: int = 1,
shuffle: bool = False, sampler: Optional["Sampler[int]"] = None,
batch_sampler: Optional["Sampler[Sequence[int]]"] = None,
shuffle: bool = False, sampler: Union["Sampler[int]", ReproducibleSampler, UnrepeatedSampler] = None,
batch_sampler: Union["Sampler[Sequence[int]]", ReproducibleBatchSampler] = None,
num_workers: int = 0, collate_fn: Optional[Callable] = None,
pin_memory: bool = False, drop_last: bool = False,
timeout: float = 0, worker_init_fn: Optional[Callable] = None,

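The widened type hints above mean TorchDataLoader now explicitly accepts fastNLP's reproducible samplers alongside plain torch samplers. A minimal sketch; the dataset is a placeholder and only the keyword names visible in the signature above are relied on:

    from fastNLP.core.dataloaders.torch_dataloader.fdl import TorchDataLoader
    from fastNLP.core.samplers import RandomSampler   # fastNLP's reproducible sampler

    sampler = RandomSampler(dataset)                   # dataset defined elsewhere
    dl = TorchDataLoader(dataset, batch_size=8, sampler=sampler)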

+ 0
- 1
fastNLP/core/drivers/driver.py View File

@@ -380,7 +380,6 @@ class Driver(ABC):
"""
# 单卡 driver 不需要这个函数;
if self._pids is not None:

exc_type, exc_value, exc_traceback_obj = sys.exc_info()
_write_exc_info = {
'exc_type': str(exc_type.__name__),


+ 1
- 1
fastNLP/core/drivers/torch_driver/ddp.py View File

@@ -526,7 +526,7 @@ class TorchDDPDriver(TorchDriver):

def barrier(self):
if int(os.environ.get(FASTNLP_NO_SYNC, 0)) < 1: # only actually synchronize when FASTNLP_NO_SYNC is less than 1
torch.distributed.barrier(async_op=True)
torch.distributed.barrier(async_op=False)

def is_distributed(self):
return True

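The barrier fix above matters because an async barrier (async_op=True) returns immediately and never actually synchronizes the ranks; the surrounding FASTNLP_NO_SYNC check remains the switch for disabling synchronization. A hedged sketch of that switch, based only on the comments in this hunk and in the env.py hunk further down:

    import os

    # '0' (default): barrier and gather/broadcast are active
    # '1': per the ddp hunk, values >= 1 skip the barrier
    # '2': per env.py, barrier and gather/broadcast are both disabled
    os.environ["FASTNLP_NO_SYNC"] = "2"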

+ 8
- 4
fastNLP/core/drivers/torch_driver/torch_driver.py View File

@@ -9,8 +9,9 @@ from fastNLP.envs.imports import _NEED_IMPORT_TORCH
from pathlib import Path
if _NEED_IMPORT_TORCH:
import torch
from torch.utils.data import DataLoader, IterableDataset, RandomSampler, Sampler, BatchSampler, Dataset
from torch.utils.data import DataLoader, IterableDataset, Sampler, BatchSampler, Dataset
from torch.optim import Optimizer
from torch.utils.data import RandomSampler as TorchRandomSampler
_reduces = {
'sum': torch.max,
'min': torch.min,
@@ -30,7 +31,7 @@ from fastNLP.core.utils import apply_to_collection, torch_move_data_to_device
from fastNLP.envs import rank_zero_call
from fastNLP.envs import FASTNLP_SEED_WORKERS, FASTNLP_GLOBAL_RANK, FASTNLP_MODEL_FILENAME, FASTNLP_CHECKPOINT_FILENAME
from fastNLP.core.log import logger
from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler, RandomBatchSampler
from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler, RandomBatchSampler, RandomSampler


class TorchDriver(Driver):
@@ -211,8 +212,8 @@ class TorchDriver(Driver):

states['sampler_states'] = sampler_states
else:
raise RuntimeError(
'The sampler has no `state_dict()` method, it will fail to recover to the specific batch.')
raise RuntimeError('The sampler has no `state_dict()` method, fastNLP cannot save the training '
'state.')

# 2. save the model state;
if should_save_model:
@@ -283,6 +284,9 @@ class TorchDriver(Driver):
sampler = dataloader_args.batch_sampler
elif isinstance(dataloader_args.sampler, ReproducibleSampler):
sampler = dataloader_args.sampler
elif isinstance(dataloader_args.sampler, TorchRandomSampler):
sampler = RandomSampler(dataloader_args.sampler.data_source)
logger.debug("Replace torch RandomSampler into fastNLP RandomSampler.")
elif self.is_distributed():
raise RuntimeError("It is not allowed to use checkpoint retraining when you do not use our or "
"`ReproducibleSampler`.")

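The torch_driver hunks above do two things: refuse to save a checkpoint when the sampler exposes no state_dict(), and transparently swap a plain torch RandomSampler for fastNLP's reproducible RandomSampler so that recovery stays possible. A rough illustration of why the reproducible sampler is needed; state_dict() appears in the hunk above, load_state_dict() is assumed by symmetry, and the dataset is a placeholder:

    from fastNLP.core.samplers import RandomSampler

    sampler = RandomSampler(dataset)        # dataset defined elsewhere
    state = sampler.state_dict()            # what TorchDriver stores into 'sampler_states'

    # ... training is interrupted and restarted ...
    resumed = RandomSampler(dataset)
    resumed.load_state_dict(state)          # resume the iteration order mid-epoch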

+ 1
- 1
fastNLP/core/metrics/accuracy.py View File

@@ -19,7 +19,7 @@ class Accuracy(Metric):

:param str backend: four backends are currently supported: ['auto', 'torch', 'paddle', 'jittor']. 'auto' means the backend is decided from the arguments actually passed to Metric.update();
in most cases simply using 'auto' is fine.
:param bool aggregate_when_get_metric: whether to automatically aggregate the values of the same element across processes before computing the metric,
:param bool aggregate_when_get_metric: whether to automatically aggregate the values of the same element across processes before computing the metric,
which is meaningless when the backend does not support distributed execution. If None, it is set automatically in the Evaluator according to whether the sampler is distributed.
"""
super(Accuracy, self).__init__(backend=backend, aggregate_when_get_metric=aggregate_when_get_metric)

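For reference, the documented signature above translates into usage like the following sketch; only backend and aggregate_when_get_metric come from this docstring, the rest is a placeholder:

    from fastNLP.core.metrics import Accuracy

    # 'auto' resolves the backend from whatever tensors reach Metric.update()
    acc = Accuracy(backend="auto", aggregate_when_get_metric=None)
    metrics = {"acc": acc}   # passed as the Trainer/Evaluator metrics dict, as in the tests below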

+ 2
- 0
fastNLP/core/samplers/unrepeated_sampler.py View File

@@ -84,6 +84,8 @@ class UnrepeatedRandomSampler(UnrepeatedSampler):
:param rank:
:return:
"""
assert num_replicas<=len(self.dataset), f"The number of replicas({num_replicas}) should be lesser than the " \
f"number of samples({len(self.dataset)})."
assert num_replicas>0 and isinstance(num_replicas, int)
assert isinstance(rank, int) and 0<=rank<num_replicas
# note that when this is initialized, all states should default to the state at the very start of an epoch;


+ 3
- 3
fastNLP/core/utils/__init__.py View File

@@ -24,8 +24,8 @@ __all__ = [
'indice_collate_wrapper',
'deprecated',
'seq_len_to_mask',
'synchronize_safe_rm',
'synchronize_mkdir'
'rank_zero_rm',
'rank_zero_mkdir'
]

from .cache_results import cache_results
@@ -37,6 +37,6 @@ from .torch_paddle_utils import torch_paddle_move_data_to_device
from .torch_utils import torch_move_data_to_device
from .utils import get_fn_arg_names, auto_param_call, check_user_specific_params, \
dataclass_to_dict, match_and_substitute_params, apply_to_collection, nullcontext, pretty_table_printer, Option, \
indice_collate_wrapper, deprecated, seq_len_to_mask, synchronize_safe_rm, synchronize_mkdir
indice_collate_wrapper, deprecated, seq_len_to_mask, rank_zero_rm, rank_zero_mkdir



+ 15
- 18
fastNLP/core/utils/utils.py View File

@@ -38,8 +38,8 @@ __all__ = [
'indice_collate_wrapper',
'deprecated',
'seq_len_to_mask',
'synchronize_safe_rm',
'synchronize_mkdir'
'rank_zero_rm',
'rank_zero_mkdir'
]


@@ -629,7 +629,7 @@ def wait_filepath(path, exist=True):



def synchronize_safe_rm(path: Optional[Union[str, Path]]):
def rank_zero_rm(path: Optional[Union[str, Path]]):
"""
这个是因为在分布式文件系统中可能会发生错误,rank0下发删除成功后就运行走了,但实际的删除需要rank0的机器发送到远程文件系统再去执行,这个时候
在rank0那里,确实已经删除成功了,但是在远程文件系统那里这个操作还没完成,rank1读取的时候还是读取到存在这个文件;
@@ -638,15 +638,14 @@ def synchronize_safe_rm(path: Optional[Union[str, Path]]):
:param path:
:return:
"""
if path is None:
return
if isinstance(path, str):
path = Path(path)
if not path.exists():
return
if int(os.environ.get(FASTNLP_GLOBAL_RANK, 0)) == 0:
if path is None:
return
if isinstance(path, str):
path = Path(path)
if not path.exists():
return
_recursive_rm(path)
wait_filepath(path, exist=False)


def _recursive_rm(path: Path):
@@ -662,21 +661,19 @@ def _recursive_rm(path: Path):
path.rmdir()


def synchronize_mkdir(path: Optional[Union[str, Path]]):
def rank_zero_mkdir(path: Optional[Union[str, Path]]):
"""
Note that this function creates a directory; do not use it if you need to create a file;
This function guarantees that it only returns after every process has detected that path was created; make sure path is exactly the same across processes, otherwise it will deadlock.

"""
if path is None:
return
if isinstance(path, str):
path = Path(path)

if int(os.environ.get(FASTNLP_GLOBAL_RANK, 0)) == 0:
path.mkdir(parents=True, exist_ok=True)
if path is None:
return
if isinstance(path, str):
path = Path(path)

wait_filepath(path, exist=True)
path.mkdir(parents=True, exist_ok=True)


def get_class_that_defined_method(method):

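rank_zero_rm and rank_zero_mkdir are the renamed helpers used throughout the updated tests below. Typical test-style usage; the path is a placeholder:

    from fastNLP.core import rank_zero_rm
    from fastNLP.core.utils import rank_zero_mkdir

    rank_zero_mkdir("outputs/exp1")   # create the folder (exist_ok, so safe to repeat)
    # ... write checkpoints, logs, etc. ...
    rank_zero_rm("outputs/exp1")      # remove it; a None or missing path is a no-op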

+ 2
- 2
fastNLP/envs/env.py View File

@@ -49,7 +49,7 @@ FASTNLP_BACKEND_LAUNCH = "FASTNLP_BACKEND_LAUNCH"
# a value of '2' means both barrier and gather/broadcast are disabled.
FASTNLP_NO_SYNC = 'FASTNLP_NO_SYNC'

# todo: document these directly-used variables
# default file names used when saving various contents
FASTNLP_MODEL_FILENAME = "fastnlp_model.pkl.tar"
FASTNLP_CHECKPOINT_FILENAME = "fastnlp_checkpoint.pkl.tar"
FASTNLP_EVALUATE_RESULT_FILENAME = 'fastnlp_evaluate_results.json'

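These constants are the file names referenced throughout the Trainer save/load hunks above. A quick sketch of inspecting a saved folder; the folder path is a placeholder and the layout is taken from the docstrings in this commit:

    from pathlib import Path
    from fastNLP.envs import FASTNLP_MODEL_FILENAME, FASTNLP_CHECKPOINT_FILENAME

    folder = Path("checkpoints/run_0")
    has_model = (folder / FASTNLP_MODEL_FILENAME).exists()          # fastnlp_model.pkl.tar
    has_trainer = (folder / FASTNLP_CHECKPOINT_FILENAME).exists()   # fastnlp_checkpoint.pkl.tar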
+ 62
- 109
tests/core/callbacks/test_checkpoint_callback_torch.py View File

@@ -7,13 +7,14 @@ from torch.optim import SGD
import torch.distributed as dist
from pathlib import Path
import re
import time

from fastNLP.core.callbacks.checkpoint_callback import ModelCheckpointCallback, TrainerCheckpointCallback
from fastNLP.core.callbacks.checkpoint_callback import CheckpointCallback
from fastNLP.core.controllers.trainer import Trainer
from fastNLP.envs import FASTNLP_LAUNCH_TIME, FASTNLP_DISTRIBUTED_CHECK

from tests.helpers.utils import magic_argv_env_context
from fastNLP.core import synchronize_safe_rm
from fastNLP.core import rank_zero_rm
from tests.helpers.models.torch_model import TorchNormalModel_Classification_1
from tests.helpers.datasets.torch_data import TorchArgMaxDatset
from torchmetrics import Accuracy
@@ -80,44 +81,21 @@ def test_model_checkpoint_callback_1(
version,
only_state_dict
):
# def test_model_checkpoint_callback_1(
# model_and_optimizers: TrainerParameters,
# driver='torch_ddp',
# device=[0, 1],
# version=1,
# only_state_dict=True
# ):
path = Path.cwd().joinpath(f"test_model_checkpoint")
path.mkdir(exist_ok=True, parents=True)
try:
path = Path.cwd().joinpath(f"test_model_checkpoint")
path.mkdir(exist_ok=True, parents=True)

if version == 0:
callbacks = [
ModelCheckpointCallback(
monitor="acc",
save_folder=path,
save_every_n_epochs=1,
save_every_n_batches=123, # avoid overlapping with the epoch-level saves;
save_topk=None,
save_last=False,
save_on_exception=None,
only_state_dict=only_state_dict
)
]
elif version == 1:
callbacks = [
ModelCheckpointCallback(
monitor="acc",
save_folder=path,
save_every_n_epochs=3,
save_every_n_batches=None,
save_topk=2,
save_last=True,
save_on_exception=None,
only_state_dict=only_state_dict
)
]
if version == 0:
callbacks = [
CheckpointCallback(folder=path, every_n_epochs=1, every_n_batches=123, last=False, on_exceptions=None, topk=0,
monitor=None, only_state_dict=only_state_dict, save_object='model')
]
elif version == 1:
callbacks = [
CheckpointCallback(folder=path, every_n_epochs=3, every_n_batches=None, last=True, on_exceptions=None, topk=2,
monitor="acc", only_state_dict=only_state_dict, save_object='model')
]

try:
trainer = Trainer(
model=model_and_optimizers.model,
driver=driver,
@@ -134,7 +112,7 @@ def test_model_checkpoint_callback_1(
)

trainer.run()
print("Finish train")
all_saved_model_paths = {w.name: w for w in path.joinpath(os.environ[FASTNLP_LAUNCH_TIME]).iterdir()}
# check that the number of saved model files is correct;
if version == 0:
@@ -217,8 +195,7 @@ def test_model_checkpoint_callback_1(
trainer.run()

finally:
synchronize_safe_rm(path)
pass
rank_zero_rm(path)

if dist.is_initialized():
dist.destroy_process_group()
@@ -233,30 +210,23 @@ def test_model_checkpoint_callback_2(
device,
only_state_dict
):
path = Path.cwd().joinpath("test_model_checkpoint")
path.mkdir(exist_ok=True, parents=True)
try:
path = Path.cwd().joinpath("test_model_checkpoint")
path.mkdir(exist_ok=True, parents=True)

from fastNLP.core.callbacks.callback_events import Events

@Trainer.on(Events.on_train_epoch_end)
def raise_exception(trainer):
if trainer.driver.get_local_rank() == 0 and trainer.cur_epoch_idx == 4:
raise NotImplementedError

callbacks = [
ModelCheckpointCallback(
monitor="acc1",
save_folder=path,
save_every_n_epochs=None,
save_every_n_batches=None,
save_topk=None,
save_last=False,
save_on_exception=NotImplementedError,
only_state_dict=only_state_dict
),
]
from fastNLP.core.callbacks.callback_events import Events

@Trainer.on(Events.on_train_epoch_end)
def raise_exception(trainer):
if trainer.driver.get_local_rank() == 0 and trainer.cur_epoch_idx == 4:
raise NotImplementedError

callbacks = [
CheckpointCallback(folder=path, every_n_epochs=None, every_n_batches=None, last=False,
on_exceptions=NotImplementedError, topk=None, monitor=None, only_state_dict=only_state_dict,
save_object='model'),
]

try:
with pytest.raises(NotImplementedError):
trainer = Trainer(
model=model_and_optimizers.model,
@@ -315,14 +285,14 @@ def test_model_checkpoint_callback_2(
trainer.run()

finally:
synchronize_safe_rm(path)
rank_zero_rm(path)
# pass

if dist.is_initialized():
dist.destroy_process_group()


@pytest.mark.parametrize("driver,device", [("torch", "cpu"), ("torch_ddp", [6, 7]), ("torch", 7)]) # ("torch", "cpu"), ("torch_ddp", [0, 1]), ("torch", 1)
@pytest.mark.parametrize("driver,device", [("torch", "cpu"), ("torch_ddp", [0, 1]), ("torch", 0)]) # ("torch", "cpu"), ("torch_ddp", [0, 1]), ("torch", 1)
@pytest.mark.parametrize("version", [0, 1])
@pytest.mark.parametrize("only_state_dict", [True, False])
@magic_argv_env_context
@@ -333,37 +303,21 @@ def test_trainer_checkpoint_callback_1(
version,
only_state_dict
):
path = Path.cwd().joinpath(f"test_model_checkpoint")
path.mkdir(exist_ok=True, parents=True)
try:
path = Path.cwd().joinpath(f"test_model_checkpoint")
path.mkdir(exist_ok=True, parents=True)

if version == 0:
callbacks = [
TrainerCheckpointCallback(
monitor="acc",
save_folder=path,
save_every_n_epochs=7,
save_every_n_batches=123, # avoid overlapping with the epoch-level saves;
save_topk=None,
save_last=False,
save_on_exception=None,
only_state_dict=only_state_dict
)
]
elif version == 1:
callbacks = [
TrainerCheckpointCallback(
monitor="acc",
save_folder=path,
save_every_n_epochs=None,
save_every_n_batches=None,
save_topk=2,
save_last=True,
save_on_exception=None,
only_state_dict=only_state_dict
)
]
if version == 0:
callbacks = [
CheckpointCallback(folder=path, every_n_epochs=7, every_n_batches=123, last=False, on_exceptions=None, topk=0,
monitor=None, only_state_dict=only_state_dict, save_object='trainer')
]
elif version == 1:
callbacks = [
CheckpointCallback(folder=path, every_n_epochs=None, every_n_batches=None, last=True, on_exceptions=None,
topk=2, monitor="acc", only_state_dict=only_state_dict, save_object='trainer')
]

try:
trainer = Trainer(
model=model_and_optimizers.model,
driver=driver,
@@ -461,8 +415,7 @@ def test_trainer_checkpoint_callback_1(
trainer.run()

finally:
synchronize_safe_rm(path)
pass
rank_zero_rm(path)

if dist.is_initialized():
dist.destroy_process_group()
@@ -594,12 +547,12 @@ def test_trainer_checkpoint_callback_2(
callbacks = [
TrainerCheckpointCallback(
monitor="acc",
save_folder=path,
save_every_n_epochs=None,
save_every_n_batches=50,
save_topk=None,
save_last=False,
save_on_exception=None,
folder=path,
every_n_epochs=None,
every_n_batches=50,
topk=None,
last=False,
on_exception=None,
model_save_fn=model_save_fn
)
]
@@ -607,12 +560,12 @@ def test_trainer_checkpoint_callback_2(
callbacks = [
TrainerCheckpointCallback(
monitor="acc",
save_folder=path,
save_every_n_epochs=None,
save_every_n_batches=None,
save_topk=1,
save_last=True,
save_on_exception=None,
folder=path,
every_n_epochs=None,
every_n_batches=None,
topk=1,
last=True,
on_exception=None,
model_save_fn=model_save_fn
)
]
@@ -710,7 +663,7 @@ def test_trainer_checkpoint_callback_2(
trainer.run()

finally:
synchronize_safe_rm(path)
rank_zero_rm(path)
# pass

if dist.is_initialized():

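The rewritten tests above show the shape of the merged callback: a single CheckpointCallback whose save_object switches between 'model' and 'trainer' checkpoints, replacing the old ModelCheckpointCallback/TrainerCheckpointCallback pair. A condensed sketch drawn from those tests; the folder is a placeholder and defaulted arguments are omitted:

    from fastNLP.core.callbacks import CheckpointCallback

    callbacks = [
        # periodic model checkpoints, no monitoring
        CheckpointCallback(folder="checkpoints", every_n_epochs=1, every_n_batches=None,
                           last=False, topk=0, monitor=None, save_object='model'),
        # keep the 2 best full-trainer checkpoints by 'acc', plus the last one
        CheckpointCallback(folder="checkpoints", every_n_epochs=None, every_n_batches=None,
                           last=True, topk=2, monitor="acc", save_object='trainer'),
    ]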

+ 263
- 0
tests/core/callbacks/test_more_evaluate_callback.py View File

@@ -0,0 +1,263 @@
"""
Tests for more_evaluate_callback:
(1) whether it evaluates correctly;
(2) whether it can save the topk checkpoints and load them back to continue training

"""
import pytest



import os
import pytest
from typing import Any
from dataclasses import dataclass
from torch.utils.data import DataLoader
from torch.optim import SGD
import torch.distributed as dist
from pathlib import Path
import re

from fastNLP.core.controllers.trainer import Trainer
from fastNLP.envs import FASTNLP_LAUNCH_TIME, FASTNLP_DISTRIBUTED_CHECK

from tests.helpers.utils import magic_argv_env_context
from fastNLP.core import rank_zero_rm
from tests.helpers.models.torch_model import TorchNormalModel_Classification_1
from tests.helpers.datasets.torch_data import TorchArgMaxDatset
from torchmetrics import Accuracy
from fastNLP.core.metrics import Metric
from fastNLP.core.log import logger
from fastNLP.core.callbacks import MoreEvaluateCallback


@dataclass
class ArgMaxDatasetConfig:
num_labels: int = 10
feature_dimension: int = 10
data_num: int = 100
seed: int = 0

batch_size: int = 4
shuffle: bool = True



@dataclass
class TrainerParameters:
model: Any = None
optimizers: Any = None
train_dataloader: Any = None
evaluate_dataloaders: Any = None
input_mapping: Any = None
output_mapping: Any = None
metrics: Any = None
more_metrics: Any = None


@pytest.fixture(scope="module", params=[0], autouse=True)
def model_and_optimizers(request):
trainer_params = TrainerParameters()

trainer_params.model = TorchNormalModel_Classification_1(
num_labels=ArgMaxDatasetConfig.num_labels,
feature_dimension=ArgMaxDatasetConfig.feature_dimension
)
trainer_params.optimizers = SGD(trainer_params.model.parameters(), lr=0.001)
dataset = TorchArgMaxDatset(
feature_dimension=ArgMaxDatasetConfig.feature_dimension,
data_num=ArgMaxDatasetConfig.data_num,
seed=ArgMaxDatasetConfig.seed
)
_dataloader = DataLoader(
dataset=dataset,
batch_size=ArgMaxDatasetConfig.batch_size,
shuffle=True
)

class LossMetric(Metric):
def __init__(self):
super().__init__()
self.register_element('loss')

def update(self, loss):
self.loss += loss.item()

def get_metric(self) -> dict:
return self.loss.item()

trainer_params.train_dataloader = _dataloader
trainer_params.evaluate_dataloaders = _dataloader
trainer_params.metrics = {'loss': LossMetric()}

trainer_params.more_metrics = {"acc": Accuracy()}

return trainer_params


@pytest.mark.parametrize("driver,device", [("torch", "cpu"), ("torch_ddp", [0, 1]), ("torch", 1)]) # ("torch", "cpu"), ("torch_ddp", [0, 1]), ("torch", 1)
@pytest.mark.parametrize("version", [0, 1])
@pytest.mark.parametrize("only_state_dict", [True, False])
@magic_argv_env_context
def test_model_more_evaluate_callback_1(
model_and_optimizers: TrainerParameters,
driver,
device,
version,
only_state_dict
):
try:
path = Path.cwd().joinpath(f"test_model_checkpoint")
path.mkdir(exist_ok=True, parents=True)

if version == 0:
callbacks = [
MoreEvaluateCallback(dataloaders=model_and_optimizers.evaluate_dataloaders,
metrics=model_and_optimizers.more_metrics,
evaluate_every=-1,
folder=path, topk=-1,
topk_monitor='acc', only_state_dict=only_state_dict, save_object='model')
]
elif version == 1:
callbacks = [
MoreEvaluateCallback(dataloaders=model_and_optimizers.evaluate_dataloaders,
metrics=model_and_optimizers.more_metrics,
evaluate_every=None, watch_monitor='loss', watch_monitor_larger_better=False,
folder=path, topk=1, topk_monitor='acc', only_state_dict=only_state_dict,
save_object='model')
]
n_epochs = 5
trainer = Trainer(
model=model_and_optimizers.model,
driver=driver,
device=device,
optimizers=model_and_optimizers.optimizers,
train_dataloader=model_and_optimizers.train_dataloader,
evaluate_dataloaders=model_and_optimizers.evaluate_dataloaders,
input_mapping=model_and_optimizers.input_mapping,
output_mapping=model_and_optimizers.output_mapping,
metrics=model_and_optimizers.metrics,
n_epochs=n_epochs,
callbacks=callbacks,
output_from_new_proc="all",
evaluate_fn='train_step'
)

trainer.run()

all_saved_model_paths = {w.name: w for w in path.joinpath(os.environ[FASTNLP_LAUNCH_TIME]).iterdir()}
# check that the number of saved model files is correct;
if version == 0:
assert len(all_saved_model_paths) == n_epochs
elif version == 1:
assert len(all_saved_model_paths) == 1

for folder in all_saved_model_paths:
trainer = Trainer(
model=model_and_optimizers.model,
driver=driver,
device=device,
optimizers=model_and_optimizers.optimizers,
train_dataloader=model_and_optimizers.train_dataloader,
evaluate_dataloaders=model_and_optimizers.evaluate_dataloaders,
input_mapping=model_and_optimizers.input_mapping,
output_mapping=model_and_optimizers.output_mapping,
metrics=model_and_optimizers.metrics,
n_epochs=2,
output_from_new_proc="all",
evaluate_fn='train_step'
)
folder = path.joinpath(os.environ[FASTNLP_LAUNCH_TIME]).joinpath(folder)
trainer.load_model(folder, only_state_dict=only_state_dict)

trainer.run()

finally:
rank_zero_rm(path)

if dist.is_initialized():
dist.destroy_process_group()


@pytest.mark.parametrize("driver,device", [("torch", "cpu"), ("torch_ddp", [0, 1]), ("torch", 0)]) # ("torch", "cpu"), ("torch_ddp", [0, 1]), ("torch", 1)
@pytest.mark.parametrize("version", [0, 1])
@pytest.mark.parametrize("only_state_dict", [True, False])
@magic_argv_env_context
def test_trainer_checkpoint_callback_1(
model_and_optimizers: TrainerParameters,
driver,
device,
version,
only_state_dict
):
try:
path = Path.cwd().joinpath(f"test_model_checkpoint")
path.mkdir(exist_ok=True, parents=True)

if version == 0:
callbacks = [
MoreEvaluateCallback(dataloaders=model_and_optimizers.evaluate_dataloaders,
metrics=model_and_optimizers.more_metrics,
evaluate_every=-1,
folder=path, topk=-1,
topk_monitor='acc', only_state_dict=only_state_dict, save_object='trainer')
]
elif version == 1:
callbacks = [
MoreEvaluateCallback(dataloaders=model_and_optimizers.evaluate_dataloaders,
metrics=model_and_optimizers.more_metrics,
evaluate_every=None, watch_monitor='loss', watch_monitor_larger_better=False,
folder=path, topk=1, topk_monitor='acc', only_state_dict=only_state_dict,
save_object='trainer')
]
n_epochs = 5
trainer = Trainer(
model=model_and_optimizers.model,
driver=driver,
device=device,
optimizers=model_and_optimizers.optimizers,
train_dataloader=model_and_optimizers.train_dataloader,
evaluate_dataloaders=model_and_optimizers.evaluate_dataloaders,
input_mapping=model_and_optimizers.input_mapping,
output_mapping=model_and_optimizers.output_mapping,
metrics=model_and_optimizers.metrics,
n_epochs=n_epochs,
callbacks=callbacks,
output_from_new_proc="all",
evaluate_fn='train_step'
)

trainer.run()

all_saved_model_paths = {w.name: w for w in path.joinpath(os.environ[FASTNLP_LAUNCH_TIME]).iterdir()}
# check that the number of saved model files is correct;
if version == 0:
assert len(all_saved_model_paths) == n_epochs
elif version == 1:
assert len(all_saved_model_paths) == 1

for folder in all_saved_model_paths:
trainer = Trainer(
model=model_and_optimizers.model,
driver=driver,
device=device,
optimizers=model_and_optimizers.optimizers,
train_dataloader=model_and_optimizers.train_dataloader,
evaluate_dataloaders=model_and_optimizers.evaluate_dataloaders,
input_mapping=model_and_optimizers.input_mapping,
output_mapping=model_and_optimizers.output_mapping,
metrics=model_and_optimizers.metrics,
n_epochs=7,
output_from_new_proc="all",
evaluate_fn='train_step'
)
folder = path.joinpath(os.environ[FASTNLP_LAUNCH_TIME]).joinpath(folder)
trainer.load(folder, only_state_dict=only_state_dict)

trainer.run()

finally:
rank_zero_rm(path)

if dist.is_initialized():
dist.destroy_process_group()

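The new test file above also doubles as the usage reference for MoreEvaluateCallback: an extra evaluation loop with its own metrics, run either on a fixed schedule (evaluate_every=-1, i.e. every epoch) or triggered by watching a monitor from the main evaluation (watch_monitor), optionally keeping the topk checkpoints ranked by topk_monitor. Condensed from those tests; the dataloader and metric are placeholders:

    from fastNLP.core.callbacks import MoreEvaluateCallback

    callbacks = [
        MoreEvaluateCallback(dataloaders=val_dataloader,          # extra evaluation data
                             metrics={"acc": my_accuracy_metric}, # evaluated separately from Trainer.metrics
                             evaluate_every=-1,                   # or None with watch_monitor='loss'
                             folder="checkpoints", topk=1, topk_monitor="acc",
                             save_object='model'),
    ]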
+ 2
- 2
tests/core/controllers/test_trainer_wo_evaluator_torch.py View File

@@ -15,7 +15,7 @@ from tests.helpers.datasets.torch_data import TorchNormalDataset_Classification
from tests.helpers.callbacks.helper_callbacks import RecordLossCallback
from tests.helpers.callbacks.helper_callbacks_torch import RecordAccumulationStepsCallback_Torch
from tests.helpers.utils import magic_argv_env_context, Capturing
from fastNLP.core import synchronize_safe_rm
from fastNLP.core import rank_zero_rm


@dataclass
@@ -239,7 +239,7 @@ def test_trainer_output_from_new_proc(
assert err_path.exists()

path = Path(os.path.abspath(output_from_new_proc))
synchronize_safe_rm(path)
rank_zero_rm(path)


@pytest.mark.parametrize("driver,device", [("torch", [1, 2])])


+ 7
- 7
tests/core/drivers/paddle_driver/test_single_device.py View File

@@ -11,7 +11,7 @@ from tests.helpers.models.paddle_model import PaddleNormalModel_Classification_1
from tests.helpers.datasets.paddle_data import PaddleNormalDataset, PaddleRandomMaxDataset
from tests.helpers.datasets.torch_data import TorchNormalDataset
from tests.helpers.models.torch_model import TorchNormalModel_Classification_1
from fastNLP.core import synchronize_safe_rm
from fastNLP.core import rank_zero_rm

import paddle
from paddle.io import DataLoader, BatchSampler
@@ -578,11 +578,11 @@ def test_save_and_load_model(prepare_test_save_load, only_state_dict):
assert paddle.equal_all(res1["pred"], res2["pred"])
finally:
if only_state_dict:
synchronize_safe_rm(path)
rank_zero_rm(path)
else:
synchronize_safe_rm(path + ".pdiparams")
synchronize_safe_rm(path + ".pdiparams.info")
synchronize_safe_rm(path + ".pdmodel")
rank_zero_rm(path + ".pdiparams")
rank_zero_rm(path + ".pdiparams.info")
rank_zero_rm(path + ".pdmodel")

@pytest.mark.parametrize("only_state_dict", ([True, False]))
def test_save_and_load_with_randombatchsampler(only_state_dict):
@@ -652,7 +652,7 @@ def test_save_and_load_with_randombatchsampler(only_state_dict):
assert len(left_y_batches) + len(already_seen_y_set) == len(dataset)
assert len(left_y_batches | already_seen_y_set) == len(dataset)
finally:
synchronize_safe_rm(path)
rank_zero_rm(path)

@pytest.mark.parametrize("only_state_dict", ([True, False]))
def test_save_and_load_with_randomsampler(only_state_dict):
@@ -730,4 +730,4 @@ def test_save_and_load_with_randomsampler(only_state_dict):
assert len(left_y_batches) + len(already_seen_y_set) == len(dataset)
assert len(left_y_batches | already_seen_y_set) == len(dataset)
finally:
synchronize_safe_rm(path)
rank_zero_rm(path)

+ 8
- 8
tests/core/log/test_logger.py View File

@@ -6,7 +6,7 @@ import logging
import re

from fastNLP.envs.env import FASTNLP_LAUNCH_TIME
from fastNLP.core import synchronize_safe_rm
from fastNLP.core import rank_zero_rm
from fastNLP.core.log.logger import logger

from tests.helpers.utils import magic_argv_env_context, recover_logger
@@ -56,7 +56,7 @@ def test_add_file_ddp_1_torch():
pattern = re.compile(msg)
assert len(pattern.findall(line)) == 1

synchronize_safe_rm(filepath)
rank_zero_rm(filepath)
dist.barrier()
dist.destroy_process_group()

@@ -105,7 +105,7 @@ def test_add_file_ddp_2_torch():
pattern = re.compile(msg)
assert len(pattern.findall(line)) == 1
finally:
synchronize_safe_rm(path)
rank_zero_rm(path)

dist.barrier()
dist.destroy_process_group()
@@ -155,7 +155,7 @@ def test_add_file_ddp_3_torch():
pattern = re.compile(msg)
assert len(pattern.findall(line)) == 1

synchronize_safe_rm(file)
rank_zero_rm(file)
dist.barrier()
dist.destroy_process_group()

@@ -202,7 +202,7 @@ def test_add_file_ddp_4_torch():
pattern = re.compile(msg)
assert len(pattern.findall(line)) == 1
finally:
synchronize_safe_rm(path)
rank_zero_rm(path)

dist.barrier()
dist.destroy_process_group()
@@ -225,7 +225,7 @@ class TestLogger:
line = ''.join([l for l in f])
assert self.msg in line
finally:
synchronize_safe_rm(path)
rank_zero_rm(path)

@recover_logger
def test_add_file_2(self):
@@ -243,7 +243,7 @@ class TestLogger:
line = ''.join([l for l in f])
assert self.msg in line
finally:
synchronize_safe_rm(origin_path)
rank_zero_rm(origin_path)

@recover_logger
def test_add_file_3(self):
@@ -279,7 +279,7 @@ class TestLogger:
line = ''.join([l for l in f])
assert self.msg in line
finally:
synchronize_safe_rm(path)
rank_zero_rm(path)

@recover_logger
def test_stdout(self, capsys):


+ 10
- 10
tests/core/utils/test_cache_results.py View File

@@ -8,7 +8,7 @@ import sys
from fastNLP.core.utils.cache_results import cache_results
from tests.helpers.common.utils import check_time_elapse

from fastNLP.core import synchronize_safe_rm
from fastNLP.core import rank_zero_rm


def get_subprocess_results(cmd):
@@ -56,7 +56,7 @@ class TestCacheResults:
res = demo()

finally:
synchronize_safe_rm(cache_fp)
rank_zero_rm(cache_fp)

def test_cache_save_refresh(self):
cache_fp = 'demo.pkl'
@@ -70,7 +70,7 @@ class TestCacheResults:
with check_time_elapse(1, op='ge'):
res = demo()
finally:
synchronize_safe_rm(cache_fp)
rank_zero_rm(cache_fp)

def test_cache_no_func_change(self):
cache_fp = os.path.abspath('demo.pkl')
@@ -91,7 +91,7 @@ class TestCacheResults:
with check_time_elapse(1, op='lt'):
res = demo()
finally:
synchronize_safe_rm('demo.pkl')
rank_zero_rm('demo.pkl')

def test_cache_func_change(self, capsys):
cache_fp = 'demo.pkl'
@@ -121,7 +121,7 @@ class TestCacheResults:
assert 'is different from its last cache' not in output[0]

finally:
synchronize_safe_rm('demo.pkl')
rank_zero_rm('demo.pkl')

def test_cache_check_hash(self):
cache_fp = 'demo.pkl'
@@ -152,7 +152,7 @@ class TestCacheResults:
assert 'is different from its last cache' in output[0]

finally:
synchronize_safe_rm('demo.pkl')
rank_zero_rm('demo.pkl')

# a change in an external function also causes the cache to change
def test_refer_fun_change(self):
@@ -177,7 +177,7 @@ class TestCacheResults:
assert 'is different from its last cache' in res

finally:
synchronize_safe_rm(cache_fp)
rank_zero_rm(cache_fp)

# a change in an external method also causes the cache to change
def test_refer_class_method_change(self):
@@ -202,7 +202,7 @@ class TestCacheResults:
assert 'is different from its last cache' in res

finally:
synchronize_safe_rm(cache_fp)
rank_zero_rm(cache_fp)

def test_duplicate_keyword(self):
with pytest.raises(RuntimeError):
@@ -240,7 +240,7 @@ class TestCacheResults:
results = cache()
assert (1, 2) == results
finally:
synchronize_safe_rm('demo/')
rank_zero_rm('demo/')

def test_result_none_error(self):
@cache_results('demo.pkl')
@@ -251,7 +251,7 @@ class TestCacheResults:
with pytest.raises(RuntimeError):
results = cache()
finally:
synchronize_safe_rm('demo.pkl')
rank_zero_rm('demo.pkl')


if __name__ == '__main__':


+ 2
- 2
tests/envs/test_set_backend.py View File

@@ -2,7 +2,7 @@ import os

from fastNLP.envs.set_backend import dump_fastnlp_backend
from tests.helpers.utils import Capturing
from fastNLP.core import synchronize_safe_rm
from fastNLP.core import rank_zero_rm


def test_dump_fastnlp_envs():
@@ -14,4 +14,4 @@ def test_dump_fastnlp_envs():
assert filepath in output[0]
assert os.path.exists(filepath)
finally:
synchronize_safe_rm(filepath)
rank_zero_rm(filepath)

+ 2
- 2
tests/modules/mix_modules/test_mix_module.py View File

@@ -9,7 +9,7 @@ import numpy as np

from fastNLP.modules.mix_modules.mix_module import MixModule
from fastNLP.modules.mix_modules.utils import paddle2torch, torch2paddle
from fastNLP.core import synchronize_safe_rm
from fastNLP.core import rank_zero_rm


############################################################################
@@ -227,7 +227,7 @@ class TorchPaddleMixModuleTestCase(unittest.TestCase):

self.assertDictEqual(state_dict, new_state_dict)
finally:
synchronize_safe_rm(path)
rank_zero_rm(path)

def if_device_correct(self, device):


