diff --git a/tests/core/drivers/paddle_driver/test_fleet.py b/tests/core/drivers/paddle_driver/test_fleet.py index 52739f53..34c80888 100644 --- a/tests/core/drivers/paddle_driver/test_fleet.py +++ b/tests/core/drivers/paddle_driver/test_fleet.py @@ -527,7 +527,7 @@ class TestSaveLoad: @classmethod def setup_class(cls): # 不在这里 setup 的话会报错 - cls.driver = generate_driver(10, 10) + cls.driver = generate_driver(10, 10, device=[0,1]) def setup_method(self): self.dataset = PaddleRandomMaxDataset(20, 10) @@ -633,7 +633,7 @@ class TestSaveLoad: batch_sampler=BucketedBatchSampler( self.dataset, length=[10 for i in range(len(self.dataset))], - batch_size=4, + batch_size=2, ) ) dataloader.batch_sampler.set_distributed(