diff --git a/fastNLP/core/drivers/torch_driver/initialize_torch_driver.py b/fastNLP/core/drivers/torch_driver/initialize_torch_driver.py index f149855f..5ee946c4 100644 --- a/fastNLP/core/drivers/torch_driver/initialize_torch_driver.py +++ b/fastNLP/core/drivers/torch_driver/initialize_torch_driver.py @@ -76,7 +76,7 @@ def initialize_torch_driver(driver: str, device: Optional[Union[str, torch.devic logger.info("Notice you are using `torch_ddp` driver, but your chosen `device` is only one gpu, we will " "still use `TorchDDPDriver` for you, but if you mean using `torch_ddp`, you should " "choose `torch` driver.") - return TorchDDPDriver(model, device, **kwargs) + return TorchDDPDriver(model, [device], **kwargs) else: return TorchDDPDriver(model, device, **kwargs) elif driver == "fairscale": diff --git a/fastNLP/core/drivers/torch_driver/torch_driver.py b/fastNLP/core/drivers/torch_driver/torch_driver.py index 8e37f550..172a3cf0 100644 --- a/fastNLP/core/drivers/torch_driver/torch_driver.py +++ b/fastNLP/core/drivers/torch_driver/torch_driver.py @@ -218,6 +218,8 @@ class TorchDriver(Driver): # 2. 
保存模型的状态; if should_save_model: model = self.unwrap_model() + if not os.path.exists(folder): + os.makedirs(folder, exist_ok=True) if only_state_dict: model_state_dict = {name: param.cpu().detach().clone() for name, param in model.state_dict().items()} # 对于单卡的 driver 来讲,我们实际上(现在)不应该考虑用户在DDP环境下使用单卡模式,从而造成效率损失; @@ -401,7 +403,17 @@ class TorchDriver(Driver): res.sampler = dataloader.batch_sampler.sampler if hasattr(dataloader.batch_sampler.sampler, "shuffle"): res.shuffle = dataloader.batch_sampler.sampler.shuffle - elif isinstance(dataloader.batch_sampler.sampler, RandomSampler): + elif isinstance(dataloader.batch_sampler.sampler, TorchRandomSampler): + res.shuffle = True + else: + res.shuffle = False + # the RandomBatchSampler case: the reproducible wrapper holds the real BatchSampler + elif hasattr(dataloader.batch_sampler, "batch_sampler"): + batch_sampler = dataloader.batch_sampler.batch_sampler + res.sampler = batch_sampler.sampler + if hasattr(batch_sampler.sampler, "shuffle"): + res.shuffle = batch_sampler.sampler.shuffle + elif isinstance(batch_sampler.sampler, TorchRandomSampler): + res.shuffle = True + else: + res.shuffle = False diff --git a/tests/core/drivers/torch_driver/test_single_device.py b/tests/core/drivers/torch_driver/test_single_device.py new file mode 100644 index 00000000..4290d02c --- /dev/null +++ b/tests/core/drivers/torch_driver/test_single_device.py @@ -0,0 +1,697 @@ +import os +os.environ["FASTNLP_BACKEND"] = "torch" +import pytest +from pathlib import Path + +from fastNLP.core.drivers.torch_driver.single_device import TorchSingleDriver +from fastNLP.core.samplers import RandomBatchSampler, RandomSampler +from tests.helpers.models.torch_model import TorchNormalModel_Classification_1 +from tests.helpers.datasets.torch_data import TorchNormalDataset, TorchArgMaxDatset +from tests.helpers.datasets.paddle_data import PaddleNormalDataset +from tests.helpers.models.paddle_model import PaddleNormalModel_Classification_1 +from fastNLP.core import rank_zero_rm + +import torch +from torch.utils.data import DataLoader,
BatchSampler +import paddle + +def dataloader_with_randombatchsampler(dataset, batch_size, shuffle, drop_last): + """ + 建立一个 batch_samper 为 RandomBatchSampler 的 dataloader + """ + if shuffle: + sampler = torch.utils.data.RandomSampler(dataset) + else: + sampler = torch.utils.data.SequentialSampler(dataset) + dataloader = DataLoader( + dataset=dataset, + batch_sampler=RandomBatchSampler( + BatchSampler( + sampler, batch_size=batch_size, drop_last=drop_last + ), + batch_size=batch_size, + drop_last=drop_last, + ), + ) + + return dataloader + +def dataloader_with_randomsampler(dataset, batch_size, shuffle, drop_last, seed=0): + """ + 建立一个 samper 为 RandomSampler 的 dataloader + """ + dataloader = DataLoader( + dataset, + sampler=RandomSampler(dataset, shuffle, seed=seed), + drop_last=drop_last, + batch_size=batch_size + ) + return dataloader + +############################################################################ +# +# 测试基类 TorchDrvier 中的一些简单函数 +# +############################################################################ + +class TestTorchDriverFunctions: + """ + 使用 TorchSingleDriver 测试基类的函数 + """ + + @classmethod + def setup_class(self): + model = TorchNormalModel_Classification_1(10, 32) + self.driver = TorchSingleDriver(model, device="cpu") + + def test_check_single_optimizer_legality(self): + """ + 测试传入单个 optimizer 时的表现 + """ + optimizer = torch.optim.Adam( + params=self.driver.model.parameters(), + lr=0.01 + ) + + self.driver.set_optimizers(optimizer) + + optimizer = paddle.optimizer.Adam( + parameters=PaddleNormalModel_Classification_1(10, 32).parameters(), + learning_rate=0.01, + ) + # 传入 torch 的 optimize r时,应该报错 ValueError + with pytest.raises(ValueError): + self.driver.set_optimizers(optimizer) + + def test_check_optimizers_legality(self): + """ + 测试传入 optimizer list 的表现 + """ + optimizers = [ + torch.optim.Adam( + params=self.driver.model.parameters(), + lr=0.01 + ) for i in range(10) + ] + + self.driver.set_optimizers(optimizers) + + optimizers += [ 
+ paddle.optimizer.Adam( + parameters=PaddleNormalModel_Classification_1(10, 32).parameters(), + learning_rate=0.01, + ) + ] + + with pytest.raises(ValueError): + self.driver.set_optimizers(optimizers) + + def test_check_dataloader_legality_in_train(self): + """ + 测试 `is_train` 参数为 True 时,_check_dataloader_legality 函数的表现 + """ + dataloader = DataLoader(TorchNormalDataset()) + TorchSingleDriver.check_dataloader_legality(dataloader, "dataloader", True) + + # 创建 paddle 的 dataloader + dataloader = paddle.io.DataLoader( + PaddleNormalDataset(), + batch_size=32, shuffle=True + ) + with pytest.raises(ValueError): + TorchSingleDriver.check_dataloader_legality(dataloader, "dataloader", True) + + def test_check_dataloader_legality_in_test(self): + """ + 测试 `is_train` 参数为 False 时,_check_dataloader_legality 函数的表现 + """ + # 此时传入的应该是dict + dataloader = { + "train": DataLoader(TorchNormalDataset()), + "test": DataLoader(TorchNormalDataset()) + } + TorchSingleDriver.check_dataloader_legality(dataloader, "dataloader", False) + + # 传入的不是 dict,应该报错 + dataloader = DataLoader(TorchNormalDataset()) + with pytest.raises(ValueError): + TorchSingleDriver.check_dataloader_legality(dataloader, "dataloader", False) + + # 创建 paddle 的 dataloader + train_loader = paddle.io.DataLoader( + PaddleNormalDataset(), + batch_size=32, shuffle=True + ) + test_loader = paddle.io.DataLoader( + PaddleNormalDataset(), + batch_size=32, shuffle=True + ) + dataloader = {"train": train_loader, "test": test_loader} + with pytest.raises(ValueError): + TorchSingleDriver.check_dataloader_legality(dataloader, "dataloader", False) + + def test_tensor_to_numeric(self): + """ + 测试 tensor_to_numeric 函数 + """ + # 单个张量 + tensor = torch.tensor(3) + res = TorchSingleDriver.tensor_to_numeric(tensor) + assert res == 3 + + tensor = torch.rand((3, 4)) + res = TorchSingleDriver.tensor_to_numeric(tensor) + assert res == tensor.tolist() + + # 张量list + tensor_list = [torch.rand((6, 4, 2)) for i in range(10)] + res = 
TorchSingleDriver.tensor_to_numeric(tensor_list) + assert isinstance(res, list) + tensor_list = [t.tolist() for t in tensor_list] + assert res == tensor_list + + # 张量tuple + tensor_tuple = tuple([torch.rand((6, 4, 2)) for i in range(10)]) + res = TorchSingleDriver.tensor_to_numeric(tensor_tuple) + assert isinstance(res, tuple) + tensor_tuple = tuple([t.tolist() for t in tensor_tuple]) + assert res == tensor_tuple + + # 张量dict + tensor_dict = { + "tensor": torch.rand((3, 4)), + "list": [torch.rand((6, 4, 2)) for i in range(10)], + "dict":{ + "list": [torch.rand((6, 4, 2)) for i in range(10)], + "tensor": torch.rand((3, 4)) + }, + "int": 2, + "string": "test string" + } + + res = TorchSingleDriver.tensor_to_numeric(tensor_dict) + assert isinstance(res, dict) + assert res["tensor"] == tensor_dict["tensor"].tolist() + assert isinstance(res["list"], list) + for r, d in zip(res["list"], tensor_dict["list"]): + assert r == d.tolist() + assert isinstance(res["int"], int) + assert isinstance(res["string"], str) + assert isinstance(res["dict"], dict) + assert isinstance(res["dict"]["list"], list) + for r, d in zip(res["dict"]["list"], tensor_dict["dict"]["list"]): + assert r == d.tolist() + assert res["dict"]["tensor"] == tensor_dict["dict"]["tensor"].tolist() + + def test_set_model_mode(self): + """ + 测试set_model_mode函数 + """ + self.driver.set_model_mode("train") + assert self.driver.model.training + self.driver.set_model_mode("eval") + assert not self.driver.model.training + # 应该报错 + with pytest.raises(AssertionError): + self.driver.set_model_mode("test") + + def test_move_model_to_device_cpu(self): + """ + 测试move_model_to_device函数 + """ + TorchSingleDriver.move_model_to_device(self.driver.model, "cpu") + assert self.driver.model.linear1.weight.device.type == "cpu" + + def test_move_model_to_device_gpu(self): + """ + 测试move_model_to_device函数 + """ + TorchSingleDriver.move_model_to_device(self.driver.model, "cuda") + assert self.driver.model.linear1.weight.device.type == 
"cuda" + assert self.driver.model.linear1.weight.device.index == 0 + + def test_worker_init_function(self): + """ + 测试worker_init_function + """ + # 先确保不影响运行 + # TODO:正确性 + TorchSingleDriver.worker_init_function(0) + + def test_set_deterministic_dataloader(self): + """ + 测试set_deterministic_dataloader + """ + # 先确保不影响运行 + # TODO:正确性 + dataloader = DataLoader(TorchNormalDataset()) + self.driver.set_deterministic_dataloader(dataloader) + + def test_set_sampler_epoch(self): + """ + 测试set_sampler_epoch + """ + # 先确保不影响运行 + # TODO:正确性 + dataloader = DataLoader(TorchNormalDataset()) + self.driver.set_sampler_epoch(dataloader, 0) + + @pytest.mark.parametrize("batch_size", [16]) + @pytest.mark.parametrize("shuffle", [True, False]) + @pytest.mark.parametrize("drop_last", [True, False]) + def test_get_dataloader_args(self, batch_size, shuffle, drop_last): + """ + 测试正常情况下 get_dataloader_args 的表现 + """ + dataloader = DataLoader( + TorchNormalDataset(), + batch_size=batch_size, + shuffle=shuffle, + drop_last=drop_last, + ) + res = TorchSingleDriver.get_dataloader_args(dataloader) + + assert isinstance(res.dataset, TorchNormalDataset) + assert isinstance(res.batch_sampler, BatchSampler) + if shuffle: + assert isinstance(res.sampler, torch.utils.data.RandomSampler) + else: + assert isinstance(res.sampler, torch.utils.data.SequentialSampler) + assert res.shuffle == shuffle + assert res.batch_size == batch_size + assert res.drop_last == drop_last + + @pytest.mark.parametrize("batch_size", [16]) + @pytest.mark.parametrize("shuffle", [True, False]) + @pytest.mark.parametrize("drop_last", [True, False]) + def test_get_dataloader_args_with_randombatchsampler(self, batch_size, shuffle, drop_last): + """ + 测试替换了 batch_sampler 后 get_dataloader_args 的表现 + """ + dataset = TorchNormalDataset() + dataloader = dataloader_with_randombatchsampler(dataset, batch_size, shuffle, drop_last) + res = TorchSingleDriver.get_dataloader_args(dataloader) + + assert isinstance(res.dataset, 
TorchNormalDataset) + assert isinstance(res.batch_sampler, RandomBatchSampler) + if shuffle: + assert isinstance(res.sampler, torch.utils.data.RandomSampler) + else: + assert isinstance(res.sampler, torch.utils.data.SequentialSampler) + assert res.shuffle == shuffle + assert res.batch_size == batch_size + assert res.drop_last == drop_last + + @pytest.mark.parametrize("batch_size", [16]) + @pytest.mark.parametrize("shuffle", [True, False]) + @pytest.mark.parametrize("drop_last", [True, False]) + def test_get_dataloader_args_with_randomsampler(self, batch_size, shuffle, drop_last): + """ + 测试替换了 sampler 后 get_dataloader_args 的表现 + """ + dataset = TorchNormalDataset() + dataloader = dataloader_with_randomsampler(dataset, batch_size, shuffle, drop_last) + res = TorchSingleDriver.get_dataloader_args(dataloader) + + assert isinstance(res.dataset, TorchNormalDataset) + assert isinstance(res.batch_sampler, BatchSampler) + assert isinstance(res.sampler, RandomSampler) + assert res.shuffle == shuffle + assert res.batch_size == batch_size + assert res.drop_last == drop_last + + +############################################################################ +# +# 测试 TorchSingleDrvier 中的一些简单函数 +# +############################################################################ + +class TestSingleDeviceFunction: + """ + 测试其它函数的测试例 + """ + + @classmethod + def setup_class(cls): + model = TorchNormalModel_Classification_1(10, 784) + cls.driver = TorchSingleDriver(model, device="cpu") + + def test_unwrap_model(self): + """ + 测试能否运行 + """ + res = self.driver.unwrap_model() + assert res is self.driver.model + + def test_is_distributed(self): + assert self.driver.is_distributed() == False + + def test_move_data_to_device(self): + """ + 这个函数仅调用了 torch_move_data_to_device ,测试例在 tests/core/utils/test_torch_utils.py 中 + 就不重复测试了 + """ + self.driver.move_data_to_device(torch.rand((32, 64))) + + +############################################################################ +# +# 测试 
set_dist_repro_dataloader 函数 +# +############################################################################ + +class TestSetDistReproDataloader: + """ + 专门测试 set_dist_repro_dataloader 函数的类 + """ + def setup_method(self): + self.dataset = TorchNormalDataset(20) + model = TorchNormalModel_Classification_1(10, 32) + self.driver = TorchSingleDriver(model, device="cpu") + + def test_with_reproducible_false(self): + """ + 测试 set_dist_repro_dataloader 参数 `reproducible` 为 False 时的表现 + 当dist为字符串时,此时应该返回原来的 dataloader + """ + dataloader = DataLoader(self.dataset, batch_size=2, shuffle=True) + replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, dist="dist", reproducible=False) + + assert replaced_loader is dataloader + + @pytest.mark.parametrize("shuffle", [True, False]) + def test_with_reproducible_true(self, shuffle): + """ + 测试 set_dist_repro_dataloader 参数 `reproducible` 为 True 时的表现 + 当dist为字符串时,此时应该返回新的 dataloader,且如果原 sampler 为 torch.utils.data.RandomSampler(shuffle=True), + 只会替换 Sampler 为 RandomSampler;否则会替换 batch_sampler 为 RandomBatchSampler + """ + dataloader = DataLoader(self.dataset, batch_size=2, shuffle=shuffle) + replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, dist="dist", reproducible=True) + + assert not (replaced_loader is dataloader) + if shuffle: + # 此时会替换 sampler + assert isinstance(replaced_loader.batch_sampler, torch.utils.data.BatchSampler) + assert not (replaced_loader.batch_sampler is dataloader.batch_sampler) + assert isinstance(replaced_loader.batch_sampler.sampler, RandomSampler) + else: + # 此时会替换 batch_sampler + assert isinstance(replaced_loader.batch_sampler, RandomBatchSampler) + assert isinstance(replaced_loader.batch_sampler.batch_sampler, BatchSampler) + assert replaced_loader.batch_sampler.batch_size == dataloader.batch_sampler.batch_size + assert replaced_loader.drop_last == dataloader.drop_last + + self.check_set_dist_repro_dataloader(dataloader, replaced_loader, shuffle) + + 
@pytest.mark.parametrize("shuffle", ([True, False])) + def test_with_dist_batch_sampler(self, shuffle): + """ + 测试 set_dist_repro_dataloader 参数 dist 不是字符串时的表现,且 dist 是 ReproducibleBatchSampler + 应该返回新的 dataloader,并将 batch_sampler 替换为 dist 对应的 Sampler + """ + dataloader = DataLoader(self.dataset, batch_size=2, shuffle=shuffle) + dist = RandomBatchSampler(BatchSampler(self.dataset, batch_size=4, drop_last=False), 4, False) + replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, dist=dist, reproducible=False) + + assert not (replaced_loader is dataloader) + assert isinstance(replaced_loader.batch_sampler, RandomBatchSampler) + assert replaced_loader.batch_sampler is dist + + self.check_set_dist_repro_dataloader(dataloader, replaced_loader, shuffle) + + @pytest.mark.parametrize("shuffle", ([True, False])) + def test_with_dist_sampler(self, shuffle): + """ + 测试 set_dist_repro_dataloader 参数 dist 不是字符串时的表现 + 应该返回新的 dataloader,并将 batch_sampler.sampler 替换为 dist 对应的 Sampler + """ + dataloader = DataLoader(self.dataset, batch_size=2, shuffle=not shuffle) + dist = RandomSampler(self.dataset, shuffle=shuffle) + replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, dist=dist, reproducible=False) + + assert not (replaced_loader is dataloader) + assert isinstance(replaced_loader.batch_sampler, BatchSampler) + assert isinstance(replaced_loader.batch_sampler.sampler, RandomSampler) + assert not (replaced_loader.batch_sampler is dataloader.batch_sampler) + assert replaced_loader.batch_sampler.sampler is dist + assert replaced_loader.batch_sampler.batch_size == dataloader.batch_sampler.batch_size + + self.check_set_dist_repro_dataloader(dataloader, replaced_loader, shuffle) + + @pytest.mark.parametrize("shuffle", ([True, False])) + def test_with_dataloader_reproducible_batch_sampler(self, shuffle): + """ + 测试 set_dist_repro_dataloader 参数 dataloader 已经支持断点重训时的表现 + 应该返回新的 dataloader,且其余各项设置和原来相同 + """ + dataloader = 
dataloader_with_randombatchsampler(self.dataset, 4, shuffle, False) + replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, dist="dist", reproducible=False) + + assert not (replaced_loader is dataloader) + assert isinstance(replaced_loader.batch_sampler, RandomBatchSampler) + assert not (replaced_loader.batch_sampler is dataloader.batch_sampler) + assert replaced_loader.batch_sampler.batch_size == dataloader.batch_sampler.batch_size + assert replaced_loader.drop_last == dataloader.drop_last + + self.check_set_dist_repro_dataloader(dataloader, replaced_loader, shuffle) + + @pytest.mark.parametrize("shuffle", ([True, False])) + def test_with_dataloader_reproducible_sampler(self, shuffle): + """ + 测试 set_dist_repro_dataloader 参数 dataloader 已经支持断点重训时的表现 + 应该返回新的 dataloader,且其余各项设置和原来相同 + """ + dataloader = dataloader_with_randomsampler(self.dataset, 2, shuffle, False) + replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, dist="dist", reproducible=False) + + assert not (replaced_loader is dataloader) + assert not (replaced_loader.batch_sampler is dataloader.batch_sampler) + assert isinstance(replaced_loader.batch_sampler.sampler, RandomSampler) + assert not (replaced_loader.batch_sampler.sampler is dataloader.batch_sampler.sampler) + assert replaced_loader.batch_sampler.batch_size == 2 + assert replaced_loader.batch_sampler.sampler.shuffle == shuffle + + self.check_set_dist_repro_dataloader(dataloader, replaced_loader, shuffle) + + def check_set_dist_repro_dataloader(self, dataloader, replaced_loader, shuffle): + """ + 测试单卡下 set_dist_repro_dataloader 函数的执行结果是否正确 + """ + # 迭代两个 batch + num_consumed_batches = 2 + already_seen_idx = set() + for idx, batch in enumerate(replaced_loader): + if idx >= num_consumed_batches: + break + already_seen_idx.update(batch) + if isinstance(replaced_loader.batch_sampler, RandomBatchSampler): + sampler_states = replaced_loader.batch_sampler.state_dict() + else: + sampler_states = 
replaced_loader.batch_sampler.sampler.state_dict() + + # 重新加载,应该可以输出剩下的内容,且对于 TorchNormalDataset 来说,排序后应该是一个 range + left_idxes = set() + if isinstance(replaced_loader.batch_sampler, RandomBatchSampler): + batch_size = replaced_loader.batch_sampler.batch_size + sampler_states["num_consumed_samples"] = num_consumed_batches * batch_size + # 重新改造 dataloader + new_loader = dataloader_with_randombatchsampler(replaced_loader.dataset, batch_size, shuffle, False) + new_loader.batch_sampler.load_state_dict(sampler_states) + else: + batch_size = replaced_loader.batch_sampler.batch_size + sampler_states["num_consumed_samples"] = num_consumed_batches * batch_size + # 重新构造 dataloader + new_loader = dataloader_with_randomsampler(replaced_loader.dataset, batch_size, shuffle, False) + new_loader.batch_sampler.sampler.load_state_dict(sampler_states) + for idx, batch in enumerate(new_loader): + left_idxes.update(batch) + + assert len(left_idxes) + len(already_seen_idx) == len(self.dataset) + assert len(left_idxes | already_seen_idx) == len(self.dataset) + +############################################################################ +# +# 测试 save 和 load 相关的功能 +# +############################################################################ + +def generate_random_driver(features, labels, fp16=False, device="cpu"): + """ + 生成driver + """ + model = TorchNormalModel_Classification_1(labels, features) + opt = torch.optim.Adam(params=model.parameters(), lr=0.01) + driver = TorchSingleDriver(model, device=device, fp16=fp16) + driver.set_optimizers(opt) + driver.setup() + + return driver + +@pytest.fixture +def prepare_test_save_load(): + dataset = TorchArgMaxDatset(10, 40) + dataloader = DataLoader(dataset, batch_size=4) + driver1, driver2 = generate_random_driver(10, 10), generate_random_driver(10, 10) + return driver1, driver2, dataloader + +@pytest.mark.parametrize("only_state_dict", ([True, False])) +def test_save_and_load_model(prepare_test_save_load, only_state_dict): + """ + 测试 
save_model 和 load_model 函数 + """ + try: + path = "model" + driver1, driver2, dataloader = prepare_test_save_load + + driver1.save_model(path, only_state_dict) + driver2.load_model(path, only_state_dict) + + for batch in dataloader: + batch = driver1.move_data_to_device(batch) + res1 = driver1.model.evaluate_step(**batch) + res2 = driver2.model.evaluate_step(**batch) + + assert torch.equal(res1["preds"], res2["preds"]) + finally: + rank_zero_rm(path) + +@pytest.mark.parametrize("only_state_dict", ([True, False])) +@pytest.mark.parametrize("fp16", ([True, False])) +def test_save_and_load_with_randombatchsampler(only_state_dict, fp16): + """ + 测试save和load函数,主要测试 dataloader 被替换了 sampler 之后的情况 + """ + + try: + path = "model.ckp" + dataset = TorchArgMaxDatset(10, 40) + dataloader = dataloader_with_randombatchsampler(dataset, 4, True, False) + driver1, driver2 = generate_random_driver(10, 10, fp16, "cuda"), generate_random_driver(10, 10, False, "cuda") + + num_consumed_batches = 2 + + already_seen_x_set = set() + already_seen_y_set = set() + for idx, batch in enumerate(dataloader): + if idx >= num_consumed_batches: + break + already_seen_x_set.update(batch["x"]) + already_seen_y_set.update(batch["y"]) + + sampler_states = dataloader.batch_sampler.state_dict() + save_states = {"num_consumed_batches": num_consumed_batches} + driver1.save(Path(path), save_states, dataloader, only_state_dict, should_save_model=True) + # 加载 + # 更改 batch_size + + dataloader = dataloader_with_randombatchsampler(dataset, 2, True, False) + load_states = driver2.load(Path(path), dataloader, only_state_dict, should_load_model=True) + replaced_loader = load_states.pop("dataloader") + # 1. 检查 optimizer 的状态 + # TODO optimizer 的 state_dict 总是为空 + + # 2. 
检查 batch_sampler 是否被正确地加载和替换 + assert not (replaced_loader is dataloader) + assert replaced_loader.batch_sampler is dataloader.batch_sampler + assert isinstance(replaced_loader.batch_sampler, RandomBatchSampler) + assert replaced_loader.batch_sampler.index_list == sampler_states["index_list"] + assert replaced_loader.batch_sampler.num_consumed_samples == num_consumed_batches * 4 + + # 3. 检查 fp16 是否被加载 + if fp16: + assert isinstance(driver2.grad_scaler, torch.cuda.amp.GradScaler) + + # 4. 检查 model 的参数是否正确 + # 5. 检查 batch_idx + start_batch = load_states.pop('batch_idx_in_epoch') + assert start_batch == 2 * num_consumed_batches + left_x_batches = set() + left_y_batches = set() + for idx, batch in enumerate(replaced_loader): + + batch = driver2.move_data_to_device(batch) + left_x_batches.update(batch["x"]) + left_y_batches.update(batch["y"]) + res1 = driver1.model.evaluate_step(**batch) + res2 = driver2.model.evaluate_step(**batch) + assert torch.equal(res1["preds"], res2["preds"]) + + assert len(left_x_batches) + len(already_seen_x_set) == len(dataset) + assert len(left_x_batches | already_seen_x_set) == len(dataset) + assert len(left_y_batches) + len(already_seen_y_set) == len(dataset) + assert len(left_y_batches | already_seen_y_set) == len(dataset) + finally: + rank_zero_rm(path) + +@pytest.mark.parametrize("only_state_dict", ([True, False])) +@pytest.mark.parametrize("fp16", ([True, False])) +def test_save_and_load_with_randomsampler(only_state_dict, fp16): + """ + 测试save和load函数,主要测试 dataloader 被替换了 batch_sampler 的情况 + """ + + try: + path = "model.ckp" + + driver1, driver2 = generate_random_driver(10, 10, fp16, "cuda"), generate_random_driver(10, 10, False, "cuda") + dataset = TorchArgMaxDatset(10, 40) + dataloader = dataloader_with_randomsampler(dataset, 4, True, False) + num_consumed_batches = 2 + + already_seen_x_set = set() + already_seen_y_set = set() + for idx, batch in enumerate(dataloader): + if idx >= num_consumed_batches: + break + 
already_seen_x_set.update(batch["x"]) + already_seen_y_set.update(batch["y"]) + + sampler_states = dataloader.batch_sampler.sampler.state_dict() + save_states = {"num_consumed_batches": num_consumed_batches} + driver1.save(Path(path), save_states, dataloader, only_state_dict, should_save_model=True) + + # 加载 + # 更改 batch_size + dataloader = dataloader_with_randomsampler(dataset, 2, True, False) + load_states = driver2.load(Path(path), dataloader, only_state_dict, should_load_model=True) + replaced_loader = load_states.pop("dataloader") + + # 1. 检查 optimizer 的状态 + # TODO optimizer 的 state_dict 总是为空 + + # 2. 检查 sampler 是否被正确地加载和替换 + assert not (replaced_loader is dataloader) + assert isinstance(replaced_loader.batch_sampler.sampler, RandomSampler) + assert replaced_loader.batch_sampler.sampler.seed == sampler_states["seed"] + assert replaced_loader.batch_sampler.sampler.epoch == sampler_states["epoch"] + assert replaced_loader.batch_sampler.sampler.num_consumed_samples == 4 * num_consumed_batches + assert len(replaced_loader.batch_sampler.sampler.dataset) == sampler_states["length"] + assert replaced_loader.batch_sampler.sampler.shuffle == sampler_states["shuffle"] + + # 3. 检查 fp16 是否被加载 + if fp16: + assert isinstance(driver2.grad_scaler, torch.cuda.amp.GradScaler) + + # 4. 检查 model 的参数是否正确 + # 5. 
检查 batch_idx + start_batch = load_states.pop('batch_idx_in_epoch') + assert start_batch == 2 * num_consumed_batches + left_x_batches = set() + left_y_batches = set() + for idx, batch in enumerate(replaced_loader): + + batch = driver2.move_data_to_device(batch) + left_x_batches.update(batch["x"]) + left_y_batches.update(batch["y"]) + res1 = driver1.model.evaluate_step(**batch) + res2 = driver2.model.evaluate_step(**batch) + assert torch.equal(res1["preds"], res2["preds"]) + + assert len(left_x_batches) + len(already_seen_x_set) == len(dataset) + assert len(left_x_batches | already_seen_x_set) == len(dataset) + assert len(left_y_batches) + len(already_seen_y_set) == len(dataset) + assert len(left_y_batches | already_seen_y_set) == len(dataset) + finally: + rank_zero_rm(path) diff --git a/tests/core/drivers/torch_driver/test_utils.py b/tests/core/drivers/torch_driver/test_utils.py index c1e604c9..8f0172e0 100644 --- a/tests/core/drivers/torch_driver/test_utils.py +++ b/tests/core/drivers/torch_driver/test_utils.py @@ -1,35 +1,36 @@ -from torch.utils.data.sampler import SequentialSampler, RandomSampler - -from fastNLP.core.samplers.sampler import ReproduceSampler -from tests.helpers.datasets.normal_data import NormalIterator - - -class TestReproduceSampler: - - def test_sequentialsampler(self): - normal_iterator = NormalIterator(num_of_data=20) - sequential_sampler = SequentialSampler(normal_iterator) - - reproduce_sampler = ReproduceSampler(sequential_sampler) - # iter_seq_sampler = iter(sequential_sampler) - # for each in iter_seq_sampler: - # print(each) - iter_reproduce_sampler = iter(reproduce_sampler) - forward_step = 3 - for _ in range(forward_step): - next(iter_reproduce_sampler) - state = reproduce_sampler.save_state() - assert state["current_batch_idx"] == forward_step - - new_repro_sampler = ReproduceSampler(sequential_sampler) - assert new_repro_sampler.save_state()["current_batch_idx"] == 0 - - new_repro_sampler.load_state(state) - iter_new_repro_sampler = 
iter(new_repro_sampler) - new_index_list = [] - for each in iter_new_repro_sampler: - new_index_list.append(each) - assert new_index_list == list(range(3, 20)) - - - +import os +import pytest +os.environ["FASTNLP_BACKEND"] = "torch" + +from fastNLP.core.drivers.torch_driver.utils import ( + replace_batch_sampler, + replace_sampler, +) +from fastNLP.core.samplers import RandomBatchSampler, RandomSampler +from torch.utils.data import DataLoader, BatchSampler + +from tests.helpers.datasets.torch_data import TorchNormalDataset + +def test_replace_batch_sampler(): + dataset = TorchNormalDataset(10) + dataloader = DataLoader(dataset, batch_size=32) + batch_sampler = RandomBatchSampler(dataloader.batch_sampler, batch_size=16, drop_last=False) + + replaced_loader = replace_batch_sampler(dataloader, batch_sampler) + + assert not (replaced_loader is dataloader) + assert isinstance(replaced_loader.batch_sampler, RandomBatchSampler) + assert isinstance(replaced_loader.dataset, TorchNormalDataset) + assert len(replaced_loader.dataset) == len(dataset) + assert replaced_loader.batch_sampler.batch_size == 16 + +def test_replace_sampler(): + dataset = TorchNormalDataset(10) + dataloader = DataLoader(dataset, batch_size=32) + sampler = RandomSampler(dataset) + + replaced_loader = replace_sampler(dataloader, sampler) + + assert not (replaced_loader is dataloader) + assert isinstance(replaced_loader.batch_sampler, BatchSampler) + assert isinstance(replaced_loader.batch_sampler.sampler, RandomSampler) \ No newline at end of file