|
|
@@ -527,7 +527,7 @@ class TestSaveLoad: |
|
|
|
@classmethod |
|
|
|
def setup_class(cls): |
|
|
|
# 不在这里 setup 的话会报错 |
|
|
|
cls.driver = generate_driver(10, 10) |
|
|
|
cls.driver = generate_driver(10, 10, device=[0,1]) |
|
|
|
|
|
|
|
def setup_method(self): |
|
|
|
self.dataset = PaddleRandomMaxDataset(20, 10) |
|
|
@@ -633,7 +633,7 @@ class TestSaveLoad: |
|
|
|
batch_sampler=BucketedBatchSampler( |
|
|
|
self.dataset, |
|
|
|
length=[10 for i in range(len(self.dataset))], |
|
|
|
batch_size=4, |
|
|
|
batch_size=2, |
|
|
|
) |
|
|
|
) |
|
|
|
dataloader.batch_sampler.set_distributed( |
|
|
|