| @@ -114,6 +114,126 @@ def test_trainer_event_trigger_2( | |||||
| n_epochs=2, | n_epochs=2, | ||||
| ): | ): | ||||
| @Trainer.on(Events.on_after_trainer_initialized()) | |||||
| def on_after_trainer_initialized(trainer, driver): | |||||
| print("on_after_trainer_initialized") | |||||
| @Trainer.on(Events.on_sanity_check_begin()) | |||||
| def on_sanity_check_begin(trainer): | |||||
| print("on_sanity_check_begin") | |||||
| @Trainer.on(Events.on_sanity_check_end()) | |||||
| def on_sanity_check_end(trainer, sanity_check_res): | |||||
| print("on_sanity_check_end") | |||||
| @Trainer.on(Events.on_train_begin()) | |||||
| def on_train_begin(trainer): | |||||
| print("on_train_begin") | |||||
| @Trainer.on(Events.on_train_end()) | |||||
| def on_train_end(trainer): | |||||
| print("on_train_end") | |||||
| @Trainer.on(Events.on_train_epoch_begin()) | |||||
| def on_train_epoch_begin(trainer): | |||||
| if trainer.cur_epoch_idx >= 1: | |||||
| # 触发 on_exception; | |||||
| raise Exception | |||||
| print("on_train_epoch_begin") | |||||
| @Trainer.on(Events.on_train_epoch_end()) | |||||
| def on_train_epoch_end(trainer): | |||||
| print("on_train_epoch_end") | |||||
| @Trainer.on(Events.on_fetch_data_begin()) | |||||
| def on_fetch_data_begin(trainer): | |||||
| print("on_fetch_data_begin") | |||||
| @Trainer.on(Events.on_fetch_data_end()) | |||||
| def on_fetch_data_end(trainer): | |||||
| print("on_fetch_data_end") | |||||
| @Trainer.on(Events.on_train_batch_begin()) | |||||
| def on_train_batch_begin(trainer, batch, indices=None): | |||||
| print("on_train_batch_begin") | |||||
| @Trainer.on(Events.on_train_batch_end()) | |||||
| def on_train_batch_end(trainer): | |||||
| print("on_train_batch_end") | |||||
| @Trainer.on(Events.on_exception()) | |||||
| def on_exception(trainer, exception): | |||||
| print("on_exception") | |||||
| @Trainer.on(Events.on_before_backward()) | |||||
| def on_before_backward(trainer, outputs): | |||||
| print("on_before_backward") | |||||
| @Trainer.on(Events.on_after_backward()) | |||||
| def on_after_backward(trainer): | |||||
| print("on_after_backward") | |||||
| @Trainer.on(Events.on_before_optimizers_step()) | |||||
| def on_before_optimizers_step(trainer, optimizers): | |||||
| print("on_before_optimizers_step") | |||||
| @Trainer.on(Events.on_after_optimizers_step()) | |||||
| def on_after_optimizers_step(trainer, optimizers): | |||||
| print("on_after_optimizers_step") | |||||
| @Trainer.on(Events.on_before_zero_grad()) | |||||
| def on_before_zero_grad(trainer, optimizers): | |||||
| print("on_before_zero_grad") | |||||
| @Trainer.on(Events.on_after_zero_grad()) | |||||
| def on_after_zero_grad(trainer, optimizers): | |||||
| print("on_after_zero_grad") | |||||
| @Trainer.on(Events.on_evaluate_begin()) | |||||
| def on_evaluate_begin(trainer): | |||||
| print("on_evaluate_begin") | |||||
| @Trainer.on(Events.on_evaluate_end()) | |||||
| def on_evaluate_end(trainer, results): | |||||
| print("on_evaluate_end") | |||||
| with pytest.raises(Exception): | |||||
| with Capturing() as output: | |||||
| trainer = Trainer( | |||||
| model=model_and_optimizers.model, | |||||
| driver=driver, | |||||
| device=device, | |||||
| optimizers=model_and_optimizers.optimizers, | |||||
| train_dataloader=model_and_optimizers.train_dataloader, | |||||
| evaluate_dataloaders=model_and_optimizers.evaluate_dataloaders, | |||||
| input_mapping=model_and_optimizers.input_mapping, | |||||
| output_mapping=model_and_optimizers.output_mapping, | |||||
| metrics=model_and_optimizers.metrics, | |||||
| n_epochs=n_epochs, | |||||
| ) | |||||
| trainer.run() | |||||
| if dist.is_initialized(): | |||||
| dist.destroy_process_group() | |||||
| for name, member in Events.__members__.items(): | |||||
| assert member.value in output[0] | |||||
| @pytest.mark.parametrize("driver,device", [("torch", "cpu")]) # , ("torch", 6), ("torch", [6, 7]) | |||||
| @pytest.mark.torch | |||||
| @magic_argv_env_context | |||||
| def test_trainer_event_trigger_3( | |||||
| model_and_optimizers: TrainerParameters, | |||||
| driver, | |||||
| device, | |||||
| n_epochs=2, | |||||
| ): | |||||
| @Trainer.on(Events.on_after_trainer_initialized) | @Trainer.on(Events.on_after_trainer_initialized) | ||||
| def on_after_trainer_initialized(trainer, driver): | def on_after_trainer_initialized(trainer, driver): | ||||
| print("on_after_trainer_initialized") | print("on_after_trainer_initialized") | ||||
| @@ -227,3 +347,7 @@ def test_trainer_event_trigger_2( | |||||