|
|
@@ -6,7 +6,7 @@ from pathlib import Path |
|
|
|
from fastNLP.core.drivers.torch_driver.single_device import TorchSingleDriver |
|
|
|
from fastNLP.core.samplers import RandomBatchSampler, RandomSampler |
|
|
|
from tests.helpers.models.torch_model import TorchNormalModel_Classification_1 |
|
|
|
from tests.helpers.datasets.torch_data import TorchNormalDataset, TorchArgMaxDatset |
|
|
|
from tests.helpers.datasets.torch_data import TorchNormalDataset, TorchArgMaxDataset |
|
|
|
from tests.helpers.datasets.paddle_data import PaddleNormalDataset |
|
|
|
from tests.helpers.models.paddle_model import PaddleNormalModel_Classification_1 |
|
|
|
from fastNLP.core import rank_zero_rm |
|
|
@@ -17,7 +17,7 @@ import paddle |
|
|
|
|
|
|
|
def dataloader_with_randombatchsampler(dataset, batch_size, shuffle, drop_last): |
|
|
|
""" |
|
|
|
建立一个 batch_samper 为 RandomBatchSampler 的 dataloader |
|
|
|
建立一个 batch_sampler 为 RandomBatchSampler 的 dataloader |
|
|
|
""" |
|
|
|
if shuffle: |
|
|
|
sampler = torch.utils.data.RandomSampler(dataset) |
|
|
@@ -38,7 +38,7 @@ def dataloader_with_randombatchsampler(dataset, batch_size, shuffle, drop_last): |
|
|
|
|
|
|
|
def dataloader_with_randomsampler(dataset, batch_size, shuffle, drop_last, seed=0): |
|
|
|
""" |
|
|
|
建立一个 samper 为 RandomSampler 的 dataloader |
|
|
|
建立一个 sampler 为 RandomSampler 的 dataloader |
|
|
|
""" |
|
|
|
dataloader = DataLoader( |
|
|
|
dataset, |
|
|
@@ -531,7 +531,7 @@ def generate_random_driver(features, labels, fp16=False, device="cpu"): |
|
|
|
|
|
|
|
@pytest.fixture |
|
|
|
def prepare_test_save_load(): |
|
|
|
dataset = TorchArgMaxDatset(10, 40) |
|
|
|
dataset = TorchArgMaxDataset(10, 40) |
|
|
|
dataloader = DataLoader(dataset, batch_size=4) |
|
|
|
driver1, driver2 = generate_random_driver(10, 10), generate_random_driver(10, 10) |
|
|
|
return driver1, driver2, dataloader |
|
|
@@ -566,7 +566,7 @@ def test_save_and_load_with_randombatchsampler(only_state_dict, fp16): |
|
|
|
|
|
|
|
try: |
|
|
|
path = "model.ckp" |
|
|
|
dataset = TorchArgMaxDatset(10, 40) |
|
|
|
dataset = TorchArgMaxDataset(10, 40) |
|
|
|
dataloader = dataloader_with_randombatchsampler(dataset, 4, True, False) |
|
|
|
driver1, driver2 = generate_random_driver(10, 10, fp16, "cuda"), generate_random_driver(10, 10, False, "cuda") |
|
|
|
|
|
|
@@ -636,7 +636,7 @@ def test_save_and_load_with_randomsampler(only_state_dict, fp16): |
|
|
|
path = "model.ckp" |
|
|
|
|
|
|
|
driver1, driver2 = generate_random_driver(10, 10, fp16, "cuda"), generate_random_driver(10, 10, False, "cuda") |
|
|
|
dataset = TorchArgMaxDatset(10, 40) |
|
|
|
dataset = TorchArgMaxDataset(10, 40) |
|
|
|
dataloader = dataloader_with_randomsampler(dataset, 4, True, False) |
|
|
|
num_consumed_batches = 2 |
|
|
|
|
|
|
|