From f74b9b6bec391a7dac82b43970ca237b0bbaec8b Mon Sep 17 00:00:00 2001
From: yh_cc
Date: Thu, 28 Apr 2022 16:27:17 +0800
Subject: [PATCH] Rename all validate callbacks to evaluate; move
 callback.on_train_end()
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 fastNLP/core/callbacks/callback.py            | 36 ++++++++++++++++---
 fastNLP/core/callbacks/callback_events.py     |  4 +--
 fastNLP/core/callbacks/callback_manager.py    |  4 +--
 fastNLP/core/callbacks/checkpoint_callback.py |  2 +-
 fastNLP/core/callbacks/early_stop_callback.py |  8 ++---
 .../core/callbacks/has_monitor_callback.py    |  2 +-
 .../callbacks/load_best_model_callback.py     | 25 ++++---------
 .../core/callbacks/more_evaluate_callback.py  | 10 +++---
 fastNLP/core/callbacks/progress_callback.py   |  7 ++--
 .../controllers/loops/train_batch_loop.py     |  2 +-
 fastNLP/core/controllers/trainer.py           | 31 ++++++++--------
 fastNLP/core/controllers/utils/utils.py       | 16 ++++-----
 .../drivers/jittor_driver/jittor_driver.py    |  2 +-
 fastNLP/core/log/logger.py                    | 13 +++++++
 tests/helpers/callbacks/helper_callbacks.py   | 10 +++---
 15 files changed, 99 insertions(+), 73 deletions(-)

diff --git a/fastNLP/core/callbacks/callback.py b/fastNLP/core/callbacks/callback.py
index 1d3d1f11..982df7da 100644
--- a/fastNLP/core/callbacks/callback.py
+++ b/fastNLP/core/callbacks/callback.py
@@ -12,6 +12,34 @@ from fastNLP.core.callbacks.callback_events import _SingleEventState
 class Callback:
     r"""
     The base class that every callback in actual use should inherit from, whether it is one of the callback classes fastNLP provides by default or one the user defines;
+    The callback hooks fire in roughly the following order:
+        Trainer.__init__():
+            on_after_trainer_initialized()
+        Trainer.run():
+            if num_eval_sanity_batch > 0:
+                on_sanity_check_begin()  # only if num_eval_sanity_batch is set
+                on_sanity_check_end()
+            try:
+                on_train_begin()
+                while cur_epoch_idx < n_epochs:
+                    on_train_epoch_begin()
+                    while batch_idx_in_epoch <= num_batches_per_epoch:
+                        on_fetch_data_begin()
+                        on_fetch_data_end()
+                        on_train_batch_begin()
+                        on_before_backward()
+                        on_after_backward()
+                        on_before_zero_grad()  # actual invocation depends on accumulation_steps
+                        on_after_zero_grad()  # actual invocation depends on accumulation_steps
+                        on_before_optimizers_step()  # actual invocation depends on accumulation_steps
+                        on_after_optimizers_step()  # actual invocation depends on accumulation_steps
+                        on_train_batch_end()
+                    on_train_epoch_end()
+            except BaseException:
+                on_exception()
+            finally:
+                on_train_end()
+    Other hooks such as on_evaluate_begin()/on_evaluate_end() fire according to the evaluate schedule: after on_train_batch_end() for step-based or custom schedules, and after on_train_epoch_end() for epoch-based schedules.
     """
 
     def on_after_trainer_initialized(self, trainer, driver):
@@ -221,9 +249,9 @@ class Callback:
         """
         pass
 
-    def on_validate_begin(self, trainer):
+    def on_evaluate_begin(self, trainer):
         """
-        Called when a validate run is about to start. If the validate frequency is step-based or custom, this hook is called
+        Called when an evaluate run is about to start. If the evaluate frequency is step-based or custom, this hook is called
         after on_train_batch_end; if it is epoch-based, it is called after on_train_epoch_end.
 
         :param trainer:
@@ -231,9 +259,9 @@ class Callback:
         """
         pass
 
-    def on_validate_end(self, trainer, results):
+    def on_evaluate_end(self, trainer, results):
         """
-        Called when a validate run finishes, with the validate results passed in.
+        Called when an evaluate run finishes, with the evaluate results passed in.
 
         :param trainer:
         :param results: the results of the evaluate run, usually a dict.
diff --git a/fastNLP/core/callbacks/callback_events.py b/fastNLP/core/callbacks/callback_events.py
index ef972b35..3f3691e3 100644
--- a/fastNLP/core/callbacks/callback_events.py
+++ b/fastNLP/core/callbacks/callback_events.py
@@ -96,8 +96,8 @@ class Events(EventEnum):
     on_after_optimizers_step = "on_after_optimizers_step"
     on_before_zero_grad = "on_before_zero_grad"
     on_after_zero_grad = "on_after_zero_grad"
-    on_validate_begin = "on_validate_begin"
-    on_validate_end = "on_validate_end"
+    on_evaluate_begin = "on_evaluate_begin"
+    on_evaluate_end = "on_evaluate_end"
 
 
 class EventsList:
diff --git a/fastNLP/core/callbacks/callback_manager.py b/fastNLP/core/callbacks/callback_manager.py
index c5b00e71..90d2e1b1 100644
--- a/fastNLP/core/callbacks/callback_manager.py
+++ b/fastNLP/core/callbacks/callback_manager.py
@@ -281,9 +281,9 @@ class CallbackManager:
         pass
 
     @_transfer
-    def on_validate_begin(self, trainer):
+    def on_evaluate_begin(self, trainer):
         pass
 
     @_transfer
-    def on_validate_end(self, trainer, results):
+    def on_evaluate_end(self, trainer, results):
         pass
diff --git a/fastNLP/core/callbacks/checkpoint_callback.py b/fastNLP/core/callbacks/checkpoint_callback.py
index e12873d3..0f4ed04d 100644
--- a/fastNLP/core/callbacks/checkpoint_callback.py
+++ b/fastNLP/core/callbacks/checkpoint_callback.py
@@ -114,7 +114,7 @@ class CheckpointCallback(Callback):
         if self.topk_saver.topk_queue and trainer.evaluator is None:
             logger.warning(f"You set `topk={self.topk}`, but `evaluate_dataloaders` is not set in Trainer.")
 
-    def on_validate_end(self, trainer, results):
+    def on_evaluate_end(self, trainer, results):
         # if a save actually happened, the returned folder is not None
         folder = self.topk_saver.save_topk(trainer, results)
diff --git a/fastNLP/core/callbacks/early_stop_callback.py b/fastNLP/core/callbacks/early_stop_callback.py
index 0923eb00..1e867866 100644
--- a/fastNLP/core/callbacks/early_stop_callback.py
+++ b/fastNLP/core/callbacks/early_stop_callback.py
@@ -16,13 +16,13 @@ class EarlyStopCallback(HasMonitorCallback):
             the best match as the monitor. If None, the monitor set in the Trainer is tried. A function may also be
             passed in; it takes the evaluation results (a dict) as its argument and returns a float to use as the monitor value.
         :param larger_better: whether a larger monitor value is better.
-        :param patience: stop training after this many validate runs without improvement.
+        :param patience: stop training after this many evaluate runs without improvement.
         """
         super(EarlyStopCallback, self).__init__(monitor=monitor, larger_better=larger_better, must_have_monitor=True)
         self.wait = 0
         self.patience = patience
 
-    def on_validate_end(self, trainer, results):
+    def on_evaluate_end(self, trainer, results):
         monitor_value = self.get_monitor_value(results)
         if monitor_value is None:
             return
@@ -32,13 +32,13 @@ class EarlyStopCallback(HasMonitorCallback):
             self.wait += 1
 
     def on_fetch_data_begin(self, trainer):
-        # with step-based validate, this hook is the next to run, so the check happens here.
+        # with step-based evaluate, this hook is the next to run, so the check happens here.
         if self.wait >= self.patience:
             raise EarlyStopException(f"After {self.wait} validations, no improvement for "
                                      f"metric `{self._real_monitor}`")
 
     def on_train_epoch_begin(self, trainer):
-        # with epoch-based validate, this hook is the next to run, so the check happens here.
+        # with epoch-based evaluate, this hook is the next to run, so the check happens here.
         if self.wait >= self.patience:
             raise EarlyStopException(f"After {self.wait} validations, no improvement for "
                                      f"metric `{self._real_monitor}`(best value: {self.monitor_value})")
diff --git a/fastNLP/core/callbacks/has_monitor_callback.py b/fastNLP/core/callbacks/has_monitor_callback.py
index b13f9dd6..52214ff0 100644
--- a/fastNLP/core/callbacks/has_monitor_callback.py
+++ b/fastNLP/core/callbacks/has_monitor_callback.py
@@ -216,6 +216,6 @@ class ExecuteOnceBetterMonitor(HasMonitorCallback):
         _check_valid_parameters_number(execute_fn, expected_params=[], fn_name='execute_fn')
         self.execute_fn = execute_fn
 
-    def on_validate_end(self, trainer, results):
+    def on_evaluate_end(self, trainer, results):
         if self.is_better_results(results):
             self.execute_fn()
\ No newline at end of file
diff --git a/fastNLP/core/callbacks/load_best_model_callback.py b/fastNLP/core/callbacks/load_best_model_callback.py
index 91bdb084..5addd2e2 100644
--- a/fastNLP/core/callbacks/load_best_model_callback.py
+++ b/fastNLP/core/callbacks/load_best_model_callback.py
@@ -76,7 +76,7 @@ class LoadBestModelCallback(HasMonitorCallback):
 
         super().on_after_trainer_initialized(trainer, driver)
 
-    def on_validate_end(self, trainer, results):
+    def on_evaluate_end(self, trainer, results):
         if self.is_better_results(results, keep_if_better=True):
             if self.real_save_folder:
                 trainer.save_model(folder=self.real_save_folder, only_state_dict=self.only_state_dict,
@@ -95,27 +95,14 @@ class LoadBestModelCallback(HasMonitorCallback):
             self.buffer.seek(0)
             trainer.load_model(folder=self.buffer, only_state_dict=self.only_state_dict)
-        trainer.driver.barrier()
+        self._delete_after_after(trainer)
 
+    def _delete_after_after(self, trainer):
+        trainer.driver.barrier()
         if self.delete_after_after:
-            if self.real_save_folder and int(os.environ.get(FASTNLP_GLOBAL_RANK, 0)) == 0:
-                # only rank 0 needs to perform the deletion.
-                logger.info(f"Deleting {self.real_save_folder}...")
-                shutil.rmtree(self.real_save_folder)
-                try:
-                    # if the folder is empty, it gets removed
-                    os.rmdir(self.save_folder)
-                except:
-                    pass
-            elif hasattr(self, 'buffer'):
-                self.buffer.close()
-                del self.buffer
-
-    def on_exception(self, trainer, exception):
-        if self.delete_after_after:
-            if self.real_save_folder:  # whichever process hits the exception does the deletion
+            if self.real_save_folder:
                 logger.info(f"Deleting {self.real_save_folder}...")
-                shutil.rmtree(self.real_save_folder)
+                shutil.rmtree(self.real_save_folder, ignore_errors=True)
                 try:
                     # if the folder is empty, it gets removed
                     os.rmdir(self.save_folder)
diff --git a/fastNLP/core/callbacks/more_evaluate_callback.py b/fastNLP/core/callbacks/more_evaluate_callback.py
index 6c015bdf..b5800134 100644
--- a/fastNLP/core/callbacks/more_evaluate_callback.py
+++ b/fastNLP/core/callbacks/more_evaluate_callback.py
@@ -31,8 +31,8 @@ class MoreEvaluateCallback(HasMonitorCallback):
 
     :param dataloaders: the data to be evaluated.
     :param metrics: the metrics to use.
-    :param evaluate_every: may be a negative integer, a positive integer, or a function; (1) a negative integer means validate once every that many epochs; (2) a positive integer means
-        evaluate once every that many batches; (3) a function is a user-supplied callable that controls the validate frequency; it should take the trainer object as its argument and return
+    :param evaluate_every: may be a negative integer, a positive integer, or a function; (1) a negative integer means evaluate once every that many epochs; (2) a positive integer means
+        evaluate once every that many batches; (3) a function is a user-supplied callable that controls the evaluate frequency; it should take the trainer object as its argument and return
         a bool, where True means an evaluate run is needed; the function is called after every batch to decide whether to evaluate.
     :param watch_monitor: monitors an entry of the Trainer's evaluate results; when it is not None, evaluate_every is
        ignored. The intent is that whenever the {watch_monitor} entry of the Trainer's evaluate results improves, one
        evaluate run is triggered. This parameter has two
@@ -128,7 +128,7 @@ class MoreEvaluateCallback(HasMonitorCallback):
             results = self.evaluator.run(num_eval_batch_per_dl=self.num_eval_sanity_batch)
             self.topk_saver.get_monitor_value(results)
 
-    def on_validate_end(self, trainer, results):
+    def on_evaluate_end(self, trainer, results):
         if self.is_better_results(results, keep_if_better=True):
             results = self.evaluator.run()
             self.topk_saver.save_topk(trainer, results)
@@ -137,8 +137,8 @@ class MoreEvaluateCallback(HasMonitorCallback):
         if self.watch_monitor is not None:
             return
         if isinstance(self.evaluate_every, int) and self.evaluate_every < 0:
-            validate_every = -self.evaluate_every
-            if trainer.cur_epoch_idx % validate_every == 0:
+            evaluate_every = -self.evaluate_every
+            if trainer.cur_epoch_idx % evaluate_every == 0:
                 results = self.evaluator.run()
                 self.topk_saver.save_topk(trainer, results)
diff --git a/fastNLP/core/callbacks/progress_callback.py b/fastNLP/core/callbacks/progress_callback.py
index a6f82896..bacdea48 100644
--- a/fastNLP/core/callbacks/progress_callback.py
+++ b/fastNLP/core/callbacks/progress_callback.py
@@ -100,7 +100,7 @@ class RichCallback(ProgressCallback):
                 self.progress_bar.update(self.task2id['epoch'], description=f'Epoch:{trainer.cur_epoch_idx}',
                                          advance=self.epoch_bar_update_advance, refresh=True)
 
-    def on_validate_end(self, trainer, results):
+    def on_evaluate_end(self, trainer, results):
         if len(results)==0:
             return
         rule_style = ''
@@ -122,9 +122,6 @@ class RichCallback(ProgressCallback):
         else:
             self.progress_bar.print(results)
 
-    def on_exception(self, trainer, exception):
-        self.clear_tasks()
-
     def clear_tasks(self):
         for key, taskid in self.task2id.items():
             self.progress_bar.destroy_task(taskid)
@@ -178,7 +175,7 @@ class RawTextCallback(ProgressCallback):
                   f'finished {round(trainer.global_forward_batches/trainer.total_batches*100, 2)}%.'
         logger.info(text)
 
-    def on_validate_end(self, trainer, results):
+    def on_evaluate_end(self, trainer, results):
         if len(results)==0:
             return
         base_text = f'Eval. results on Epoch:{trainer.cur_epoch_idx}, Batch:{trainer.batch_idx_in_epoch}'
diff --git a/fastNLP/core/controllers/loops/train_batch_loop.py b/fastNLP/core/controllers/loops/train_batch_loop.py
index cfb54111..ef05e0c4 100644
--- a/fastNLP/core/controllers/loops/train_batch_loop.py
+++ b/fastNLP/core/controllers/loops/train_batch_loop.py
@@ -43,7 +43,7 @@ class TrainBatchLoop(Loop):
                 trainer.check_batch_step_fn()
                 trainer.on_train_batch_end()
-            trainer.step_validate()
+            trainer.step_evaluate()
         trainer.batch_idx_in_epoch = 0
 
     @staticmethod
diff --git a/fastNLP/core/controllers/trainer.py b/fastNLP/core/controllers/trainer.py
index cbec1a01..307901b1 100644
--- a/fastNLP/core/controllers/trainer.py
+++ b/fastNLP/core/controllers/trainer.py
@@ -339,11 +339,11 @@ class Trainer(TrainerEventTrigger):
             self.num_batches_per_epoch = len(self.dataloader)
             self.total_batches = self.num_batches_per_epoch * self.n_epochs
             self.global_forward_batches = self.num_batches_per_epoch * self.cur_epoch_idx + self.batch_idx_in_epoch
-        self.on_train_begin()
-        self.driver.barrier()
-        self.driver.zero_grad(self.set_grad_to_none)
 
         try:
+            self.on_train_begin()
+            self.driver.barrier()
+            self.driver.zero_grad(self.set_grad_to_none)
             while self.cur_epoch_idx < self.n_epochs:
                 # this guards against saving again after Trainer.load before the current epoch has finished
                 self.start_batch_idx_in_epoch = self.trainer_state.batch_idx_in_epoch
@@ -356,10 +356,8 @@ class Trainer(TrainerEventTrigger):
                 self.cur_epoch_idx += 1
                 self.on_train_epoch_end()
                 self.driver.barrier()
-                self.epoch_validate()
+                self.epoch_evaluate()
                 self.driver.barrier()
-            self.on_train_end()
-            self.driver.barrier()
 
         except EarlyStopException as e:
             logger.info(f"Catch early stop exception: {e.msg}.")
@@ -373,17 +371,20 @@ class Trainer(TrainerEventTrigger):
                 self.driver.on_exception()
             self.on_exception(e)
             raise e
+        finally:
+            self.on_train_end()
+            self.driver.barrier()
 
     def _set_num_eval_batch_per_dl(self, num_eval_batch_per_dl):
-        def _validate_fn(trainer: Trainer, validate_fn: Callable) -> None:
-            trainer.on_validate_begin()
-            _validate_res: dict = validate_fn()
-            trainer.on_validate_end(_validate_res)
+        def _evaluate_fn(trainer: Trainer, evaluate_fn: Callable) -> None:
+            trainer.on_evaluate_begin()
+            _evaluate_res: dict = evaluate_fn()
+            trainer.on_evaluate_end(_evaluate_res)
 
         if self.evaluator is not None:
-            self.run_evaluate = partial(_validate_fn, self, partial(self.evaluator.run, num_eval_batch_per_dl))
+            self.run_evaluate = partial(_evaluate_fn, self, partial(self.evaluator.run, num_eval_batch_per_dl))
 
-    def step_validate(self):
+    def step_evaluate(self):
         """
         Called after each batch; runs evaluate according to the configured schedule.
 
@@ -396,7 +397,7 @@ class Trainer(TrainerEventTrigger):
         elif self.evaluate_every > 0 and self.global_forward_batches % self.evaluate_every == 0:
             self.run_evaluate()
 
-    def epoch_validate(self):
+    def epoch_evaluate(self):
         """
         Called after each epoch; runs evaluate according to the configured schedule.
 
@@ -404,8 +405,8 @@ class Trainer(TrainerEventTrigger):
         if self.evaluator is not None:
             if isinstance(self.evaluate_every, int) and self.evaluate_every < 0:
-                validate_every = -self.evaluate_every
-                if self.cur_epoch_idx % validate_every == 0:
+                evaluate_every = -self.evaluate_every
+                if self.cur_epoch_idx % evaluate_every == 0:
                     self.run_evaluate()
 
     def add_callback_fn(self, event: Optional[Union[Events, EventsList]], fn: Callable):
diff --git a/fastNLP/core/controllers/utils/utils.py b/fastNLP/core/controllers/utils/utils.py
index cc7a1b66..a2b2d5ae 100644
--- a/fastNLP/core/controllers/utils/utils.py
+++ b/fastNLP/core/controllers/utils/utils.py
@@ -81,12 +81,12 @@ class TrainerEventTrigger:
     def on_after_zero_grad(self, optimizers):
         self.callback_manager.on_after_zero_grad(self, optimizers)
 
-    def on_validate_begin(self):
-        self.callback_manager.on_validate_begin(self)
+    def on_evaluate_begin(self):
+        self.callback_manager.on_evaluate_begin(self)
 
-    def on_validate_end(self, results):
+    def on_evaluate_end(self, results):
         self.trainer_state.save_on_this_step = True
-        self.callback_manager.on_validate_end(self, results)
+        self.callback_manager.on_evaluate_end(self, results)
 
 
 class _TruncatedDataLoader:
@@ -126,8 +126,8 @@ class _TruncatedDataLoader:
         return getattr(self.dataloader, item)
 
 
-def check_evaluate_every(validate_every):
-    if not callable(validate_every) and (not isinstance(validate_every, int) or validate_every == 0):
+def check_evaluate_every(evaluate_every):
+    if not callable(evaluate_every) and (not isinstance(evaluate_every, int) or evaluate_every == 0):
         raise ValueError("Parameter 'evaluate_every' should be set to 'int' type and either < 0 or > 0.")
-    if callable(validate_every):
-        _check_valid_parameters_number(validate_every, expected_params=['trainer'])
+    if callable(evaluate_every):
+        _check_valid_parameters_number(evaluate_every, expected_params=['trainer'])
diff --git a/fastNLP/core/drivers/jittor_driver/jittor_driver.py b/fastNLP/core/drivers/jittor_driver/jittor_driver.py
index 84e3f002..bcebc6d0 100644
--- a/fastNLP/core/drivers/jittor_driver/jittor_driver.py
+++ b/fastNLP/core/drivers/jittor_driver/jittor_driver.py
@@ -63,7 +63,7 @@ class JittorDriver(Driver):
     def check_evaluator_mode(self, mode: str):
         model = self.unwrap_model()
-        if mode == "validate":
+        if mode == "evaluate":
             if not hasattr(model, "evaluate_step"):
                 if hasattr(model, "test_step"):
                     logger.warning_once(
diff --git a/fastNLP/core/log/logger.py b/fastNLP/core/log/logger.py
index 086089ea..bdfc299f 100644
--- a/fastNLP/core/log/logger.py
+++ b/fastNLP/core/log/logger.py
@@ -173,6 +173,19 @@ class FastNLPLogger(logging.Logger, metaclass=LoggerSingleton):
             kwargs["extra"] = extra
         return kwargs
 
+    def setLevel(self, level) -> None:
+        """
+        Set the log level of this logger and of all its handlers.
+
+        :param level:
+        :return:
+        """
+        if isinstance(level, str):
+            level = level.upper()
+        super().setLevel(level)
+        for handler in self.handlers:
+            handler.setLevel(level)
+
 
 def _get_level(level):
     if not isinstance(level, int):
diff --git a/tests/helpers/callbacks/helper_callbacks.py b/tests/helpers/callbacks/helper_callbacks.py
index c3a9d4da..4fd5b654 100644
--- a/tests/helpers/callbacks/helper_callbacks.py
+++ b/tests/helpers/callbacks/helper_callbacks.py
@@ -38,7 +38,7 @@ class RecordMetricCallback(Callback):
         self.metric_threshold = metric_threshold
         self.metric_begin_value = None
 
-    def on_validate_end(self, trainer, results):
+    def on_evaluate_end(self, trainer, results):
         self.metric = results[self.monitor]
         if self.metric_begin_value is None:
             self.metric_begin_value = self.metric
@@ -113,11 +113,11 @@ class RecordTrainerEventTriggerCallback(Callback):
     def on_after_zero_grad(self, trainer, optimizers):
         print("on_after_zero_grad")
 
-    def on_validate_begin(self, trainer):
-        print("on_validate_begin")
+    def on_evaluate_begin(self, trainer):
+        print("on_evaluate_begin")
 
-    def on_validate_end(self, trainer, results):
-        print("on_validate_end")
+    def on_evaluate_end(self, trainer, results):
+        print("on_evaluate_end")
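
For reference, a minimal sketch of how user code lines up with the renamed hooks; this sketch is not part of the patch. It assumes Callback is importable from the top-level fastNLP package and that the evaluate results contain an 'acc' key; both are illustrative assumptions. The hook signatures follow the patch above.

# Hypothetical sketch: a custom callback written against the renamed hooks.
from fastNLP import Callback  # import path assumed

class TrackBestCallback(Callback):
    """Remembers the best value of one entry in the evaluate results."""

    def __init__(self, monitor='acc'):  # 'acc' is a placeholder metric name
        self.monitor = monitor
        self.best = None

    def on_evaluate_end(self, trainer, results):
        # `results` is the dict produced by the evaluate run, as documented
        # in the on_evaluate_end docstring of callback.py above.
        value = results.get(self.monitor)
        if value is not None and (self.best is None or value > self.best):
            self.best = value

    def on_train_end(self, trainer):
        # After this patch, Trainer.run() calls on_train_end() from a finally
        # block, so this report also appears when training raises or stops early.
        print(f'best {self.monitor}: {self.best}')

# Registration is unchanged, e.g. Trainer(..., evaluate_every=-1, callbacks=[TrackBestCallback()]);
# evaluate_every=-1 means evaluate once per epoch, per the evaluate_every docs above.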