Browse Source

添加了测试 函数式 callback

tags/v1.0.0alpha
YWMditto 2 years ago
parent
commit
30af3b032f
1 changed files with 124 additions and 0 deletions
  1. +124
    -0
      tests/core/controllers/test_trainer_event_trigger.py

+ 124
- 0
tests/core/controllers/test_trainer_event_trigger.py View File

@@ -111,6 +111,126 @@ def test_trainer_event_trigger_2(
n_epochs=2, n_epochs=2,
): ):


@Trainer.on(Events.on_after_trainer_initialized())
def on_after_trainer_initialized(trainer, driver):
print("on_after_trainer_initialized")

@Trainer.on(Events.on_sanity_check_begin())
def on_sanity_check_begin(trainer):
print("on_sanity_check_begin")

@Trainer.on(Events.on_sanity_check_end())
def on_sanity_check_end(trainer, sanity_check_res):
print("on_sanity_check_end")

@Trainer.on(Events.on_train_begin())
def on_train_begin(trainer):
print("on_train_begin")

@Trainer.on(Events.on_train_end())
def on_train_end(trainer):
print("on_train_end")

@Trainer.on(Events.on_train_epoch_begin())
def on_train_epoch_begin(trainer):
if trainer.cur_epoch_idx >= 1:
# 触发 on_exception;
raise Exception
print("on_train_epoch_begin")

@Trainer.on(Events.on_train_epoch_end())
def on_train_epoch_end(trainer):
print("on_train_epoch_end")

@Trainer.on(Events.on_fetch_data_begin())
def on_fetch_data_begin(trainer):
print("on_fetch_data_begin")

@Trainer.on(Events.on_fetch_data_end())
def on_fetch_data_end(trainer):
print("on_fetch_data_end")

@Trainer.on(Events.on_train_batch_begin())
def on_train_batch_begin(trainer, batch, indices=None):
print("on_train_batch_begin")

@Trainer.on(Events.on_train_batch_end())
def on_train_batch_end(trainer):
print("on_train_batch_end")

@Trainer.on(Events.on_exception())
def on_exception(trainer, exception):
print("on_exception")

@Trainer.on(Events.on_before_backward())
def on_before_backward(trainer, outputs):
print("on_before_backward")

@Trainer.on(Events.on_after_backward())
def on_after_backward(trainer):
print("on_after_backward")

@Trainer.on(Events.on_before_optimizers_step())
def on_before_optimizers_step(trainer, optimizers):
print("on_before_optimizers_step")

@Trainer.on(Events.on_after_optimizers_step())
def on_after_optimizers_step(trainer, optimizers):
print("on_after_optimizers_step")

@Trainer.on(Events.on_before_zero_grad())
def on_before_zero_grad(trainer, optimizers):
print("on_before_zero_grad")

@Trainer.on(Events.on_after_zero_grad())
def on_after_zero_grad(trainer, optimizers):
print("on_after_zero_grad")

@Trainer.on(Events.on_evaluate_begin())
def on_evaluate_begin(trainer):
print("on_evaluate_begin")

@Trainer.on(Events.on_evaluate_end())
def on_evaluate_end(trainer, results):
print("on_evaluate_end")

with pytest.raises(Exception):
with Capturing() as output:
trainer = Trainer(
model=model_and_optimizers.model,
driver=driver,
device=device,
optimizers=model_and_optimizers.optimizers,
train_dataloader=model_and_optimizers.train_dataloader,
evaluate_dataloaders=model_and_optimizers.evaluate_dataloaders,
input_mapping=model_and_optimizers.input_mapping,
output_mapping=model_and_optimizers.output_mapping,
metrics=model_and_optimizers.metrics,

n_epochs=n_epochs,
)

trainer.run()

if dist.is_initialized():
dist.destroy_process_group()

for name, member in Events.__members__.items():
assert member.value in output[0]




@pytest.mark.parametrize("driver,device", [("torch", "cpu")]) # , ("torch", 6), ("torch", [6, 7])
@pytest.mark.torch
@magic_argv_env_context
def test_trainer_event_trigger_3(
model_and_optimizers: TrainerParameters,
driver,
device,
n_epochs=2,
):

@Trainer.on(Events.on_after_trainer_initialized) @Trainer.on(Events.on_after_trainer_initialized)
def on_after_trainer_initialized(trainer, driver): def on_after_trainer_initialized(trainer, driver):
print("on_after_trainer_initialized") print("on_after_trainer_initialized")
@@ -224,3 +344,7 @@ def test_trainer_event_trigger_2(











Loading…
Cancel
Save