@@ -10,7 +10,8 @@ __all__ = [ | |||||
'ProgressCallback', | 'ProgressCallback', | ||||
'RichCallback', | 'RichCallback', | ||||
"LRSchedCallback", | "LRSchedCallback", | ||||
'LoadBestModelCallback' | |||||
'LoadBestModelCallback', | |||||
"EarlyStopCallback" | |||||
] | ] | ||||
@@ -21,4 +22,5 @@ from .checkpoint_callback import ModelCheckpointCallback, TrainerCheckpointCallb | |||||
from .progress_callback import choose_progress_callback, ProgressCallback, RichCallback | from .progress_callback import choose_progress_callback, ProgressCallback, RichCallback | ||||
from .lr_scheduler_callback import LRSchedCallback | from .lr_scheduler_callback import LRSchedCallback | ||||
from .load_best_model_callback import LoadBestModelCallback | from .load_best_model_callback import LoadBestModelCallback | ||||
from .early_stop_callback import EarlyStopCallback | |||||
@@ -1,11 +1,15 @@ | |||||
from typing import Union, Callable, Dict, Optional | |||||
from typing import Union, Callable, Dict, Optional, Any | |||||
from abc import ABC | |||||
__all__ = [ | __all__ = [ | ||||
'Callback', | 'Callback', | ||||
] | ] | ||||
from .callback_events import Events, EventsList, Filter | from .callback_events import Events, EventsList, Filter | ||||
from .utils import _get_monitor_value | |||||
from fastNLP.core.callbacks.callback_events import _SingleEventState | from fastNLP.core.callbacks.callback_events import _SingleEventState | ||||
from fastNLP.core.log import logger | |||||
from fastNLP.core.utils import apply_to_collection | |||||
class Callback: | class Callback: | ||||
@@ -150,4 +154,82 @@ class _CallbackWrapper(Callback): | |||||
return self.fn.__name__ | return self.fn.__name__ | ||||
class CanItemDataType(ABC): | |||||
""" | |||||
检测可以进行传输的对象。 | |||||
""" | |||||
@classmethod | |||||
def __subclasshook__(cls, subclass: Any) -> Union[bool, Any]: | |||||
if cls is CanItemDataType: | |||||
item = getattr(subclass, 'item', None) | |||||
return callable(item) | |||||
return NotImplemented | |||||
class HasMonitorCallback(Callback): | |||||
def __init__(self, monitor, larger_better, must_have_monitor=False): | |||||
self.set_monitor(monitor, larger_better) | |||||
self.must_have_moinitor = must_have_monitor | |||||
def set_monitor(self, monitor, larger_better): | |||||
self.monitor = str(monitor) if monitor is not None else None | |||||
self.larger_better = bool(larger_better) | |||||
if larger_better: | |||||
self.monitor_value = float('-inf') | |||||
else: | |||||
self.monitor_value = float('inf') | |||||
self._real_monitor = self.monitor | |||||
def on_after_trainer_initialized(self, trainer, driver): | |||||
""" | |||||
如果本身的 monitor 没有设置,则根据 Trainer 中的 monitor 设置 monitor 。 | |||||
同时对于必须要有 monitor 设置的 callback ,该函数会进行检查。 | |||||
:param trainer: | |||||
:param driver: | |||||
:return: | |||||
""" | |||||
if self.monitor is None and trainer.monitor is not None: | |||||
self.set_monitor(monitor=trainer.monitor, larger_better=trainer.larger_better) | |||||
if self.must_have_moinitor and self.monitor is None: | |||||
raise RuntimeError(f"No `monitor` is set for {self.__class__.__name__}. " | |||||
f"You can set it in the initialization or through Trainer.") | |||||
def get_monitor_value(self, results:Dict)->float: | |||||
""" | |||||
获取 monitor 的值,如果 monitor 没有直接找到,会尝试使用匹配的方式寻找,并把匹配到的设置到 self._real_monitor 属性上。 | |||||
:param results: | |||||
:return: | |||||
""" | |||||
if len(results)==0: | |||||
return 0 | |||||
# 保证所有的 tensor 都被转换为了 python 特定的类型 | |||||
results = apply_to_collection(results, dtype=CanItemDataType, function=lambda x: x.item()) | |||||
use_monitor, monitor_value = _get_monitor_value(monitor=self.monitor, | |||||
real_monitor=self._real_monitor, | |||||
res=results) | |||||
if self._real_monitor != use_monitor: # 发生了替换需要打印 | |||||
logger.warning( | |||||
f"We can not find `{self.monitor}` in the evaluation result (with keys as {list(results.keys())}), " | |||||
f"we use the `{use_monitor}` as the monitor for {self.__class__.__name__}.") | |||||
self._real_monitor = use_monitor | |||||
return monitor_value | |||||
def is_better_monitor_value(self, monitor_value: float, keep_if_better=True): | |||||
""" | |||||
检测 monitor_value 是否是更好的 | |||||
:param monitor_value: | |||||
:param keep_if_better: 如果传入的 monitor_value 值更好,则将其保存下来。 | |||||
:return: | |||||
""" | |||||
better = False | |||||
if (self.larger_better and monitor_value > self.monitor_value) or \ | |||||
(not self.larger_better and monitor_value < self.monitor_value): | |||||
better = True | |||||
if keep_if_better: | |||||
self.monitor_value = monitor_value | |||||
return better |
@@ -5,12 +5,12 @@ __all__ = [ | |||||
import os | import os | ||||
from typing import Union, Optional, Callable, Dict, Sequence, Any, Mapping | from typing import Union, Optional, Callable, Dict, Sequence, Any, Mapping | ||||
from pathlib import Path | from pathlib import Path | ||||
from abc import ABC | |||||
import sys | import sys | ||||
from copy import deepcopy | |||||
import fastNLP | import fastNLP | ||||
from .callback import Callback, Filter | |||||
from .callback import Callback, HasMonitorCallback | |||||
from fastNLP.core.callbacks.utils import _get_monitor_value | from fastNLP.core.callbacks.utils import _get_monitor_value | ||||
from fastNLP.core.log import logger | from fastNLP.core.log import logger | ||||
from fastNLP.envs import FASTNLP_LAUNCH_TIME | from fastNLP.envs import FASTNLP_LAUNCH_TIME | ||||
@@ -18,22 +18,7 @@ from fastNLP.core.utils import synchronize_safe_rm, synchronize_mkdir | |||||
from fastNLP.core.utils import apply_to_collection | from fastNLP.core.utils import apply_to_collection | ||||
class CanItemDataType(ABC): | |||||
""" | |||||
检测可以进行传输的对象。 | |||||
""" | |||||
@classmethod | |||||
def __subclasshook__(cls, subclass: Any) -> Union[bool, Any]: | |||||
if cls is CanItemDataType: | |||||
item = getattr(subclass, 'item', None) | |||||
return callable(item) | |||||
return NotImplemented | |||||
class CheckpointCallback(Callback): | |||||
class CheckpointCallback(HasMonitorCallback): | |||||
def __init__( | def __init__( | ||||
self, | self, | ||||
monitor, | monitor, | ||||
@@ -48,13 +33,8 @@ class CheckpointCallback(Callback): | |||||
model_save_fn: Optional[Callable] = None, | model_save_fn: Optional[Callable] = None, | ||||
**kwargs, | **kwargs, | ||||
): | ): | ||||
# 我们新加了逻辑,如果 checkpoint callback 自己没有设置 monitor 和 larger_better,那么我们会将其在 trainer 中的设置赋值给它们; | |||||
# if monitor is None and save_topk is not None: | |||||
# raise ValueError("Parameter `monitor` must be set when you want to use 'save_topk'.") | |||||
if monitor is not None and not isinstance(monitor, str): | |||||
raise ValueError("Parameter `monitor` should be of 'str' type.") | |||||
super().__init__(monitor=monitor, larger_better=larger_better, | |||||
must_have_monitor=save_topk is not None) | |||||
if save_folder is None: | if save_folder is None: | ||||
logger.warning( | logger.warning( | ||||
"Parameter `path` is None, and we will use the current work directory to find and load your model.") | "Parameter `path` is None, and we will use the current work directory to find and load your model.") | ||||
@@ -92,13 +72,12 @@ class CheckpointCallback(Callback): | |||||
"`BaseException` type.") | "`BaseException` type.") | ||||
else: | else: | ||||
save_on_exception = [] | save_on_exception = [] | ||||
self.monitor = monitor | |||||
self.save_folder = Path(save_folder) | self.save_folder = Path(save_folder) | ||||
self.save_every_n_epochs = save_every_n_epochs | self.save_every_n_epochs = save_every_n_epochs | ||||
self.save_every_n_batches = save_every_n_batches | self.save_every_n_batches = save_every_n_batches | ||||
self.save_last = save_last | self.save_last = save_last | ||||
self.save_topk = save_topk | self.save_topk = save_topk | ||||
self.larger_better = larger_better | |||||
self.only_state_dict = only_state_dict | self.only_state_dict = only_state_dict | ||||
self.model_save_fn = model_save_fn | self.model_save_fn = model_save_fn | ||||
self.save_on_exception = save_on_exception | self.save_on_exception = save_on_exception | ||||
@@ -108,12 +87,6 @@ class CheckpointCallback(Callback): | |||||
self._topk_model = {} | self._topk_model = {} | ||||
self._topn = 0 # 表示目前已经保存了几个最好的模型; | self._topn = 0 # 表示目前已经保存了几个最好的模型; | ||||
# 因为我们在 `_get_validate_metric` 函数中,当在返回的 `validate_res` 字典中找不到 `monitor` 时,是使用匹配找到的 | |||||
# key 对应的 value 当做结果;但是这样存在的一个问题在于如果用户传入的 metric 返回的 sub_metric 的名字可能会混淆,并且其在下一次 | |||||
# 训练的代码中修改了这些 sub_metric 返回的顺序,那么就会导致模糊匹配拿到的 key 和 value 与之前的不是同一个,这显然不是合理的行为; | |||||
# 因此我们通过该变量来表示我们通过模糊匹配拿到的 key; | |||||
self._real_monitor = self.monitor | |||||
# 注意这里应当保证只有进程 0 在执行这个操作,因为当用户使用 python -m torch.distributed.launch 来拉起进程的时候, | # 注意这里应当保证只有进程 0 在执行这个操作,因为当用户使用 python -m torch.distributed.launch 来拉起进程的时候, | ||||
# FASTNLP_LAUNCH_TIME 在每一个进程上的值是不一样的; | # FASTNLP_LAUNCH_TIME 在每一个进程上的值是不一样的; | ||||
self.timestamp_path = self.save_folder.joinpath(os.environ[FASTNLP_LAUNCH_TIME]) | self.timestamp_path = self.save_folder.joinpath(os.environ[FASTNLP_LAUNCH_TIME]) | ||||
@@ -121,20 +94,15 @@ class CheckpointCallback(Callback): | |||||
synchronize_mkdir(self.timestamp_path) | synchronize_mkdir(self.timestamp_path) | ||||
def on_after_trainer_initialized(self, trainer, driver): | def on_after_trainer_initialized(self, trainer, driver): | ||||
if self.monitor is None: | |||||
if trainer.monitor is not None: | |||||
self.monitor = trainer.monitor | |||||
self.larger_better = trainer.larger_better | |||||
elif self.save_topk is not None: | |||||
raise RuntimeError("You are using `topk` mode, but you have not set the `monitor` value either in this" | |||||
"callback or in trainer.") | |||||
else: | |||||
self.monitor = None | |||||
if self.save_topk is not None: | |||||
super().on_after_trainer_initialized(trainer, driver) | |||||
if self.save_topk is not None and trainer.evaluator is None: | if self.save_topk is not None and trainer.evaluator is None: | ||||
raise RuntimeError("You are using `topk` mode, but there is no `evaluator` in trainer.") | |||||
logger.warning("You set `save_topk`, but `validate_dataloaders` is not set in Trainer.") | |||||
def on_validate_end(self, trainer, validate_res): | |||||
self._save_topk(trainer, validate_res) | |||||
def on_validate_end(self, trainer, results): | |||||
if len(results) == 0: | |||||
return | |||||
self._save_topk(trainer, results) | |||||
def on_train_epoch_end(self, trainer: "fastNLP.Trainer"): | def on_train_epoch_end(self, trainer: "fastNLP.Trainer"): | ||||
if trainer.cur_epoch_idx % self.save_every_n_epochs == 0: | if trainer.cur_epoch_idx % self.save_every_n_epochs == 0: | ||||
@@ -157,7 +125,7 @@ class CheckpointCallback(Callback): | |||||
def on_sanity_check_end(self, trainer, sanity_check_res): | def on_sanity_check_end(self, trainer, sanity_check_res): | ||||
# 主要核对一下 monitor 是否存在。 | # 主要核对一下 monitor 是否存在。 | ||||
self._get_validate_metric(sanity_check_res) | |||||
self.get_monitor_value(results=sanity_check_res) | |||||
def on_save_checkpoint(self, trainer) -> Dict: | def on_save_checkpoint(self, trainer) -> Dict: | ||||
""" | """ | ||||
@@ -168,8 +136,7 @@ class CheckpointCallback(Callback): | |||||
states = {} | states = {} | ||||
states['timestamp_path'] = str(self.timestamp_path.absolute()) | states['timestamp_path'] = str(self.timestamp_path.absolute()) | ||||
states['_topk_model'] = apply_to_collection(self._topk_model, dtype=CanItemDataType, | |||||
function=lambda x:x.item()) | |||||
states['_topk_model'] = deepcopy(self._topk_model) | |||||
states['save_topk'] = 0 if self.save_topk is None else self.save_topk | states['save_topk'] = 0 if self.save_topk is None else self.save_topk | ||||
states['_real_monitor'] = self._real_monitor | states['_real_monitor'] = self._real_monitor | ||||
return states | return states | ||||
@@ -190,30 +157,30 @@ class CheckpointCallback(Callback): | |||||
self._topk_model.update(self._topk_model) | self._topk_model.update(self._topk_model) | ||||
self._real_monitor = states["real_monitor"] | self._real_monitor = states["real_monitor"] | ||||
def _save_topk(self, trainer: "fastNLP.Trainer", validate_res: Dict): | |||||
def _save_topk(self, trainer: "fastNLP.Trainer", results: Dict): | |||||
""" | """ | ||||
根据validate_res决定保存哪些model的函数。会自动移除掉不满足topk的文件夹。 | 根据validate_res决定保存哪些model的函数。会自动移除掉不满足topk的文件夹。 | ||||
:param trainer: | :param trainer: | ||||
:param validate_res: | |||||
:param results: | |||||
:return: | :return: | ||||
""" | """ | ||||
if self.save_topk is not None: | if self.save_topk is not None: | ||||
_metric_value = self._get_validate_metric(validate_res) | |||||
monitor_value = self.get_monitor_value(results=results) | |||||
folder_name = f"{self.folder_prefix}-epoch_{trainer.cur_epoch_idx}-batch_{trainer.global_forward_batches}" \ | folder_name = f"{self.folder_prefix}-epoch_{trainer.cur_epoch_idx}-batch_{trainer.global_forward_batches}" \ | ||||
f"-{self._real_monitor}_{_metric_value}" | |||||
f"-{self._real_monitor}_{monitor_value}" | |||||
_should_save = False | _should_save = False | ||||
if self._topn < self.save_topk: | if self._topn < self.save_topk: | ||||
self._topk_model[folder_name] = _metric_value | |||||
self._topk_model[folder_name] = monitor_value | |||||
self._topn += 1 | self._topn += 1 | ||||
_should_save = True | _should_save = True | ||||
else: | else: | ||||
_least_valuable_model = (min if self.larger_better else max)(self._topk_model, | _least_valuable_model = (min if self.larger_better else max)(self._topk_model, | ||||
key=lambda x: self._topk_model[x]) | key=lambda x: self._topk_model[x]) | ||||
if (self.larger_better and _metric_value > self._topk_model[_least_valuable_model]) or \ | |||||
(self.larger_better is False and _metric_value < self._topk_model[_least_valuable_model]): | |||||
self._topk_model[folder_name] = _metric_value | |||||
if (self.larger_better and monitor_value > self._topk_model[_least_valuable_model]) or \ | |||||
(self.larger_better is False and monitor_value < self._topk_model[_least_valuable_model]): | |||||
self._topk_model[folder_name] = monitor_value | |||||
_should_save = True | _should_save = True | ||||
self._topk_model.pop(_least_valuable_model) | self._topk_model.pop(_least_valuable_model) | ||||
synchronize_safe_rm(self.timestamp_path.joinpath(_least_valuable_model)) | synchronize_safe_rm(self.timestamp_path.joinpath(_least_valuable_model)) | ||||
@@ -249,7 +216,11 @@ class CheckpointCallback(Callback): | |||||
:return: | :return: | ||||
""" | """ | ||||
use_monitor, value = _get_monitor_value(monitor=self.monitor, real_monitor=self._real_monitor, res=res) | use_monitor, value = _get_monitor_value(monitor=self.monitor, real_monitor=self._real_monitor, res=res) | ||||
if self._real_monitor != use_monitor: | |||||
logger.warning(f"We can not find `{self._real_monitor}` in the evaluation result (with keys as {list(res.keys())}), " | |||||
f"we use the `{use_monitor}` as the monitor for {self.__class__.__name__}.") | |||||
self._real_monitor = use_monitor | self._real_monitor = use_monitor | ||||
return value | return value | ||||
@property | @property | ||||
@@ -277,7 +248,7 @@ class ModelCheckpointCallback(CheckpointCallback): | |||||
若 model_save_fn 不为 None,则 fastNLP 将 folder 绝对路径传递给该函数,fastNLP 不在该 folder 下创建任何文件。 | 若 model_save_fn 不为 None,则 fastNLP 将 folder 绝对路径传递给该函数,fastNLP 不在该 folder 下创建任何文件。 | ||||
:param monitor: 监控的 metric 的名称。如果在 evaluation 结果中没有找到完全一致的名称,将使用 最短公共字符串算法 找到最匹配 | :param monitor: 监控的 metric 的名称。如果在 evaluation 结果中没有找到完全一致的名称,将使用 最短公共字符串算法 找到最匹配 | ||||
的那个作为 monitor 。 | |||||
的那个作为 monitor 。如果为 None 将尝试从 Trainer 中获取该值。 | |||||
:param save_folder: 保存的文件夹,fastNLP 将在该文件下以时间戳创建子文件夹,并在里面保存。因此不同次运行可以将被保存到不同的 | :param save_folder: 保存的文件夹,fastNLP 将在该文件下以时间戳创建子文件夹,并在里面保存。因此不同次运行可以将被保存到不同的 | ||||
时间戳文件夹中。如果为 None ,默认使用当前文件夹。 | 时间戳文件夹中。如果为 None ,默认使用当前文件夹。 | ||||
:param save_every_n_epochs: 多少个 epoch 保存一次。 | :param save_every_n_epochs: 多少个 epoch 保存一次。 | ||||
@@ -324,7 +295,7 @@ class TrainerCheckpointCallback(CheckpointCallback): | |||||
若 model_save_fn 不为 None,则 fastNLP 只会在每个 folder 下生成 fastnlp_trainer.pkl.tar 文件。 | 若 model_save_fn 不为 None,则 fastNLP 只会在每个 folder 下生成 fastnlp_trainer.pkl.tar 文件。 | ||||
:param monitor: 监控的 metric 的名称。如果在 evaluation 结果中没有找到完全一致的名称,将使用 最短公共字符串算法 找到最匹配 | :param monitor: 监控的 metric 的名称。如果在 evaluation 结果中没有找到完全一致的名称,将使用 最短公共字符串算法 找到最匹配 | ||||
的那个作为 monitor 。 | |||||
的那个作为 monitor 。如果为 None 将尝试从 Trainer 中获取该值。 | |||||
:param save_folder: 保存的文件夹,fastNLP 将在该文件下以时间戳创建子文件夹,并在里面保存。因此不同次运行可以将被保存到不同的 | :param save_folder: 保存的文件夹,fastNLP 将在该文件下以时间戳创建子文件夹,并在里面保存。因此不同次运行可以将被保存到不同的 | ||||
时间戳文件夹中。如果为 None ,默认使用当前文件夹。 | 时间戳文件夹中。如果为 None ,默认使用当前文件夹。 | ||||
:param save_every_n_epochs: 多少个 epoch 保存一次。 | :param save_every_n_epochs: 多少个 epoch 保存一次。 | ||||
@@ -0,0 +1,61 @@ | |||||
__all__ = [ | |||||
'EarlyStopCallback' | |||||
] | |||||
from typing import Dict | |||||
from .callback import HasMonitorCallback | |||||
from fastNLP.core.utils.exceptions import EarlyStopException | |||||
class EarlyStopCallback(HasMonitorCallback): | |||||
def __init__(self, monitor:str=None, larger_better:bool=True, patience:int=10): | |||||
""" | |||||
:param str monitor: 监控的 metric 值。如果为 None,将尝试使用 Trainer 设置的 monitor 。 | |||||
:param larger_better: monitor 的值是否是越大越好。 | |||||
:param patience: 多少次 validate 不没有提升就停止。 | |||||
""" | |||||
super(EarlyStopCallback, self).__init__(monitor=monitor, larger_better=larger_better, must_have_monitor=True) | |||||
self.wait = 0 | |||||
self.patience = patience | |||||
def on_validate_end(self, trainer, results): | |||||
if len(results)==0: | |||||
return | |||||
monitor_value = self.get_monitor_value(results) | |||||
if self.is_better_monitor_value(monitor_value, keep_if_better=True): | |||||
self.wait = 0 | |||||
else: | |||||
self.wait += 1 | |||||
def on_fetch_data_begin(self, trainer): | |||||
# 当是 step validate 的时候,下一步执行的就是这个, 所以在这里检查。 | |||||
if self.wait >= self.patience: | |||||
raise EarlyStopException(f"After {self.wait} validations, no improvement for " | |||||
f"metric `{self._real_monitor}`") | |||||
def on_train_epoch_begin(self, trainer): | |||||
# 当是 epoch validate 的时候,下一步执行的就是这个, 所以在这里检查。 | |||||
if self.wait >= self.patience: | |||||
raise EarlyStopException(f"After {self.wait} validations, no improvement for " | |||||
f"metric `{self._real_monitor}`(best value: {self.monitor_value})") | |||||
def on_save_checkpoint(self, trainer) -> Dict: | |||||
states = { | |||||
'patience': self.patience, | |||||
'wait': self.wait, | |||||
'monitor': self.monitor, | |||||
'monitor_value': self.monitor_value | |||||
} | |||||
return states | |||||
def on_load_checkpoint(self, trainer, states): | |||||
self.patience = states['patience'] | |||||
self.wait = states['wait'] | |||||
self.monitor = states['monitor'] | |||||
self.monitor_value = float(states['monitor_value']) | |||||
def callback_name(self): | |||||
return f'EarlyStopCallback#monitor-{self.monitor}#patience-{self.patience}' | |||||
@@ -4,8 +4,7 @@ __all__ = [ | |||||
import os | import os | ||||
from typing import Optional, Callable | from typing import Optional, Callable | ||||
from .callback import Callback | |||||
from .utils import _get_monitor_value | |||||
from .callback import HasMonitorCallback | |||||
from io import BytesIO | from io import BytesIO | ||||
import shutil | import shutil | ||||
@@ -14,15 +13,15 @@ from fastNLP.core.log import logger | |||||
from fastNLP.envs import all_rank_call | from fastNLP.envs import all_rank_call | ||||
class LoadBestModelCallback(Callback): | |||||
def __init__(self, monitor:str, larger_better:bool = True, only_state_dict:bool = True, | |||||
class LoadBestModelCallback(HasMonitorCallback): | |||||
def __init__(self, monitor:str=None, larger_better:bool = True, only_state_dict:bool = True, | |||||
save_folder:Optional[str] = None, model_save_fn:Optional[Callable] = None, | save_folder:Optional[str] = None, model_save_fn:Optional[Callable] = None, | ||||
model_load_fn:Optional[Callable] = None, | model_load_fn:Optional[Callable] = None, | ||||
delete_after_train:bool = True): | delete_after_train:bool = True): | ||||
""" | """ | ||||
保存最佳的 monitor 值最佳的模型,并在训练结束的时候重新加载模型。仅在训练正常结束的时候才能加载最好的模型。 | 保存最佳的 monitor 值最佳的模型,并在训练结束的时候重新加载模型。仅在训练正常结束的时候才能加载最好的模型。 | ||||
:param str monitor: 监控的 metric 值。 | |||||
:param str monitor: 监控的 metric 值。如果为 None,将尝试使用 Trainer 设置的 monitor 。 | |||||
:param larger_better: 该 metric 值是否是越大越好。 | :param larger_better: 该 metric 值是否是越大越好。 | ||||
:param save_folder: 保存的文件夹,如果为空,则保存在内存中。不为空,则保存一份权重到文件中,当为多机训练,且本值不为空时,请确保 | :param save_folder: 保存的文件夹,如果为空,则保存在内存中。不为空,则保存一份权重到文件中,当为多机训练,且本值不为空时,请确保 | ||||
不同的机器均可访问当该路径。当 model_save_fn 不为 None 时该值一定不能为空。 | 不同的机器均可访问当该路径。当 model_save_fn 不为 None 时该值一定不能为空。 | ||||
@@ -33,6 +32,7 @@ class LoadBestModelCallback(Callback): | |||||
请在函数内完成对模型的加载。 | 请在函数内完成对模型的加载。 | ||||
:param delete_after_train: 在训练结束后是否删掉模型。 | :param delete_after_train: 在训练结束后是否删掉模型。 | ||||
""" | """ | ||||
super().__init__(monitor=monitor, larger_better=larger_better, must_have_monitor=True) | |||||
if model_load_fn is not None: | if model_load_fn is not None: | ||||
assert callable(model_load_fn), "`model_load_fn` must be a callable object." | assert callable(model_load_fn), "`model_load_fn` must be a callable object." | ||||
assert model_save_fn is not None, "`model_load_fn` and `model_save_fn` must be passed at the same time." | assert model_save_fn is not None, "`model_load_fn` and `model_save_fn` must be passed at the same time." | ||||
@@ -56,15 +56,11 @@ class LoadBestModelCallback(Callback): | |||||
self.real_save_folder = None | self.real_save_folder = None | ||||
self.buffer = BytesIO() | self.buffer = BytesIO() | ||||
self.monitor = monitor | |||||
self.larger_better = larger_better | |||||
self.save_folder = save_folder | self.save_folder = save_folder | ||||
self.only_state_dict = only_state_dict | self.only_state_dict = only_state_dict | ||||
self.model_save_fn = model_save_fn | self.model_save_fn = model_save_fn | ||||
self.model_load_fn = model_load_fn | self.model_load_fn = model_load_fn | ||||
self.delete_after_after = delete_after_train | self.delete_after_after = delete_after_train | ||||
self._real_monitor = None | |||||
self.monitor_value = float('-inf') if larger_better else float('inf') | |||||
def on_after_trainer_initialized(self, trainer, driver): | def on_after_trainer_initialized(self, trainer, driver): | ||||
if self.save_folder is not None and driver.is_distributed() and int(os.environ.get(FASTNLP_BACKEND_LAUNCH, 0))==1: | if self.save_folder is not None and driver.is_distributed() and int(os.environ.get(FASTNLP_BACKEND_LAUNCH, 0))==1: | ||||
@@ -76,13 +72,16 @@ class LoadBestModelCallback(Callback): | |||||
raise RuntimeError(f"Currently {driver.__class__.__name__} does not support using `save_folder` to " | raise RuntimeError(f"Currently {driver.__class__.__name__} does not support using `save_folder` to " | ||||
f"save best model when launch using script.") | f"save best model when launch using script.") | ||||
super().on_after_trainer_initialized(trainer, driver) | |||||
def on_sanity_check_end(self, trainer, sanity_check_res): | |||||
self.get_monitor_value(sanity_check_res) | |||||
def on_validate_end(self, trainer, results): | def on_validate_end(self, trainer, results): | ||||
self._real_monitor, monitor_value = _get_monitor_value(monitor=self.monitor, | |||||
real_monitor=self._real_monitor, | |||||
res=results) | |||||
if (monitor_value < self.monitor_value and self.larger_better is False) or \ | |||||
(monitor_value > self.monitor_value and self.larger_better): | |||||
self.monitor_value = monitor_value | |||||
if len(results)==0: | |||||
return | |||||
monitor_value = self.get_monitor_value(results) | |||||
if self.is_better_monitor_value(monitor_value, keep_if_better=True): | |||||
if self.real_save_folder: | if self.real_save_folder: | ||||
trainer.save_model(folder=self.real_save_folder, only_state_dict=self.only_state_dict, | trainer.save_model(folder=self.real_save_folder, only_state_dict=self.only_state_dict, | ||||
model_save_fn=self.model_save_fn) | model_save_fn=self.model_save_fn) | ||||
@@ -8,7 +8,7 @@ __all__ = [ | |||||
'RichCallback' | 'RichCallback' | ||||
] | ] | ||||
from .callback import Callback | |||||
from .callback import HasMonitorCallback | |||||
from fastNLP.core.callbacks.utils import _get_monitor_value | from fastNLP.core.callbacks.utils import _get_monitor_value | ||||
from fastNLP.core.utils import f_rich_progress | from fastNLP.core.utils import f_rich_progress | ||||
from fastNLP.core.log import logger | from fastNLP.core.log import logger | ||||
@@ -28,15 +28,13 @@ def choose_progress_callback(progress_bar:str): | |||||
return None | return None | ||||
class ProgressCallback(Callback): | |||||
class ProgressCallback(HasMonitorCallback): | |||||
def on_train_end(self, trainer): | def on_train_end(self, trainer): | ||||
f_rich_progress.stop() | f_rich_progress.stop() | ||||
def on_sanity_check_end(self, trainer, sanity_check_res): | def on_sanity_check_end(self, trainer, sanity_check_res): | ||||
if len(sanity_check_res) and getattr(self, 'monitor', None) is not None: | if len(sanity_check_res) and getattr(self, 'monitor', None) is not None: | ||||
self._real_monitor, monitor_value = _get_monitor_value(monitor=self.monitor, | |||||
real_monitor=self._real_monitor, | |||||
res=sanity_check_res) | |||||
self.get_monitor_value(sanity_check_res) | |||||
class RichCallback(ProgressCallback): | class RichCallback(ProgressCallback): | ||||
@@ -46,28 +44,22 @@ class RichCallback(ProgressCallback): | |||||
:param print_every: 多少个 batch 更新一次显示。 | :param print_every: 多少个 batch 更新一次显示。 | ||||
:param loss_round_ndigit: 显示的 loss 保留多少位有效数字 | :param loss_round_ndigit: 显示的 loss 保留多少位有效数字 | ||||
:param monitor: 当检测到这个key的结果更好时,会打印出不同的颜色进行提示。 | |||||
:param monitor: 当检测到这个key的结果更好时,会打印出不同的颜色进行提示。如果为 None ,会尝试使用 trainer 中设置的 monitor 。 | |||||
:param larger_better: 是否是monitor的结果越大越好。 | :param larger_better: 是否是monitor的结果越大越好。 | ||||
:param format_json: 是否format json再打印 | :param format_json: 是否format json再打印 | ||||
""" | """ | ||||
super().__init__() | |||||
super().__init__(monitor=monitor, larger_better=larger_better, must_have_monitor=False) | |||||
self.print_every = print_every | self.print_every = print_every | ||||
self.progress_bar = f_rich_progress | self.progress_bar = f_rich_progress | ||||
self.task2id = {} | self.task2id = {} | ||||
self.loss = 0 | self.loss = 0 | ||||
self.loss_round_ndigit = loss_round_ndigit | self.loss_round_ndigit = loss_round_ndigit | ||||
self.monitor = monitor | |||||
self.larger_better = larger_better | |||||
if larger_better: | |||||
self.monitor_value = float('-inf') | |||||
else: | |||||
self.monitor_value = float('inf') | |||||
self._real_monitor = monitor | |||||
self.format_json = format_json | self.format_json = format_json | ||||
def on_after_trainer_initialized(self, trainer, driver): | def on_after_trainer_initialized(self, trainer, driver): | ||||
if not self.progress_bar.disable: | if not self.progress_bar.disable: | ||||
self.progress_bar.set_disable(flag=trainer.driver.get_local_rank() != 0) | self.progress_bar.set_disable(flag=trainer.driver.get_local_rank() != 0) | ||||
super(RichCallback, self).on_after_trainer_initialized(trainer, driver) | |||||
def on_train_begin(self, trainer): | def on_train_begin(self, trainer): | ||||
self.task2id['epoch'] = self.progress_bar.add_task(description='Epoch:0', total=trainer.n_epochs, | self.task2id['epoch'] = self.progress_bar.add_task(description='Epoch:0', total=trainer.n_epochs, | ||||
@@ -109,16 +101,12 @@ class RichCallback(ProgressCallback): | |||||
text_style = '' | text_style = '' | ||||
characters = '-' | characters = '-' | ||||
if self.monitor is not None: | if self.monitor is not None: | ||||
self._real_monitor, monitor_value = _get_monitor_value(monitor=self.monitor, | |||||
real_monitor=self._real_monitor, | |||||
res=results) | |||||
if (self.larger_better and monitor_value > self.monitor_value) or \ | |||||
(not self.larger_better and monitor_value < self.monitor_value): | |||||
monitor_value = self.get_monitor_value(results) | |||||
if self.is_better_monitor_value(monitor_value, keep_if_better=True): | |||||
if abs(self.monitor_value) != float('inf'): | if abs(self.monitor_value) != float('inf'): | ||||
rule_style = 'spring_green3' | rule_style = 'spring_green3' | ||||
text_style = '[bold]' | text_style = '[bold]' | ||||
characters = '+' | characters = '+' | ||||
self.monitor_value = monitor_value | |||||
self.progress_bar.print() | self.progress_bar.print() | ||||
self.progress_bar.console.rule(text_style+f"Eval. results on Epoch:{trainer.cur_epoch_idx}, " | self.progress_bar.console.rule(text_style+f"Eval. results on Epoch:{trainer.cur_epoch_idx}, " | ||||
f"Batch:{trainer.batch_idx_in_epoch}", | f"Batch:{trainer.batch_idx_in_epoch}", | ||||
@@ -151,18 +139,12 @@ class RawTextCallback(ProgressCallback): | |||||
:param larger_better: 是否是monitor的结果越大越好。 | :param larger_better: 是否是monitor的结果越大越好。 | ||||
:param format_json: 是否format json再打印 | :param format_json: 是否format json再打印 | ||||
""" | """ | ||||
super().__init__() | |||||
super().__init__(monitor=monitor, larger_better=larger_better, must_have_monitor=False) | |||||
self.print_every = print_every | self.print_every = print_every | ||||
self.task2id = {} | self.task2id = {} | ||||
self.loss = 0 | self.loss = 0 | ||||
self.loss_round_ndigit = loss_round_ndigit | self.loss_round_ndigit = loss_round_ndigit | ||||
self.monitor = monitor | |||||
self.larger_better = larger_better | |||||
if larger_better: | |||||
self.monitor_value = float('-inf') | |||||
else: | |||||
self.monitor_value = float('inf') | |||||
self._real_monitor = monitor | |||||
self.set_monitor(monitor, larger_better) | |||||
self.format_json = format_json | self.format_json = format_json | ||||
self.num_signs = 10 | self.num_signs = 10 | ||||
@@ -189,14 +171,10 @@ class RawTextCallback(ProgressCallback): | |||||
base_text = f'Eval. results on Epoch:{trainer.cur_epoch_idx}, Batch:{trainer.batch_idx_in_epoch}' | base_text = f'Eval. results on Epoch:{trainer.cur_epoch_idx}, Batch:{trainer.batch_idx_in_epoch}' | ||||
text = '' | text = '' | ||||
if self.monitor is not None: | if self.monitor is not None: | ||||
self._real_monitor, monitor_value = _get_monitor_value(monitor=self.monitor, | |||||
real_monitor=self._real_monitor, | |||||
res=results) | |||||
if (self.larger_better and monitor_value > self.monitor_value) or \ | |||||
(not self.larger_better and monitor_value < self.monitor_value): | |||||
monitor_value = self.get_monitor_value(results) | |||||
if self.is_better_monitor_value(monitor_value, keep_if_better=True): | |||||
if abs(self.monitor_value) != float('inf'): | if abs(self.monitor_value) != float('inf'): | ||||
text = '+'*self.num_signs + base_text + '+'*self.num_signs | text = '+'*self.num_signs + base_text + '+'*self.num_signs | ||||
self.monitor_value = monitor_value | |||||
if len(text) == 0: | if len(text) == 0: | ||||
text = '-'*self.num_signs + base_text + '-'*self.num_signs | text = '-'*self.num_signs + base_text + '-'*self.num_signs | ||||
@@ -19,23 +19,31 @@ def _get_monitor_value(monitor: str, real_monitor: Optional[str], res: dict) ->( | |||||
if monitor in res: | if monitor in res: | ||||
return monitor, res[monitor] | return monitor, res[monitor] | ||||
if real_monitor in res: | |||||
return real_monitor, res[real_monitor] | |||||
pairs = [] | pairs = [] | ||||
for idx, (key, value) in enumerate(res.items()): | for idx, (key, value) in enumerate(res.items()): | ||||
match = SequenceMatcher(None, key, monitor).find_longest_match(0, len(key), 0, len(monitor)) | |||||
pairs.append((key, value, match.size, idx)) | |||||
match_size = _match_length(monitor, key) | |||||
pairs.append((key, value, match_size, idx)) | |||||
pairs.sort(key=lambda pair: (pair[2], -pair[3]), reverse=True) | pairs.sort(key=lambda pair: (pair[2], -pair[3]), reverse=True) | ||||
key, value, match_size = pairs[0][:3] | key, value, match_size = pairs[0][:3] | ||||
if real_monitor is not None and real_monitor in res and real_monitor != key: | |||||
# 如果 real_monitor 比新找的更长就继续用之前的。 | |||||
match = SequenceMatcher(None, real_monitor, monitor).find_longest_match(0, len(real_monitor), 0, len(monitor)) | |||||
if match.size > match_size: | |||||
return real_monitor, res[real_monitor] | |||||
return key, value | |||||
logger.warning(f"We can not find `{monitor}` in the evaluation result (with keys as {list(res.keys())}), " | |||||
f"we use the `{key}` as the monitor.") | |||||
real_monitor = key | |||||
return real_monitor, value | |||||
def _match_length(a:str, b:str)->int: | |||||
""" | |||||
需要把长度短的放在前面 | |||||
:param a: | |||||
:param b: | |||||
:return: | |||||
""" | |||||
short = a if len(a) < len(b) else b | |||||
long = a if len(a)>=len(b) else b | |||||
match = SequenceMatcher(None, short, long).find_longest_match(0, len(short), 0, len(long)) | |||||
return match.size | |||||
@@ -25,6 +25,7 @@ from fastNLP.core.utils import check_fn_not_empty_params, get_fn_arg_names, matc | |||||
from fastNLP.envs import rank_zero_call | from fastNLP.envs import rank_zero_call | ||||
from fastNLP.core.log import logger | from fastNLP.core.log import logger | ||||
from fastNLP.envs import FASTNLP_MODEL_FILENAME | from fastNLP.envs import FASTNLP_MODEL_FILENAME | ||||
from fastNLP.core.utils.exceptions import EarlyStopException | |||||
class Trainer(TrainerEventTrigger): | class Trainer(TrainerEventTrigger): | ||||
@@ -49,6 +50,8 @@ class Trainer(TrainerEventTrigger): | |||||
output_mapping: Optional[Union[Callable, Dict]] = None, | output_mapping: Optional[Union[Callable, Dict]] = None, | ||||
accumulation_steps: int = 1, | accumulation_steps: int = 1, | ||||
fp16: bool = False, | fp16: bool = False, | ||||
monitor: str = None, | |||||
larger_better: bool = True, | |||||
marker: Optional[str] = None, | marker: Optional[str] = None, | ||||
**kwargs | **kwargs | ||||
): | ): | ||||
@@ -102,6 +105,10 @@ class Trainer(TrainerEventTrigger): | |||||
如果 batch 是一个 `dataclass`,那么我们会先将该 dataclass 转换为一个 Dict,然后再进行上述转换 | 如果 batch 是一个 `dataclass`,那么我们会先将该 dataclass 转换为一个 Dict,然后再进行上述转换 | ||||
:param accumulation_steps: 梯度累积的步数,表示每隔几个 batch 优化器迭代一次;默认为 1; | :param accumulation_steps: 梯度累积的步数,表示每隔几个 batch 优化器迭代一次;默认为 1; | ||||
:param fp16: 是否开启混合精度训练;默认为 False; | :param fp16: 是否开启混合精度训练;默认为 False; | ||||
:param monitor: 当存在 validate_dataloaders 时,默认的 monitor metric 的名字。传入的 callback 如果有 monitor 参数且没有 | |||||
在 callback 初始化设定的,将采取这个值。如果在 evaluation 结果中没有找到完全一致的名称,将使用 最短公共字符串算法 找到最匹配 | |||||
的那个作为 monitor 。 | |||||
:param larger_better: monitor 的值是否是越大越好。 | |||||
:param marker: 用于标记一个 Trainer 实例,从而在用户调用 `Trainer.on` 函数时,标记该 callback 函数属于哪一个具体的 'trainer' 实例;默认为 None; | :param marker: 用于标记一个 Trainer 实例,从而在用户调用 `Trainer.on` 函数时,标记该 callback 函数属于哪一个具体的 'trainer' 实例;默认为 None; | ||||
:param kwargs: 一些其它的可能需要的参数; | :param kwargs: 一些其它的可能需要的参数; | ||||
torch_non_blocking: 表示用于 pytorch 的 tensor 的 to 方法的参数 non_blocking; | torch_non_blocking: 表示用于 pytorch 的 tensor 的 to 方法的参数 non_blocking; | ||||
@@ -210,6 +217,8 @@ class Trainer(TrainerEventTrigger): | |||||
self.evaluator = None | self.evaluator = None | ||||
self.epoch_validate = lambda *args, **kwargs: ... | self.epoch_validate = lambda *args, **kwargs: ... | ||||
self.step_validate = lambda *args, **kwargs: ... | self.step_validate = lambda *args, **kwargs: ... | ||||
self.monitor = monitor | |||||
self.larger_better = larger_better | |||||
if metrics is not None and validate_dataloaders is not None: | if metrics is not None and validate_dataloaders is not None: | ||||
if not callable(validate_every) and (not isinstance(validate_every, int) or validate_every == 0): | if not callable(validate_every) and (not isinstance(validate_every, int) or validate_every == 0): | ||||
raise ValueError("Parameter 'validate_every' should be set to 'int' type and either < 0 or > 0.") | raise ValueError("Parameter 'validate_every' should be set to 'int' type and either < 0 or > 0.") | ||||
@@ -239,6 +248,7 @@ class Trainer(TrainerEventTrigger): | |||||
else: | else: | ||||
# validate_every > 0 | # validate_every > 0 | ||||
self._step_validate_filter = Filter(every=validate_every) | self._step_validate_filter = Filter(every=validate_every) | ||||
self.metrics = metrics | self.metrics = metrics | ||||
self.validate_every = validate_every | self.validate_every = validate_every | ||||
@@ -320,6 +330,10 @@ class Trainer(TrainerEventTrigger): | |||||
self.driver.barrier() | self.driver.barrier() | ||||
self.on_train_end() | self.on_train_end() | ||||
self.driver.barrier() | self.driver.barrier() | ||||
except EarlyStopException as e: | |||||
logger.info(f"Catch early stop exception: {e.msg}.") | |||||
self.on_exception(e) | |||||
except KeyboardInterrupt as e: | except KeyboardInterrupt as e: | ||||
self.driver.on_exception() | self.driver.on_exception() | ||||
self.on_exception(e) | self.on_exception(e) | ||||
@@ -0,0 +1,10 @@ | |||||
class EarlyStopException(BaseException): | |||||
r""" | |||||
用于EarlyStop时从Trainer训练循环中跳出。 | |||||
""" | |||||
def __init__(self, msg): | |||||
super(EarlyStopException, self).__init__(msg) | |||||
self.msg = msg |
@@ -12,32 +12,27 @@ def test_get_monitor_value(): | |||||
with Capturing() as output: | with Capturing() as output: | ||||
monitor, value = _get_monitor_value(monitor='f1', real_monitor=None, res=res) | monitor, value = _get_monitor_value(monitor='f1', real_monitor=None, res=res) | ||||
assert monitor == 'f1' and value==0.2 | assert monitor == 'f1' and value==0.2 | ||||
assert 'We can not find' not in output[0] | |||||
# 测试可以匹配,且选择更靠前的 | # 测试可以匹配,且选择更靠前的 | ||||
res = {'acc#f1': 0.2, 'acc#rec': 0.3, 'add#f':0.4} | res = {'acc#f1': 0.2, 'acc#rec': 0.3, 'add#f':0.4} | ||||
with Capturing() as output: | with Capturing() as output: | ||||
monitor, value = _get_monitor_value(monitor='f1', real_monitor=None, res=res) | monitor, value = _get_monitor_value(monitor='f1', real_monitor=None, res=res) | ||||
assert monitor=='acc#f1' and value==0.2 | assert monitor=='acc#f1' and value==0.2 | ||||
assert 'We can not find' in output[0] | |||||
# 测试monitor匹配不上,使用real_monitor | # 测试monitor匹配不上,使用real_monitor | ||||
res = {'acc#f1': 0.2, 'acc#rec': 0.3, 'add#f':0.4} | res = {'acc#f1': 0.2, 'acc#rec': 0.3, 'add#f':0.4} | ||||
with Capturing() as output: | with Capturing() as output: | ||||
monitor, value = _get_monitor_value(monitor='acc#f', real_monitor='acc#rec', res=res) | |||||
monitor, value = _get_monitor_value(monitor='acc', real_monitor='acc#rec', res=res) | |||||
assert monitor=='acc#rec' and value==0.3 | assert monitor=='acc#rec' and value==0.3 | ||||
assert 'We can not find' not in output[0] | |||||
# 测试monitor/real_monitor匹配不上, 重新选择 | # 测试monitor/real_monitor匹配不上, 重新选择 | ||||
res = {'acc#f1': 0.2, 'acc#rec': 0.3, 'add#f':0.4} | res = {'acc#f1': 0.2, 'acc#rec': 0.3, 'add#f':0.4} | ||||
with Capturing() as output: | with Capturing() as output: | ||||
monitor, value = _get_monitor_value(monitor='acc#f', real_monitor='acc#r', res=res) | monitor, value = _get_monitor_value(monitor='acc#f', real_monitor='acc#r', res=res) | ||||
assert monitor=='acc#f1' and value==0.2 | assert monitor=='acc#f1' and value==0.2 | ||||
assert 'We can not find' in output[0] | |||||
# 测试partial的位置 | # 测试partial的位置 | ||||
res = {"acc#acc": 0.52, "loss#loss": 2} | res = {"acc#acc": 0.52, "loss#loss": 2} | ||||
with Capturing() as output: | with Capturing() as output: | ||||
monitor, value = _get_monitor_value(monitor='-loss', real_monitor=None, res=res) | monitor, value = _get_monitor_value(monitor='-loss', real_monitor=None, res=res) | ||||
assert monitor=='loss#loss' and value==2 | assert monitor=='loss#loss' and value==2 | ||||
assert 'We can not find' in output[0] |