|
|
@@ -132,7 +132,6 @@ def test_trainer_torch_with_evaluator( |
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.parametrize("driver,device", [("torch", [0, 1]), ("torch", 1)]) # ("torch", [0, 1]),("torch", 1) |
|
|
|
@pytest.mark.parametrize("callbacks", [[RecordMetricCallback(monitor="acc", metric_threshold=0.3, larger_better=True)]]) |
|
|
|
@pytest.mark.parametrize("fp16", [True, False]) |
|
|
|
@pytest.mark.parametrize("accumulation_steps", [1, 3]) |
|
|
|
@magic_argv_env_context |
|
|
@@ -140,12 +139,11 @@ def test_trainer_torch_with_evaluator_fp16_accumulation_steps( |
|
|
|
model_and_optimizers: TrainerParameters, |
|
|
|
driver, |
|
|
|
device, |
|
|
|
callbacks, |
|
|
|
fp16, |
|
|
|
accumulation_steps, |
|
|
|
n_epochs=6, |
|
|
|
): |
|
|
|
|
|
|
|
callbacks = [RecordMetricCallback(monitor="acc", metric_threshold=0.3, larger_better=True)] |
|
|
|
trainer = Trainer( |
|
|
|
model=model_and_optimizers.model, |
|
|
|
driver=driver, |
|
|
|