|
|
@@ -9,12 +9,13 @@ import sys |
|
|
|
from fastNLP.core.log import logger |
|
|
|
from .topk_saver import TopkSaver |
|
|
|
from .callback import Callback |
|
|
|
from ..utils.exceptions import EarlyStopException |
|
|
|
|
|
|
|
|
|
|
|
class CheckpointCallback(Callback): |
|
|
|
def __init__(self, folder: Optional[Union[str, Path]] = None, every_n_epochs: Optional[int] = None, |
|
|
|
every_n_batches: Optional[int] = None, last: bool = False, |
|
|
|
on_exceptions: Optional[Union[BaseException, Sequence[BaseException]]] = None, topk: int = 0, |
|
|
|
every_n_batches: Optional[int] = None, last: bool = False, topk: int = 0, |
|
|
|
on_exceptions: Optional[Union[BaseException, Sequence[BaseException]]] = [EarlyStopException], |
|
|
|
monitor: Optional[Union[str, Callable]] = None, larger_better: bool = True, |
|
|
|
only_state_dict: bool = True, model_save_fn: Optional[Callable] = None, save_object: str = 'model', |
|
|
|
save_evaluate_results=True, **kwargs): |
|
|
@@ -49,7 +50,7 @@ class CheckpointCallback(Callback): |
|
|
|
:param every_n_batches: 多少个 batch 保存一次。 |
|
|
|
:param last: 如果为 True ,将在每次 epoch 运行结束都保存一次,会覆盖之前的保存。 |
|
|
|
:param topk: 保存 monitor 结果 topK 个。 |
|
|
|
:param on_exceptions: 在出异常信息时,是否保存。传入需要捕获的异常的类。 |
|
|
|
:param on_exceptions: 在出异常信息时,是否保存。传入需要捕获的异常的类。默认将捕获 EarlyStopException 。 |
|
|
|
:param larger_better: monitor 的值是否时越大越好。 |
|
|
|
:param only_state_dict: 保存模型时是否只保存 state_dict 。当 model_save_fn 不为 None 时,该参数无效。 |
|
|
|
:param model_save_fn: 个性化的保存函数,当触发保存操作时,就调用这个函数,这个函数应当接受一个文件夹作为参数,不返回任何东西。 |
|
|
|