From ce835212e6613f42f4acee50a17799f602d85a63 Mon Sep 17 00:00:00 2001 From: yh_cc Date: Mon, 11 Apr 2022 14:34:11 +0800 Subject: [PATCH 1/2] =?UTF-8?q?=E5=B0=86CheckpointCallback=E6=8B=86?= =?UTF-8?q?=E5=88=86=E4=B8=BAModelCheckpointCallback=E5=92=8CTrainerCheckp?= =?UTF-8?q?ointCallback=EF=BC=8C=E4=BF=AE=E6=94=B9=E4=BA=86=E9=83=A8?= =?UTF-8?q?=E5=88=86=E5=AE=9E=E7=8E=B0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- fastNLP/core/callbacks/callback_manager.py | 6 +- fastNLP/core/callbacks/checkpoint_callback.py | 293 +++++++++++------- .../callbacks/load_best_model_callback.py | 2 +- fastNLP/core/controllers/trainer.py | 21 +- fastNLP/core/drivers/driver.py | 3 + .../core/drivers/torch_driver/torch_driver.py | 2 - fastNLP/core/utils/utils.py | 3 + fastNLP/envs/set_env_on_import.py | 2 +- .../test_checkpoint_callback_torch.py | 133 ++++---- 9 files changed, 280 insertions(+), 185 deletions(-) diff --git a/fastNLP/core/callbacks/callback_manager.py b/fastNLP/core/callbacks/callback_manager.py index c239f8b1..8b53c70b 100644 --- a/fastNLP/core/callbacks/callback_manager.py +++ b/fastNLP/core/callbacks/callback_manager.py @@ -8,7 +8,7 @@ __all__ = [ from .callback_events import Events from .callback import Callback -from .checkpoint_callback import CheckpointCallback +from .checkpoint_callback import TrainerCheckpointCallback from .progress_callback import ProgressCallback, choose_progress_callback from fastNLP.core.log import logger @@ -98,7 +98,7 @@ class CallbackManager: :return: """ for each_callback in self.class_callbacks: - if isinstance(each_callback, CheckpointCallback) and each_callback.is_trainer_checkpoint: + if isinstance(each_callback, TrainerCheckpointCallback): self._has_trainer_checkpoint = True self.dissect_one_callback(each_callback) @@ -210,7 +210,7 @@ class CallbackManager: each_callback.on_load_checkpoint(trainer, None) @property - def has_trainer_chechpoint(self) -> bool: + def has_trainer_checkpoint(self) -> bool: return self._has_trainer_checkpoint @_transfer diff --git a/fastNLP/core/callbacks/checkpoint_callback.py b/fastNLP/core/callbacks/checkpoint_callback.py index 5fcc7e26..5cd102e0 100644 --- a/fastNLP/core/callbacks/checkpoint_callback.py +++ b/fastNLP/core/callbacks/checkpoint_callback.py @@ -1,12 +1,12 @@ -import os -from typing import Union, Optional, Callable, Dict, Sequence -from pathlib import Path -from functools import partial -from time import sleep - __all__ = [ 'CheckpointCallback' ] +import os +from typing import Union, Optional, Callable, Dict, Sequence, Any, Mapping +from pathlib import Path +from abc import ABC +import sys + import fastNLP from .callback import Callback, Filter @@ -14,35 +14,37 @@ from fastNLP.core.callbacks.utils import _get_monitor_value from fastNLP.core.log import logger from fastNLP.envs import FASTNLP_LAUNCH_TIME from fastNLP.core.utils import synchronize_safe_rm, synchronize_mkdir +from fastNLP.core.utils import apply_to_collection -class CheckpointCallback(Callback): +class CanItemDataType(ABC): """ - 1. 因为只有 'Trainer' 才有 callback,因此评测 metric 实际上就是 validate 时干的事情; - 2. 默认 'save_last' 为 True,即 model_checkpoint 的默认逻辑是在每一个 epoch 下保存最后的一个模型,模型名字为 last.pth.tar; - 3. 理论上一个 model_checkpoint 的实例只会负责一个 monitor 的监视,如果用户在训练过程中指定了多个 monitor 的监视,例如 "acc1", - "acc2", ... 那么我们会为用户创建多个 model_checkpoint 的实例; - 4. 理论上,在实际保存的过程中,topk 模式和 固定频率保存的模式是完全独立的,我们确实应当采取一些措施至少保证两者的名字不一样; + 检测可以进行传输的对象。 + """ + @classmethod + def __subclasshook__(cls, subclass: Any) -> Union[bool, Any]: + if cls is CanItemDataType: + item = getattr(subclass, 'item', None) + return callable(item) + return NotImplemented + + + +class CheckpointCallback(Callback): def __init__( self, monitor, - is_trainer_checkpoint: Optional[bool] = False, - save_folder: Optional[Union[str, Path]] = None, - save_every_n_epochs: Optional[int] = None, - save_every_n_global_batches: Optional[int] = None, + save_every_n_batches: Optional[int] = None, save_last: bool = True, save_topk: Optional[int] = None, save_on_exception: Optional[Union[BaseException, Sequence[BaseException]]] = None, - larger_better: bool = True, only_state_dict: bool = True, - model_save_fn: Optional[Callable] = None, - **kwargs, ): if monitor is None and save_topk is not None: @@ -51,9 +53,6 @@ class CheckpointCallback(Callback): if monitor is not None and not isinstance(monitor, str): raise ValueError("Parameter `monitor` should be of 'str' type.") - if not isinstance(is_trainer_checkpoint, bool): - raise TypeError("Parameter 'is_trainer_checkpoint' can only be `bool` type.") - if save_folder is None: logger.warning( "Parameter `path` is None, and we will use the current work directory to find and load your model.") @@ -67,15 +66,15 @@ class CheckpointCallback(Callback): if not isinstance(save_every_n_epochs, int) or save_every_n_epochs < 1: raise ValueError("parameter save_after_epoch_num should be an int and greater than or equal to 1.") - # 突然发现有一个骚操作在于 'Filter' 内部记载的状态值例如 'num_called' 是这个类全局的,而每次调用 __call__ 中输入的 - # 函数却是及时传入的,也就是说,我们可以保证 'Filter' 的正常控制频率的逻辑,然后每一次运行的函数都不一样; - self._filter_every_n_epochs = Filter(every=save_every_n_epochs) + else: + save_every_n_epochs = sys.maxsize # 使得没有数字可以整除 - if save_every_n_global_batches is not None: - if not isinstance(save_every_n_global_batches, int) or save_every_n_global_batches < 1: + if save_every_n_batches is not None: + if not isinstance(save_every_n_batches, int) or save_every_n_batches < 1: raise ValueError( - "parameter save_every_n_global_batches should be an int and greater than or equal to 1.") - self._filter_every_n_global_batches = Filter(every=save_every_n_global_batches) + "parameter save_every_n_batches should be an int and greater than or equal to 1.") + else: + save_every_n_batches = sys.maxsize # 使得没有数字可以整除 if save_topk is not None: if not isinstance(save_topk, int) or save_topk < 1: @@ -89,12 +88,12 @@ class CheckpointCallback(Callback): if not issubclass(exception, BaseException): raise TypeError("Each exception in parameter `save_on_exception` can only be " "`BaseException` type.") - + else: + save_on_exception = [] self.monitor = monitor - self.is_trainer_checkpoint = is_trainer_checkpoint self.save_folder = Path(save_folder) self.save_every_n_epochs = save_every_n_epochs - self.save_every_n_global_batches = save_every_n_global_batches + self.save_every_n_batches = save_every_n_batches self.save_last = save_last self.save_topk = save_topk self.larger_better = larger_better @@ -107,7 +106,7 @@ class CheckpointCallback(Callback): self._topk_model = {} self._topn = 0 # 表示目前已经保存了几个最好的模型; - # 因为我们在 `_get_validate_metric` 函数中,当在返回的 `validate_res` 字典中找不到 `monitor` 时,是使用模糊匹配找到的第一个 + # 因为我们在 `_get_validate_metric` 函数中,当在返回的 `validate_res` 字典中找不到 `monitor` 时,是使用匹配找到的 # key 对应的 value 当做结果;但是这样存在的一个问题在于如果用户传入的 metric 返回的 sub_metric 的名字可能会混淆,并且其在下一次 # 训练的代码中修改了这些 sub_metric 返回的顺序,那么就会导致模糊匹配拿到的 key 和 value 与之前的不是同一个,这显然不是合理的行为; # 因此我们通过该变量来表示我们通过模糊匹配拿到的 key; @@ -115,76 +114,83 @@ class CheckpointCallback(Callback): # 注意这里应当保证只有进程 0 在执行这个操作,因为当用户使用 python -m torch.distributed.launch 来拉起进程的时候, # FASTNLP_LAUNCH_TIME 在每一个进程上的值是不一样的; - self.log_filepath = self.save_folder.joinpath(os.environ[FASTNLP_LAUNCH_TIME]) + self.timestamp_path = self.save_folder.joinpath(os.environ[FASTNLP_LAUNCH_TIME]) # 我们只需要保证这个创建文件夹的操作只在进程 0 上进行即可;因为后续的实际的保存操作,其它进程实际并不会去执行; - synchronize_mkdir(self.log_filepath) + synchronize_mkdir(self.timestamp_path) def on_validate_end(self, trainer, validate_res): self._save_topk(trainer, validate_res) def on_train_epoch_end(self, trainer: "fastNLP.Trainer"): - self._save_every_n_epochs(trainer) - self._save_last(trainer) + if trainer.cur_epoch_idx % self.save_every_n_epochs == 0: + folder_name = f'{self.folder_prefix}-epoch_{trainer.cur_epoch_idx}' + self.save(trainer, folder_name=folder_name) + if self.save_last: + folder_name = f'{self.folder_prefix}-last' + self.save(trainer, folder_name=folder_name) def on_train_batch_end(self, trainer): - self._save_every_n_global_batches(trainer) + if trainer.global_forward_batches % self.save_every_n_batches == 0: + folder_name = f'{self.folder_prefix}-epoch_{trainer.cur_epoch_idx}-batch_{trainer.global_forward_batches}' + self.save(trainer, folder_name=folder_name) def on_exception(self, trainer, exception: BaseException): - if self.save_on_exception is not None and exception.__class__ in self.save_on_exception: - folder = self._get_checkpoint_real_save_folder(trainer=trainer, topk=False, metric=None) - folder = folder + f"_{exception.__class__.__name__}" - self._save_fn(trainer=trainer, topk=False, metric=None, substitute_folder=folder) + if exception.__class__ in self.save_on_exception: + folder_name = f'{self.folder_prefix}-epoch_{trainer.cur_epoch_idx}-batch_{trainer.global_forward_batches}-' \ + f'exception_{exception.__class__.__name__}' + self.save(trainer=trainer, folder_name=folder_name) def on_sanity_check_end(self, trainer, sanity_check_res): + # 主要核对一下 monitor 是否存在。 self._get_validate_metric(sanity_check_res) def on_save_checkpoint(self, trainer) -> Dict: """ - 我们需要保存 CheckpointCallback 内部的几个 filter 的状态; + 保存 timestamp_path 使得之后可以继续训练并保存到该文件夹。 + topk_model的状态 + _real_monitor的值 """ + states = {} - if self.save_every_n_epochs is not None: - states["_filter_every_n_epochs"] = self._filter_every_n_epochs.state_dict() - if self.save_every_n_global_batches is not None: - states["_filter_every_n_global_batches"] = self._filter_every_n_global_batches.state_dict() - states["real_monitor"] = self._real_monitor + states['timestamp_path'] = str(self.timestamp_path.absolute()) + states['_topk_model'] = apply_to_collection(self._topk_model, dtype=CanItemDataType, + function=lambda x:x.item()) + states['save_topk'] = 0 if self.save_topk is None else self.save_topk + states['_real_monitor'] = self._real_monitor return states def on_load_checkpoint(self, trainer, states: Optional[Dict]): - if self.save_every_n_epochs is not None: - self._filter_every_n_epochs.load_state_dict(states["_filter_every_n_epochs"]) - if self.save_every_n_global_batches is not None: - self._filter_every_n_global_batches.load_state_dict(states["_filter_every_n_global_batches"]) + timestamp_path = states['timestamp_path'] + if not os.path.exists(timestamp_path): + logger.info(f"The resuming save folder {timestamp_path} is not exists, will checkpoint save to " + f" {self.timestamp_path.absolute()}.") + else: + logger.info(f"Resume to save in path: {timestamp_path}.") + self.timestamp_path = Path(timestamp_path) + _topk_model = states['_topk_model'] + save_topk = None if int(states['save_topk']) == 0 else int(states['save_topk']) + if save_topk is not None and self.save_topk is not None: + assert self.save_topk == save_topk, f"The checkpoint set save_topk={save_topk}, while this callback set it " \ + f"as {save_topk}." + self._topk_model.update(self._topk_model) self._real_monitor = states["real_monitor"] - def _save_every_n_epochs(self, trainer: "fastNLP.Trainer"): - if self.save_every_n_epochs is not None: - if self.is_trainer_checkpoint: - _fn_every_n_epochs = trainer.save - else: - _fn_every_n_epochs = trainer.save_model - _fn_every_n_epochs = partial(self._save_fn, trainer, False, None, _fn_every_n_epochs, None) - _fn_every_n_epochs = self._filter_every_n_epochs(_fn_every_n_epochs) - _fn_every_n_epochs() - - def _save_every_n_global_batches(self, trainer: "fastNLP.Trainer"): - if self.save_every_n_global_batches is not None: - if self.is_trainer_checkpoint: - _fn_every_n_global_batches = trainer.save - else: - _fn_every_n_global_batches = trainer.save_model - _fn_every_n_global_batches = partial(self._save_fn, trainer, False, None, _fn_every_n_global_batches, None) - _fn_every_n_global_batches = self._filter_every_n_global_batches(_fn_every_n_global_batches) - _fn_every_n_global_batches() - def _save_topk(self, trainer: "fastNLP.Trainer", validate_res: Dict): + """ + 根据validate_res决定保存哪些model的函数。会自动移除掉不满足topk的文件夹。 + + :param trainer: + :param validate_res: + :return: + """ if self.save_topk is not None: _metric_value = self._get_validate_metric(validate_res) - _saved_name = self._get_checkpoint_real_save_folder(trainer=trainer, topk=True, metric=_metric_value) + folder_name = f"{self.folder_prefix}-epoch_{trainer.cur_epoch_idx}-batch_{trainer.global_forward_batches}" \ + f"-{self._real_monitor}_{_metric_value}" _should_save = False if self._topn < self.save_topk: - self._topk_model[_saved_name] = _metric_value + self._topk_model[folder_name] = _metric_value self._topn += 1 _should_save = True else: @@ -192,39 +198,27 @@ class CheckpointCallback(Callback): key=lambda x: self._topk_model[x]) if (self.larger_better and _metric_value > self._topk_model[_least_valuable_model]) or \ (self.larger_better is False and _metric_value < self._topk_model[_least_valuable_model]): - self._topk_model[_saved_name] = _metric_value + self._topk_model[folder_name] = _metric_value _should_save = True self._topk_model.pop(_least_valuable_model) - synchronize_safe_rm(self.log_filepath.joinpath(_least_valuable_model)) + synchronize_safe_rm(self.timestamp_path.joinpath(_least_valuable_model)) assert len(self._topk_model) == self.save_topk == self._topn if _should_save: - self._save_fn(trainer=trainer, topk=True, metric=_metric_value, substitute_folder=_saved_name) + self.save(trainer, folder_name=folder_name) - def _save_last(self, trainer: "fastNLP.Trainer"): - if self.save_last: - self._save_fn(trainer=trainer, topk=False, metric=None, substitute_folder="last") - - def _save_fn(self, trainer, topk: bool = False, metric: Optional[Union[int, float]] = None, - substitute_fn: Optional[Callable] = None, substitute_folder: Optional[str] = None): - # 首先根据当前的 epoch 和 batch 在 parent_path/FASTNLP_LAUNCH_TIME 下创建子文件夹 epoch-batch-monitor 或者 - # epoch-batch-monitor-monitor_value; - if substitute_folder is None: - folder = self.log_filepath.joinpath(self._get_checkpoint_real_save_folder(trainer, topk, metric)) - else: - folder = self.log_filepath.joinpath(substitute_folder) + def save(self, trainer, folder_name): + """ + 执行保存的函数,将数据保存在 save_folder/timestamp/folder_name 下。 + :param trainer: + :param folder_name: + :return: + """ + folder = self.timestamp_path.joinpath(folder_name) synchronize_mkdir(folder) - - # 然后再调用 trainer 的 save_model(用于保存模型)或者 save(用于断点重训)函数; - if substitute_fn is not None: - _fn = substitute_fn - else: - if self.is_trainer_checkpoint: - _fn = trainer.save - else: - _fn = trainer.save_model + _fn = getattr(trainer, self.save_fn_name) _fn( folder=folder, only_state_dict=self.only_state_dict, @@ -243,18 +237,95 @@ class CheckpointCallback(Callback): self._real_monitor = use_monitor return value - def _get_checkpoint_real_save_folder(self, trainer: "fastNLP.Trainer", topk: bool = False, - metric: Optional[Union[int, float]] = None) -> str: + @property + def folder_prefix(self): + raise NotImplementedError("The `folder_prefix` is not specified") + + @property + def save_fn_name(self): + raise NotImplementedError("The `save_fn_name` is not specified.") + + +class ModelCheckpointCallback(CheckpointCallback): + """ + 保存模型 checkpoint 的 callback ,其保存的文件目录以及文件名命名规则如下 + + - save_folder/ + - YYYY-mm-dd-HH_MM_SS_fffff/ # 自动根据当前脚本的启动时间创建的 + - model-epoch_{epoch_idx}/ # 满足 save_every_n_epochs 条件保存的模型 + - model-epoch_{epoch_idx}-batch_{global_batch_idx}/ # 满足 save_every_n_batches 保存的模型 + - model-last/ # 最后一个 epoch 的保存 + - model-epoch_{epoch_idx}-batch_{global_batch_idx}-exception_{exception_type}/ # exception时保存。 + - model-epoch_{epoch_idx}-batch_{global_batch_idx}-{monitor}_{monitor_value}/ # 满足topk条件存储文件名 + + model_save_fn 为 None ,则以上每个 folder 中,将生成 fastnlp_model.pkl.tar 文件。 + 若 model_save_fn 不为 None,则 fastNLP 将 folder 绝对路径传递给该函数,fastNLP 不在该 folder 下创建任何文件。 + + :param monitor: 监控的 metric 的名称。如果在 evaluation 结果中没有找到完全一致的名称,将使用 最短公共字符串算法 找到最匹配 + 的那个作为 monitor 。 + :param save_folder: 保存的文件夹,fastNLP 将在该文件下以时间戳创建子文件夹,并在里面保存。因此不同次运行可以将被保存到不同的 + 时间戳文件夹中。如果为 None ,默认使用当前文件夹。 + :param save_every_n_epochs: 多少个 epoch 保存一次。 + :param save_every_n_batches: 多少个 batch 保存一次。 + :param save_last: 如果为 True ,将在每次 epoch 运行结束都保存一次,会覆盖之前的保存。 + :param save_topk: 保存 monitor 结果 topK 个。 + :param save_on_exception: 在出异常信息时,是否保存。传入需要捕获的异常的类。 + :param larger_better: monitor 的值是否时越大越好。 + :param only_state_dict: 保存模型时是否只保存 state_dict 。当 model_save_fn 不为 None 时,该参数无效。 + :param model_save_fn: 个性化的保存函数,当触发保存操作时,就调用这个函数,这个函数应当接受一个文件夹作为参数,不返回任何东西。 + 如果传入了 model_save_fn 函数,fastNLP 将不再进行模型相关的保存。在多卡场景下,我们只在 rank 0 上会运行该函数。 + :param kwargs: + """ + @property + def save_fn_name(self): + return 'save_model' + + @property + def callback_name(self): """ - 获取当前保存模型的真正地名字; - metric 参数仅当 mode 为 'topk' 时起作用; + 通过该值决定两个 CheckpointCallback 实例是否可以共用断点重训的状态; + :return: """ - cur_epoch_idx = trainer.cur_epoch_idx - global_forward_batches = trainer.global_forward_batches - _other = "" - if topk: - _other = f"_{metric}" - return f"epoch_{cur_epoch_idx}-global_batch_{global_forward_batches}-{self._real_monitor}{_other}" + return f"model_checkpoint#monitor-{self.monitor}#topK-{self.save_topk}#only_state_dict-{self.only_state_dict}" + + @property + def folder_prefix(self): + return 'model' + + +class TrainerCheckpointCallback(CheckpointCallback): + """ + 保存 Trainer checkpoint 的 callback ,其保存的文件目录以及文件名命名规则如下 + + - save_folder/ + - YYYY-mm-dd-HH_MM_SS_fffff/ # 自动根据当前脚本的启动时间创建的 + - trainer-epoch_{epoch_idx}/ # 满足 save_every_n_epochs 条件保存的模型 + - trainer-epoch_{epoch_idx}-batch_{global_batch_idx}/ # 满足 save_every_n_batches 保存的模型 + - trainer-last/ # 最后一个 epoch 的保存 + - trainer-epoch_{epoch_idx}-batch_{global_batch_idx}-exception_{exception_type}/ # exception时保存。 + - trainer-epoch_{epoch_idx}-batch_{global_batch_idx}-{monitor}_{monitor_value}/ # 满足topk条件存储文件名 + + model_save_fn 为 None ,则以上每个 folder 中,将生成两个文件:fastnlp_trainer.pkl.tar 以及 fastnlp_model.pkl.tar 。 + 若 model_save_fn 不为 None,则 fastNLP 只会在每个 folder 下生成 fastnlp_trainer.pkl.tar 文件。 + + :param monitor: 监控的 metric 的名称。如果在 evaluation 结果中没有找到完全一致的名称,将使用 最短公共字符串算法 找到最匹配 + 的那个作为 monitor 。 + :param save_folder: 保存的文件夹,fastNLP 将在该文件下以时间戳创建子文件夹,并在里面保存。因此不同次运行可以将被保存到不同的 + 时间戳文件夹中。如果为 None ,默认使用当前文件夹。 + :param save_every_n_epochs: 多少个 epoch 保存一次。 + :param save_every_n_batches: 多少个 batch 保存一次。 + :param save_last: 如果为 True ,将在每次 epoch 运行结束都保存一次,会覆盖之前的保存。 + :param save_topk: 保存 monitor 结果 topK 个。 + :param save_on_exception: 在出异常信息时,是否保存。 + :param larger_better: monitor 的值是否时越大越好。 + :param only_state_dict: 保存模型时是否只保存 state_dict 。当 model_save_fn 不为 None 时,该参数无意义。 + :param model_save_fn: 个性化的保存函数,当触发保存操作时,就调用这个函数,这个函数应当接受一个文件夹作为参数,不返回任何东西。 + 如果传入了 model_save_fn 函数,fastNLP 将不再进行模型相关的保存。在多卡场景下,我们只在 rank 0 上会运行该函数。 + :param kwargs: + """ + @property + def save_fn_name(self): + return 'save' @property def callback_name(self): @@ -262,6 +333,8 @@ class CheckpointCallback(Callback): 通过该值决定两个 CheckpointCallback 实例是否可以共用断点重训的状态; :return: """ - return f"monitor-{self.monitor}#trainer_checkpoint-{self.is_trainer_checkpoint}#only_state_dict-{self.only_state_dict}" - + return f"trainer_checkpoint#monitor-{self.monitor}#topK-{self.save_topk}#only_state_dict-{self.only_state_dict}" + @property + def folder_prefix(self): + return 'trainer' diff --git a/fastNLP/core/callbacks/load_best_model_callback.py b/fastNLP/core/callbacks/load_best_model_callback.py index b4ef4e62..e7b94f8c 100644 --- a/fastNLP/core/callbacks/load_best_model_callback.py +++ b/fastNLP/core/callbacks/load_best_model_callback.py @@ -31,7 +31,7 @@ class LoadBestModelCallback(Callback): 请在函数内完成对模型的保存。 :param model_load_fn: 加载 model 的函数,与 model_save_fn 必须同时不为空。本函数的输入为一个已经创建好的文件夹,没有输出, 请在函数内完成对模型的加载。 - :param delete_after_train: 在加载了最佳模型之后是否删掉模型。 + :param delete_after_train: 在训练结束后是否删掉模型。 """ if model_load_fn is not None: assert callable(model_load_fn), "`model_load_fn` must be a callable object." diff --git a/fastNLP/core/controllers/trainer.py b/fastNLP/core/controllers/trainer.py index 9e1ccfbf..e7aaeea8 100644 --- a/fastNLP/core/controllers/trainer.py +++ b/fastNLP/core/controllers/trainer.py @@ -251,7 +251,7 @@ class Trainer(TrainerEventTrigger): self.driver.set_deterministic_dataloader(self.dataloader) self.dataloader = self.driver.set_dist_repro_dataloader(dataloader=self.train_dataloader, dist=_dist_sampler, - reproducible=self.callback_manager.has_trainer_chechpoint) + reproducible=self.callback_manager.has_trainer_checkpoint) self.set_grad_to_none = kwargs.get("set_grad_to_none", True) self.on_after_trainer_initialized(self.driver) @@ -509,7 +509,7 @@ class Trainer(TrainerEventTrigger): :param folder: 保存模型的地址; :param only_state_dict: 是否只保存模型的 `state_dict`; - :param save_fn: 用户自己定制的用来替换该保存函数本身保存逻辑的函数; + :param model_save_fn: 用户自己定制的用来替换该保存函数本身保存逻辑的函数; :param kwargs: 一些 driver 的保存模型的函数的参数另有其它; """ @@ -534,7 +534,16 @@ class Trainer(TrainerEventTrigger): def load_model(self, folder: Union[str, Path, BinaryIO, io.BytesIO], only_state_dict: bool = False, model_load_fn: Optional[Callable] = None, **kwargs): + """ + 加载模型 + :param folder: 读取 model 的文件夹,默认会尝试读取该文件夹下的 fastnlp_model.pkl.tar 文件。在 model_load_fn 不为空时, + 直接将该 folder 传递到 model_load_fn 中。 + :param only_state_dict: 要读取的文件中是否仅包含模型权重。在 model_load_fn 不为 None 时,该参数无意义。 + :param model_load_fn: callable 的函数,接受一个 folder 作为参数,不返回任何内容。 + :param kwargs: + :return: + """ self.on_load_model() self.driver.barrier() if not isinstance(folder, (io.BytesIO, BinaryIO)): @@ -555,7 +564,13 @@ class Trainer(TrainerEventTrigger): def save(self, folder: Union[str, Path], only_state_dict: bool = True, model_save_fn: Optional[Callable] = None, **kwargs): r""" - 用于断点重训的保存函数; + 用于断点重训 Trainer 的保存函数; + + :param folder: + :param only_state_dict: + :param model_save_fn: + :param kwargs: + :return: """ self.driver.barrier() diff --git a/fastNLP/core/drivers/driver.py b/fastNLP/core/drivers/driver.py index 4b141761..68ec9128 100644 --- a/fastNLP/core/drivers/driver.py +++ b/fastNLP/core/drivers/driver.py @@ -68,9 +68,12 @@ class Driver(ABC): def set_sampler_epoch(self, dataloader, cur_epoch_idx): r""" 对于分布式的 sampler,例如 torch 的 DistributedSampler,其需要在每一个 epoch 前设置随机数种子,来保证每一个进程上的 shuffle 是一样的; + dataloader 中可能真正发挥作用的是 batch_sampler 也可能是 sampler。 + :param dataloader: 需要设置 epoch 的 dataloader 。 :param cur_epoch_idx: 当前是第几个 epoch; """ + @abstractmethod def train_step(self, batch): """ diff --git a/fastNLP/core/drivers/torch_driver/torch_driver.py b/fastNLP/core/drivers/torch_driver/torch_driver.py index 96d11761..685e1f63 100644 --- a/fastNLP/core/drivers/torch_driver/torch_driver.py +++ b/fastNLP/core/drivers/torch_driver/torch_driver.py @@ -143,8 +143,6 @@ class TorchDriver(Driver): :param filepath: 保存到哪个文件夹; :param only_state_dict: 是否只保存权重; - :param model_save_fn: - :return: """ model = self.unwrap_model() diff --git a/fastNLP/core/utils/utils.py b/fastNLP/core/utils/utils.py index 66159f24..0d497bc2 100644 --- a/fastNLP/core/utils/utils.py +++ b/fastNLP/core/utils/utils.py @@ -44,6 +44,9 @@ __all__ = [ ] + + + def get_fn_arg_names(fn: Callable) -> List[str]: r""" 返回一个函数的所有参数的名字; diff --git a/fastNLP/envs/set_env_on_import.py b/fastNLP/envs/set_env_on_import.py index 38b79b44..0e67cf20 100644 --- a/fastNLP/envs/set_env_on_import.py +++ b/fastNLP/envs/set_env_on_import.py @@ -65,7 +65,7 @@ def set_env_on_import(): # fastNLP 内部使用的一些变量 if FASTNLP_LAUNCH_TIME not in os.environ: - cur_time = f"{datetime.datetime.now().strftime('%Y-%m-%d-%H_%M_%S_%M_%f')}" + cur_time = f"{datetime.datetime.now().strftime('%Y-%m-%d-%H_%M_%S_%f')}" os.environ[FASTNLP_LAUNCH_TIME] = cur_time # 设置对应的值 diff --git a/tests/core/callbacks/test_checkpoint_callback_torch.py b/tests/core/callbacks/test_checkpoint_callback_torch.py index f7cc6e5f..1f404bb8 100644 --- a/tests/core/callbacks/test_checkpoint_callback_torch.py +++ b/tests/core/callbacks/test_checkpoint_callback_torch.py @@ -8,7 +8,7 @@ import torch.distributed as dist from pathlib import Path import re -from fastNLP.core.callbacks.checkpoint_callback import CheckpointCallback +from fastNLP.core.callbacks.checkpoint_callback import ModelCheckpointCallback, TrainerCheckpointCallback from fastNLP.core.controllers.trainer import Trainer from fastNLP.envs import FASTNLP_MODEL_FILENAME, FASTNLP_CHECKPOINT_FILENAME, FASTNLP_LAUNCH_TIME, FASTNLP_DISTRIBUTED_CHECK @@ -80,16 +80,23 @@ def test_model_checkpoint_callback_1( version, only_state_dict ): +# def test_model_checkpoint_callback_1( +# model_and_optimizers: TrainerParameters, +# driver='torch_ddp', +# device=[0, 1], +# version=1, +# only_state_dict=True +# ): path = Path.cwd().joinpath(f"test_model_checkpoint") path.mkdir(exist_ok=True, parents=True) if version == 0: callbacks = [ - CheckpointCallback( + ModelCheckpointCallback( monitor="acc", save_folder=path, save_every_n_epochs=1, - save_every_n_global_batches=123, # 避免和 epoch 的保存重复; + save_every_n_batches=123, # 避免和 epoch 的保存重复; save_topk=None, save_last=False, save_on_exception=None, @@ -98,11 +105,11 @@ def test_model_checkpoint_callback_1( ] elif version == 1: callbacks = [ - CheckpointCallback( + ModelCheckpointCallback( monitor="acc", save_folder=path, save_every_n_epochs=3, - save_every_n_global_batches=None, + save_every_n_batches=None, save_topk=2, save_last=True, save_on_exception=None, @@ -121,7 +128,6 @@ def test_model_checkpoint_callback_1( input_mapping=model_and_optimizers.input_mapping, output_mapping=model_and_optimizers.output_mapping, metrics=model_and_optimizers.metrics, - n_epochs=10, callbacks=callbacks, output_from_new_proc="all" @@ -134,31 +140,31 @@ def test_model_checkpoint_callback_1( if version == 0: if driver == "torch": - assert "epoch_10-global_batch_250-acc" in all_saved_model_paths - assert "epoch_4-global_batch_123-acc" in all_saved_model_paths + assert "model-epoch_10" in all_saved_model_paths + assert "model-epoch_4-batch_123" in all_saved_model_paths - epoch_save_path = all_saved_model_paths["epoch_10-global_batch_250-acc"] - step_save_path = all_saved_model_paths["epoch_4-global_batch_123-acc"] + epoch_save_path = all_saved_model_paths["model-epoch_10"] + step_save_path = all_saved_model_paths["model-epoch_4-batch_123"] assert len(all_saved_model_paths) == 12 # ddp 下的文件名不同,因为同样的数据,ddp 用了更少的步数跑完; else: - assert "epoch_6-global_batch_78-acc" in all_saved_model_paths - assert "epoch_9-global_batch_123-acc" in all_saved_model_paths + assert "model-epoch_6" in all_saved_model_paths + assert "model-epoch_9-batch_123" in all_saved_model_paths - epoch_save_path = all_saved_model_paths["epoch_6-global_batch_78-acc"] - step_save_path = all_saved_model_paths["epoch_9-global_batch_123-acc"] + epoch_save_path = all_saved_model_paths["model-epoch_6"] + step_save_path = all_saved_model_paths["model-epoch_9-batch_123"] assert len(all_saved_model_paths) == 11 all_state_dicts = [epoch_save_path, step_save_path] elif version == 1: - pattern = re.compile("epoch_[0-9]+-global_batch_[0-9]+-[a-z|A-Z]+_[0-9]*.?[0-9]*") + pattern = re.compile("model-epoch_[0-9]+-batch_[0-9]+-[a-zA-Z#]+_[0-9]*.?[0-9]*") if driver == "torch": - assert "epoch_9-global_batch_225-acc" in all_saved_model_paths - assert "last" in all_saved_model_paths + assert "model-epoch_9" in all_saved_model_paths + assert "model-last" in all_saved_model_paths aLL_topk_folders = [] for each_folder_name in all_saved_model_paths: each_folder_name = pattern.findall(each_folder_name) @@ -166,15 +172,15 @@ def test_model_checkpoint_callback_1( aLL_topk_folders.append(each_folder_name[0]) assert len(aLL_topk_folders) == 2 - epoch_save_path = all_saved_model_paths["epoch_9-global_batch_225-acc"] - last_save_path = all_saved_model_paths["last"] + epoch_save_path = all_saved_model_paths["model-epoch_9"] + last_save_path = all_saved_model_paths["model-last"] topk_save_path = all_saved_model_paths[aLL_topk_folders[0]] assert len(all_saved_model_paths) == 6 # ddp 下的文件名不同,因为同样的数据,ddp 用了更少的步数跑完; else: - assert "epoch_9-global_batch_117-acc" in all_saved_model_paths - assert "last" in all_saved_model_paths + assert "model-epoch_9" in all_saved_model_paths + assert "model-last" in all_saved_model_paths aLL_topk_folders = [] for each_folder_name in all_saved_model_paths: @@ -183,8 +189,8 @@ def test_model_checkpoint_callback_1( aLL_topk_folders.append(each_folder_name[0]) assert len(aLL_topk_folders) == 2 - epoch_save_path = all_saved_model_paths["epoch_9-global_batch_117-acc"] - last_save_path = all_saved_model_paths["last"] + epoch_save_path = all_saved_model_paths["model-epoch_9"] + last_save_path = all_saved_model_paths["model-last"] topk_save_path = all_saved_model_paths[aLL_topk_folders[0]] assert len(all_saved_model_paths) == 6 @@ -212,7 +218,7 @@ def test_model_checkpoint_callback_1( finally: synchronize_safe_rm(path) - # pass + pass if dist.is_initialized(): dist.destroy_process_group() @@ -238,11 +244,11 @@ def test_model_checkpoint_callback_2( raise NotImplementedError callbacks = [ - CheckpointCallback( + ModelCheckpointCallback( monitor="acc1", save_folder=path, save_every_n_epochs=None, - save_every_n_global_batches=None, + save_every_n_batches=None, save_topk=None, save_last=False, save_on_exception=NotImplementedError, @@ -279,12 +285,12 @@ def test_model_checkpoint_callback_2( all_saved_model_paths = {w.name: w for w in path.joinpath(os.environ[FASTNLP_LAUNCH_TIME]).iterdir()} if driver == "torch": - assert "epoch_4-global_batch_100-acc_NotImplementedError" in all_saved_model_paths - exception_model_path = all_saved_model_paths["epoch_4-global_batch_100-acc_NotImplementedError"] + assert "model-epoch_4-batch_100-exception_NotImplementedError" in all_saved_model_paths + exception_model_path = all_saved_model_paths["model-epoch_4-batch_100-exception_NotImplementedError"] # ddp 下的文件名不同,因为同样的数据,ddp 用了更少的步数跑完; else: - assert "epoch_4-global_batch_52-acc_NotImplementedError" in all_saved_model_paths - exception_model_path = all_saved_model_paths["epoch_4-global_batch_52-acc_NotImplementedError"] + assert "model-epoch_4-batch_52-exception_NotImplementedError" in all_saved_model_paths + exception_model_path = all_saved_model_paths["model-epoch_4-batch_52-exception_NotImplementedError"] assert len(all_saved_model_paths) == 1 all_state_dicts = [exception_model_path] @@ -332,12 +338,11 @@ def test_trainer_checkpoint_callback_1( if version == 0: callbacks = [ - CheckpointCallback( + TrainerCheckpointCallback( monitor="acc", - is_trainer_checkpoint=True, save_folder=path, save_every_n_epochs=7, - save_every_n_global_batches=123, # 避免和 epoch 的保存重复; + save_every_n_batches=123, # 避免和 epoch 的保存重复; save_topk=None, save_last=False, save_on_exception=None, @@ -346,12 +351,11 @@ def test_trainer_checkpoint_callback_1( ] elif version == 1: callbacks = [ - CheckpointCallback( + TrainerCheckpointCallback( monitor="acc", - is_trainer_checkpoint=True, save_folder=path, save_every_n_epochs=None, - save_every_n_global_batches=None, + save_every_n_batches=None, save_topk=2, save_last=True, save_on_exception=None, @@ -383,31 +387,31 @@ def test_trainer_checkpoint_callback_1( if version == 0: if driver == "torch": - assert "epoch_7-global_batch_175-acc" in all_saved_model_paths - assert "epoch_4-global_batch_123-acc" in all_saved_model_paths + assert "trainer-epoch_7" in all_saved_model_paths + assert "trainer-epoch_4-batch_123" in all_saved_model_paths - epoch_save_path = all_saved_model_paths["epoch_7-global_batch_175-acc"] - step_save_path = all_saved_model_paths["epoch_4-global_batch_123-acc"] + epoch_save_path = all_saved_model_paths["trainer-epoch_7"] + step_save_path = all_saved_model_paths["trainer-epoch_4-batch_123"] assert len(all_saved_model_paths) == 3 # ddp 下的文件名不同,因为同样的数据,ddp 用了更少的步数跑完; else: - assert "epoch_7-global_batch_91-acc" in all_saved_model_paths - assert "epoch_9-global_batch_123-acc" in all_saved_model_paths + assert "trainer-epoch_7" in all_saved_model_paths + assert "trainer-epoch_9-batch_123" in all_saved_model_paths - epoch_save_path = all_saved_model_paths["epoch_7-global_batch_91-acc"] - step_save_path = all_saved_model_paths["epoch_9-global_batch_123-acc"] + epoch_save_path = all_saved_model_paths["trainer-epoch_7"] + step_save_path = all_saved_model_paths["trainer-epoch_9-batch_123"] assert len(all_saved_model_paths) == 2 all_state_dicts = [epoch_save_path, step_save_path] elif version == 1: - pattern = re.compile("epoch_[0-9]+-global_batch_[0-9]+-[a-z|A-Z]+_[0-9]*.?[0-9]*") + pattern = re.compile("trainer-epoch_[0-9]+-batch_[0-9]+-[a-zA-Z#]+_[0-9]*.?[0-9]*") # all_saved_model_paths = {w.name: w for w in path.joinpath(os.environ[FASTNLP_LAUNCH_TIME]).iterdir()} if driver == "torch": - assert "last" in all_saved_model_paths + assert "trainer-last" in all_saved_model_paths aLL_topk_folders = [] for each_folder_name in all_saved_model_paths: each_folder_name = pattern.findall(each_folder_name) @@ -415,13 +419,13 @@ def test_trainer_checkpoint_callback_1( aLL_topk_folders.append(each_folder_name[0]) assert len(aLL_topk_folders) == 2 - last_save_path = all_saved_model_paths["last"] + last_save_path = all_saved_model_paths["trainer-last"] topk_save_path = all_saved_model_paths[aLL_topk_folders[0]] assert len(all_saved_model_paths) == 3 # ddp 下的文件名不同,因为同样的数据,ddp 用了更少的步数跑完; else: - assert "last" in all_saved_model_paths + assert "trainer-last" in all_saved_model_paths aLL_topk_folders = [] for each_folder_name in all_saved_model_paths: @@ -430,7 +434,7 @@ def test_trainer_checkpoint_callback_1( aLL_topk_folders.append(each_folder_name[0]) assert len(aLL_topk_folders) == 2 - last_save_path = all_saved_model_paths["last"] + last_save_path = all_saved_model_paths["trainer-last"] topk_save_path = all_saved_model_paths[aLL_topk_folders[0]] assert len(all_saved_model_paths) == 3 @@ -474,10 +478,11 @@ def test_trainer_checkpoint_callback_2( device, version ): + pytest.skip("Skip transformers test for now.") path = Path.cwd().joinpath(f"test_model_checkpoint") path.mkdir(exist_ok=True, parents=True) - import transformers + import transformers # 版本4.16.2 import torch from torchmetrics import Accuracy from transformers import AutoModelForSequenceClassification @@ -587,12 +592,11 @@ def test_trainer_checkpoint_callback_2( if version == 0: callbacks = [ - CheckpointCallback( + TrainerCheckpointCallback( monitor="acc", - is_trainer_checkpoint=True, save_folder=path, save_every_n_epochs=None, - save_every_n_global_batches=50, + save_every_n_batches=50, save_topk=None, save_last=False, save_on_exception=None, @@ -601,12 +605,11 @@ def test_trainer_checkpoint_callback_2( ] elif version == 1: callbacks = [ - CheckpointCallback( + TrainerCheckpointCallback( monitor="acc", - is_trainer_checkpoint=True, save_folder=path, save_every_n_epochs=None, - save_every_n_global_batches=None, + save_every_n_batches=None, save_topk=1, save_last=True, save_on_exception=None, @@ -638,27 +641,27 @@ def test_trainer_checkpoint_callback_2( if version == 0: if driver == "torch": - assert "epoch_1-global_batch_200-acc" in all_saved_model_paths + assert "trainer-epoch_1-batch_200" in all_saved_model_paths - epoch_save_path = all_saved_model_paths["epoch_1-global_batch_200-acc"] + epoch_save_path = all_saved_model_paths["trainer-epoch_1-batch_200"] assert len(all_saved_model_paths) == 4 # ddp 下的文件名不同,因为同样的数据,ddp 用了更少的步数跑完; else: - assert "epoch_1-global_batch_100-acc" in all_saved_model_paths + assert "trainer-epoch_1-batch_100" in all_saved_model_paths - epoch_save_path = all_saved_model_paths["epoch_1-global_batch_100-acc"] + epoch_save_path = all_saved_model_paths["trainer-epoch_1-batch_100"] assert len(all_saved_model_paths) == 2 all_state_dicts = [epoch_save_path] elif version == 1: - pattern = re.compile("epoch_[0-9]+-global_batch_[0-9]+-[a-z|A-Z]+_[0-9]*.?[0-9]*") + pattern = re.compile("trainer-epoch_[0-9]+-batch_[0-9]+-[a-zA-Z#]+_[0-9]*.?[0-9]*") # all_saved_model_paths = {w.name: w for w in path.joinpath(os.environ[FASTNLP_LAUNCH_TIME]).iterdir()} if driver == "torch": - assert "last" in all_saved_model_paths + assert "trainer-last" in all_saved_model_paths aLL_topk_folders = [] for each_folder_name in all_saved_model_paths: each_folder_name = pattern.findall(each_folder_name) @@ -666,13 +669,13 @@ def test_trainer_checkpoint_callback_2( aLL_topk_folders.append(each_folder_name[0]) assert len(aLL_topk_folders) == 1 - last_save_path = all_saved_model_paths["last"] + last_save_path = all_saved_model_paths["trainer-last"] topk_save_path = all_saved_model_paths[aLL_topk_folders[0]] assert len(all_saved_model_paths) == 2 # ddp 下的文件名不同,因为同样的数据,ddp 用了更少的步数跑完; else: - assert "last" in all_saved_model_paths + assert "trainer-last" in all_saved_model_paths aLL_topk_folders = [] for each_folder_name in all_saved_model_paths: @@ -681,7 +684,7 @@ def test_trainer_checkpoint_callback_2( aLL_topk_folders.append(each_folder_name[0]) assert len(aLL_topk_folders) == 1 - last_save_path = all_saved_model_paths["last"] + last_save_path = all_saved_model_paths["trainer-last"] topk_save_path = all_saved_model_paths[aLL_topk_folders[0]] assert len(all_saved_model_paths) == 2 From 607367588c70b12c535f82b912760f63d558d316 Mon Sep 17 00:00:00 2001 From: yh_cc Date: Mon, 11 Apr 2022 15:21:45 +0800 Subject: [PATCH 2/2] =?UTF-8?q?1.=E5=A2=9E=E5=8A=A0DataSet=E7=9A=84=5F=5Fs?= =?UTF-8?q?etitem=5F=5F=E6=96=B9=E6=B3=95=EF=BC=8C=E4=BD=BF=E5=BE=97?= =?UTF-8?q?=E5=85=B6=E5=8F=AF=E4=BB=A5=E7=9B=B4=E6=8E=A5random.shuffle(dat?= =?UTF-8?q?aset);=202.=E4=BC=98=E5=8C=96=E9=83=A8=E5=88=86log=E8=BE=93?= =?UTF-8?q?=E5=87=BA=E6=98=BE=E7=A4=BA?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- fastNLP/core/callbacks/__init__.py | 5 +++-- fastNLP/core/callbacks/checkpoint_callback.py | 3 ++- fastNLP/core/controllers/evaluator.py | 18 ++++++++--------- fastNLP/core/controllers/trainer.py | 1 + fastNLP/core/dataset/dataset.py | 20 ++++++++++++++----- fastNLP/core/dataset/field.py | 5 +---- fastNLP/core/drivers/driver.py | 14 +++++++------ .../torch_driver/initialize_torch_driver.py | 4 ++-- tests/core/dataset/test_dataset.py | 14 +++++++++++++ 9 files changed, 55 insertions(+), 29 deletions(-) diff --git a/fastNLP/core/callbacks/__init__.py b/fastNLP/core/callbacks/__init__.py index f45cf5e0..a47ab998 100644 --- a/fastNLP/core/callbacks/__init__.py +++ b/fastNLP/core/callbacks/__init__.py @@ -4,7 +4,8 @@ __all__ = [ 'EventsList', 'Filter', 'CallbackManager', - 'CheckpointCallback', + 'ModelCheckpointCallback', + 'TrainerCheckpointCallback', 'choose_progress_callback', 'ProgressCallback', 'RichCallback', @@ -16,7 +17,7 @@ __all__ = [ from .callback import Callback from .callback_events import EventsList, Events, Filter from .callback_manager import CallbackManager -from .checkpoint_callback import CheckpointCallback +from .checkpoint_callback import ModelCheckpointCallback, TrainerCheckpointCallback from .progress_callback import choose_progress_callback, ProgressCallback, RichCallback from .lr_scheduler_callback import LRSchedCallback from .load_best_model_callback import LoadBestModelCallback diff --git a/fastNLP/core/callbacks/checkpoint_callback.py b/fastNLP/core/callbacks/checkpoint_callback.py index 5cd102e0..d3a3b52d 100644 --- a/fastNLP/core/callbacks/checkpoint_callback.py +++ b/fastNLP/core/callbacks/checkpoint_callback.py @@ -1,5 +1,6 @@ __all__ = [ - 'CheckpointCallback' + 'ModelCheckpointCallback', + 'TrainerCheckpointCallback' ] import os from typing import Union, Optional, Callable, Dict, Sequence, Any, Mapping diff --git a/fastNLP/core/controllers/evaluator.py b/fastNLP/core/controllers/evaluator.py index f58a7faf..bd66d0a0 100644 --- a/fastNLP/core/controllers/evaluator.py +++ b/fastNLP/core/controllers/evaluator.py @@ -133,17 +133,18 @@ class Evaluator: self.driver.barrier() - def run(self, num_eval_batch_per_dl: int = -1) -> Dict: + def run(self, num_eval_batch_per_dl: int = -1, **kwargs) -> Dict: """ 返回一个字典类型的数据,其中key为metric的名字,value为对应metric的结果。 - 如果存在多个metric,一个dataloader的情况,key的命名规则是 - metric_indicator_name#metric_name - 如果存在多个数据集,一个metric的情况,key的命名规则是 - metric_indicator_name#dataloader_name (其中 # 是默认的 separator ,可以通过 Evaluator 初始化参数修改)。 - 如果存在多个metric,多个dataloader的情况,key的命名规则是 - metric_indicator_name#metric_name#dataloader_name - :param num_eval_batch_per_dl: 每个 dataloader 测试多少个 batch 的数据,-1 为测试所有数据。 + 如果存在多个metric,一个dataloader的情况,key的命名规则是 + metric_indicator_name#metric_name + 如果存在多个数据集,一个metric的情况,key的命名规则是 + metric_indicator_name#metric_name#dataloader_name (其中 # 是默认的 separator ,可以通过 Evaluator 初始化参数修改)。 + 如果存在多个metric,多个dataloader的情况,key的命名规则是 + metric_indicator_name#metric_name#dataloader_name + 其中 metric_indicator_name 可能不存在。 + :param num_eval_batch_per_dl: 每个 dataloader 测试多少个 batch 的数据,-1 为测试所有数据。 :return: """ assert isinstance(num_eval_batch_per_dl, int), "num_eval_batch_per_dl must be of int type." @@ -157,7 +158,6 @@ class Evaluator: assert self.driver.has_test_dataloaders() metric_results = {} - self.reset() evaluate_context = self.driver.get_evaluate_context() self.driver.set_model_mode(mode='eval' if self.model_use_eval_mode else 'train') diff --git a/fastNLP/core/controllers/trainer.py b/fastNLP/core/controllers/trainer.py index e7aaeea8..b7456b61 100644 --- a/fastNLP/core/controllers/trainer.py +++ b/fastNLP/core/controllers/trainer.py @@ -291,6 +291,7 @@ class Trainer(TrainerEventTrigger): raise FileNotFoundError("You are using `resume_from`, but we can not find your specific file.") if self.evaluator is not None and num_eval_sanity_batch > 0: + logger.info(f"Running evaluator sanity check for {num_eval_sanity_batch} batches.") self.on_sanity_check_begin() sanity_check_res = self.evaluator.run(num_eval_batch_per_dl=num_eval_sanity_batch) self.on_sanity_check_end(sanity_check_res) diff --git a/fastNLP/core/dataset/dataset.py b/fastNLP/core/dataset/dataset.py index 037fde00..9630a3a0 100644 --- a/fastNLP/core/dataset/dataset.py +++ b/fastNLP/core/dataset/dataset.py @@ -8,9 +8,8 @@ __all__ = [ import _pickle as pickle from copy import deepcopy -from typing import Optional, List, Callable, Union, Dict, Any +from typing import Optional, List, Callable, Union, Dict, Any, Mapping from functools import partial -import warnings import numpy as np from threading import Thread @@ -197,6 +196,20 @@ class DataSet: else: raise KeyError("Unrecognized type {} for idx in __getitem__ method".format(type(idx))) + def __setitem__(self, key, value): + assert isinstance(key, int) and key