diff --git a/tests/core/controllers/test_trainer_w_evaluator_torch.py b/tests/core/controllers/test_trainer_w_evaluator_torch.py index ef8c8a93..8944e45d 100644 --- a/tests/core/controllers/test_trainer_w_evaluator_torch.py +++ b/tests/core/controllers/test_trainer_w_evaluator_torch.py @@ -132,7 +132,6 @@ def test_trainer_torch_with_evaluator( @pytest.mark.parametrize("driver,device", [("torch", [0, 1]), ("torch", 1)]) # ("torch", [0, 1]),("torch", 1) -@pytest.mark.parametrize("callbacks", [[RecordMetricCallback(monitor="acc", metric_threshold=0.3, larger_better=True)]]) @pytest.mark.parametrize("fp16", [True, False]) @pytest.mark.parametrize("accumulation_steps", [1, 3]) @magic_argv_env_context @@ -140,12 +139,11 @@ def test_trainer_torch_with_evaluator_fp16_accumulation_steps( model_and_optimizers: TrainerParameters, driver, device, - callbacks, fp16, accumulation_steps, n_epochs=6, ): - + callbacks = [RecordMetricCallback(monitor="acc", metric_threshold=0.3, larger_better=True)] trainer = Trainer( model=model_and_optimizers.model, driver=driver, diff --git a/tests/core/drivers/torch_driver/test_torch_replace_sampler.py b/tests/core/drivers/torch_driver/test_torch_replace_sampler.py index f6e483b6..edb98190 100644 --- a/tests/core/drivers/torch_driver/test_torch_replace_sampler.py +++ b/tests/core/drivers/torch_driver/test_torch_replace_sampler.py @@ -77,3 +77,4 @@ def check_replace_sampler(driver): +