Browse Source

添加了 test_trainer_event_trigger_3 的测试

tags/v1.0.0alpha
YWMditto 2 years ago
parent
commit
3d4c318f0e
1 changed files with 79 additions and 0 deletions
  1. +79
    -0
      tests/core/controllers/test_trainer_event_trigger.py

+ 79
- 0
tests/core/controllers/test_trainer_event_trigger.py View File

@@ -219,6 +219,85 @@ def test_trainer_event_trigger_2(
assert member.value in output[0] assert member.value in output[0]




@pytest.mark.parametrize("driver,device", [("torch", "cpu"), ("torch", 6)])
@pytest.mark.torch
@magic_argv_env_context
def test_trainer_event_trigger_3(
model_and_optimizers: TrainerParameters,
driver,
device,
n_epochs=2,
):
import re

once_message_1 = "This message should be typed 1 times."
once_message_2 = "test_filter_fn"
once_message_3 = "once message 3"
twice_message = "twice message hei hei"

@Trainer.on(Events.on_train_epoch_begin(every=2))
def train_epoch_begin_1(trainer):
print(once_message_1)

@Trainer.on(Events.on_train_epoch_begin())
def train_epoch_begin_2(trainer):
print(twice_message)

@Trainer.on(Events.on_train_epoch_begin(once=2))
def train_epoch_begin_3(trainer):
print(once_message_3)

def filter_fn(filter, trainer):
if trainer.cur_epoch_idx == 1:
return True
else:
return False

@Trainer.on(Events.on_train_epoch_end(filter_fn=filter_fn))
def test_filter_fn(trainer):
print(once_message_2)


with Capturing() as output:
trainer = Trainer(
model=model_and_optimizers.model,
driver=driver,
device=device,
optimizers=model_and_optimizers.optimizers,
train_dataloader=model_and_optimizers.train_dataloader,
evaluate_dataloaders=model_and_optimizers.evaluate_dataloaders,
input_mapping=model_and_optimizers.input_mapping,
output_mapping=model_and_optimizers.output_mapping,
metrics=model_and_optimizers.metrics,

n_epochs=n_epochs,
)

trainer.run()

if dist.is_initialized():
dist.destroy_process_group()


once_pattern_1 = re.compile(once_message_1)
once_pattern_2 = re.compile(once_message_2)
once_pattern_3 = re.compile(once_message_3)
twice_pattern = re.compile(twice_message)

once_res_1 = once_pattern_1.findall(output[0])
assert len(once_res_1) == 1
once_res_2 = once_pattern_2.findall(output[0])
assert len(once_res_2) == 1
once_res_3 = once_pattern_3.findall(output[0])
assert len(once_res_3) == 1
twice_res = twice_pattern.findall(output[0])
assert len(twice_res) == 2














Loading…
Cancel
Save