From 30af3b032fd492f453c766d1117cb90f30b3efb5 Mon Sep 17 00:00:00 2001 From: YWMditto Date: Tue, 3 May 2022 16:42:33 +0800 Subject: [PATCH] =?UTF-8?q?=E6=B7=BB=E5=8A=A0=E4=BA=86=E6=B5=8B=E8=AF=95?= =?UTF-8?q?=20=E5=87=BD=E6=95=B0=E5=BC=8F=20callback?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../controllers/test_trainer_event_trigger.py | 124 ++++++++++++++++++ 1 file changed, 124 insertions(+) diff --git a/tests/core/controllers/test_trainer_event_trigger.py b/tests/core/controllers/test_trainer_event_trigger.py index 84752287..fab07b3c 100644 --- a/tests/core/controllers/test_trainer_event_trigger.py +++ b/tests/core/controllers/test_trainer_event_trigger.py @@ -111,6 +111,126 @@ def test_trainer_event_trigger_2( n_epochs=2, ): + @Trainer.on(Events.on_after_trainer_initialized()) + def on_after_trainer_initialized(trainer, driver): + print("on_after_trainer_initialized") + + @Trainer.on(Events.on_sanity_check_begin()) + def on_sanity_check_begin(trainer): + print("on_sanity_check_begin") + + @Trainer.on(Events.on_sanity_check_end()) + def on_sanity_check_end(trainer, sanity_check_res): + print("on_sanity_check_end") + + @Trainer.on(Events.on_train_begin()) + def on_train_begin(trainer): + print("on_train_begin") + + @Trainer.on(Events.on_train_end()) + def on_train_end(trainer): + print("on_train_end") + + @Trainer.on(Events.on_train_epoch_begin()) + def on_train_epoch_begin(trainer): + if trainer.cur_epoch_idx >= 1: + # 触发 on_exception; + raise Exception + print("on_train_epoch_begin") + + @Trainer.on(Events.on_train_epoch_end()) + def on_train_epoch_end(trainer): + print("on_train_epoch_end") + + @Trainer.on(Events.on_fetch_data_begin()) + def on_fetch_data_begin(trainer): + print("on_fetch_data_begin") + + @Trainer.on(Events.on_fetch_data_end()) + def on_fetch_data_end(trainer): + print("on_fetch_data_end") + + @Trainer.on(Events.on_train_batch_begin()) + def on_train_batch_begin(trainer, batch, indices=None): + print("on_train_batch_begin") + + @Trainer.on(Events.on_train_batch_end()) + def on_train_batch_end(trainer): + print("on_train_batch_end") + + @Trainer.on(Events.on_exception()) + def on_exception(trainer, exception): + print("on_exception") + + @Trainer.on(Events.on_before_backward()) + def on_before_backward(trainer, outputs): + print("on_before_backward") + + @Trainer.on(Events.on_after_backward()) + def on_after_backward(trainer): + print("on_after_backward") + + @Trainer.on(Events.on_before_optimizers_step()) + def on_before_optimizers_step(trainer, optimizers): + print("on_before_optimizers_step") + + @Trainer.on(Events.on_after_optimizers_step()) + def on_after_optimizers_step(trainer, optimizers): + print("on_after_optimizers_step") + + @Trainer.on(Events.on_before_zero_grad()) + def on_before_zero_grad(trainer, optimizers): + print("on_before_zero_grad") + + @Trainer.on(Events.on_after_zero_grad()) + def on_after_zero_grad(trainer, optimizers): + print("on_after_zero_grad") + + @Trainer.on(Events.on_evaluate_begin()) + def on_evaluate_begin(trainer): + print("on_evaluate_begin") + + @Trainer.on(Events.on_evaluate_end()) + def on_evaluate_end(trainer, results): + print("on_evaluate_end") + + with pytest.raises(Exception): + with Capturing() as output: + trainer = Trainer( + model=model_and_optimizers.model, + driver=driver, + device=device, + optimizers=model_and_optimizers.optimizers, + train_dataloader=model_and_optimizers.train_dataloader, + evaluate_dataloaders=model_and_optimizers.evaluate_dataloaders, + input_mapping=model_and_optimizers.input_mapping, + output_mapping=model_and_optimizers.output_mapping, + metrics=model_and_optimizers.metrics, + + n_epochs=n_epochs, + ) + + trainer.run() + + if dist.is_initialized(): + dist.destroy_process_group() + + for name, member in Events.__members__.items(): + assert member.value in output[0] + + + + +@pytest.mark.parametrize("driver,device", [("torch", "cpu")]) # , ("torch", 6), ("torch", [6, 7]) +@pytest.mark.torch +@magic_argv_env_context +def test_trainer_event_trigger_3( + model_and_optimizers: TrainerParameters, + driver, + device, + n_epochs=2, +): + @Trainer.on(Events.on_after_trainer_initialized) def on_after_trainer_initialized(trainer, driver): print("on_after_trainer_initialized") @@ -224,3 +344,7 @@ def test_trainer_event_trigger_2( + + + +