
Merge branch 'dev0.8.0' of github.com:fastnlp/fastNLP into dev0.8.0

tags/v1.0.0alpha
YWMditto 2 years ago
parent commit 9675eae798
15 changed files with 334 additions and 213 deletions
  1. +3    -2    fastNLP/core/callbacks/__init__.py
  2. +3    -3    fastNLP/core/callbacks/callback_manager.py
  3. +184  -110  fastNLP/core/callbacks/checkpoint_callback.py
  4. +1    -1    fastNLP/core/callbacks/load_best_model_callback.py
  5. +9    -9    fastNLP/core/controllers/evaluator.py
  6. +19   -3    fastNLP/core/controllers/trainer.py
  7. +15   -5    fastNLP/core/dataset/dataset.py
  8. +1    -4    fastNLP/core/dataset/field.py
  9. +11   -6    fastNLP/core/drivers/driver.py
  10. +2   -2    fastNLP/core/drivers/torch_driver/initialize_torch_driver.py
  11. +0   -2    fastNLP/core/drivers/torch_driver/torch_driver.py
  12. +3   -0    fastNLP/core/utils/utils.py
  13. +1   -1    fastNLP/envs/set_env_on_import.py
  14. +68  -65   tests/core/callbacks/test_checkpoint_callback_torch.py
  15. +14  -0    tests/core/dataset/test_dataset.py

+3 -2   fastNLP/core/callbacks/__init__.py

@@ -4,7 +4,8 @@ __all__ = [
     'EventsList',
     'Filter',
     'CallbackManager',
-    'CheckpointCallback',
+    'ModelCheckpointCallback',
+    'TrainerCheckpointCallback',
     'choose_progress_callback',
     'ProgressCallback',
     'RichCallback',
@@ -16,7 +17,7 @@ __all__ = [
 from .callback import Callback
 from .callback_events import EventsList, Events, Filter
 from .callback_manager import CallbackManager
-from .checkpoint_callback import CheckpointCallback
+from .checkpoint_callback import ModelCheckpointCallback, TrainerCheckpointCallback
 from .progress_callback import choose_progress_callback, ProgressCallback, RichCallback
 from .lr_scheduler_callback import LRSchedCallback
 from .load_best_model_callback import LoadBestModelCallback


+3 -3   fastNLP/core/callbacks/callback_manager.py

@@ -8,7 +8,7 @@ __all__ = [

 from .callback_events import Events
 from .callback import Callback
-from .checkpoint_callback import CheckpointCallback
+from .checkpoint_callback import TrainerCheckpointCallback
 from .progress_callback import ProgressCallback, choose_progress_callback
 from fastNLP.core.log import logger

@@ -98,7 +98,7 @@ class CallbackManager:
         :return:
         """
         for each_callback in self.class_callbacks:
-            if isinstance(each_callback, CheckpointCallback) and each_callback.is_trainer_checkpoint:
+            if isinstance(each_callback, TrainerCheckpointCallback):
                 self._has_trainer_checkpoint = True
             self.dissect_one_callback(each_callback)

@@ -210,7 +210,7 @@ class CallbackManager:
             each_callback.on_load_checkpoint(trainer, None)

     @property
-    def has_trainer_chechpoint(self) -> bool:
+    def has_trainer_checkpoint(self) -> bool:
         return self._has_trainer_checkpoint

     @_transfer


+184 -110   fastNLP/core/callbacks/checkpoint_callback.py

@@ -1,12 +1,13 @@
+__all__ = [
+    'ModelCheckpointCallback',
+    'TrainerCheckpointCallback'
+]
 import os
-from typing import Union, Optional, Callable, Dict, Sequence
+from typing import Union, Optional, Callable, Dict, Sequence, Any, Mapping
 from pathlib import Path
-from functools import partial
+from time import sleep
+from abc import ABC
+import sys

-
-__all__ = [
-    'CheckpointCallback'
-]

 import fastNLP
 from .callback import Callback, Filter
@@ -14,35 +15,37 @@ from fastNLP.core.callbacks.utils import _get_monitor_value
 from fastNLP.core.log import logger
 from fastNLP.envs import FASTNLP_LAUNCH_TIME
 from fastNLP.core.utils import synchronize_safe_rm, synchronize_mkdir
+from fastNLP.core.utils import apply_to_collection


-class CheckpointCallback(Callback):
+class CanItemDataType(ABC):
     """
-    1. Since only the 'Trainer' has callbacks, evaluating the metric is in fact what happens during validate;
-    2. 'save_last' defaults to True, i.e. by default model_checkpoint saves the latest model at the end of every epoch,
-       under the name last.pth.tar;
-    3. In principle one model_checkpoint instance only watches a single monitor; if the user asks for several monitors
-       during training, e.g. "acc1", "acc2", ..., we create one model_checkpoint instance per monitor;
-    4. In principle the topk mode and the fixed-frequency mode are completely independent during the actual saving; we
-       should at least take some measure to guarantee that the two do not use the same name;
+    Detect objects that can be transferred (i.e. that provide an ``item()`` method).
     """

+    @classmethod
+    def __subclasshook__(cls, subclass: Any) -> Union[bool, Any]:
+        if cls is CanItemDataType:
+            item = getattr(subclass, 'item', None)
+            return callable(item)
+        return NotImplemented


+class CheckpointCallback(Callback):
     def __init__(
             self,
             monitor,
-            is_trainer_checkpoint: Optional[bool] = False,
             save_folder: Optional[Union[str, Path]] = None,
             save_every_n_epochs: Optional[int] = None,
-            save_every_n_global_batches: Optional[int] = None,
+            save_every_n_batches: Optional[int] = None,
             save_last: bool = True,
             save_topk: Optional[int] = None,
             save_on_exception: Optional[Union[BaseException, Sequence[BaseException]]] = None,
             larger_better: bool = True,
             only_state_dict: bool = True,
             model_save_fn: Optional[Callable] = None,
             **kwargs,
     ):
         if monitor is None and save_topk is not None:
@@ -51,9 +54,6 @@ class CheckpointCallback(Callback):
         if monitor is not None and not isinstance(monitor, str):
             raise ValueError("Parameter `monitor` should be of 'str' type.")

-        if not isinstance(is_trainer_checkpoint, bool):
-            raise TypeError("Parameter 'is_trainer_checkpoint' can only be `bool` type.")
-
         if save_folder is None:
             logger.warning(
                 "Parameter `path` is None, and we will use the current work directory to find and load your model.")
@@ -67,15 +67,15 @@ class CheckpointCallback(Callback):
             if not isinstance(save_every_n_epochs, int) or save_every_n_epochs < 1:
                 raise ValueError("parameter save_after_epoch_num should be an int and greater than or equal to 1.")

-            # A nice trick here: the state recorded inside 'Filter', e.g. 'num_called', is global to the instance, while
-            # the function passed to each __call__ is supplied on the fly; that is, 'Filter' keeps its usual
-            # frequency-control logic even though the function it runs can differ on every call;
-            self._filter_every_n_epochs = Filter(every=save_every_n_epochs)
+        else:
+            save_every_n_epochs = sys.maxsize  # so that no counter is ever divisible by it

-        if save_every_n_global_batches is not None:
-            if not isinstance(save_every_n_global_batches, int) or save_every_n_global_batches < 1:
+        if save_every_n_batches is not None:
+            if not isinstance(save_every_n_batches, int) or save_every_n_batches < 1:
                 raise ValueError(
-                    "parameter save_every_n_global_batches should be an int and greater than or equal to 1.")
-            self._filter_every_n_global_batches = Filter(every=save_every_n_global_batches)
+                    "parameter save_every_n_batches should be an int and greater than or equal to 1.")
+        else:
+            save_every_n_batches = sys.maxsize  # so that no counter is ever divisible by it

         if save_topk is not None:
             if not isinstance(save_topk, int) or save_topk < 1:
@@ -89,12 +89,12 @@ class CheckpointCallback(Callback):
                 if not issubclass(exception, BaseException):
                     raise TypeError("Each exception in parameter `save_on_exception` can only be "
                                     "`BaseException` type.")
+        else:
+            save_on_exception = []

         self.monitor = monitor
-        self.is_trainer_checkpoint = is_trainer_checkpoint
         self.save_folder = Path(save_folder)
         self.save_every_n_epochs = save_every_n_epochs
-        self.save_every_n_global_batches = save_every_n_global_batches
+        self.save_every_n_batches = save_every_n_batches
         self.save_last = save_last
         self.save_topk = save_topk
         self.larger_better = larger_better
@@ -107,7 +107,7 @@ class CheckpointCallback(Callback):
         self._topk_model = {}
         self._topn = 0  # how many of the best models have been saved so far;

-        # Because in `_get_validate_metric`, when `monitor` is not found in the returned `validate_res` dict, we use the first key found by fuzzy matching and take
+        # Because in `_get_validate_metric`, when `monitor` is not found in the returned `validate_res` dict, we use the key found by matching and take
         # its value as the result; the problem is that the sub_metric names returned by the user's metric may be easy to confuse, and if the next
         # training run changes the order in which those sub_metrics are returned, fuzzy matching may pick a different key and value than before, which is clearly not reasonable behavior;
         # therefore we use this variable to record the key obtained by fuzzy matching;
@@ -115,76 +115,83 @@ class CheckpointCallback(Callback):

         # Note that only process 0 should perform this operation, because when the user launches the processes with
         # python -m torch.distributed.launch, FASTNLP_LAUNCH_TIME has a different value on every process;
-        self.log_filepath = self.save_folder.joinpath(os.environ[FASTNLP_LAUNCH_TIME])
+        self.timestamp_path = self.save_folder.joinpath(os.environ[FASTNLP_LAUNCH_TIME])
         # We only need to make sure the folder is created on process 0; the later actual saves are not executed by the other processes anyway;
-        synchronize_mkdir(self.log_filepath)
+        synchronize_mkdir(self.timestamp_path)

     def on_validate_end(self, trainer, validate_res):
         self._save_topk(trainer, validate_res)

     def on_train_epoch_end(self, trainer: "fastNLP.Trainer"):
-        self._save_every_n_epochs(trainer)
-        self._save_last(trainer)
+        if trainer.cur_epoch_idx % self.save_every_n_epochs == 0:
+            folder_name = f'{self.folder_prefix}-epoch_{trainer.cur_epoch_idx}'
+            self.save(trainer, folder_name=folder_name)
+        if self.save_last:
+            folder_name = f'{self.folder_prefix}-last'
+            self.save(trainer, folder_name=folder_name)

     def on_train_batch_end(self, trainer):
-        self._save_every_n_global_batches(trainer)
+        if trainer.global_forward_batches % self.save_every_n_batches == 0:
+            folder_name = f'{self.folder_prefix}-epoch_{trainer.cur_epoch_idx}-batch_{trainer.global_forward_batches}'
+            self.save(trainer, folder_name=folder_name)

     def on_exception(self, trainer, exception: BaseException):
-        if self.save_on_exception is not None and exception.__class__ in self.save_on_exception:
-            folder = self._get_checkpoint_real_save_folder(trainer=trainer, topk=False, metric=None)
-            folder = folder + f"_{exception.__class__.__name__}"
-            self._save_fn(trainer=trainer, topk=False, metric=None, substitute_folder=folder)
+        if exception.__class__ in self.save_on_exception:
+            folder_name = f'{self.folder_prefix}-epoch_{trainer.cur_epoch_idx}-batch_{trainer.global_forward_batches}-' \
+                          f'exception_{exception.__class__.__name__}'
+            self.save(trainer=trainer, folder_name=folder_name)

     def on_sanity_check_end(self, trainer, sanity_check_res):
+        # mainly verify that the monitor exists.
         self._get_validate_metric(sanity_check_res)

     def on_save_checkpoint(self, trainer) -> Dict:
         """
-        We need to save the state of the several filters inside CheckpointCallback;
+        Save timestamp_path so that a later run can resume training and keep saving into this folder.
+        The state of topk_model.
+        The value of _real_monitor.
         """

         states = {}
-        if self.save_every_n_epochs is not None:
-            states["_filter_every_n_epochs"] = self._filter_every_n_epochs.state_dict()
-        if self.save_every_n_global_batches is not None:
-            states["_filter_every_n_global_batches"] = self._filter_every_n_global_batches.state_dict()
-        states["real_monitor"] = self._real_monitor
+        states['timestamp_path'] = str(self.timestamp_path.absolute())
+        states['_topk_model'] = apply_to_collection(self._topk_model, dtype=CanItemDataType,
+                                                    function=lambda x: x.item())
+        states['save_topk'] = 0 if self.save_topk is None else self.save_topk
+        states['_real_monitor'] = self._real_monitor
         return states

     def on_load_checkpoint(self, trainer, states: Optional[Dict]):
-        if self.save_every_n_epochs is not None:
-            self._filter_every_n_epochs.load_state_dict(states["_filter_every_n_epochs"])
-        if self.save_every_n_global_batches is not None:
-            self._filter_every_n_global_batches.load_state_dict(states["_filter_every_n_global_batches"])
+        timestamp_path = states['timestamp_path']
+        if not os.path.exists(timestamp_path):
+            logger.info(f"The resuming save folder {timestamp_path} is not exists, will checkpoint save to "
+                        f" {self.timestamp_path.absolute()}.")
+        else:
+            logger.info(f"Resume to save in path: {timestamp_path}.")
+            self.timestamp_path = Path(timestamp_path)
+        _topk_model = states['_topk_model']
+        save_topk = None if int(states['save_topk']) == 0 else int(states['save_topk'])
+        if save_topk is not None and self.save_topk is not None:
+            assert self.save_topk == save_topk, f"The checkpoint set save_topk={save_topk}, while this callback set it " \
+                                                f"as {save_topk}."
+        self._topk_model.update(self._topk_model)
         self._real_monitor = states["real_monitor"]

-    def _save_every_n_epochs(self, trainer: "fastNLP.Trainer"):
-        if self.save_every_n_epochs is not None:
-            if self.is_trainer_checkpoint:
-                _fn_every_n_epochs = trainer.save
-            else:
-                _fn_every_n_epochs = trainer.save_model
-            _fn_every_n_epochs = partial(self._save_fn, trainer, False, None, _fn_every_n_epochs, None)
-            _fn_every_n_epochs = self._filter_every_n_epochs(_fn_every_n_epochs)
-            _fn_every_n_epochs()
-
-    def _save_every_n_global_batches(self, trainer: "fastNLP.Trainer"):
-        if self.save_every_n_global_batches is not None:
-            if self.is_trainer_checkpoint:
-                _fn_every_n_global_batches = trainer.save
-            else:
-                _fn_every_n_global_batches = trainer.save_model
-            _fn_every_n_global_batches = partial(self._save_fn, trainer, False, None, _fn_every_n_global_batches, None)
-            _fn_every_n_global_batches = self._filter_every_n_global_batches(_fn_every_n_global_batches)
-            _fn_every_n_global_batches()
-
     def _save_topk(self, trainer: "fastNLP.Trainer", validate_res: Dict):
+        """
+        Decide, based on validate_res, which models to save. Folders that no longer belong to the topk are removed automatically.
+
+        :param trainer:
+        :param validate_res:
+        :return:
+        """
         if self.save_topk is not None:
             _metric_value = self._get_validate_metric(validate_res)
-            _saved_name = self._get_checkpoint_real_save_folder(trainer=trainer, topk=True, metric=_metric_value)
+            folder_name = f"{self.folder_prefix}-epoch_{trainer.cur_epoch_idx}-batch_{trainer.global_forward_batches}" \
+                          f"-{self._real_monitor}_{_metric_value}"

             _should_save = False
             if self._topn < self.save_topk:
-                self._topk_model[_saved_name] = _metric_value
+                self._topk_model[folder_name] = _metric_value
                 self._topn += 1
                 _should_save = True
             else:
@@ -192,39 +199,27 @@ class CheckpointCallback(Callback):
                                            key=lambda x: self._topk_model[x])
                 if (self.larger_better and _metric_value > self._topk_model[_least_valuable_model]) or \
                         (self.larger_better is False and _metric_value < self._topk_model[_least_valuable_model]):
-                    self._topk_model[_saved_name] = _metric_value
+                    self._topk_model[folder_name] = _metric_value
                     _should_save = True
                     self._topk_model.pop(_least_valuable_model)
-                    synchronize_safe_rm(self.log_filepath.joinpath(_least_valuable_model))
+                    synchronize_safe_rm(self.timestamp_path.joinpath(_least_valuable_model))

                 assert len(self._topk_model) == self.save_topk == self._topn

             if _should_save:
-                self._save_fn(trainer=trainer, topk=True, metric=_metric_value, substitute_folder=_saved_name)
+                self.save(trainer, folder_name=folder_name)

-    def _save_last(self, trainer: "fastNLP.Trainer"):
-        if self.save_last:
-            self._save_fn(trainer=trainer, topk=False, metric=None, substitute_folder="last")
-
-    def _save_fn(self, trainer, topk: bool = False, metric: Optional[Union[int, float]] = None,
-                 substitute_fn: Optional[Callable] = None, substitute_folder: Optional[str] = None):
-        # First create a sub-folder epoch-batch-monitor or epoch-batch-monitor-monitor_value under
-        # parent_path/FASTNLP_LAUNCH_TIME, based on the current epoch and batch;
-        if substitute_folder is None:
-            folder = self.log_filepath.joinpath(self._get_checkpoint_real_save_folder(trainer, topk, metric))
-        else:
-            folder = self.log_filepath.joinpath(substitute_folder)
+    def save(self, trainer, folder_name):
+        """
+        Perform the actual save; the data is saved under save_folder/timestamp/folder_name.
+
+        :param trainer:
+        :param folder_name:
+        :return:
+        """
+        folder = self.timestamp_path.joinpath(folder_name)
         synchronize_mkdir(folder)

-        # Then call the trainer's save_model (to save the model) or save (for resumable training);
-        if substitute_fn is not None:
-            _fn = substitute_fn
-        else:
-            if self.is_trainer_checkpoint:
-                _fn = trainer.save
-            else:
-                _fn = trainer.save_model
+        _fn = getattr(trainer, self.save_fn_name)
         _fn(
             folder=folder,
             only_state_dict=self.only_state_dict,
@@ -243,18 +238,95 @@ class CheckpointCallback(Callback):
         self._real_monitor = use_monitor
         return value

-    def _get_checkpoint_real_save_folder(self, trainer: "fastNLP.Trainer", topk: bool = False,
-                                         metric: Optional[Union[int, float]] = None) -> str:
-        """
-        Get the real name of the folder the current model is saved under;
-        the metric parameter only takes effect when mode is 'topk';
-        :return:
-        """
-        cur_epoch_idx = trainer.cur_epoch_idx
-        global_forward_batches = trainer.global_forward_batches
-        _other = ""
-        if topk:
-            _other = f"_{metric}"
-        return f"epoch_{cur_epoch_idx}-global_batch_{global_forward_batches}-{self._real_monitor}{_other}"
+    @property
+    def folder_prefix(self):
+        raise NotImplementedError("The `folder_prefix` is not specified")
+
+    @property
+    def save_fn_name(self):
+        raise NotImplementedError("The `save_fn_name` is not specified.")
+
+
+class ModelCheckpointCallback(CheckpointCallback):
+    """
+    Callback that saves model checkpoints. The folders and files it writes are named as follows:
+
+    - save_folder/
+        - YYYY-mm-dd-HH_MM_SS_fffff/  # created automatically from the launch time of the current script
+            - model-epoch_{epoch_idx}/  # model saved when the save_every_n_epochs condition is met
+            - model-epoch_{epoch_idx}-batch_{global_batch_idx}/  # model saved when the save_every_n_batches condition is met
+            - model-last/  # save from the last epoch
+            - model-epoch_{epoch_idx}-batch_{global_batch_idx}-exception_{exception_type}/  # saved on exception.
+            - model-epoch_{epoch_idx}-batch_{global_batch_idx}-{monitor}_{monitor_value}/  # folder name used when the topk condition is met
+
+    If model_save_fn is None, a fastnlp_model.pkl.tar file is generated inside each folder above.
+    If model_save_fn is not None, fastNLP passes the absolute path of the folder to that function and creates no file in the folder itself.
+
+    :param monitor: name of the metric to monitor. If no exact match is found in the evaluation results, the closest match
+        found by the shortest-common-string algorithm is used as the monitor.
+    :param save_folder: folder to save into; fastNLP creates a timestamped sub-folder inside it and saves there, so different
+        runs can be saved into different timestamp folders. If None, the current folder is used by default.
+    :param save_every_n_epochs: save once every this many epochs.
+    :param save_every_n_batches: save once every this many batches.
+    :param save_last: if True, save once at the end of every epoch, overwriting the previous save.
+    :param save_topk: keep the topK saves according to the monitor result.
+    :param save_on_exception: whether to save when an exception occurs. Pass the exception classes to be caught.
+    :param larger_better: whether a larger monitor value is better.
+    :param only_state_dict: whether to save only the state_dict when saving the model. Ignored when model_save_fn is not None.
+    :param model_save_fn: customized save function; it is called whenever a save is triggered, should accept a folder as its
+        argument and return nothing. If model_save_fn is passed, fastNLP performs no model saving of its own. In the multi-card
+        setting the function is only run on rank 0.
+    :param kwargs:
+    """
+    @property
+    def save_fn_name(self):
+        return 'save_model'
+
+    @property
+    def callback_name(self):
+        """
+        This value decides whether two CheckpointCallback instances can share the state used for resumable training
+        :return:
+        """
+        return f"model_checkpoint#monitor-{self.monitor}#topK-{self.save_topk}#only_state_dict-{self.only_state_dict}"
+
+    @property
+    def folder_prefix(self):
+        return 'model'
+
+
+class TrainerCheckpointCallback(CheckpointCallback):
+    """
+    Callback that saves Trainer checkpoints. The folders and files it writes are named as follows:
+
+    - save_folder/
+        - YYYY-mm-dd-HH_MM_SS_fffff/  # created automatically from the launch time of the current script
+            - trainer-epoch_{epoch_idx}/  # saved when the save_every_n_epochs condition is met
+            - trainer-epoch_{epoch_idx}-batch_{global_batch_idx}/  # saved when the save_every_n_batches condition is met
+            - trainer-last/  # save from the last epoch
+            - trainer-epoch_{epoch_idx}-batch_{global_batch_idx}-exception_{exception_type}/  # saved on exception.
+            - trainer-epoch_{epoch_idx}-batch_{global_batch_idx}-{monitor}_{monitor_value}/  # folder name used when the topk condition is met
+
+    If model_save_fn is None, two files are generated inside each folder above: fastnlp_trainer.pkl.tar and fastnlp_model.pkl.tar.
+    If model_save_fn is not None, fastNLP only generates a fastnlp_trainer.pkl.tar file in each folder.
+
+    :param monitor: name of the metric to monitor. If no exact match is found in the evaluation results, the closest match
+        found by the shortest-common-string algorithm is used as the monitor.
+    :param save_folder: folder to save into; fastNLP creates a timestamped sub-folder inside it and saves there, so different
+        runs can be saved into different timestamp folders. If None, the current folder is used by default.
+    :param save_every_n_epochs: save once every this many epochs.
+    :param save_every_n_batches: save once every this many batches.
+    :param save_last: if True, save once at the end of every epoch, overwriting the previous save.
+    :param save_topk: keep the topK saves according to the monitor result.
+    :param save_on_exception: whether to save when an exception occurs.
+    :param larger_better: whether a larger monitor value is better.
+    :param only_state_dict: whether to save only the state_dict when saving the model. Meaningless when model_save_fn is not None.
+    :param model_save_fn: customized save function; it is called whenever a save is triggered, should accept a folder as its
+        argument and return nothing. If model_save_fn is passed, fastNLP performs no model saving of its own. In the multi-card
+        setting the function is only run on rank 0.
+    :param kwargs:
+    """
+    @property
+    def save_fn_name(self):
+        return 'save'

     @property
     def callback_name(self):
@@ -262,6 +334,8 @@ class CheckpointCallback(Callback):
         """
         This value decides whether two CheckpointCallback instances can share the state used for resumable training;
         :return:
         """
-        return f"monitor-{self.monitor}#trainer_checkpoint-{self.is_trainer_checkpoint}#only_state_dict-{self.only_state_dict}"
-
+        return f"trainer_checkpoint#monitor-{self.monitor}#topK-{self.save_topk}#only_state_dict-{self.only_state_dict}"
+
+    @property
+    def folder_prefix(self):
+        return 'trainer'
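
For orientation, a minimal usage sketch of the two renamed callbacks. The constructor arguments are the ones shown in this diff; the save_folder path and the hyper-parameter values are illustrative only, and the callbacks are meant to be handed to the Trainer via its callbacks argument (as the trainer tests further down do).

    # A minimal sketch, assuming the fastNLP dev0.8.0 API from this commit; values are illustrative.
    from fastNLP.core.callbacks import ModelCheckpointCallback, TrainerCheckpointCallback

    callbacks = [
        # writes model checkpoints under ./checkpoints/<launch timestamp>/model-*
        ModelCheckpointCallback(
            monitor="acc",
            save_folder="./checkpoints",
            save_every_n_epochs=1,
            save_every_n_batches=500,
            save_topk=2,
            save_last=True,
            only_state_dict=True,
        ),
        # writes full trainer state under ./checkpoints/<launch timestamp>/trainer-*,
        # which is what resumable training reads back
        TrainerCheckpointCallback(
            monitor="acc",
            save_folder="./checkpoints",
            save_every_n_epochs=1,
        ),
    ]
    # then: Trainer(..., callbacks=callbacks)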

+1 -1   fastNLP/core/callbacks/load_best_model_callback.py

@@ -31,7 +31,7 @@ class LoadBestModelCallback(Callback):
             please complete the saving of the model inside the function.
         :param model_load_fn: function that loads the model; it and model_save_fn must both be non-None at the same time. Its input is an already created folder and it has no output,
             please complete the loading of the model inside the function.
-        :param delete_after_train: whether to delete the model after the best model has been loaded.
+        :param delete_after_train: whether to delete the model after training ends.
         """
         if model_load_fn is not None:
             assert callable(model_load_fn), "`model_load_fn` must be a callable object."


+9 -9   fastNLP/core/controllers/evaluator.py

@@ -133,17 +133,18 @@ class Evaluator:

         self.driver.barrier()

-    def run(self, num_eval_batch_per_dl: int = -1) -> Dict:
+    def run(self, num_eval_batch_per_dl: int = -1, **kwargs) -> Dict:
         """
         Returns a dict whose keys are the metric names and whose values are the corresponding metric results.
-        With multiple metrics and a single dataloader, keys are named
-            metric_indicator_name#metric_name
-        With multiple datasets and a single metric, keys are named
-            metric_indicator_name#dataloader_name (where # is the default separator and can be changed via the Evaluator init parameters).
-        With multiple metrics and multiple dataloaders, keys are named
-            metric_indicator_name#metric_name#dataloader_name
-        :param num_eval_batch_per_dl: how many batches of each dataloader to evaluate; -1 means evaluate all the data
+        With multiple metrics and a single dataloader, keys are named
+            metric_indicator_name#metric_name
+        With multiple datasets and a single metric, keys are named
+            metric_indicator_name#metric_name#dataloader_name (where # is the default separator and can be changed via the Evaluator init parameters).
+        With multiple metrics and multiple dataloaders, keys are named
+            metric_indicator_name#metric_name#dataloader_name
+        where metric_indicator_name may be absent.
+
+        :param num_eval_batch_per_dl: how many batches of each dataloader to evaluate; -1 means evaluate all the data.
         :return:
         """
         assert isinstance(num_eval_batch_per_dl, int), "num_eval_batch_per_dl must be of int type."
@@ -157,7 +158,6 @@ class Evaluator:
         assert self.driver.has_test_dataloaders()

         metric_results = {}
-
         self.reset()
         evaluate_context = self.driver.get_evaluate_context()
         self.driver.set_model_mode(mode='eval' if self.model_use_eval_mode else 'train')
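
To make the key-naming rules above concrete, a hypothetical result dict (not taken from any test in this commit) for two metrics evaluated on two dataloaders, with the default '#' separator, could look like this:

    # Hypothetical illustration of Evaluator.run() result keys; names and numbers are made up.
    validate_res = {
        "acc#acc#dev": 0.92,    # metric_indicator_name#metric_name#dataloader_name
        "acc#acc#test": 0.90,
        "f#span#dev": 0.88,
        "f#span#test": 0.86,
    }
    # with a single metric and a single dataloader the key collapses, and the
    # metric_indicator_name part may be absent, as the updated docstring notes.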


+19 -3   fastNLP/core/controllers/trainer.py

@@ -251,7 +251,7 @@ class Trainer(TrainerEventTrigger):
         self.driver.set_deterministic_dataloader(self.dataloader)

         self.dataloader = self.driver.set_dist_repro_dataloader(dataloader=self.train_dataloader, dist=_dist_sampler,
-                                                                reproducible=self.callback_manager.has_trainer_chechpoint)
+                                                                reproducible=self.callback_manager.has_trainer_checkpoint)

         self.set_grad_to_none = kwargs.get("set_grad_to_none", True)
         self.on_after_trainer_initialized(self.driver)
@@ -291,6 +291,7 @@ class Trainer(TrainerEventTrigger):
             raise FileNotFoundError("You are using `resume_from`, but we can not find your specific file.")

         if self.evaluator is not None and num_eval_sanity_batch > 0:
+            logger.info(f"Running evaluator sanity check for {num_eval_sanity_batch} batches.")
             self.on_sanity_check_begin()
             sanity_check_res = self.evaluator.run(num_eval_batch_per_dl=num_eval_sanity_batch)
             self.on_sanity_check_end(sanity_check_res)
@@ -509,7 +510,7 @@ class Trainer(TrainerEventTrigger):

         :param folder: where the model is saved;
         :param only_state_dict: whether to save only the model's `state_dict`;
-        :param save_fn: user-defined function used to replace the saving logic of this save function itself;
+        :param model_save_fn: user-defined function used to replace the saving logic of this save function itself;
         :param kwargs: some drivers' model-saving functions take further parameters;
         """

@@ -534,7 +535,16 @@ class Trainer(TrainerEventTrigger):

     def load_model(self, folder: Union[str, Path, BinaryIO, io.BytesIO], only_state_dict: bool = False,
                    model_load_fn: Optional[Callable] = None, **kwargs):
+        """
+        Load the model.
+
+        :param folder: folder to load the model from; by default it tries to read the fastnlp_model.pkl.tar file in that folder.
+            When model_load_fn is not None, the folder is passed directly to model_load_fn.
+        :param only_state_dict: whether the file to read contains only the model weights. Meaningless when model_load_fn is not None.
+        :param model_load_fn: a callable that accepts a folder as its argument and returns nothing.
+        :param kwargs:
+        :return:
+        """
         self.on_load_model()
         self.driver.barrier()
         if not isinstance(folder, (io.BytesIO, BinaryIO)):
@@ -555,7 +565,13 @@ class Trainer(TrainerEventTrigger):

     def save(self, folder: Union[str, Path], only_state_dict: bool = True, model_save_fn: Optional[Callable] = None, **kwargs):
         r"""
-        Save function used for resumable (checkpointed) training;
+        Save function used to checkpoint the Trainer for resumable training;
+
+        :param folder:
+        :param only_state_dict:
+        :param model_save_fn:
+        :param kwargs:
+        :return:
         """
         self.driver.barrier()
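
A minimal sketch of the save/load pair whose docstrings are extended above, assuming `trainer` is an already constructed fastNLP Trainer; the folder paths are illustrative only:

    # save only the model weights
    trainer.save_model(folder="./snapshots/model", only_state_dict=True)
    # restore them later into the same architecture
    trainer.load_model(folder="./snapshots/model", only_state_dict=True)

    # save the full trainer state for resumable training
    # (this is the method TrainerCheckpointCallback calls via save_fn_name == 'save')
    trainer.save(folder="./snapshots/trainer", only_state_dict=True)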




+15 -5   fastNLP/core/dataset/dataset.py

@@ -8,9 +8,8 @@ __all__ = [

 import _pickle as pickle
 from copy import deepcopy
-from typing import Optional, List, Callable, Union, Dict, Any
+from typing import Optional, List, Callable, Union, Dict, Any, Mapping
 from functools import partial
-import warnings

 import numpy as np
 from threading import Thread
@@ -197,6 +196,20 @@ class DataSet:
         else:
             raise KeyError("Unrecognized type {} for idx in __getitem__ method".format(type(idx)))

+    def __setitem__(self, key, value):
+        assert isinstance(key, int) and key<len(self)
+        assert isinstance(value, Instance) or isinstance(value, Mapping)
+        ins_keys = set(value.keys())
+        ds_keys = set(self.get_field_names())
+
+        if len(ins_keys - ds_keys) != 0:
+            raise KeyError(f"The following keys are not found in the Dataset:{list(ins_keys - ds_keys)}.")
+        if len(ds_keys - ins_keys) != 0:
+            raise KeyError(f"The following keys are not found in the Instance:{list(ds_keys - ins_keys)}.")
+
+        for field_name, field in self.field_arrays.items():
+            field[key] = value[field_name]
+
     def __getattribute__(self, item):
         return object.__getattribute__(self, item)

@@ -813,6 +826,3 @@ class DataSet:
         self.collate_fns.set_input(*field_names)


-
-class IterableDataset:
-    pass
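
A small sketch of the new DataSet.__setitem__, mirroring the test added at the bottom of this commit (the import path is assumed; field names and values are illustrative):

    from fastNLP.core.dataset import DataSet

    ds = DataSet({"x": [[1, 2, 3, 4]] * 4, "y": [[5, 6]] * 4})

    ds[2] = ds[1]                               # an Instance works
    ds[0] = {"x": [7, 8, 9, 10], "y": [1, 2]}   # so does a Mapping with exactly the same keys

    # a key mismatch raises KeyError, e.g. ds[0] = {"x": [1]} would complain about the missing "y"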


+1 -4   fastNLP/core/dataset/field.py

@@ -46,9 +46,6 @@ class FieldArray:

     def __setitem__(self, idx: int, val: Any):
         assert isinstance(idx, int)
-        if idx == -1:
-            idx = len(self) - 1
-        assert 0 <= idx < len(self), f"0<= idx <{len(self)}, but idx is {idx}"
         self.content[idx] = val

     def get(self, indices: Union[int, List[int]]):
@@ -79,7 +76,7 @@ class FieldArray:

     def split(self, sep: str = None, inplace: bool = True):
         r"""
-        Call .split() on each element in turn; this only makes sense when the elements of this field are str. The return value
+        Call .split() on each element in turn; this only makes sense when the elements of this field are str.

         :param sep: the separator; if None, str.split() is called directly.
         :param inplace: if True, the newly generated values replace this field; otherwise a list is returned.


+11 -6   fastNLP/core/drivers/driver.py

@@ -6,6 +6,7 @@ from abc import ABC, abstractmethod
 from datetime import datetime
 from pathlib import Path
 from io import BytesIO
+import json

 __all__ = [
     'Driver'
@@ -68,9 +69,12 @@ class Driver(ABC):
     def set_sampler_epoch(self, dataloader, cur_epoch_idx):
         r"""
         For distributed samplers, e.g. torch's DistributedSampler, the random seed must be set before every epoch so that the shuffle is the same on every process;
+        in a dataloader, the component actually in effect may be the batch_sampler or the sampler.

+        :param dataloader: the dataloader whose epoch needs to be set.
         :param cur_epoch_idx: the index of the current epoch;
         """

     @abstractmethod
     def train_step(self, batch):
         """
@@ -444,13 +448,14 @@ class Driver(ABC):

         exc_type, exc_value, exc_traceback_obj = sys.exc_info()
         _write_exc_info = {
-            'exc_type': exc_type,
-            'exc_value': exc_value,
-            'time': str(datetime.now().strftime('%Y-%m-%d-%H:%M:%S')),
-            'global_rank': getattr(self, "global_rank", None),
-            'rank': self.get_local_rank(),
+            'exc_type': str(exc_type.__name__),
+            'exc_value': str(exc_value),
+            'exc_time': str(datetime.now().strftime('%Y-%m-%d-%H:%M:%S')),
+            'exc_global_rank': getattr(self, "global_rank", None),
+            'exc_local_rank': self.get_local_rank(),
         }
-        sys.stderr.write(str(_write_exc_info)+"\n")
+        sys.stderr.write("\nException info:\n")
+        sys.stderr.write(json.dumps(_write_exc_info, indent=2)+"\n")

         sys.stderr.write(f"Start to stop these pids:{self._pids}, please wait several seconds.\n")
         for pid in self._pids:
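
The rewritten block above only stores JSON-serialisable strings (hence str(exc_type.__name__) and str(exc_value)), so the report can go through json.dumps. A standalone sketch with illustrative values:

    import json, sys
    from datetime import datetime

    _write_exc_info = {
        'exc_type': 'RuntimeError',           # illustrative, i.e. str(exc_type.__name__)
        'exc_value': 'CUDA out of memory',    # illustrative, i.e. str(exc_value)
        'exc_time': str(datetime.now().strftime('%Y-%m-%d-%H:%M:%S')),
        'exc_global_rank': 0,
        'exc_local_rank': 0,
    }
    sys.stderr.write("\nException info:\n")
    sys.stderr.write(json.dumps(_write_exc_info, indent=2) + "\n")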


+2 -2   fastNLP/core/drivers/torch_driver/initialize_torch_driver.py

@@ -27,7 +27,7 @@ def initialize_torch_driver(driver: str, device: Optional[Union[str, torch.devic
     # world_size and rank
     if FASTNLP_BACKEND_LAUNCH in os.environ:
         if device is not None:
-            logger.warning("Parameter `device` would be ignored when you are using `torch.distributed.run` to pull "
+            logger.info("Parameter `device` would be ignored when you are using `torch.distributed.run` to pull "
                         "up your script. And we will directly get the local device via "
                         "`os.environ['LOCAL_RANK']`.")
         return TorchDDPDriver(model, torch.device(f"cuda:{os.environ['LOCAL_RANK']}"), True, **kwargs)
@@ -65,7 +65,7 @@ def initialize_torch_driver(driver: str, device: Optional[Union[str, torch.devic
     if not isinstance(device, List):
         return TorchSingleDriver(model, device, **kwargs)
     else:
-        logger.warning("Notice you are using `torch` driver but your chosen `device` are multi gpus, we will use "
+        logger.info("Notice you are using `torch` driver but your chosen `device` are multi gpus, we will use "
                     "`TorchDDPDriver` by default. But if you mean using `TorchDDPDriver`, you should choose parameter"
                     "`driver` as `TorchDDPDriver`.")
         return TorchDDPDriver(model, device, **kwargs)


+0 -2   fastNLP/core/drivers/torch_driver/torch_driver.py

@@ -143,8 +143,6 @@ class TorchDriver(Driver):

         :param filepath: which folder to save to;
         :param only_state_dict: whether to save only the weights;
-        :param model_save_fn:
-
         :return:
         """
         model = self.unwrap_model()


+3 -0   fastNLP/core/utils/utils.py

@@ -44,6 +44,9 @@ __all__ = [
 ]


+
+
+
 def get_fn_arg_names(fn: Callable) -> List[str]:
     r"""
     Return the names of all parameters of a function;


+1 -1   fastNLP/envs/set_env_on_import.py

@@ -65,7 +65,7 @@ def set_env_on_import():

     # some variables used internally by fastNLP
     if FASTNLP_LAUNCH_TIME not in os.environ:
-        cur_time = f"{datetime.datetime.now().strftime('%Y-%m-%d-%H_%M_%S_%M_%f')}"
+        cur_time = f"{datetime.datetime.now().strftime('%Y-%m-%d-%H_%M_%S_%f')}"
         os.environ[FASTNLP_LAUNCH_TIME] = cur_time

     # set the corresponding values
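
The one-character change above drops a duplicated minute field (%M) from the launch-time format string. A quick check of the old and fixed formats:

    import datetime

    now = datetime.datetime(2022, 4, 11, 9, 5, 30, 123456)
    print(now.strftime('%Y-%m-%d-%H_%M_%S_%M_%f'))  # old:   2022-04-11-09_05_30_05_123456  (minute repeated)
    print(now.strftime('%Y-%m-%d-%H_%M_%S_%f'))     # fixed: 2022-04-11-09_05_30_123456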


+68 -65   tests/core/callbacks/test_checkpoint_callback_torch.py

@@ -8,7 +8,7 @@ import torch.distributed as dist
 from pathlib import Path
 import re

-from fastNLP.core.callbacks.checkpoint_callback import CheckpointCallback
+from fastNLP.core.callbacks.checkpoint_callback import ModelCheckpointCallback, TrainerCheckpointCallback
 from fastNLP.core.controllers.trainer import Trainer
 from fastNLP.envs import FASTNLP_MODEL_FILENAME, FASTNLP_CHECKPOINT_FILENAME, FASTNLP_LAUNCH_TIME, FASTNLP_DISTRIBUTED_CHECK

@@ -80,16 +80,23 @@ def test_model_checkpoint_callback_1(
         version,
         only_state_dict
 ):
+    # def test_model_checkpoint_callback_1(
+    #         model_and_optimizers: TrainerParameters,
+    #         driver='torch_ddp',
+    #         device=[0, 1],
+    #         version=1,
+    #         only_state_dict=True
+    # ):
     path = Path.cwd().joinpath(f"test_model_checkpoint")
     path.mkdir(exist_ok=True, parents=True)

     if version == 0:
         callbacks = [
-            CheckpointCallback(
+            ModelCheckpointCallback(
                 monitor="acc",
                 save_folder=path,
                 save_every_n_epochs=1,
-                save_every_n_global_batches=123,  # avoid overlapping with the epoch saves;
+                save_every_n_batches=123,  # avoid overlapping with the epoch saves;
                 save_topk=None,
                 save_last=False,
                 save_on_exception=None,
@@ -98,11 +105,11 @@ def test_model_checkpoint_callback_1(
         ]
     elif version == 1:
         callbacks = [
-            CheckpointCallback(
+            ModelCheckpointCallback(
                 monitor="acc",
                 save_folder=path,
                 save_every_n_epochs=3,
-                save_every_n_global_batches=None,
+                save_every_n_batches=None,
                 save_topk=2,
                 save_last=True,
                 save_on_exception=None,
@@ -121,7 +128,6 @@ def test_model_checkpoint_callback_1(
         input_mapping=model_and_optimizers.input_mapping,
         output_mapping=model_and_optimizers.output_mapping,
         metrics=model_and_optimizers.metrics,
-
         n_epochs=10,
         callbacks=callbacks,
         output_from_new_proc="all"
@@ -134,31 +140,31 @@ def test_model_checkpoint_callback_1(
         if version == 0:

             if driver == "torch":
-                assert "epoch_10-global_batch_250-acc" in all_saved_model_paths
-                assert "epoch_4-global_batch_123-acc" in all_saved_model_paths
+                assert "model-epoch_10" in all_saved_model_paths
+                assert "model-epoch_4-batch_123" in all_saved_model_paths

-                epoch_save_path = all_saved_model_paths["epoch_10-global_batch_250-acc"]
-                step_save_path = all_saved_model_paths["epoch_4-global_batch_123-acc"]
+                epoch_save_path = all_saved_model_paths["model-epoch_10"]
+                step_save_path = all_saved_model_paths["model-epoch_4-batch_123"]

                 assert len(all_saved_model_paths) == 12
             # file names differ under ddp, because with the same data ddp finishes in fewer steps;
             else:
-                assert "epoch_6-global_batch_78-acc" in all_saved_model_paths
-                assert "epoch_9-global_batch_123-acc" in all_saved_model_paths
+                assert "model-epoch_6" in all_saved_model_paths
+                assert "model-epoch_9-batch_123" in all_saved_model_paths

-                epoch_save_path = all_saved_model_paths["epoch_6-global_batch_78-acc"]
-                step_save_path = all_saved_model_paths["epoch_9-global_batch_123-acc"]
+                epoch_save_path = all_saved_model_paths["model-epoch_6"]
+                step_save_path = all_saved_model_paths["model-epoch_9-batch_123"]

                 assert len(all_saved_model_paths) == 11
             all_state_dicts = [epoch_save_path, step_save_path]

         elif version == 1:

-            pattern = re.compile("epoch_[0-9]+-global_batch_[0-9]+-[a-z|A-Z]+_[0-9]*.?[0-9]*")
+            pattern = re.compile("model-epoch_[0-9]+-batch_[0-9]+-[a-zA-Z#]+_[0-9]*.?[0-9]*")

             if driver == "torch":
-                assert "epoch_9-global_batch_225-acc" in all_saved_model_paths
-                assert "last" in all_saved_model_paths
+                assert "model-epoch_9" in all_saved_model_paths
+                assert "model-last" in all_saved_model_paths
                 aLL_topk_folders = []
                 for each_folder_name in all_saved_model_paths:
                     each_folder_name = pattern.findall(each_folder_name)
@@ -166,15 +172,15 @@ def test_model_checkpoint_callback_1(
                         aLL_topk_folders.append(each_folder_name[0])
                 assert len(aLL_topk_folders) == 2

-                epoch_save_path = all_saved_model_paths["epoch_9-global_batch_225-acc"]
-                last_save_path = all_saved_model_paths["last"]
+                epoch_save_path = all_saved_model_paths["model-epoch_9"]
+                last_save_path = all_saved_model_paths["model-last"]
                 topk_save_path = all_saved_model_paths[aLL_topk_folders[0]]

                 assert len(all_saved_model_paths) == 6
             # file names differ under ddp, because with the same data ddp finishes in fewer steps;
             else:
-                assert "epoch_9-global_batch_117-acc" in all_saved_model_paths
-                assert "last" in all_saved_model_paths
+                assert "model-epoch_9" in all_saved_model_paths
+                assert "model-last" in all_saved_model_paths

                 aLL_topk_folders = []
                 for each_folder_name in all_saved_model_paths:
@@ -183,8 +189,8 @@ def test_model_checkpoint_callback_1(
                         aLL_topk_folders.append(each_folder_name[0])
                 assert len(aLL_topk_folders) == 2

-                epoch_save_path = all_saved_model_paths["epoch_9-global_batch_117-acc"]
-                last_save_path = all_saved_model_paths["last"]
+                epoch_save_path = all_saved_model_paths["model-epoch_9"]
+                last_save_path = all_saved_model_paths["model-last"]
                 topk_save_path = all_saved_model_paths[aLL_topk_folders[0]]

                 assert len(all_saved_model_paths) == 6
@@ -212,7 +218,7 @@ def test_model_checkpoint_callback_1(

     finally:
         synchronize_safe_rm(path)
-        # pass
+        pass

     if dist.is_initialized():
         dist.destroy_process_group()
@@ -238,11 +244,11 @@ def test_model_checkpoint_callback_2(
         raise NotImplementedError

     callbacks = [
-        CheckpointCallback(
+        ModelCheckpointCallback(
             monitor="acc1",
             save_folder=path,
             save_every_n_epochs=None,
-            save_every_n_global_batches=None,
+            save_every_n_batches=None,
             save_topk=None,
             save_last=False,
             save_on_exception=NotImplementedError,
@@ -279,12 +285,12 @@ def test_model_checkpoint_callback_2(
         all_saved_model_paths = {w.name: w for w in path.joinpath(os.environ[FASTNLP_LAUNCH_TIME]).iterdir()}

         if driver == "torch":
-            assert "epoch_4-global_batch_100-acc_NotImplementedError" in all_saved_model_paths
-            exception_model_path = all_saved_model_paths["epoch_4-global_batch_100-acc_NotImplementedError"]
+            assert "model-epoch_4-batch_100-exception_NotImplementedError" in all_saved_model_paths
+            exception_model_path = all_saved_model_paths["model-epoch_4-batch_100-exception_NotImplementedError"]
         # file names differ under ddp, because with the same data ddp finishes in fewer steps;
         else:
-            assert "epoch_4-global_batch_52-acc_NotImplementedError" in all_saved_model_paths
-            exception_model_path = all_saved_model_paths["epoch_4-global_batch_52-acc_NotImplementedError"]
+            assert "model-epoch_4-batch_52-exception_NotImplementedError" in all_saved_model_paths
+            exception_model_path = all_saved_model_paths["model-epoch_4-batch_52-exception_NotImplementedError"]

         assert len(all_saved_model_paths) == 1
         all_state_dicts = [exception_model_path]
@@ -332,12 +338,11 @@ def test_trainer_checkpoint_callback_1(

     if version == 0:
         callbacks = [
-            CheckpointCallback(
+            TrainerCheckpointCallback(
                 monitor="acc",
-                is_trainer_checkpoint=True,
                 save_folder=path,
                 save_every_n_epochs=7,
-                save_every_n_global_batches=123,  # avoid overlapping with the epoch saves;
+                save_every_n_batches=123,  # avoid overlapping with the epoch saves;
                 save_topk=None,
                 save_last=False,
                 save_on_exception=None,
@@ -346,12 +351,11 @@ def test_trainer_checkpoint_callback_1(
         ]
     elif version == 1:
         callbacks = [
-            CheckpointCallback(
+            TrainerCheckpointCallback(
                 monitor="acc",
-                is_trainer_checkpoint=True,
                 save_folder=path,
                 save_every_n_epochs=None,
-                save_every_n_global_batches=None,
+                save_every_n_batches=None,
                 save_topk=2,
                 save_last=True,
                 save_on_exception=None,
@@ -383,31 +387,31 @@ def test_trainer_checkpoint_callback_1(
         if version == 0:

             if driver == "torch":
-                assert "epoch_7-global_batch_175-acc" in all_saved_model_paths
-                assert "epoch_4-global_batch_123-acc" in all_saved_model_paths
+                assert "trainer-epoch_7" in all_saved_model_paths
+                assert "trainer-epoch_4-batch_123" in all_saved_model_paths

-                epoch_save_path = all_saved_model_paths["epoch_7-global_batch_175-acc"]
-                step_save_path = all_saved_model_paths["epoch_4-global_batch_123-acc"]
+                epoch_save_path = all_saved_model_paths["trainer-epoch_7"]
+                step_save_path = all_saved_model_paths["trainer-epoch_4-batch_123"]

                 assert len(all_saved_model_paths) == 3
             # file names differ under ddp, because with the same data ddp finishes in fewer steps;
             else:
-                assert "epoch_7-global_batch_91-acc" in all_saved_model_paths
-                assert "epoch_9-global_batch_123-acc" in all_saved_model_paths
+                assert "trainer-epoch_7" in all_saved_model_paths
+                assert "trainer-epoch_9-batch_123" in all_saved_model_paths

-                epoch_save_path = all_saved_model_paths["epoch_7-global_batch_91-acc"]
-                step_save_path = all_saved_model_paths["epoch_9-global_batch_123-acc"]
+                epoch_save_path = all_saved_model_paths["trainer-epoch_7"]
+                step_save_path = all_saved_model_paths["trainer-epoch_9-batch_123"]

                 assert len(all_saved_model_paths) == 2
             all_state_dicts = [epoch_save_path, step_save_path]

         elif version == 1:

-            pattern = re.compile("epoch_[0-9]+-global_batch_[0-9]+-[a-z|A-Z]+_[0-9]*.?[0-9]*")
+            pattern = re.compile("trainer-epoch_[0-9]+-batch_[0-9]+-[a-zA-Z#]+_[0-9]*.?[0-9]*")

             # all_saved_model_paths = {w.name: w for w in path.joinpath(os.environ[FASTNLP_LAUNCH_TIME]).iterdir()}
             if driver == "torch":
-                assert "last" in all_saved_model_paths
+                assert "trainer-last" in all_saved_model_paths
                 aLL_topk_folders = []
                 for each_folder_name in all_saved_model_paths:
                     each_folder_name = pattern.findall(each_folder_name)
@@ -415,13 +419,13 @@ def test_trainer_checkpoint_callback_1(
                         aLL_topk_folders.append(each_folder_name[0])
                 assert len(aLL_topk_folders) == 2

-                last_save_path = all_saved_model_paths["last"]
+                last_save_path = all_saved_model_paths["trainer-last"]
                 topk_save_path = all_saved_model_paths[aLL_topk_folders[0]]

                 assert len(all_saved_model_paths) == 3
             # file names differ under ddp, because with the same data ddp finishes in fewer steps;
             else:
-                assert "last" in all_saved_model_paths
+                assert "trainer-last" in all_saved_model_paths

                 aLL_topk_folders = []
                 for each_folder_name in all_saved_model_paths:
@@ -430,7 +434,7 @@ def test_trainer_checkpoint_callback_1(
                         aLL_topk_folders.append(each_folder_name[0])
                 assert len(aLL_topk_folders) == 2

-                last_save_path = all_saved_model_paths["last"]
+                last_save_path = all_saved_model_paths["trainer-last"]
                 topk_save_path = all_saved_model_paths[aLL_topk_folders[0]]

                 assert len(all_saved_model_paths) == 3
@@ -474,10 +478,11 @@ def test_trainer_checkpoint_callback_2(
         device,
         version
 ):
+    pytest.skip("Skip transformers test for now.")
     path = Path.cwd().joinpath(f"test_model_checkpoint")
     path.mkdir(exist_ok=True, parents=True)

-    import transformers
+    import transformers  # version 4.16.2
     import torch
     from torchmetrics import Accuracy
     from transformers import AutoModelForSequenceClassification
@@ -587,12 +592,11 @@ def test_trainer_checkpoint_callback_2(

     if version == 0:
         callbacks = [
-            CheckpointCallback(
+            TrainerCheckpointCallback(
                 monitor="acc",
-                is_trainer_checkpoint=True,
                 save_folder=path,
                 save_every_n_epochs=None,
-                save_every_n_global_batches=50,
+                save_every_n_batches=50,
                 save_topk=None,
                 save_last=False,
                 save_on_exception=None,
@@ -601,12 +605,11 @@ def test_trainer_checkpoint_callback_2(
         ]
     elif version == 1:
         callbacks = [
-            CheckpointCallback(
+            TrainerCheckpointCallback(
                 monitor="acc",
-                is_trainer_checkpoint=True,
                 save_folder=path,
                 save_every_n_epochs=None,
-                save_every_n_global_batches=None,
+                save_every_n_batches=None,
                 save_topk=1,
                 save_last=True,
                 save_on_exception=None,
@@ -638,27 +641,27 @@ def test_trainer_checkpoint_callback_2(
         if version == 0:

             if driver == "torch":
-                assert "epoch_1-global_batch_200-acc" in all_saved_model_paths
+                assert "trainer-epoch_1-batch_200" in all_saved_model_paths

-                epoch_save_path = all_saved_model_paths["epoch_1-global_batch_200-acc"]
+                epoch_save_path = all_saved_model_paths["trainer-epoch_1-batch_200"]

                 assert len(all_saved_model_paths) == 4
             # file names differ under ddp, because with the same data ddp finishes in fewer steps;
             else:
-                assert "epoch_1-global_batch_100-acc" in all_saved_model_paths
+                assert "trainer-epoch_1-batch_100" in all_saved_model_paths

-                epoch_save_path = all_saved_model_paths["epoch_1-global_batch_100-acc"]
+                epoch_save_path = all_saved_model_paths["trainer-epoch_1-batch_100"]

                 assert len(all_saved_model_paths) == 2
             all_state_dicts = [epoch_save_path]

         elif version == 1:

-            pattern = re.compile("epoch_[0-9]+-global_batch_[0-9]+-[a-z|A-Z]+_[0-9]*.?[0-9]*")
+            pattern = re.compile("trainer-epoch_[0-9]+-batch_[0-9]+-[a-zA-Z#]+_[0-9]*.?[0-9]*")

             # all_saved_model_paths = {w.name: w for w in path.joinpath(os.environ[FASTNLP_LAUNCH_TIME]).iterdir()}
             if driver == "torch":
-                assert "last" in all_saved_model_paths
+                assert "trainer-last" in all_saved_model_paths
                 aLL_topk_folders = []
                 for each_folder_name in all_saved_model_paths:
                     each_folder_name = pattern.findall(each_folder_name)
@@ -666,13 +669,13 @@ def test_trainer_checkpoint_callback_2(
                         aLL_topk_folders.append(each_folder_name[0])
                 assert len(aLL_topk_folders) == 1

-                last_save_path = all_saved_model_paths["last"]
+                last_save_path = all_saved_model_paths["trainer-last"]
                 topk_save_path = all_saved_model_paths[aLL_topk_folders[0]]

                 assert len(all_saved_model_paths) == 2
             # file names differ under ddp, because with the same data ddp finishes in fewer steps;
             else:
-                assert "last" in all_saved_model_paths
+                assert "trainer-last" in all_saved_model_paths

                 aLL_topk_folders = []
                 for each_folder_name in all_saved_model_paths:
@@ -681,7 +684,7 @@ def test_trainer_checkpoint_callback_2(
                         aLL_topk_folders.append(each_folder_name[0])
                 assert len(aLL_topk_folders) == 1

-                last_save_path = all_saved_model_paths["last"]
+                last_save_path = all_saved_model_paths["trainer-last"]
                 topk_save_path = all_saved_model_paths[aLL_topk_folders[0]]

                 assert len(all_saved_model_paths) == 2


+14 -0   tests/core/dataset/test_dataset.py

@@ -105,6 +105,20 @@ class TestDataSetMethods(unittest.TestCase):
         self.assertTrue(isinstance(field_array, FieldArray))
         self.assertEqual(len(field_array), 40)

+    def test_setitem(self):
+        ds = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6]] * 40})
+        ds.add_field('i', list(range(len(ds))))
+        assert ds.get_field('i').content == list(range(len(ds)))
+        import random
+        random.shuffle(ds)
+        import numpy as np
+        np.random.shuffle(ds)
+        assert ds.get_field('i').content != list(range(len(ds)))
+
+        ins1 = ds[1]
+        ds[2] = ds[1]
+        assert ds[2]['x'] == ins1['x'] and ds[2]['y'] == ins1['y']
+
     def test_get_item_error(self):
         with self.assertRaises(RuntimeError):
             ds = DataSet({"x": [[1, 2, 3, 4]] * 10, "y": [[5, 6]] * 10})

