
Change the Events members to lowercase

tags/v1.0.0alpha
YWMditto 2 years ago
commit b9b0b53430
3 changed files with 26 additions and 26 deletions
1. +24 -24 fastNLP/core/callbacks/callback_events.py
2. +1 -1 tests/core/callbacks/test_checkpoint_callback_torch.py
3. +1 -1 tests/core/controllers/test_trainer_wo_evaluator_torch.py

+24 -24 fastNLP/core/callbacks/callback_events.py

@@ -74,30 +74,30 @@ class EventEnum(_SingleEventState, Enum):


@unique
class Events(EventEnum):
-    ON_AFTER_TRAINER_INITIALIZED = "on_after_trainer_initialized"
-    ON_SANITY_CHECK_BEGIN = "on_sanity_check_begin"
-    ON_SANITY_CHECK_END = "on_sanity_check_end"
-    ON_TRAIN_BEGIN = "on_train_begin"
-    ON_TRAIN_END = "on_train_end"
-    ON_TRAIN_EPOCH_BEGIN = "on_train_epoch_begin"
-    ON_TRAIN_EPOCH_END = "on_train_epoch_end"
-    ON_FETCH_DATA_BEGIN = "on_fetch_data_begin"
-    ON_FETCH_DATA_END = "on_fetch_data_end"
-    ON_TRAIN_BATCH_BEGIN = "on_train_batch_begin"
-    ON_TRAIN_BATCH_END = "on_train_batch_end"
-    ON_EXCEPTION = "on_exception"
-    ON_SAVE_MODEL = "on_save_model"
-    ON_LOAD_MODEL = "on_load_model"
-    ON_SAVE_CHECKPOINT = "on_save_checkpoint"
-    ON_LOAD_CHECKPOINT = "on_load_checkpoint"
-    ON_BEFORE_BACKWARD = "on_before_backward"
-    ON_AFTER_BACKWARD = "on_after_backward"
-    ON_BEFORE_OPTIMIZERS_STEP = "on_before_optimizers_step"
-    ON_AFTER_OPTIMIZERS_STEP = "on_after_optimizers_step"
-    ON_BEFORE_ZERO_GRAD = "on_before_zero_grad"
-    ON_AFTER_ZERO_GRAD = "on_after_zero_grad"
-    ON_VALIDATE_BEGIN = "on_validate_begin"
-    ON_VALIDATE_END = "on_validate_end"
+    on_after_trainer_initialized = "on_after_trainer_initialized"
+    on_sanity_check_begin = "on_sanity_check_begin"
+    on_sanity_check_end = "on_sanity_check_end"
+    on_train_begin = "on_train_begin"
+    on_train_end = "on_train_end"
+    on_train_epoch_begin = "on_train_epoch_begin"
+    on_train_epoch_end = "on_train_epoch_end"
+    on_fetch_data_begin = "on_fetch_data_begin"
+    on_fetch_data_end = "on_fetch_data_end"
+    on_train_batch_begin = "on_train_batch_begin"
+    on_train_batch_end = "on_train_batch_end"
+    on_exception = "on_exception"
+    on_save_model = "on_save_model"
+    on_load_model = "on_load_model"
+    on_save_checkpoint = "on_save_checkpoint"
+    on_load_checkpoint = "on_load_checkpoint"
+    on_before_backward = "on_before_backward"
+    on_after_backward = "on_after_backward"
+    on_before_optimizers_step = "on_before_optimizers_step"
+    on_after_optimizers_step = "on_after_optimizers_step"
+    on_before_zero_grad = "on_before_zero_grad"
+    on_after_zero_grad = "on_after_zero_grad"
+    on_validate_begin = "on_validate_begin"
+    on_validate_end = "on_validate_end"




class EventsList:
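A minimal sketch of registering a handler against the renamed members (the top-level import of Trainer from fastNLP and the log_epoch handler are assumptions for illustration; the Trainer.on decorator and trainer.cur_epoch_idx follow the tests changed below). Because each member is now named after the hook string itself, an event can also be looked up by that lowercase name:

from fastNLP import Trainer  # assumed import path, for illustration only
from fastNLP.core.callbacks.callback_events import Events

# Lowercase member names mirror the callback hooks they trigger, so the
# string form and the attribute form resolve to the same enum member.
assert Events["on_train_epoch_end"] is Events.on_train_epoch_end

@Trainer.on(Events.on_train_epoch_end)
def log_epoch(trainer):
    # Hypothetical handler; cur_epoch_idx is used the same way in the tests below.
    print(f"finished epoch {trainer.cur_epoch_idx}")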


+1 -1 tests/core/callbacks/test_checkpoint_callback_torch.py

@@ -238,7 +238,7 @@ def test_model_checkpoint_callback_2(


    from fastNLP.core.callbacks.callback_events import Events

-    @Trainer.on(Events.ON_TRAIN_EPOCH_END)
+    @Trainer.on(Events.on_train_epoch_end)
    def raise_exception(trainer):
        if trainer.driver.get_local_rank() == 0 and trainer.cur_epoch_idx == 4:
            raise NotImplementedError


+1 -1 tests/core/controllers/test_trainer_wo_evaluator_torch.py

@@ -254,7 +254,7 @@ def test_trainer_on_exception(
):
    from fastNLP.core.callbacks.callback_events import Events

-    @Trainer.on(Events.ON_TRAIN_EPOCH_END)
+    @Trainer.on(Events.on_train_epoch_end)
    def raise_exception(trainer):
        if trainer.driver.get_local_rank() == cur_rank:
            raise NotImplementedError

