Browse Source

修复测试

tags/v1.0.0alpha
yh_cc 2 years ago
parent
commit
4781178a5a
2 changed files with 2 additions and 3 deletions
  1. +1
    -3
      tests/core/controllers/test_trainer_w_evaluator_torch.py
  2. +1
    -0
      tests/core/drivers/torch_driver/test_torch_replace_sampler.py

+ 1
- 3
tests/core/controllers/test_trainer_w_evaluator_torch.py View File

@@ -132,7 +132,6 @@ def test_trainer_torch_with_evaluator(


@pytest.mark.parametrize("driver,device", [("torch", [0, 1]), ("torch", 1)]) # ("torch", [0, 1]),("torch", 1)
@pytest.mark.parametrize("callbacks", [[RecordMetricCallback(monitor="acc", metric_threshold=0.3, larger_better=True)]])
@pytest.mark.parametrize("fp16", [True, False])
@pytest.mark.parametrize("accumulation_steps", [1, 3])
@magic_argv_env_context
@@ -140,12 +139,11 @@ def test_trainer_torch_with_evaluator_fp16_accumulation_steps(
model_and_optimizers: TrainerParameters,
driver,
device,
callbacks,
fp16,
accumulation_steps,
n_epochs=6,
):
callbacks = [RecordMetricCallback(monitor="acc", metric_threshold=0.3, larger_better=True)]
trainer = Trainer(
model=model_and_optimizers.model,
driver=driver,


+ 1
- 0
tests/core/drivers/torch_driver/test_torch_replace_sampler.py View File

@@ -77,3 +77,4 @@ def check_replace_sampler(driver):





Loading…
Cancel
Save