|
|
@@ -219,6 +219,85 @@ def test_trainer_event_trigger_2( |
|
|
|
assert member.value in output[0] |
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.parametrize("driver,device", [("torch", "cpu"), ("torch", 6)]) |
|
|
|
@pytest.mark.torch |
|
|
|
@magic_argv_env_context |
|
|
|
def test_trainer_event_trigger_3( |
|
|
|
model_and_optimizers: TrainerParameters, |
|
|
|
driver, |
|
|
|
device, |
|
|
|
n_epochs=2, |
|
|
|
): |
|
|
|
import re |
|
|
|
|
|
|
|
once_message_1 = "This message should be typed 1 times." |
|
|
|
once_message_2 = "test_filter_fn" |
|
|
|
once_message_3 = "once message 3" |
|
|
|
twice_message = "twice message hei hei" |
|
|
|
|
|
|
|
@Trainer.on(Events.on_train_epoch_begin(every=2)) |
|
|
|
def train_epoch_begin_1(trainer): |
|
|
|
print(once_message_1) |
|
|
|
|
|
|
|
@Trainer.on(Events.on_train_epoch_begin()) |
|
|
|
def train_epoch_begin_2(trainer): |
|
|
|
print(twice_message) |
|
|
|
|
|
|
|
@Trainer.on(Events.on_train_epoch_begin(once=2)) |
|
|
|
def train_epoch_begin_3(trainer): |
|
|
|
print(once_message_3) |
|
|
|
|
|
|
|
def filter_fn(filter, trainer): |
|
|
|
if trainer.cur_epoch_idx == 1: |
|
|
|
return True |
|
|
|
else: |
|
|
|
return False |
|
|
|
|
|
|
|
@Trainer.on(Events.on_train_epoch_end(filter_fn=filter_fn)) |
|
|
|
def test_filter_fn(trainer): |
|
|
|
print(once_message_2) |
|
|
|
|
|
|
|
|
|
|
|
with Capturing() as output: |
|
|
|
trainer = Trainer( |
|
|
|
model=model_and_optimizers.model, |
|
|
|
driver=driver, |
|
|
|
device=device, |
|
|
|
optimizers=model_and_optimizers.optimizers, |
|
|
|
train_dataloader=model_and_optimizers.train_dataloader, |
|
|
|
evaluate_dataloaders=model_and_optimizers.evaluate_dataloaders, |
|
|
|
input_mapping=model_and_optimizers.input_mapping, |
|
|
|
output_mapping=model_and_optimizers.output_mapping, |
|
|
|
metrics=model_and_optimizers.metrics, |
|
|
|
|
|
|
|
n_epochs=n_epochs, |
|
|
|
) |
|
|
|
|
|
|
|
trainer.run() |
|
|
|
|
|
|
|
if dist.is_initialized(): |
|
|
|
dist.destroy_process_group() |
|
|
|
|
|
|
|
|
|
|
|
once_pattern_1 = re.compile(once_message_1) |
|
|
|
once_pattern_2 = re.compile(once_message_2) |
|
|
|
once_pattern_3 = re.compile(once_message_3) |
|
|
|
twice_pattern = re.compile(twice_message) |
|
|
|
|
|
|
|
once_res_1 = once_pattern_1.findall(output[0]) |
|
|
|
assert len(once_res_1) == 1 |
|
|
|
once_res_2 = once_pattern_2.findall(output[0]) |
|
|
|
assert len(once_res_2) == 1 |
|
|
|
once_res_3 = once_pattern_3.findall(output[0]) |
|
|
|
assert len(once_res_3) == 1 |
|
|
|
twice_res = twice_pattern.findall(output[0]) |
|
|
|
assert len(twice_res) == 2 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|