|
|
@@ -545,22 +545,17 @@ def generate_random_driver(features, labels, fp16=False, device="cpu"): |
|
|
|
|
|
|
|
return driver |
|
|
|
|
|
|
|
@pytest.fixture |
|
|
|
def prepare_test_save_load(): |
|
|
|
dataset = TorchArgMaxDataset(10, 40) |
|
|
|
dataloader = DataLoader(dataset, batch_size=4) |
|
|
|
driver1, driver2 = generate_random_driver(10, 10), generate_random_driver(10, 10) |
|
|
|
return driver1, driver2, dataloader |
|
|
|
|
|
|
|
@pytest.mark.torch |
|
|
|
@pytest.mark.parametrize("only_state_dict", ([True, False])) |
|
|
|
def test_save_and_load_model(prepare_test_save_load, only_state_dict): |
|
|
|
def test_save_and_load_model(only_state_dict): |
|
|
|
""" |
|
|
|
测试 save_model 和 load_model 函数 |
|
|
|
""" |
|
|
|
try: |
|
|
|
path = "model" |
|
|
|
driver1, driver2, dataloader = prepare_test_save_load |
|
|
|
dataset = TorchArgMaxDataset(10, 40) |
|
|
|
dataloader = DataLoader(dataset, batch_size=4) |
|
|
|
driver1, driver2 = generate_random_driver(10, 10), generate_random_driver(10, 10) |
|
|
|
|
|
|
|
driver1.save_model(path, only_state_dict) |
|
|
|
driver2.load_model(path, only_state_dict) |
|
|
|