From 57caf1d028e5302db7f2e00d1fb7d156ea370c9c Mon Sep 17 00:00:00 2001
From: YWMditto
Date: Tue, 12 Apr 2022 11:43:26 +0800
Subject: [PATCH] =?UTF-8?q?checkpoint=20callback=20=E5=8A=A0=E5=85=A5?=
 =?UTF-8?q?=E4=BA=86=20on=5Fafter=5Ftrainer=5Finitialized=20=E7=9A=84?=
 =?UTF-8?q?=E9=80=BB=E8=BE=91?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 fastNLP/core/callbacks/checkpoint_callback.py | 18 ++++++++++++++++--
 1 file changed, 16 insertions(+), 2 deletions(-)

diff --git a/fastNLP/core/callbacks/checkpoint_callback.py b/fastNLP/core/callbacks/checkpoint_callback.py
index d3a3b52d..12b6a9e6 100644
--- a/fastNLP/core/callbacks/checkpoint_callback.py
+++ b/fastNLP/core/callbacks/checkpoint_callback.py
@@ -48,8 +48,9 @@ class CheckpointCallback(Callback):
         model_save_fn: Optional[Callable] = None,
         **kwargs,
     ):
-        if monitor is None and save_topk is not None:
-            raise ValueError("Parameter `monitor` must be set when you want to use 'save_topk'.")
+        # 我们新加了逻辑,如果 checkpoint callback 自己没有设置 monitor 和 larger_better,那么我们会将其在 trainer 中的设置赋值给它们;
+        # if monitor is None and save_topk is not None:
+        #     raise ValueError("Parameter `monitor` must be set when you want to use 'save_topk'.")
 
         if monitor is not None and not isinstance(monitor, str):
             raise ValueError("Parameter `monitor` should be of 'str' type.")
@@ -119,6 +120,19 @@ class CheckpointCallback(Callback):
         # 我们只需要保证这个创建文件夹的操作只在进程 0 上进行即可;因为后续的实际的保存操作,其它进程实际并不会去执行;
         synchronize_mkdir(self.timestamp_path)
 
+    def on_after_trainer_initialized(self, trainer, driver):
+        if self.monitor is None:
+            if trainer.monitor is not None:
+                self.monitor = trainer.monitor
+                self.larger_better = trainer.larger_better
+            elif self.save_topk is not None:
+                raise RuntimeError("You are using `topk` mode, but you have not set the `monitor` value either in this "
+                                   "callback or in trainer.")
+            else:
+                self.monitor = None
+        if self.save_topk is not None and trainer.evaluator is None:
+            raise RuntimeError("You are using `topk` mode, but there is no `evaluator` in trainer.")
+
     def on_validate_end(self, trainer, validate_res):
         self._save_topk(trainer, validate_res)