Browse Source

修复TopkSaver在CheckpointCallbackk中的bug;修改Saver对象,使得其方便被直接使用

tags/v1.0.0alpha
yh_cc 3 years ago
parent
commit
1ecbdc7446
2 changed files with 6 additions and 5 deletions
  1. +4
    -3
      fastNLP/core/callbacks/checkpoint_callback.py
  2. +2
    -2
      fastNLP/core/callbacks/topk_saver.py

+ 4
- 3
fastNLP/core/callbacks/checkpoint_callback.py View File

@@ -91,9 +91,10 @@ class CheckpointCallback(Callback):
else:
on_exceptions = []

self.topk_saver = TopkSaver(topk, monitor, larger_better, folder, only_state_dict,
model_save_fn, save_evaluate_results,
save_object, **kwargs)
self.topk_saver = TopkSaver(topk=topk, monitor=monitor, larger_better=larger_better, folder=folder,
save_object=save_object, only_state_dict=only_state_dict, model_save_fn=model_save_fn,
save_evaluate_results=save_evaluate_results, **kwargs)

self.topk = topk
self.save_object = save_object



+ 2
- 2
fastNLP/core/callbacks/topk_saver.py View File

@@ -169,11 +169,11 @@ class TopkQueue:


class TopkSaver(MonitorUtility, Saver):
def __init__(self, topk:int, monitor:str, larger_better:bool=True, folder:str=None, save_object:str='model',
def __init__(self, topk:int=0, monitor:str=None, larger_better:bool=True, folder:str=None, save_object:str='model',
only_state_dict:bool=True, model_save_fn:Callable=None, save_evaluate_results:bool=True,
**kwargs):
"""
用来保存识别 topk 模型并保存,也可以仅当一个保存 saver 使用。保存路径为
用来识别 topk 模型并保存,也可以仅当一个保存 Saver 使用。保存路径为
- folder/
- YYYY-mm-dd-HH_MM_SS_fffff/ # 自动根据当前脚本的启动时间创建的
- {save_object}-epoch_{epoch_idx}-batch_{global_batch_idx}-{topk_monitor}_{monitor_value}/ # 满足topk条件存储文件名


Loading…
Cancel
Save