diff --git a/tests/core/controllers/test_trainer_event_trigger.py b/tests/core/controllers/test_trainer_event_trigger.py
index bff1044c..d9eeb16f 100644
--- a/tests/core/controllers/test_trainer_event_trigger.py
+++ b/tests/core/controllers/test_trainer_event_trigger.py
@@ -114,6 +114,126 @@ def test_trainer_event_trigger_2(
     n_epochs=2,
 ):
+    @Trainer.on(Events.on_after_trainer_initialized())
+    def on_after_trainer_initialized(trainer, driver):
+        print("on_after_trainer_initialized")
+
+    @Trainer.on(Events.on_sanity_check_begin())
+    def on_sanity_check_begin(trainer):
+        print("on_sanity_check_begin")
+
+    @Trainer.on(Events.on_sanity_check_end())
+    def on_sanity_check_end(trainer, sanity_check_res):
+        print("on_sanity_check_end")
+
+    @Trainer.on(Events.on_train_begin())
+    def on_train_begin(trainer):
+        print("on_train_begin")
+
+    @Trainer.on(Events.on_train_end())
+    def on_train_end(trainer):
+        print("on_train_end")
+
+    @Trainer.on(Events.on_train_epoch_begin())
+    def on_train_epoch_begin(trainer):
+        if trainer.cur_epoch_idx >= 1:
+            # 触发 on_exception;
+            raise Exception
+        print("on_train_epoch_begin")
+
+    @Trainer.on(Events.on_train_epoch_end())
+    def on_train_epoch_end(trainer):
+        print("on_train_epoch_end")
+
+    @Trainer.on(Events.on_fetch_data_begin())
+    def on_fetch_data_begin(trainer):
+        print("on_fetch_data_begin")
+
+    @Trainer.on(Events.on_fetch_data_end())
+    def on_fetch_data_end(trainer):
+        print("on_fetch_data_end")
+
+    @Trainer.on(Events.on_train_batch_begin())
+    def on_train_batch_begin(trainer, batch, indices=None):
+        print("on_train_batch_begin")
+
+    @Trainer.on(Events.on_train_batch_end())
+    def on_train_batch_end(trainer):
+        print("on_train_batch_end")
+
+    @Trainer.on(Events.on_exception())
+    def on_exception(trainer, exception):
+        print("on_exception")
+
+    @Trainer.on(Events.on_before_backward())
+    def on_before_backward(trainer, outputs):
+        print("on_before_backward")
+
+    @Trainer.on(Events.on_after_backward())
+    def on_after_backward(trainer):
+        print("on_after_backward")
+
+    @Trainer.on(Events.on_before_optimizers_step())
+    def on_before_optimizers_step(trainer, optimizers):
+        print("on_before_optimizers_step")
+
+    @Trainer.on(Events.on_after_optimizers_step())
+    def on_after_optimizers_step(trainer, optimizers):
+        print("on_after_optimizers_step")
+
+    @Trainer.on(Events.on_before_zero_grad())
+    def on_before_zero_grad(trainer, optimizers):
+        print("on_before_zero_grad")
+
+    @Trainer.on(Events.on_after_zero_grad())
+    def on_after_zero_grad(trainer, optimizers):
+        print("on_after_zero_grad")
+
+    @Trainer.on(Events.on_evaluate_begin())
+    def on_evaluate_begin(trainer):
+        print("on_evaluate_begin")
+
+    @Trainer.on(Events.on_evaluate_end())
+    def on_evaluate_end(trainer, results):
+        print("on_evaluate_end")
+
+    with pytest.raises(Exception):
+        with Capturing() as output:
+            trainer = Trainer(
+                model=model_and_optimizers.model,
+                driver=driver,
+                device=device,
+                optimizers=model_and_optimizers.optimizers,
+                train_dataloader=model_and_optimizers.train_dataloader,
+                evaluate_dataloaders=model_and_optimizers.evaluate_dataloaders,
+                input_mapping=model_and_optimizers.input_mapping,
+                output_mapping=model_and_optimizers.output_mapping,
+                metrics=model_and_optimizers.metrics,
+
+                n_epochs=n_epochs,
+            )
+
+            trainer.run()
+
+    if dist.is_initialized():
+        dist.destroy_process_group()
+
+    for name, member in Events.__members__.items():
+        assert member.value in output[0]
+
+
+
+
+@pytest.mark.parametrize("driver,device", [("torch", "cpu")]) # , ("torch", 6), ("torch", [6, 7])
+@pytest.mark.torch
+@magic_argv_env_context
+def test_trainer_event_trigger_3(
+    model_and_optimizers: TrainerParameters,
+    driver,
+    device,
+    n_epochs=2,
+):
+
     @Trainer.on(Events.on_after_trainer_initialized)
     def on_after_trainer_initialized(trainer, driver):
         print("on_after_trainer_initialized")
 
@@ -227,3 +347,7 @@ def test_trainer_event_trigger_2(
 
 
 
+
+
+
+