|
@@ -1,11 +1,10 @@ |
|
|
import os |
|
|
import os |
|
|
from numpy import isin |
|
|
|
|
|
os.environ["FASTNLP_BACKEND"] = "paddle" |
|
|
os.environ["FASTNLP_BACKEND"] = "paddle" |
|
|
import pytest |
|
|
import pytest |
|
|
|
|
|
from pathlib import Path |
|
|
|
|
|
|
|
|
from fastNLP.core.drivers.paddle_driver.single_device import PaddleSingleDriver |
|
|
from fastNLP.core.drivers.paddle_driver.single_device import PaddleSingleDriver |
|
|
from fastNLP.core.samplers.reproducible_sampler import RandomSampler |
|
|
|
|
|
from fastNLP.core.samplers import RandomBatchSampler |
|
|
|
|
|
|
|
|
from fastNLP.core.samplers import RandomBatchSampler, RandomSampler |
|
|
from tests.helpers.models.paddle_model import PaddleNormalModel_Classification_1 |
|
|
from tests.helpers.models.paddle_model import PaddleNormalModel_Classification_1 |
|
|
from tests.helpers.datasets.paddle_data import PaddleNormalDataset, PaddleRandomMaxDataset |
|
|
from tests.helpers.datasets.paddle_data import PaddleNormalDataset, PaddleRandomMaxDataset |
|
|
from tests.helpers.datasets.torch_data import TorchNormalDataset |
|
|
from tests.helpers.datasets.torch_data import TorchNormalDataset |
|
@@ -42,27 +41,101 @@ def prepare_test_save_load(): |
|
|
driver1, driver2 = generate_random_driver(10, 10), generate_random_driver(10, 10) |
|
|
driver1, driver2 = generate_random_driver(10, 10), generate_random_driver(10, 10) |
|
|
return driver1, driver2, dataloader |
|
|
return driver1, driver2, dataloader |
|
|
|
|
|
|
|
|
@pytest.mark.parametrize("reproducible", [True, False]) |
|
|
|
|
|
@pytest.mark.parametrize("only_state_dict", [True, False]) |
|
|
|
|
|
def test_save_and_load(prepare_test_save_load, reproducible, only_state_dict): |
|
|
|
|
|
|
|
|
@pytest.mark.parametrize("only_state_dict", ([True, False])) |
|
|
|
|
|
def test_save_and_load_with_randombatchsampler(only_state_dict): |
|
|
""" |
|
|
""" |
|
|
测试save和load函数 |
|
|
|
|
|
TODO optimizer的state_dict为空,暂时不测试 |
|
|
|
|
|
|
|
|
测试save和load函数,主要测试 dataloader 被替换了 sampler 之后的情况 |
|
|
""" |
|
|
""" |
|
|
|
|
|
|
|
|
try: |
|
|
try: |
|
|
path = "model.ckp" |
|
|
path = "model.ckp" |
|
|
driver1, driver2, dataloader = prepare_test_save_load |
|
|
|
|
|
dataloader = driver1.set_dist_repro_dataloader(dataloader, "dist", reproducible) |
|
|
|
|
|
|
|
|
|
|
|
driver1.save(path, {}, dataloader, only_state_dict, should_save_model=True) |
|
|
|
|
|
driver2.load(path, dataloader, only_state_dict, should_load_model=True) |
|
|
|
|
|
|
|
|
driver1, driver2 = generate_random_driver(10, 10), generate_random_driver(10, 10) |
|
|
|
|
|
dataset = PaddleRandomMaxDataset(80, 10) |
|
|
|
|
|
dataloader = DataLoader( |
|
|
|
|
|
dataset=dataset, |
|
|
|
|
|
batch_sampler=RandomBatchSampler(BatchSampler(dataset, batch_size=4), 4, False) |
|
|
|
|
|
) |
|
|
|
|
|
|
|
|
|
|
|
# TODO 断点重训完善后在这里迭代几次 |
|
|
|
|
|
|
|
|
|
|
|
sampler_states = dataloader.batch_sampler.state_dict() |
|
|
|
|
|
if only_state_dict: |
|
|
|
|
|
driver1.save(Path(path), {}, dataloader, only_state_dict, should_save_model=True) |
|
|
|
|
|
else: |
|
|
|
|
|
driver1.save(Path(path), {}, dataloader, only_state_dict, should_save_model=True, input_spec=[paddle.ones((16, 10))]) |
|
|
|
|
|
states = driver2.load(Path(path), dataloader, only_state_dict, should_load_model=True) |
|
|
|
|
|
|
|
|
|
|
|
# 1. 检查 optimizer 的状态 |
|
|
|
|
|
# TODO optimizer 的 state_dict 总是为空 |
|
|
|
|
|
|
|
|
|
|
|
# 2. 检查 batch_sampler 是否被正确地加载和替换 |
|
|
|
|
|
replaced_loader = states["dataloader"] |
|
|
|
|
|
assert isinstance(replaced_loader.batch_sampler, RandomBatchSampler) |
|
|
|
|
|
assert replaced_loader.batch_sampler.index_list == sampler_states["index_list"] |
|
|
|
|
|
assert replaced_loader.batch_sampler.data_idx == sampler_states["data_idx"] |
|
|
|
|
|
|
|
|
|
|
|
# 3. 检查 model 的参数是否被正确加载 |
|
|
|
|
|
for batch in dataloader: |
|
|
|
|
|
res1 = driver1.validate_step(batch) |
|
|
|
|
|
res2 = driver2.validate_step(batch) |
|
|
|
|
|
|
|
|
|
|
|
assert paddle.equal_all(res1["pred"], res2["pred"]) |
|
|
|
|
|
|
|
|
|
|
|
# 4. 检查 batch_idx |
|
|
|
|
|
# TODO |
|
|
|
|
|
finally: |
|
|
|
|
|
synchronize_safe_rm(path) |
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.parametrize("only_state_dict", [True, False])
def test_save_and_load_with_randomsampler(only_state_dict):
    """
    Test ``save`` and ``load`` when the dataloader's ``batch_sampler.sampler``
    has been replaced by a ``RandomSampler``.

    The checkpoint is written by one driver and read back by a second one;
    the restored sampler state and the model predictions are then compared.
    """
    path = "model.ckp"
    try:
        driver1 = generate_random_driver(10, 10)
        driver2 = generate_random_driver(10, 10)

        dataset = PaddleRandomMaxDataset(80, 10)
        batch_sampler = BatchSampler(dataset=dataset, batch_size=2)
        # Swap the sampler inside the paddle BatchSampler for the fastNLP one.
        batch_sampler.sampler = RandomSampler(dataset, True)
        dataloader = DataLoader(dataset, batch_sampler=batch_sampler)

        # TODO: iterate a few batches here once checkpoint resumption is complete

        # Snapshot the sampler state before saving so it can be compared later.
        sampler_states = dataloader.batch_sampler.sampler.state_dict()

        # Only the full-model save path is given an ``input_spec``.
        extra_save_kwargs = {} if only_state_dict else {"input_spec": [paddle.ones((16, 10))]}
        driver1.save(
            Path(path),
            {},
            dataloader,
            only_state_dict,
            should_save_model=True,
            **extra_save_kwargs,
        )
        states = driver2.load(Path(path), dataloader, only_state_dict, should_load_model=True)

        # 1. Check the optimizer state
        # TODO: the optimizer's state_dict is always empty

        # 2. Check that the sampler was loaded and replaced correctly
        replaced_loader = states["dataloader"]
        restored_sampler = replaced_loader.batch_sampler.sampler

        assert isinstance(restored_sampler, RandomSampler)
        assert restored_sampler.seed == sampler_states["seed"]
        assert restored_sampler.epoch == sampler_states["epoch"]
        assert restored_sampler.num_consumed_samples == sampler_states["num_consumed_samples"]
        assert len(restored_sampler.dataset) == sampler_states["length"]
        assert restored_sampler.shuffle == sampler_states["shuffle"]

        # 3. Check that the model parameters were loaded correctly
        for batch in dataloader:
            res1 = driver1.validate_step(batch)
            res2 = driver2.validate_step(batch)

            assert paddle.equal_all(res1["pred"], res2["pred"])

        # 4. Check batch_idx
        # TODO
    finally:
        # Always remove the checkpoint, even when an assertion above fails.
        synchronize_safe_rm(path)
|
|
|
|
|
|
|
@@ -144,24 +217,138 @@ class TestSingleDeviceFunction: |
|
|
""" |
|
|
""" |
|
|
self.driver.move_data_to_device(paddle.rand((32, 64))) |
|
|
self.driver.move_data_to_device(paddle.rand((32, 64))) |
|
|
|
|
|
|
|
|
# @pytest.mark.parametrize( |
|
|
|
|
|
# "dist_sampler", [ |
|
|
|
|
|
# "dist", |
|
|
|
|
|
# RandomBatchSampler(BatchSampler(PaddleRandomMaxDataset(320, 10)), 32, False), |
|
|
|
|
|
# RandomSampler(PaddleRandomMaxDataset(320, 10)) |
|
|
|
|
|
# ] |
|
|
|
|
|
# ) |
|
|
|
|
|
# @pytest.mark.parametrize( |
|
|
|
|
|
# "reproducible", |
|
|
|
|
|
# [True, False] |
|
|
|
|
|
# ) |
|
|
|
|
|
# def test_set_dist_repro_dataloader(self, dist_sampler, reproducible): |
|
|
|
|
|
# """ |
|
|
|
|
|
# 测试set_dist_repro_dataloader函数 |
|
|
|
|
|
# """ |
|
|
|
|
|
# dataloader = DataLoader(PaddleRandomMaxDataset(320, 10), batch_size=100, shuffle=True) |
|
|
|
|
|
|
|
|
|
|
|
# res = self.driver.set_dist_repro_dataloader(dataloader, dist_sampler, reproducible) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TestSetDistReproDataloder:
    """
    Tests dedicated to the ``set_dist_repro_dataloader`` function.
    """

    def setup_method(self):
        """Build a fresh dataset, dataloader and CPU driver before every test."""
        self.dataset = PaddleNormalDataset(20)
        self.dataloader = DataLoader(self.dataset, batch_size=2, shuffle=True)
        model = PaddleNormalModel_Classification_1(10, 32)
        self.driver = PaddleSingleDriver(model, device="cpu")

    def test_set_dist_repro_dataloader_with_reproducible_false(self):
        """
        Test ``set_dist_repro_dataloader`` when ``reproducible`` is False.

        When ``dist`` is a string, the original dataloader should be returned.
        """
        replaced_loader = self.driver.set_dist_repro_dataloader(self.dataloader, dist="dist", reproducible=False)

        assert replaced_loader is self.dataloader

    def test_set_dist_repro_dataloader_with_reproducible_true(self):
        """
        Test ``set_dist_repro_dataloader`` when ``reproducible`` is True.

        When ``dist`` is a string, a new dataloader should be returned whose
        ``batch_sampler`` is a ``RandomBatchSampler``.
        """
        replaced_loader = self.driver.set_dist_repro_dataloader(self.dataloader, dist="dist", reproducible=True)

        assert replaced_loader is not self.dataloader
        assert isinstance(replaced_loader.batch_sampler, RandomBatchSampler)
        assert isinstance(replaced_loader.batch_sampler.batch_sampler, BatchSampler)
        assert replaced_loader.batch_sampler.batch_size == self.dataloader.batch_sampler.batch_size
        assert replaced_loader.drop_last == self.dataloader.drop_last

        # self.check_set_dist_repro_dataloader(self.dataloader, replaced_loader)

    def test_set_dist_repro_dataloader_with_dist_batch_sampler(self):
        """
        Test ``set_dist_repro_dataloader`` when ``dist`` is not a string but a
        ReproducibleBatchSampler.

        A new dataloader should be returned whose ``batch_sampler`` is the
        sampler passed as ``dist``.
        """
        dist = RandomBatchSampler(BatchSampler(self.dataset, batch_size=4), 4, False)
        replaced_loader = self.driver.set_dist_repro_dataloader(self.dataloader, dist=dist, reproducible=False)

        assert replaced_loader is not self.dataloader
        assert isinstance(replaced_loader.batch_sampler, RandomBatchSampler)
        assert replaced_loader.batch_sampler is dist

        # self.check_set_dist_repro_dataloader(self.dataloader, replaced_loader)

    def test_set_dist_repro_dataloader_with_dist_sampler(self):
        """
        Test ``set_dist_repro_dataloader`` when ``dist`` is not a string.

        A new dataloader should be returned whose ``batch_sampler.sampler`` is
        the sampler passed as ``dist``.
        """
        dist = RandomSampler(self.dataset, shuffle=True)
        replaced_loader = self.driver.set_dist_repro_dataloader(self.dataloader, dist=dist, reproducible=False)

        assert replaced_loader is not self.dataloader
        assert isinstance(replaced_loader.batch_sampler, BatchSampler)
        assert isinstance(replaced_loader.batch_sampler.sampler, RandomSampler)
        assert replaced_loader.batch_sampler is not self.dataloader.batch_sampler
        assert replaced_loader.batch_sampler.sampler is dist
        assert replaced_loader.batch_sampler.batch_size == self.dataloader.batch_sampler.batch_size

        # self.check_set_dist_repro_dataloader(self.dataloader, replaced_loader)

    def test_set_dist_repro_dataloader_with_dataloader_reproducible_batch_sampler(self):
        """
        Test ``set_dist_repro_dataloader`` when the dataloader already supports
        checkpoint resumption through a ``RandomBatchSampler``.

        A new dataloader should be returned with the remaining settings
        identical to the original ones.
        """
        dataloader = DataLoader(
            dataset=self.dataset,
            batch_sampler=RandomBatchSampler(BatchSampler(self.dataset, batch_size=4), 4, False)
        )
        replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, dist="dist", reproducible=False)

        assert replaced_loader is not dataloader
        assert replaced_loader.batch_sampler is not dataloader.batch_sampler
        assert replaced_loader.batch_sampler.batch_size == dataloader.batch_sampler.batch_size
        assert replaced_loader.drop_last == dataloader.drop_last

        # self.check_set_dist_repro_dataloader(dataloader, replaced_loader)

    def test_set_dist_repro_dataloader_with_dataloader_reproducible_sampler(self):
        """
        Test ``set_dist_repro_dataloader`` when the dataloader already supports
        checkpoint resumption through a ``RandomSampler``.

        A new dataloader should be returned with the remaining settings
        identical to the original ones.
        """
        batch_sampler = BatchSampler(dataset=self.dataset, batch_size=2)
        batch_sampler.sampler = RandomSampler(self.dataset, True)
        dataloader = DataLoader(
            self.dataset,
            batch_sampler=batch_sampler
        )
        replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, dist="dist", reproducible=False)

        assert replaced_loader is not dataloader
        assert replaced_loader.batch_sampler is not dataloader.batch_sampler
        assert replaced_loader.batch_sampler.sampler is not dataloader.batch_sampler.sampler
        assert replaced_loader.batch_sampler.batch_size == 2

        # self.check_set_dist_repro_dataloader(dataloader, replaced_loader)

    def check_set_dist_repro_dataloader(self, dataloader, replaced_loader):
        """
        Check that the result of ``set_dist_repro_dataloader`` on a single
        device can be interrupted and resumed.

        Two batches are consumed from ``replaced_loader``, the sampler state
        is saved and loaded back, and the remaining batches are consumed.
        Together the two passes must cover every element of ``self.dataset``
        exactly once.

        :param dataloader: the original dataloader (not used by the checks).
        :param replaced_loader: the dataloader returned by
            ``set_dist_repro_dataloader``.
        """
        # Iterate over two batches.
        # The BatchSampler may yield several times here while the dataloader
        # has only handed out a single batch.
        already_seen_idx = set()
        # ``enumerate`` supplies the batch counter; iterating the loader
        # directly would try to unpack each batch into ``(idx, batch)``.
        for idx, batch in enumerate(replaced_loader):
            already_seen_idx.update(batch)
            if idx >= 1:
                break
        if isinstance(replaced_loader.batch_sampler, RandomBatchSampler):
            sampler_states = replaced_loader.batch_sampler.state_dict()
        else:
            sampler_states = replaced_loader.batch_sampler.sampler.state_dict()
        # NOTE(review): assumes both state dicts expose "data_idx" - confirm
        # for the RandomSampler branch.
        print(sampler_states["data_idx"])

        # After reloading, the loader should yield the remaining content; for
        # PaddleNormalDataset the sorted result should be a range.
        left_idxes = set()
        if isinstance(replaced_loader.batch_sampler, RandomBatchSampler):
            replaced_loader.batch_sampler.load_state_dict(sampler_states)
        else:
            replaced_loader.batch_sampler.sampler.load_state_dict(sampler_states)
        for batch in replaced_loader:
            left_idxes.update(batch)

        # No element is seen twice and none is missing.
        assert len(left_idxes) + len(already_seen_idx) == len(self.dataset)
        assert len(left_idxes | already_seen_idx) == len(self.dataset)
|
|
|
|
|
|
|
|
class TestPaddleDriverFunctions: |
|
|
class TestPaddleDriverFunctions: |
|
|
""" |
|
|
""" |
|
@@ -229,7 +416,7 @@ class TestPaddleDriverFunctions: |
|
|
with pytest.raises(ValueError): |
|
|
with pytest.raises(ValueError): |
|
|
PaddleSingleDriver._check_dataloader_legality(dataloader, "dataloader", True) |
|
|
PaddleSingleDriver._check_dataloader_legality(dataloader, "dataloader", True) |
|
|
|
|
|
|
|
|
def test_check_dataloader_legacy_in_test(self): |
|
|
|
|
|
|
|
|
def test_check_dataloader_legality_in_test(self): |
|
|
""" |
|
|
""" |
|
|
测试is_train参数为False时,_check_dataloader_legality函数的表现 |
|
|
测试is_train参数为False时,_check_dataloader_legality函数的表现 |
|
|
""" |
|
|
""" |
|
@@ -372,11 +559,78 @@ class TestPaddleDriverFunctions: |
|
|
dataloader = DataLoader(PaddleNormalDataset()) |
|
|
dataloader = DataLoader(PaddleNormalDataset()) |
|
|
self.driver.set_sampler_epoch(dataloader, 0) |
|
|
self.driver.set_sampler_epoch(dataloader, 0) |
|
|
|
|
|
|
|
|
def test_get_dataloader_args(self): |
|
|
|
|
|
|
|
|
@pytest.mark.parametrize("batch_size", [16]) |
|
|
|
|
|
@pytest.mark.parametrize("shuffle", [True, False]) |
|
|
|
|
|
@pytest.mark.parametrize("drop_last", [True, False]) |
|
|
|
|
|
def test_get_dataloader_args(self, batch_size, shuffle, drop_last): |
|
|
""" |
|
|
""" |
|
|
测试get_dataloader_args |
|
|
|
|
|
|
|
|
测试正常情况下 get_dataloader_args 的表现 |
|
|
""" |
|
|
""" |
|
|
# 先确保不影响运行 |
|
|
|
|
|
# TODO:正确性 |
|
|
|
|
|
dataloader = DataLoader(PaddleNormalDataset()) |
|
|
|
|
|
res = PaddleSingleDriver.get_dataloader_args(dataloader) |
|
|
|
|
|
|
|
|
dataloader = DataLoader( |
|
|
|
|
|
PaddleNormalDataset(), |
|
|
|
|
|
batch_size=batch_size, |
|
|
|
|
|
shuffle=shuffle, |
|
|
|
|
|
drop_last=drop_last, |
|
|
|
|
|
) |
|
|
|
|
|
res = PaddleSingleDriver.get_dataloader_args(dataloader) |
|
|
|
|
|
|
|
|
|
|
|
assert isinstance(res.dataset, PaddleNormalDataset) |
|
|
|
|
|
assert isinstance(res.batch_sampler, BatchSampler) |
|
|
|
|
|
if shuffle: |
|
|
|
|
|
assert isinstance(res.sampler, paddle.io.RandomSampler) |
|
|
|
|
|
else: |
|
|
|
|
|
assert isinstance(res.sampler, paddle.io.SequenceSampler) |
|
|
|
|
|
assert res.shuffle == shuffle |
|
|
|
|
|
assert res.batch_size == batch_size |
|
|
|
|
|
assert res.drop_last == drop_last |
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.parametrize("batch_size", [16])
@pytest.mark.parametrize("shuffle", [True, False])
@pytest.mark.parametrize("drop_last", [True, False])
def test_get_dataloader_args_with_randombatchsampler(self, batch_size, shuffle, drop_last):
    """
    Test ``get_dataloader_args`` after the dataloader's batch_sampler has been
    replaced by a ``RandomBatchSampler``.
    """
    dataset = PaddleNormalDataset()
    # Wrap a plain paddle BatchSampler in the fastNLP RandomBatchSampler.
    wrapped_sampler = BatchSampler(dataset, batch_size=batch_size, shuffle=shuffle)
    repro_batch_sampler = RandomBatchSampler(wrapped_sampler, batch_size, drop_last)
    dataloader = DataLoader(dataset, batch_sampler=repro_batch_sampler)

    res = PaddleSingleDriver.get_dataloader_args(dataloader)

    # The reported sampler class follows the ``shuffle`` flag given to the wrapped BatchSampler.
    expected_sampler_cls = paddle.io.RandomSampler if shuffle else paddle.io.SequenceSampler

    assert isinstance(res.dataset, PaddleNormalDataset)
    assert isinstance(res.batch_sampler, RandomBatchSampler)
    assert isinstance(res.sampler, expected_sampler_cls)
    assert res.shuffle == shuffle
    assert res.batch_size == batch_size
    assert res.drop_last == drop_last
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.parametrize("batch_size", [16])
@pytest.mark.parametrize("shuffle", [True, False])
@pytest.mark.parametrize("drop_last", [True, False])
def test_get_dataloader_args_with_randomsampler(self, batch_size, shuffle, drop_last):
    """
    Test ``get_dataloader_args`` after the dataloader's sampler has been
    replaced by a ``RandomSampler``.
    """
    dataset = PaddleNormalDataset()
    # Build a paddle BatchSampler, then swap in the fastNLP sampler.
    repro_sampler = RandomSampler(dataset, shuffle)
    batch_sampler = BatchSampler(dataset, batch_size=batch_size, drop_last=drop_last)
    batch_sampler.sampler = repro_sampler
    dataloader = DataLoader(dataset, batch_sampler=batch_sampler)

    res = PaddleSingleDriver.get_dataloader_args(dataloader)

    assert isinstance(res.dataset, PaddleNormalDataset)
    assert isinstance(res.batch_sampler, BatchSampler)
    assert isinstance(res.sampler, RandomSampler)
    assert res.shuffle == shuffle
    assert res.batch_size == batch_size
    assert res.drop_last == drop_last