|
@@ -535,7 +535,7 @@ class TestSetDistReproDataloder: |
|
|
# |
|
|
# |
|
|
############################################################################ |
|
|
############################################################################ |
|
|
|
|
|
|
|
|
def generate_random_driver(features, labels, fp16, device="cpu"): |
|
|
|
|
|
|
|
|
def generate_random_driver(features, labels, fp16=False, device="cpu"): |
|
|
""" |
|
|
""" |
|
|
生成driver |
|
|
生成driver |
|
|
""" |
|
|
""" |
|
@@ -549,8 +549,8 @@ def generate_random_driver(features, labels, fp16, device="cpu"): |
|
|
|
|
|
|
|
|
@pytest.fixture |
|
|
@pytest.fixture |
|
|
def prepare_test_save_load(): |
|
|
def prepare_test_save_load(): |
|
|
dataset = PaddleRandomMaxDataset(320, 10) |
|
|
|
|
|
dataloader = DataLoader(dataset, batch_size=32) |
|
|
|
|
|
|
|
|
dataset = PaddleRandomMaxDataset(40, 10) |
|
|
|
|
|
dataloader = DataLoader(dataset, batch_size=4) |
|
|
driver1, driver2 = generate_random_driver(10, 10), generate_random_driver(10, 10) |
|
|
driver1, driver2 = generate_random_driver(10, 10), generate_random_driver(10, 10) |
|
|
return driver1, driver2, dataloader |
|
|
return driver1, driver2, dataloader |
|
|
|
|
|
|
|
|