From 4781178a5ad655a492b6c1f3b8d8949b64fe0cad Mon Sep 17 00:00:00 2001 From: yh_cc Date: Fri, 8 Apr 2022 22:36:52 +0800 Subject: [PATCH] =?UTF-8?q?=E4=BF=AE=E5=A4=8D=E6=B5=8B=E8=AF=95?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- tests/core/controllers/test_trainer_w_evaluator_torch.py | 4 +--- tests/core/drivers/torch_driver/test_torch_replace_sampler.py | 1 + 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/tests/core/controllers/test_trainer_w_evaluator_torch.py b/tests/core/controllers/test_trainer_w_evaluator_torch.py index ef8c8a93..8944e45d 100644 --- a/tests/core/controllers/test_trainer_w_evaluator_torch.py +++ b/tests/core/controllers/test_trainer_w_evaluator_torch.py @@ -132,7 +132,6 @@ def test_trainer_torch_with_evaluator( @pytest.mark.parametrize("driver,device", [("torch", [0, 1]), ("torch", 1)]) # ("torch", [0, 1]),("torch", 1) -@pytest.mark.parametrize("callbacks", [[RecordMetricCallback(monitor="acc", metric_threshold=0.3, larger_better=True)]]) @pytest.mark.parametrize("fp16", [True, False]) @pytest.mark.parametrize("accumulation_steps", [1, 3]) @magic_argv_env_context @@ -140,12 +139,11 @@ def test_trainer_torch_with_evaluator_fp16_accumulation_steps( model_and_optimizers: TrainerParameters, driver, device, - callbacks, fp16, accumulation_steps, n_epochs=6, ): - + callbacks = [RecordMetricCallback(monitor="acc", metric_threshold=0.3, larger_better=True)] trainer = Trainer( model=model_and_optimizers.model, driver=driver, diff --git a/tests/core/drivers/torch_driver/test_torch_replace_sampler.py b/tests/core/drivers/torch_driver/test_torch_replace_sampler.py index f6e483b6..edb98190 100644 --- a/tests/core/drivers/torch_driver/test_torch_replace_sampler.py +++ b/tests/core/drivers/torch_driver/test_torch_replace_sampler.py @@ -77,3 +77,4 @@ def check_replace_sampler(driver): +