|
|
@@ -19,7 +19,7 @@ class CheckpointCallback(Callback): |
|
|
|
only_state_dict: bool = True, model_save_fn: Optional[Callable] = None, save_object: str = 'model', |
|
|
|
save_evaluate_results=True, **kwargs): |
|
|
|
""" |
|
|
|
保存模型 checkpoint 的 callback ,其保存的文件目录以及文件名命名规则如下:: |
|
|
|
保存 checkpoint 的 callback ,其保存的文件目录以及文件名命名规则如下:: |
|
|
|
|
|
|
|
- folder/ |
|
|
|
- YYYY-mm-dd-HH_MM_SS_fffff/ # 自动根据当前脚本的启动时间创建的 |
|
|
@@ -29,8 +29,9 @@ class CheckpointCallback(Callback): |
|
|
|
- {save_object}-epoch_{epoch_idx}-batch_{global_batch_idx}-exception_{exception_type}/ # exception时保存。 |
|
|
|
- {save_object}-epoch_{epoch_idx}-batch_{global_batch_idx}-{monitor}_{monitor_value}/ # 满足topk条件存储文件名 |
|
|
|
|
|
|
|
model_save_fn 为 None ,则以上每个 folder 中,将生成 fastnlp_model.pkl.tar 文件。 |
|
|
|
若 model_save_fn 不为 None,则 fastNLP 将 folder 绝对路径传递给该函数,fastNLP 在该 folder 下不进行模型保存。 |
|
|
|
model_save_fn 为 None ,则以上每个 folder 中,将生成 fastnlp_model.pkl.tar 文件。若 model_save_fn 不为 None, |
|
|
|
则 fastNLP 将 folder 绝对路径传递给该函数,fastNLP 在该 folder 下不进行模型保存。默认情况下,本 checkpoint 只保存了 model |
|
|
|
的状态;如还需保存 Trainer 的状态以断点重训的话,请使用 ``save_object='trainer'`` 。 |
|
|
|
|
|
|
|
:param monitor: 监控的 metric 值。如果在 evaluation 结果中没有找到完全一致的名称,将使用 最长公共字符串算法 找到最匹配 |
|
|
|
的那个作为 monitor 。如果为 None,将尝试使用 Trainer 设置的 monitor 。也可以传入一个函数,接受参数为 evaluation 的结 |
|
|
@@ -46,22 +47,14 @@ class CheckpointCallback(Callback): |
|
|
|
:param only_state_dict: 保存模型时是否只保存 state_dict 。当 model_save_fn 不为 None 时,该参数无效。 |
|
|
|
:param model_save_fn: 个性化的保存函数,当触发保存操作时,就调用这个函数,这个函数应当接受一个文件夹作为参数,不返回任何东西。 |
|
|
|
如果传入了 model_save_fn 函数,fastNLP 将不再进行模型相关的保存。在多卡场景下,我们只在 rank 0 上会运行该函数。 |
|
|
|
:param save_object: 可选 ['trainer', 'model'],表示在保存时的保存对象为 trainer+model 还是 只是model 。 |
|
|
|
:param save_object: 可选 ['trainer', 'model'],表示在保存时的保存对象为 ``trainer+model`` 还是 只是 ``model`` 。如果 |
|
|
|
保存 ``trainer`` 对象的话,将会保存 :class:~fastNLP.Trainer 的相关状态,可以通过 :meth:`Trainer.load` 加载该断 |
|
|
|
点继续训练。如果保存的是 ``Model`` 对象,则可以通过 :meth:`Trainer.load_model` 加载该模型权重。 |
|
|
|
:param save_evaluate_results: 是否保存 evaluate 的结果。如果为 True ,在保存 topk 模型的 folder 中还将额外保存一个 |
|
|
|
fastnlp_evaluate_results.json 文件,记录当前的 results。仅在设置了 topk 的场景下有用,默认为 True 。 |
|
|
|
:param kwargs: |
|
|
|
""" |
|
|
|
super().__init__() |
|
|
|
if folder is None: |
|
|
|
logger.warning( |
|
|
|
"Parameter `folder` is None, and we will use the current work directory to find and load your model.") |
|
|
|
folder = Path.cwd() |
|
|
|
folder = Path(folder) |
|
|
|
if not folder.exists(): |
|
|
|
raise NotADirectoryError(f"Path '{folder.absolute()}' is not existed!") |
|
|
|
elif folder.is_file(): |
|
|
|
raise ValueError("Parameter `folder` should be a directory instead of a file.") |
|
|
|
|
|
|
|
if every_n_epochs is not None: |
|
|
|
if not isinstance(every_n_epochs, int) or every_n_epochs < 1: |
|
|
|
raise ValueError("Parameter `every_n_epochs` should be an int and greater than or equal to 1.") |
|
|
@@ -74,12 +67,6 @@ class CheckpointCallback(Callback): |
|
|
|
else: |
|
|
|
every_n_batches = sys.maxsize # 使得没有数字可以整除 |
|
|
|
|
|
|
|
if topk is not None: |
|
|
|
if not isinstance(topk, int): |
|
|
|
raise ValueError("Parameter `topk` should be an int.") |
|
|
|
else: |
|
|
|
topk = 0 |
|
|
|
|
|
|
|
if on_exceptions is not None: |
|
|
|
if not isinstance(on_exceptions, Sequence): |
|
|
|
on_exceptions = [on_exceptions] |
|
|
|