@@ -1,3 +1,4 @@
+import os
 from typing import Callable, Any, Union, Sequence
 from abc import ABC
 import inspect
@@ -126,7 +127,7 @@ class OverfitDataLoader:
            logger.warning("Parameter 'overfit_batches' is bigger than the length of 'train_dataloader'.")
        for idx, batch in enumerate(dataloader):
-            if idx < self.overfit_batches or self.overfit_batches < -1:
+            if idx < self.overfit_batches or self.overfit_batches <= -1:
                self.batches.append(batch)

    def __len__(self):
@@ -340,10 +340,13 @@ def test_trainer_specific_params_2(
 @pytest.mark.torch
+@pytest.mark.parametrize("driver,device", [("torch", 1), ("torch", [0, 1])])  # ("torch", [0, 1]),("torch", 1)
 @pytest.mark.parametrize("overfit_batches,num_train_batch_per_epoch", [(-1, -1), (0, -1), (3, 10), (6, -1)])
 @magic_argv_env_context
 def test_trainer_w_evaluator_overfit_torch(
     model_and_optimizers: TrainerParameters,
+    driver,
+    device,
     overfit_batches,
     num_train_batch_per_epoch
 ):
@@ -352,8 +355,8 @@ def test_trainer_w_evaluator_overfit_torch(
     """
     trainer = Trainer(
         model=model_and_optimizers.model,
-        driver="torch",
-        device=0,
+        driver=driver,
+        device=device,
         overfit_batches=overfit_batches,
         optimizers=model_and_optimizers.optimizers,
         train_dataloader=model_and_optimizers.train_dataloader,
@@ -363,17 +363,20 @@ def test_torch_wo_auto_param_call(
 # 测试 accumulation_steps;
 @pytest.mark.torch
+@pytest.mark.parametrize("driver,device", [("torch", 1), ("torch", [0, 1])])
 @pytest.mark.parametrize("overfit_batches,num_train_batch_per_epoch", [(-1, -1), (0, -1), (3, 10), (6, -1)])
 @magic_argv_env_context
 def test_trainer_overfit_torch(
     model_and_optimizers: TrainerParameters,
+    driver,
+    device,
     overfit_batches,
     num_train_batch_per_epoch
 ):
     trainer = Trainer(
         model=model_and_optimizers.model,
-        driver="torch",
-        device=0,
+        driver=driver,
+        device=device,
         overfit_batches=overfit_batches,
         optimizers=model_and_optimizers.optimizers,
         train_dataloader=model_and_optimizers.train_dataloader,