
Rename all validate to evaluate; move callback.on_train_end()

tags/v1.0.0alpha
yh_cc 3 years ago
parent commit f74b9b6bec
15 changed files with 99 additions and 73 deletions
  1. fastNLP/core/callbacks/callback.py (+32 -4)
  2. fastNLP/core/callbacks/callback_events.py (+2 -2)
  3. fastNLP/core/callbacks/callback_manager.py (+2 -2)
  4. fastNLP/core/callbacks/checkpoint_callback.py (+1 -1)
  5. fastNLP/core/callbacks/early_stop_callback.py (+4 -4)
  6. fastNLP/core/callbacks/has_monitor_callback.py (+1 -1)
  7. fastNLP/core/callbacks/load_best_model_callback.py (+6 -19)
  8. fastNLP/core/callbacks/more_evaluate_callback.py (+5 -5)
  9. fastNLP/core/callbacks/progress_callback.py (+2 -5)
 10. fastNLP/core/controllers/loops/train_batch_loop.py (+1 -1)
 11. fastNLP/core/controllers/trainer.py (+16 -15)
 12. fastNLP/core/controllers/utils/utils.py (+8 -8)
 13. fastNLP/core/drivers/jittor_driver/jittor_driver.py (+1 -1)
 14. fastNLP/core/log/logger.py (+13 -0)
 15. tests/helpers/callbacks/helper_callbacks.py (+5 -5)

fastNLP/core/callbacks/callback.py (+32 -4)

@@ -12,6 +12,34 @@ from fastNLP.core.callbacks.callback_events import _SingleEventState
 class Callback:
     r"""
     The callback class used at runtime; both the callback classes fastNLP provides by default and callback
     classes customized by users should inherit from this base class.
+    Callbacks are invoked roughly in the following order:
+        Trainer.__init__():
+            on_after_trainer_initialized()
+        Trainer.run():
+            if num_eval_sanity_batch > 0:
+                on_sanity_check_begin()  # only if num_eval_sanity_batch is set
+                on_sanity_check_end()
+            try:
+                on_train_begin()
+                while cur_epoch_idx < n_epochs:
+                    on_train_epoch_begin()
+                    while batch_idx_in_epoch <= num_batches_per_epoch:
+                        on_fetch_data_begin()
+                        on_fetch_data_end()
+                        on_train_batch_begin()
+                        on_before_backward()
+                        on_after_backward()
+                        on_before_zero_grad()        # actual invocation depends on accumulation_steps
+                        on_after_zero_grad()         # actual invocation depends on accumulation_steps
+                        on_before_optimizers_step()  # actual invocation depends on accumulation_steps
+                        on_after_optimizers_step()   # actual invocation depends on accumulation_steps
+                        on_train_batch_end()
+                    on_train_epoch_end()
+            except BaseException:
+                self.on_exception()
+            finally:
+                on_train_end()
+    Other callbacks such as on_evaluate_begin()/on_evaluate_end() are triggered whenever an evaluate run
+    happens; see their docstrings below for the exact timing.
     """

def on_after_trainer_initialized(self, trainer, driver):
@@ -221,9 +249,9 @@ class Callback:
         """
         pass

-    def on_validate_begin(self, trainer):
+    def on_evaluate_begin(self, trainer):
         """
-        Called right before a validate run. If the validate frequency is decided by step count or by a custom
+        Called right before an evaluate run. If the evaluate frequency is decided by step count or by a custom
         function, this hook is called after on_train_batch_end; if it is decided by epoch count, it is called
         after on_train_epoch_end.

         :param trainer:
@@ -231,9 +259,9 @@ class Callback:
         """
         pass

-    def on_validate_end(self, trainer, results):
+    def on_evaluate_end(self, trainer, results):
         """
-        Called when a validate run finishes; the validate results are passed in.
+        Called when an evaluate run finishes; the evaluate results are passed in.

         :param trainer:
         :param results: the evaluate results, usually a dict.
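Taken together, the hunks above rename the evaluate hook pair that user callbacks override. A minimal sketch of a custom callback written against the renamed API (the class and its print statements are hypothetical, not part of this commit):

    from fastNLP.core.callbacks.callback import Callback

    class PrintEvalCallback(Callback):
        # Hypothetical example; demonstrates the renamed hooks only.
        def on_evaluate_begin(self, trainer):
            # Fires after on_train_batch_end (step-triggered evaluate)
            # or after on_train_epoch_end (epoch-triggered evaluate).
            print(f"evaluate starts at epoch {trainer.cur_epoch_idx}")

        def on_evaluate_end(self, trainer, results):
            # `results` is the evaluate result, usually a dict.
            print(f"evaluate results: {results}")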


fastNLP/core/callbacks/callback_events.py (+2 -2)

@@ -96,8 +96,8 @@ class Events(EventEnum):
     on_after_optimizers_step = "on_after_optimizers_step"
     on_before_zero_grad = "on_before_zero_grad"
     on_after_zero_grad = "on_after_zero_grad"
-    on_validate_begin = "on_validate_begin"
-    on_validate_end = "on_validate_end"
+    on_evaluate_begin = "on_evaluate_begin"
+    on_evaluate_end = "on_evaluate_end"


 class EventsList:
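Because Events members carry the hook name as their value, a plain function can be attached to the renamed events through Trainer.add_callback_fn, whose signature appears in the trainer.py hunk below. A sketch, assuming an existing trainer and assuming the registered function receives the same arguments as the hook (not verified against this commit):

    from fastNLP.core.callbacks.callback_events import Events

    def log_eval(trainer, results):
        # Hypothetical handler mirroring on_evaluate_end(trainer, results).
        print("evaluate finished:", results)

    trainer.add_callback_fn(Events.on_evaluate_end, log_eval)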


fastNLP/core/callbacks/callback_manager.py (+2 -2)

@@ -281,9 +281,9 @@ class CallbackManager:
         pass

     @_transfer
-    def on_validate_begin(self, trainer):
+    def on_evaluate_begin(self, trainer):
         pass

     @_transfer
-    def on_validate_end(self, trainer, results):
+    def on_evaluate_end(self, trainer, results):
         pass

fastNLP/core/callbacks/checkpoint_callback.py (+1 -1)

@@ -114,7 +114,7 @@ class CheckpointCallback(Callback):
         if self.topk_saver.topk_queue and trainer.evaluator is None:
             logger.warning(f"You set `topk={self.topk}`, but `evaluate_dataloaders` is not set in Trainer.")

-    def on_validate_end(self, trainer, results):
+    def on_evaluate_end(self, trainer, results):
         # If a save happened, the returned folder is not None.
         folder = self.topk_saver.save_topk(trainer, results)



fastNLP/core/callbacks/early_stop_callback.py (+4 -4)

@@ -16,13 +16,13 @@ class EarlyStopCallback(HasMonitorCallback):
         the matching one is taken as the monitor. If None, the monitor set in the Trainer is tried. A function may
         also be passed in; it receives the evaluation results (a dict) and returns a float used as the monitor value.
     :param larger_better: whether a larger monitor value is better.
-    :param patience: stop after this many validate runs without improvement.
+    :param patience: stop after this many evaluate runs without improvement.
     """
     super(EarlyStopCallback, self).__init__(monitor=monitor, larger_better=larger_better, must_have_monitor=True)
     self.wait = 0
     self.patience = patience

-    def on_validate_end(self, trainer, results):
+    def on_evaluate_end(self, trainer, results):
         monitor_value = self.get_monitor_value(results)
         if monitor_value is None:
             return
@@ -32,13 +32,13 @@ class EarlyStopCallback(HasMonitorCallback):
         self.wait += 1

     def on_fetch_data_begin(self, trainer):
-        # With step-triggered validate, this hook is executed next, so the check happens here.
+        # With step-triggered evaluate, this hook is executed next, so the check happens here.
         if self.wait >= self.patience:
             raise EarlyStopException(f"After {self.wait} validations, no improvement for "
                                      f"metric `{self._real_monitor}`")

     def on_train_epoch_begin(self, trainer):
-        # With epoch-triggered validate, this hook is executed next, so the check happens here.
+        # With epoch-triggered evaluate, this hook is executed next, so the check happens here.
         if self.wait >= self.patience:
             raise EarlyStopException(f"After {self.wait} validations, no improvement for "
                                      f"metric `{self._real_monitor}`(best value: {self.monitor_value})")


fastNLP/core/callbacks/has_monitor_callback.py (+1 -1)

@@ -216,6 +216,6 @@ class ExecuteOnceBetterMonitor(HasMonitorCallback):
         _check_valid_parameters_number(execute_fn, expected_params=[], fn_name='execute_fn')
         self.execute_fn = execute_fn

-    def on_validate_end(self, trainer, results):
+    def on_evaluate_end(self, trainer, results):
         if self.is_better_results(results):
             self.execute_fn()

fastNLP/core/callbacks/load_best_model_callback.py (+6 -19)

@@ -76,7 +76,7 @@ class LoadBestModelCallback(HasMonitorCallback):

         super().on_after_trainer_initialized(trainer, driver)

-    def on_validate_end(self, trainer, results):
+    def on_evaluate_end(self, trainer, results):
         if self.is_better_results(results, keep_if_better=True):
             if self.real_save_folder:
                 trainer.save_model(folder=self.real_save_folder, only_state_dict=self.only_state_dict,
@@ -95,27 +95,14 @@ class LoadBestModelCallback(HasMonitorCallback):
             self.buffer.seek(0)
             trainer.load_model(folder=self.buffer, only_state_dict=self.only_state_dict)

-        trainer.driver.barrier()
-        if self.delete_after_after:
-            if self.real_save_folder and int(os.environ.get(FASTNLP_GLOBAL_RANK, 0)) == 0:
-                # Only rank 0 needs to perform the deletion.
-                logger.info(f"Deleting {self.real_save_folder}...")
-                shutil.rmtree(self.real_save_folder)
-                try:
-                    # Deleted only if the folder is empty.
-                    os.rmdir(self.save_folder)
-                except:
-                    pass
-            elif hasattr(self, 'buffer'):
-                self.buffer.close()
-                del self.buffer
+        self._delete_after_after(trainer)
+
+    def _delete_after_after(self, trainer):
+        trainer.driver.barrier()

     def on_exception(self, trainer, exception):
         if self.delete_after_after:
-            if self.real_save_folder:  # whichever process hits the exception performs the deletion
+            if self.real_save_folder:
                 logger.info(f"Deleting {self.real_save_folder}...")
-                shutil.rmtree(self.real_save_folder)
+                shutil.rmtree(self.real_save_folder, ignore_errors=True)
                 try:
                     # Deleted only if the folder is empty.
                     os.rmdir(self.save_folder)


fastNLP/core/callbacks/more_evaluate_callback.py (+5 -5)

@@ -31,8 +31,8 @@ class MoreEvaluateCallback(HasMonitorCallback):

     :param dataloaders: the data to evaluate on.
     :param metrics: the metrics to use.
-    :param evaluate_every: may be a negative integer, a positive integer, or a function; (1) a negative integer means validate once every that many epochs; (2) a positive integer means
-        evaluate once every that many batches; (3) a function means a user-supplied callable that controls the validate frequency; it should accept the trainer object as its argument and return
+    :param evaluate_every: may be a negative integer, a positive integer, or a function; (1) a negative integer means evaluate once every that many epochs; (2) a positive integer means
+        evaluate once every that many batches; (3) a function means a user-supplied callable that controls the evaluate frequency; it should accept the trainer object as its argument and return
         a bool, where True means an evaluate run is needed; the function is called at the end of every batch to decide whether to evaluate.
     :param watch_monitor: watches an entry of the Trainer's evaluate results; when it is not None, evaluate_every is ignored.
         The idea: whenever the {watch_monitor} entry of the Trainer's evaluate results improves, one evaluate run is performed. This parameter has two
@@ -128,7 +128,7 @@ class MoreEvaluateCallback(HasMonitorCallback):
             results = self.evaluator.run(num_eval_batch_per_dl=self.num_eval_sanity_batch)
             self.topk_saver.get_monitor_value(results)

-    def on_validate_end(self, trainer, results):
+    def on_evaluate_end(self, trainer, results):
         if self.is_better_results(results, keep_if_better=True):
             results = self.evaluator.run()
             self.topk_saver.save_topk(trainer, results)
@@ -137,8 +137,8 @@ class MoreEvaluateCallback(HasMonitorCallback):
         if self.watch_monitor is not None:
             return
         if isinstance(self.evaluate_every, int) and self.evaluate_every < 0:
-            validate_every = -self.evaluate_every
-            if trainer.cur_epoch_idx % validate_every == 0:
+            evaluate_every = -self.evaluate_every
+            if trainer.cur_epoch_idx % evaluate_every == 0:
                 results = self.evaluator.run()
                 self.topk_saver.save_topk(trainer, results)
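The callable form of evaluate_every documented above can be as small as this sketch (the 1000-step cadence is an arbitrary example, and the function name is hypothetical):

    def evaluate_every_1000_steps(trainer):
        # Called at the end of every batch; returning True triggers an evaluate run.
        return trainer.global_forward_batches % 1000 == 0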



fastNLP/core/callbacks/progress_callback.py (+2 -5)

@@ -100,7 +100,7 @@ class RichCallback(ProgressCallback):
         self.progress_bar.update(self.task2id['epoch'], description=f'Epoch:{trainer.cur_epoch_idx}',
                                  advance=self.epoch_bar_update_advance, refresh=True)

-    def on_validate_end(self, trainer, results):
+    def on_evaluate_end(self, trainer, results):
         if len(results)==0:
             return
         rule_style = ''
@@ -122,9 +122,6 @@ class RichCallback(ProgressCallback):
         else:
             self.progress_bar.print(results)

-    def on_exception(self, trainer, exception):
-        self.clear_tasks()
-
     def clear_tasks(self):
         for key, taskid in self.task2id.items():
             self.progress_bar.destroy_task(taskid)
@@ -178,7 +175,7 @@ class RawTextCallback(ProgressCallback):
                f'finished {round(trainer.global_forward_batches/trainer.total_batches*100, 2)}%.'
         logger.info(text)

-    def on_validate_end(self, trainer, results):
+    def on_evaluate_end(self, trainer, results):
         if len(results)==0:
             return
         base_text = f'Eval. results on Epoch:{trainer.cur_epoch_idx}, Batch:{trainer.batch_idx_in_epoch}'


fastNLP/core/controllers/loops/train_batch_loop.py (+1 -1)

@@ -43,7 +43,7 @@ class TrainBatchLoop(Loop):

             trainer.check_batch_step_fn()
             trainer.on_train_batch_end()
-            trainer.step_validate()
+            trainer.step_evaluate()
         trainer.batch_idx_in_epoch = 0

     @staticmethod


fastNLP/core/controllers/trainer.py (+16 -15)

@@ -339,11 +339,11 @@ class Trainer(TrainerEventTrigger):
         self.num_batches_per_epoch = len(self.dataloader)
         self.total_batches = self.num_batches_per_epoch * self.n_epochs
         self.global_forward_batches = self.num_batches_per_epoch * self.cur_epoch_idx + self.batch_idx_in_epoch
-        self.on_train_begin()
-        self.driver.barrier()
-        self.driver.zero_grad(self.set_grad_to_none)

         try:
+            self.on_train_begin()
+            self.driver.barrier()
+            self.driver.zero_grad(self.set_grad_to_none)
             while self.cur_epoch_idx < self.n_epochs:
                 # Prevents saving again after Trainer.load before the current epoch has finished.
                 self.start_batch_idx_in_epoch = self.trainer_state.batch_idx_in_epoch
@@ -356,10 +356,8 @@ class Trainer(TrainerEventTrigger):
                 self.cur_epoch_idx += 1
                 self.on_train_epoch_end()
                 self.driver.barrier()
-                self.epoch_validate()
+                self.epoch_evaluate()
                 self.driver.barrier()
-            self.on_train_end()
-            self.driver.barrier()

         except EarlyStopException as e:
             logger.info(f"Catch early stop exception: {e.msg}.")
@@ -373,17 +371,20 @@ class Trainer(TrainerEventTrigger):
             self.driver.on_exception()
             self.on_exception(e)
             raise e
+        finally:
+            self.on_train_end()
+            self.driver.barrier()

     def _set_num_eval_batch_per_dl(self, num_eval_batch_per_dl):
-        def _validate_fn(trainer: Trainer, validate_fn: Callable) -> None:
-            trainer.on_validate_begin()
-            _validate_res: dict = validate_fn()
-            trainer.on_validate_end(_validate_res)
+        def _evaluate_fn(trainer: Trainer, evaluate_fn: Callable) -> None:
+            trainer.on_evaluate_begin()
+            _evaluate_res: dict = evaluate_fn()
+            trainer.on_evaluate_end(_evaluate_res)

         if self.evaluator is not None:
-            self.run_evaluate = partial(_validate_fn, self, partial(self.evaluator.run, num_eval_batch_per_dl))
+            self.run_evaluate = partial(_evaluate_fn, self, partial(self.evaluator.run, num_eval_batch_per_dl))

-    def step_validate(self):
+    def step_evaluate(self):
         """
         Called at the end of every batch; runs evaluate according to the settings.

@@ -396,7 +397,7 @@ class Trainer(TrainerEventTrigger):
         elif self.evaluate_every > 0 and self.global_forward_batches % self.evaluate_every == 0:
             self.run_evaluate()

-    def epoch_validate(self):
+    def epoch_evaluate(self):
         """
         Called at the end of every epoch; runs evaluate according to the settings.

@@ -404,8 +405,8 @@ class Trainer(TrainerEventTrigger):
         """
         if self.evaluator is not None:
             if isinstance(self.evaluate_every, int) and self.evaluate_every < 0:
-                validate_every = -self.evaluate_every
-                if self.cur_epoch_idx % validate_every == 0:
+                evaluate_every = -self.evaluate_every
+                if self.cur_epoch_idx % evaluate_every == 0:
                     self.run_evaluate()

     def add_callback_fn(self, event: Optional[Union[Events, EventsList]], fn: Callable):
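The control-flow change above is the second half of the commit message: on_train_end is now reached through finally, so the hook fires on normal completion, on EarlyStopException, and on any other error. A standalone sketch of the pattern (the names are illustrative, not fastNLP API):

    def run(train, on_train_end):
        try:
            train()         # may raise EarlyStopException or anything else
        finally:
            on_train_end()  # always runs, mirroring the moved hook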


fastNLP/core/controllers/utils/utils.py (+8 -8)

@@ -81,12 +81,12 @@ class TrainerEventTrigger:
     def on_after_zero_grad(self, optimizers):
         self.callback_manager.on_after_zero_grad(self, optimizers)

-    def on_validate_begin(self):
-        self.callback_manager.on_validate_begin(self)
+    def on_evaluate_begin(self):
+        self.callback_manager.on_evaluate_begin(self)

-    def on_validate_end(self, results):
+    def on_evaluate_end(self, results):
         self.trainer_state.save_on_this_step = True
-        self.callback_manager.on_validate_end(self, results)
+        self.callback_manager.on_evaluate_end(self, results)


 class _TruncatedDataLoader:
@@ -126,8 +126,8 @@ class _TruncatedDataLoader:
         return getattr(self.dataloader, item)


-def check_evaluate_every(validate_every):
-    if not callable(validate_every) and (not isinstance(validate_every, int) or validate_every == 0):
+def check_evaluate_every(evaluate_every):
+    if not callable(evaluate_every) and (not isinstance(evaluate_every, int) or evaluate_every == 0):
         raise ValueError("Parameter 'evaluate_every' should be set to 'int' type and either < 0 or > 0.")
-    if callable(validate_every):
-        _check_valid_parameters_number(validate_every, expected_params=['trainer'])
+    if callable(evaluate_every):
+        _check_valid_parameters_number(evaluate_every, expected_params=['trainer'])
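check_evaluate_every therefore accepts any non-zero integer or a callable that takes the trainer; a quick illustration of what passes and what raises (hypothetical calls):

    check_evaluate_every(-1)                     # OK: evaluate once per epoch
    check_evaluate_every(100)                    # OK: evaluate every 100 batches
    check_evaluate_every(lambda trainer: True)   # OK: callable with a `trainer` parameter
    check_evaluate_every(0)                      # raises ValueError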

fastNLP/core/drivers/jittor_driver/jittor_driver.py (+1 -1)

@@ -63,7 +63,7 @@ class JittorDriver(Driver):

     def check_evaluator_mode(self, mode: str):
         model = self.unwrap_model()
-        if mode == "validate":
+        if mode == "evaluate":
             if not hasattr(model, "evaluate_step"):
                 if hasattr(model, "test_step"):
                     logger.warning_once(


fastNLP/core/log/logger.py (+13 -0)

@@ -173,6 +173,19 @@ class FastNLPLogger(logging.Logger, metaclass=LoggerSingleton):
         kwargs["extra"] = extra
         return kwargs

+    def setLevel(self, level) -> None:
+        """
+        Set the log level of this logger and of all of its handlers.
+
+        :param level:
+        :return:
+        """
+        if isinstance(level, str):
+            level = level.upper()
+        super().setLevel(level)
+        for handler in self.handlers:
+            handler.setLevel(level)


 def _get_level(level):
     if not isinstance(level, int):
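With this override, one call updates the logger and every attached handler, and lowercase level names are accepted because strings are uppercased first. A usage sketch (the import path is an assumption based on this diff):

    from fastNLP.core.log.logger import logger

    logger.setLevel("debug")  # uppercased to "DEBUG" and pushed to all handlers
    logger.debug("now visible on every handler")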


tests/helpers/callbacks/helper_callbacks.py (+5 -5)

@@ -38,7 +38,7 @@ class RecordMetricCallback(Callback):
         self.metric_threshold = metric_threshold
         self.metric_begin_value = None

-    def on_validate_end(self, trainer, results):
+    def on_evaluate_end(self, trainer, results):
         self.metric = results[self.monitor]
         if self.metric_begin_value is None:
             self.metric_begin_value = self.metric
@@ -113,11 +113,11 @@ class RecordTrainerEventTriggerCallback(Callback):
     def on_after_zero_grad(self, trainer, optimizers):
         print("on_after_zero_grad")

-    def on_validate_begin(self, trainer):
-        print("on_validate_begin")
+    def on_evaluate_begin(self, trainer):
+        print("on_evaluate_begin")

-    def on_validate_end(self, trainer, results):
-        print("on_validate_end")
+    def on_evaluate_end(self, trainer, results):
+        print("on_evaluate_end")




