From 1ecbdc7446fa0041b725f67e6fa111ddddd83ad5 Mon Sep 17 00:00:00 2001 From: yh_cc Date: Sat, 23 Apr 2022 00:41:12 +0800 Subject: [PATCH] =?UTF-8?q?=E4=BF=AE=E5=A4=8DTopkSaver=E5=9C=A8CheckpointC?= =?UTF-8?q?allbackk=E4=B8=AD=E7=9A=84bug;=E4=BF=AE=E6=94=B9Saver=E5=AF=B9?= =?UTF-8?q?=E8=B1=A1=EF=BC=8C=E4=BD=BF=E5=BE=97=E5=85=B6=E6=96=B9=E4=BE=BF?= =?UTF-8?q?=E8=A2=AB=E7=9B=B4=E6=8E=A5=E4=BD=BF=E7=94=A8?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- fastNLP/core/callbacks/checkpoint_callback.py | 7 ++++--- fastNLP/core/callbacks/topk_saver.py | 4 ++-- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/fastNLP/core/callbacks/checkpoint_callback.py b/fastNLP/core/callbacks/checkpoint_callback.py index 0f6dcd6a..e12873d3 100644 --- a/fastNLP/core/callbacks/checkpoint_callback.py +++ b/fastNLP/core/callbacks/checkpoint_callback.py @@ -91,9 +91,10 @@ class CheckpointCallback(Callback): else: on_exceptions = [] - self.topk_saver = TopkSaver(topk, monitor, larger_better, folder, only_state_dict, - model_save_fn, save_evaluate_results, - save_object, **kwargs) + self.topk_saver = TopkSaver(topk=topk, monitor=monitor, larger_better=larger_better, folder=folder, + save_object=save_object, only_state_dict=only_state_dict, model_save_fn=model_save_fn, + save_evaluate_results=save_evaluate_results, **kwargs) + self.topk = topk self.save_object = save_object diff --git a/fastNLP/core/callbacks/topk_saver.py b/fastNLP/core/callbacks/topk_saver.py index d541c926..8c3f3811 100644 --- a/fastNLP/core/callbacks/topk_saver.py +++ b/fastNLP/core/callbacks/topk_saver.py @@ -169,11 +169,11 @@ class TopkQueue: class TopkSaver(MonitorUtility, Saver): - def __init__(self, topk:int, monitor:str, larger_better:bool=True, folder:str=None, save_object:str='model', + def __init__(self, topk:int=0, monitor:str=None, larger_better:bool=True, folder:str=None, save_object:str='model', only_state_dict:bool=True, model_save_fn:Callable=None, save_evaluate_results:bool=True, **kwargs): """ - 用来保存识别 topk 模型并保存,也可以仅当一个保存 saver 使用。保存路径为 + 用来识别 topk 模型并保存,也可以仅当一个保存 Saver 使用。保存路径为 - folder/ - YYYY-mm-dd-HH_MM_SS_fffff/ # 自动根据当前脚本的启动时间创建的 - {save_object}-epoch_{epoch_idx}-batch_{global_batch_idx}-{topk_monitor}_{monitor_value}/ # 满足topk条件存储文件名