diff --git a/tests/core/controllers/test_trainer_event_trigger.py b/tests/core/controllers/test_trainer_event_trigger.py index 1a90a96d..01f28a0b 100644 --- a/tests/core/controllers/test_trainer_event_trigger.py +++ b/tests/core/controllers/test_trainer_event_trigger.py @@ -219,6 +219,85 @@ def test_trainer_event_trigger_2( assert member.value in output[0] +@pytest.mark.parametrize("driver,device", [("torch", "cpu"), ("torch", 6)]) +@pytest.mark.torch +@magic_argv_env_context +def test_trainer_event_trigger_3( + model_and_optimizers: TrainerParameters, + driver, + device, + n_epochs=2, +): + import re + + once_message_1 = "This message should be typed 1 times." + once_message_2 = "test_filter_fn" + once_message_3 = "once message 3" + twice_message = "twice message hei hei" + + @Trainer.on(Events.on_train_epoch_begin(every=2)) + def train_epoch_begin_1(trainer): + print(once_message_1) + + @Trainer.on(Events.on_train_epoch_begin()) + def train_epoch_begin_2(trainer): + print(twice_message) + + @Trainer.on(Events.on_train_epoch_begin(once=2)) + def train_epoch_begin_3(trainer): + print(once_message_3) + + def filter_fn(filter, trainer): + if trainer.cur_epoch_idx == 1: + return True + else: + return False + + @Trainer.on(Events.on_train_epoch_end(filter_fn=filter_fn)) + def test_filter_fn(trainer): + print(once_message_2) + + + with Capturing() as output: + trainer = Trainer( + model=model_and_optimizers.model, + driver=driver, + device=device, + optimizers=model_and_optimizers.optimizers, + train_dataloader=model_and_optimizers.train_dataloader, + evaluate_dataloaders=model_and_optimizers.evaluate_dataloaders, + input_mapping=model_and_optimizers.input_mapping, + output_mapping=model_and_optimizers.output_mapping, + metrics=model_and_optimizers.metrics, + + n_epochs=n_epochs, + ) + + trainer.run() + + if dist.is_initialized(): + dist.destroy_process_group() + + + once_pattern_1 = re.compile(once_message_1) + once_pattern_2 = re.compile(once_message_2) + once_pattern_3 = re.compile(once_message_3) + twice_pattern = re.compile(twice_message) + + once_res_1 = once_pattern_1.findall(output[0]) + assert len(once_res_1) == 1 + once_res_2 = once_pattern_2.findall(output[0]) + assert len(once_res_2) == 1 + once_res_3 = once_pattern_3.findall(output[0]) + assert len(once_res_3) == 1 + twice_res = twice_pattern.findall(output[0]) + assert len(twice_res) == 2 + + + + + +