diff --git a/tests/trainers/test_image_portrait_enhancement_trainer.py b/tests/trainers/test_image_portrait_enhancement_trainer.py index 123e0098..a9fc74cb 100644 --- a/tests/trainers/test_image_portrait_enhancement_trainer.py +++ b/tests/trainers/test_image_portrait_enhancement_trainer.py @@ -61,6 +61,7 @@ class TestImagePortraitEnhancementTrainer(unittest.TestCase): train_dataset=self.dataset_train, eval_dataset=self.dataset_val, device='gpu', + max_epochs=1, work_dir=self.tmp_dir) trainer = build_trainer( @@ -81,7 +82,7 @@ class TestImagePortraitEnhancementTrainer(unittest.TestCase): train_dataset=self.dataset_train, eval_dataset=self.dataset_val, device='gpu', - max_epochs=2, + max_epochs=1, work_dir=self.tmp_dir) trainer = build_trainer(