| @@ -219,6 +219,85 @@ def test_trainer_event_trigger_2( | |||
| assert member.value in output[0] | |||
| @pytest.mark.parametrize("driver,device", [("torch", "cpu"), ("torch", 6)]) | |||
| @pytest.mark.torch | |||
| @magic_argv_env_context | |||
| def test_trainer_event_trigger_3( | |||
| model_and_optimizers: TrainerParameters, | |||
| driver, | |||
| device, | |||
| n_epochs=2, | |||
| ): | |||
| import re | |||
| once_message_1 = "This message should be typed 1 times." | |||
| once_message_2 = "test_filter_fn" | |||
| once_message_3 = "once message 3" | |||
| twice_message = "twice message hei hei" | |||
| @Trainer.on(Events.on_train_epoch_begin(every=2)) | |||
| def train_epoch_begin_1(trainer): | |||
| print(once_message_1) | |||
| @Trainer.on(Events.on_train_epoch_begin()) | |||
| def train_epoch_begin_2(trainer): | |||
| print(twice_message) | |||
| @Trainer.on(Events.on_train_epoch_begin(once=2)) | |||
| def train_epoch_begin_3(trainer): | |||
| print(once_message_3) | |||
| def filter_fn(filter, trainer): | |||
| if trainer.cur_epoch_idx == 1: | |||
| return True | |||
| else: | |||
| return False | |||
| @Trainer.on(Events.on_train_epoch_end(filter_fn=filter_fn)) | |||
| def test_filter_fn(trainer): | |||
| print(once_message_2) | |||
| with Capturing() as output: | |||
| trainer = Trainer( | |||
| model=model_and_optimizers.model, | |||
| driver=driver, | |||
| device=device, | |||
| optimizers=model_and_optimizers.optimizers, | |||
| train_dataloader=model_and_optimizers.train_dataloader, | |||
| evaluate_dataloaders=model_and_optimizers.evaluate_dataloaders, | |||
| input_mapping=model_and_optimizers.input_mapping, | |||
| output_mapping=model_and_optimizers.output_mapping, | |||
| metrics=model_and_optimizers.metrics, | |||
| n_epochs=n_epochs, | |||
| ) | |||
| trainer.run() | |||
| if dist.is_initialized(): | |||
| dist.destroy_process_group() | |||
| once_pattern_1 = re.compile(once_message_1) | |||
| once_pattern_2 = re.compile(once_message_2) | |||
| once_pattern_3 = re.compile(once_message_3) | |||
| twice_pattern = re.compile(twice_message) | |||
| once_res_1 = once_pattern_1.findall(output[0]) | |||
| assert len(once_res_1) == 1 | |||
| once_res_2 = once_pattern_2.findall(output[0]) | |||
| assert len(once_res_2) == 1 | |||
| once_res_3 = once_pattern_3.findall(output[0]) | |||
| assert len(once_res_3) == 1 | |||
| twice_res = twice_pattern.findall(output[0]) | |||
| assert len(twice_res) == 2 | |||