|
|
@@ -111,6 +111,126 @@ def test_trainer_event_trigger_2( |
|
|
|
n_epochs=2, |
|
|
|
): |
|
|
|
|
|
|
|
@Trainer.on(Events.on_after_trainer_initialized()) |
|
|
|
def on_after_trainer_initialized(trainer, driver): |
|
|
|
print("on_after_trainer_initialized") |
|
|
|
|
|
|
|
@Trainer.on(Events.on_sanity_check_begin()) |
|
|
|
def on_sanity_check_begin(trainer): |
|
|
|
print("on_sanity_check_begin") |
|
|
|
|
|
|
|
@Trainer.on(Events.on_sanity_check_end()) |
|
|
|
def on_sanity_check_end(trainer, sanity_check_res): |
|
|
|
print("on_sanity_check_end") |
|
|
|
|
|
|
|
@Trainer.on(Events.on_train_begin()) |
|
|
|
def on_train_begin(trainer): |
|
|
|
print("on_train_begin") |
|
|
|
|
|
|
|
@Trainer.on(Events.on_train_end()) |
|
|
|
def on_train_end(trainer): |
|
|
|
print("on_train_end") |
|
|
|
|
|
|
|
@Trainer.on(Events.on_train_epoch_begin()) |
|
|
|
def on_train_epoch_begin(trainer): |
|
|
|
if trainer.cur_epoch_idx >= 1: |
|
|
|
# 触发 on_exception; |
|
|
|
raise Exception |
|
|
|
print("on_train_epoch_begin") |
|
|
|
|
|
|
|
@Trainer.on(Events.on_train_epoch_end()) |
|
|
|
def on_train_epoch_end(trainer): |
|
|
|
print("on_train_epoch_end") |
|
|
|
|
|
|
|
@Trainer.on(Events.on_fetch_data_begin()) |
|
|
|
def on_fetch_data_begin(trainer): |
|
|
|
print("on_fetch_data_begin") |
|
|
|
|
|
|
|
@Trainer.on(Events.on_fetch_data_end()) |
|
|
|
def on_fetch_data_end(trainer): |
|
|
|
print("on_fetch_data_end") |
|
|
|
|
|
|
|
@Trainer.on(Events.on_train_batch_begin()) |
|
|
|
def on_train_batch_begin(trainer, batch, indices=None): |
|
|
|
print("on_train_batch_begin") |
|
|
|
|
|
|
|
@Trainer.on(Events.on_train_batch_end()) |
|
|
|
def on_train_batch_end(trainer): |
|
|
|
print("on_train_batch_end") |
|
|
|
|
|
|
|
@Trainer.on(Events.on_exception()) |
|
|
|
def on_exception(trainer, exception): |
|
|
|
print("on_exception") |
|
|
|
|
|
|
|
@Trainer.on(Events.on_before_backward()) |
|
|
|
def on_before_backward(trainer, outputs): |
|
|
|
print("on_before_backward") |
|
|
|
|
|
|
|
@Trainer.on(Events.on_after_backward()) |
|
|
|
def on_after_backward(trainer): |
|
|
|
print("on_after_backward") |
|
|
|
|
|
|
|
@Trainer.on(Events.on_before_optimizers_step()) |
|
|
|
def on_before_optimizers_step(trainer, optimizers): |
|
|
|
print("on_before_optimizers_step") |
|
|
|
|
|
|
|
@Trainer.on(Events.on_after_optimizers_step()) |
|
|
|
def on_after_optimizers_step(trainer, optimizers): |
|
|
|
print("on_after_optimizers_step") |
|
|
|
|
|
|
|
@Trainer.on(Events.on_before_zero_grad()) |
|
|
|
def on_before_zero_grad(trainer, optimizers): |
|
|
|
print("on_before_zero_grad") |
|
|
|
|
|
|
|
@Trainer.on(Events.on_after_zero_grad()) |
|
|
|
def on_after_zero_grad(trainer, optimizers): |
|
|
|
print("on_after_zero_grad") |
|
|
|
|
|
|
|
@Trainer.on(Events.on_evaluate_begin()) |
|
|
|
def on_evaluate_begin(trainer): |
|
|
|
print("on_evaluate_begin") |
|
|
|
|
|
|
|
@Trainer.on(Events.on_evaluate_end()) |
|
|
|
def on_evaluate_end(trainer, results): |
|
|
|
print("on_evaluate_end") |
|
|
|
|
|
|
|
with pytest.raises(Exception): |
|
|
|
with Capturing() as output: |
|
|
|
trainer = Trainer( |
|
|
|
model=model_and_optimizers.model, |
|
|
|
driver=driver, |
|
|
|
device=device, |
|
|
|
optimizers=model_and_optimizers.optimizers, |
|
|
|
train_dataloader=model_and_optimizers.train_dataloader, |
|
|
|
evaluate_dataloaders=model_and_optimizers.evaluate_dataloaders, |
|
|
|
input_mapping=model_and_optimizers.input_mapping, |
|
|
|
output_mapping=model_and_optimizers.output_mapping, |
|
|
|
metrics=model_and_optimizers.metrics, |
|
|
|
|
|
|
|
n_epochs=n_epochs, |
|
|
|
) |
|
|
|
|
|
|
|
trainer.run() |
|
|
|
|
|
|
|
if dist.is_initialized(): |
|
|
|
dist.destroy_process_group() |
|
|
|
|
|
|
|
for name, member in Events.__members__.items(): |
|
|
|
assert member.value in output[0] |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.parametrize("driver,device", [("torch", "cpu")]) # , ("torch", 6), ("torch", [6, 7]) |
|
|
|
@pytest.mark.torch |
|
|
|
@magic_argv_env_context |
|
|
|
def test_trainer_event_trigger_3( |
|
|
|
model_and_optimizers: TrainerParameters, |
|
|
|
driver, |
|
|
|
device, |
|
|
|
n_epochs=2, |
|
|
|
): |
|
|
|
|
|
|
|
@Trainer.on(Events.on_after_trainer_initialized) |
|
|
|
def on_after_trainer_initialized(trainer, driver): |
|
|
|
print("on_after_trainer_initialized") |
|
|
@@ -224,3 +344,7 @@ def test_trainer_event_trigger_2( |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|