diff --git a/fastNLP/core/callbacks/callback_events.py b/fastNLP/core/callbacks/callback_events.py index 1c805ac2..7a25c45a 100644 --- a/fastNLP/core/callbacks/callback_events.py +++ b/fastNLP/core/callbacks/callback_events.py @@ -74,30 +74,30 @@ class EventEnum(_SingleEventState, Enum): @unique class Events(EventEnum): - ON_AFTER_TRAINER_INITIALIZED = "on_after_trainer_initialized" - ON_SANITY_CHECK_BEGIN = "on_sanity_check_begin" - ON_SANITY_CHECK_END = "on_sanity_check_end" - ON_TRAIN_BEGIN = "on_train_begin" - ON_TRAIN_END = "on_train_end" - ON_TRAIN_EPOCH_BEGIN = "on_train_epoch_begin" - ON_TRAIN_EPOCH_END = "on_train_epoch_end" - ON_FETCH_DATA_BEGIN = "on_fetch_data_begin" - ON_FETCH_DATA_END = "on_fetch_data_end" - ON_TRAIN_BATCH_BEGIN = "on_train_batch_begin" - ON_TRAIN_BATCH_END = "on_train_batch_end" - ON_EXCEPTION = "on_exception" - ON_SAVE_MODEL = "on_save_model" - ON_LOAD_MODEL = "on_load_model" - ON_SAVE_CHECKPOINT = "on_save_checkpoint" - ON_LOAD_CHECKPOINT = "on_load_checkpoint" - ON_BEFORE_BACKWARD = "on_before_backward" - ON_AFTER_BACKWARD = "on_after_backward" - ON_BEFORE_OPTIMIZERS_STEP = "on_before_optimizers_step" - ON_AFTER_OPTIMIZERS_STEP = "on_after_optimizers_step" - ON_BEFORE_ZERO_GRAD = "on_before_zero_grad" - ON_AFTER_ZERO_GRAD = "on_after_zero_grad" - ON_VALIDATE_BEGIN = "on_validate_begin" - ON_VALIDATE_END = "on_validate_end" + on_after_trainer_initialized = "on_after_trainer_initialized" + on_sanity_check_begin = "on_sanity_check_begin" + on_sanity_check_end = "on_sanity_check_end" + on_train_begin = "on_train_begin" + on_train_end = "on_train_end" + on_train_epoch_begin = "on_train_epoch_begin" + on_train_epoch_end = "on_train_epoch_end" + on_fetch_data_begin = "on_fetch_data_begin" + on_fetch_data_end = "on_fetch_data_end" + on_train_batch_begin = "on_train_batch_begin" + on_train_batch_end = "on_train_batch_end" + on_exception = "on_exception" + on_save_model = "on_save_model" + on_load_model = "on_load_model" + on_save_checkpoint = "on_save_checkpoint" + on_load_checkpoint = "on_load_checkpoint" + on_before_backward = "on_before_backward" + on_after_backward = "on_after_backward" + on_before_optimizers_step = "on_before_optimizers_step" + on_after_optimizers_step = "on_after_optimizers_step" + on_before_zero_grad = "on_before_zero_grad" + on_after_zero_grad = "on_after_zero_grad" + on_validate_begin = "on_validate_begin" + on_validate_end = "on_validate_end" class EventsList: diff --git a/tests/core/callbacks/test_checkpoint_callback_torch.py b/tests/core/callbacks/test_checkpoint_callback_torch.py index 557c31b2..fe0a3582 100644 --- a/tests/core/callbacks/test_checkpoint_callback_torch.py +++ b/tests/core/callbacks/test_checkpoint_callback_torch.py @@ -238,7 +238,7 @@ def test_model_checkpoint_callback_2( from fastNLP.core.callbacks.callback_events import Events - @Trainer.on(Events.ON_TRAIN_EPOCH_END) + @Trainer.on(Events.on_train_epoch_end) def raise_exception(trainer): if trainer.driver.get_local_rank() == 0 and trainer.cur_epoch_idx == 4: raise NotImplementedError diff --git a/tests/core/controllers/test_trainer_wo_evaluator_torch.py b/tests/core/controllers/test_trainer_wo_evaluator_torch.py index 0da8c976..82fa3af0 100644 --- a/tests/core/controllers/test_trainer_wo_evaluator_torch.py +++ b/tests/core/controllers/test_trainer_wo_evaluator_torch.py @@ -254,7 +254,7 @@ def test_trainer_on_exception( ): from fastNLP.core.callbacks.callback_events import Events - @Trainer.on(Events.ON_TRAIN_EPOCH_END) + @Trainer.on(Events.on_train_epoch_end) def raise_exception(trainer): if trainer.driver.get_local_rank() == cur_rank: raise NotImplementedError