|
|
@@ -100,17 +100,16 @@ def model_and_optimizers(request): |
|
|
|
# Test the ordinary (normal) case; |
|
|
|
@pytest.mark.torch |
|
|
|
@pytest.mark.parametrize("driver,device", [("torch", "cpu"), ("torch", 1), ("torch", [0, 1])]) # ("torch", "cpu"), ("torch", 1), ("torch", [0, 1]) |
|
|
|
@pytest.mark.parametrize("callbacks", [[RecordMetricCallback(monitor="acc", metric_threshold=0.2, larger_better=True)]]) |
|
|
|
@pytest.mark.parametrize("evaluate_every", [-3, -1, 100]) |
|
|
|
@magic_argv_env_context |
|
|
|
def test_trainer_torch_with_evaluator( |
|
|
|
model_and_optimizers: TrainerParameters, |
|
|
|
driver, |
|
|
|
device, |
|
|
|
callbacks, |
|
|
|
evaluate_every, |
|
|
|
n_epochs=10, |
|
|
|
): |
|
|
|
callbacks = [RecordMetricCallback(monitor="acc", metric_threshold=0.2, larger_better=True)] |
|
|
|
trainer = Trainer( |
|
|
|
model=model_and_optimizers.model, |
|
|
|
driver=driver, |
|
|
@@ -172,7 +171,7 @@ def test_trainer_torch_with_evaluator_fp16_accumulation_steps( |
|
|
|
if dist.is_initialized(): |
|
|
|
dist.destroy_process_group() |
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.torch |
|
|
|
@pytest.mark.parametrize("driver,device", [("torch", 1)]) # ("torch", [0, 1]),("torch", 1) |
|
|
|
@magic_argv_env_context |
|
|
|
def test_trainer_validate_every( |
|
|
@@ -184,9 +183,7 @@ def test_trainer_validate_every( |
|
|
|
|
|
|
|
def validate_every(trainer): |
|
|
|
if trainer.global_forward_batches % 10 == 0: |
|
|
|
print(trainer) |
|
|
|
print("\nfastNLP test validate every.\n") |
|
|
|
print(trainer.global_forward_batches) |
|
|
|
return True |
|
|
|
|
|
|
|
trainer = Trainer( |
|
|
|