| @@ -47,9 +47,7 @@ def initialize_paddle_driver(driver: str, device: Optional[Union[str, int, List[ | |||||
| raise ValueError("Parameter `device` can only be '-1' when it is smaller than 0.") | raise ValueError("Parameter `device` can only be '-1' when it is smaller than 0.") | ||||
| if device >= _could_use_device_num: | if device >= _could_use_device_num: | ||||
| raise ValueError("The gpu device that parameter `device` specifies is not existed.") | raise ValueError("The gpu device that parameter `device` specifies is not existed.") | ||||
| if device != -1: | |||||
| device = f"gpu:{device}" | |||||
| else: | |||||
| if device == -1: | |||||
| device = list(range(_could_use_device_num)) | device = list(range(_could_use_device_num)) | ||||
| elif isinstance(device, Sequence) and not isinstance(device, str): | elif isinstance(device, Sequence) and not isinstance(device, str): | ||||
| device = list(set(device)) | device = list(set(device)) | ||||
| @@ -61,9 +59,6 @@ def initialize_paddle_driver(driver: str, device: Optional[Union[str, int, List[ | |||||
| elif each >= _could_use_device_num: | elif each >= _could_use_device_num: | ||||
| raise ValueError("When parameter `device` is 'Sequence' type, the value in it should not be bigger than" | raise ValueError("When parameter `device` is 'Sequence' type, the value in it should not be bigger than" | ||||
| " the available gpu number.") | " the available gpu number.") | ||||
| if len(device) == 1: | |||||
| # 传入了 [1] 这样的,视为单卡。 | |||||
| device = device[0] | |||||
| elif device is not None and not isinstance(device, str): | elif device is not None and not isinstance(device, str): | ||||
| raise ValueError("Parameter `device` is wrong type, please check our documentation for the right use.") | raise ValueError("Parameter `device` is wrong type, please check our documentation for the right use.") | ||||
| @@ -82,6 +77,6 @@ def initialize_paddle_driver(driver: str, device: Optional[Union[str, int, List[ | |||||
| logger.warning("Notice you are using `fleet` driver, but your chosen `device` is only one gpu, we will" | logger.warning("Notice you are using `fleet` driver, but your chosen `device` is only one gpu, we will" | ||||
| "still use `PaddleFleetDriver` for you, but if you mean using `PaddleSingleDriver`, you should " | "still use `PaddleFleetDriver` for you, but if you mean using `PaddleSingleDriver`, you should " | ||||
| "choose `paddle` driver.") | "choose `paddle` driver.") | ||||
| return PaddleFleetDriver(model, device, **kwargs) | |||||
| return PaddleFleetDriver(model, [device], **kwargs) | |||||
| else: | else: | ||||
| return PaddleFleetDriver(model, device, **kwargs) | return PaddleFleetDriver(model, device, **kwargs) | ||||
| @@ -19,7 +19,12 @@ from fastNLP.envs import ( | |||||
| rank_zero_call, | rank_zero_call, | ||||
| ) | ) | ||||
| from fastNLP.core.log import logger | from fastNLP.core.log import logger | ||||
| from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler, RandomBatchSampler | |||||
| from fastNLP.core.samplers import ( | |||||
| ReproducibleBatchSampler, | |||||
| ReproducibleSampler, | |||||
| RandomBatchSampler, | |||||
| RandomSampler, | |||||
| ) | |||||
| if _NEED_IMPORT_PADDLE: | if _NEED_IMPORT_PADDLE: | ||||
| import paddle | import paddle | ||||
| @@ -29,7 +34,7 @@ if _NEED_IMPORT_PADDLE: | |||||
| Dataset, | Dataset, | ||||
| Sampler, | Sampler, | ||||
| BatchSampler, | BatchSampler, | ||||
| RandomSampler, | |||||
| RandomSampler as PaddleRandomSampler, | |||||
| ) | ) | ||||
| from paddle.optimizer import Optimizer | from paddle.optimizer import Optimizer | ||||
| @@ -333,6 +338,9 @@ class PaddleDriver(Driver): | |||||
| sampler = dataloader_args.batch_sampler | sampler = dataloader_args.batch_sampler | ||||
| elif isinstance(dataloader_args.sampler, ReproducibleSampler): | elif isinstance(dataloader_args.sampler, ReproducibleSampler): | ||||
| sampler = dataloader_args.sampler | sampler = dataloader_args.sampler | ||||
| elif isinstance(dataloader_args.sampler, PaddleRandomSampler): | |||||
| sampler = RandomSampler(dataloader_args.sampler.data_source) | |||||
| logger.debug("Replace paddle RandomSampler into fastNLP RandomSampler.") | |||||
| elif self.is_distributed(): | elif self.is_distributed(): | ||||
| raise RuntimeError("It is not allowed to use checkpoint retraining when you do not use our or " | raise RuntimeError("It is not allowed to use checkpoint retraining when you do not use our or " | ||||
| "`ReproducibleSampler`.") | "`ReproducibleSampler`.") | ||||
| @@ -464,7 +472,7 @@ class PaddleDriver(Driver): | |||||
| res.sampler = dataloader.batch_sampler.sampler | res.sampler = dataloader.batch_sampler.sampler | ||||
| if hasattr(dataloader.batch_sampler.sampler, "shuffle"): | if hasattr(dataloader.batch_sampler.sampler, "shuffle"): | ||||
| res.shuffle = dataloader.batch_sampler.sampler.shuffle | res.shuffle = dataloader.batch_sampler.sampler.shuffle | ||||
| elif isinstance(dataloader.batch_sampler.sampler, RandomSampler): | |||||
| elif isinstance(dataloader.batch_sampler.sampler, PaddleRandomSampler): | |||||
| res.shuffle = True | res.shuffle = True | ||||
| else: | else: | ||||
| res.shuffle = False | res.shuffle = False | ||||
| @@ -474,7 +482,7 @@ class PaddleDriver(Driver): | |||||
| res.sampler = batch_sampler.sampler | res.sampler = batch_sampler.sampler | ||||
| if hasattr(batch_sampler.sampler, "shuffle"): | if hasattr(batch_sampler.sampler, "shuffle"): | ||||
| res.shuffle = dataloader.batch_sampler.sampler.shuffle | res.shuffle = dataloader.batch_sampler.sampler.shuffle | ||||
| elif isinstance(batch_sampler.sampler, RandomSampler): | |||||
| elif isinstance(batch_sampler.sampler, PaddleRandomSampler): | |||||
| res.shuffle = True | res.shuffle = True | ||||
| else: | else: | ||||
| res.shuffle = False | res.shuffle = False | ||||
| @@ -19,7 +19,7 @@ def test_incorrect_driver(): | |||||
| @pytest.mark.parametrize( | @pytest.mark.parametrize( | ||||
| "device", | "device", | ||||
| ["cpu", "gpu:0", 0, [1]] | |||||
| ["cpu", "gpu:0", 0] | |||||
| ) | ) | ||||
| @pytest.mark.parametrize( | @pytest.mark.parametrize( | ||||
| "driver", | "driver", | ||||
| @@ -27,7 +27,7 @@ def test_incorrect_driver(): | |||||
| ) | ) | ||||
| def test_get_single_device(driver, device): | def test_get_single_device(driver, device): | ||||
| """ | """ | ||||
| 测试正常情况下初始化PaddleSingleDriver的情况 | |||||
| 测试正常情况下初始化 PaddleSingleDriver 的情况 | |||||
| """ | """ | ||||
| model = PaddleNormalModel_Classification_1(2, 100) | model = PaddleNormalModel_Classification_1(2, 100) | ||||
| @@ -36,7 +36,7 @@ def test_get_single_device(driver, device): | |||||
| @pytest.mark.parametrize( | @pytest.mark.parametrize( | ||||
| "device", | "device", | ||||
| [0, 1] | |||||
| [0, 1, [1]] | |||||
| ) | ) | ||||
| @pytest.mark.parametrize( | @pytest.mark.parametrize( | ||||
| "driver", | "driver", | ||||
| @@ -45,7 +45,7 @@ def test_get_single_device(driver, device): | |||||
| @magic_argv_env_context | @magic_argv_env_context | ||||
| def test_get_fleet_2(driver, device): | def test_get_fleet_2(driver, device): | ||||
| """ | """ | ||||
| 测试 fleet 多卡的初始化情况 | |||||
| 测试 fleet 多卡的初始化情况,但传入了单个 gpu | |||||
| """ | """ | ||||
| model = PaddleNormalModel_Classification_1(64, 10) | model = PaddleNormalModel_Classification_1(64, 10) | ||||
| @@ -34,7 +34,7 @@ class TestPaddleDriverFunctions: | |||||
| def test_check_single_optimizer_legality(self): | def test_check_single_optimizer_legality(self): | ||||
| """ | """ | ||||
| 测试传入单个optimizer时的表现 | |||||
| 测试传入单个 optimizer 时的表现 | |||||
| """ | """ | ||||
| optimizer = paddle.optimizer.Adam( | optimizer = paddle.optimizer.Adam( | ||||
| parameters=self.driver.model.parameters(), | parameters=self.driver.model.parameters(), | ||||
| @@ -50,7 +50,7 @@ class TestPaddleDriverFunctions: | |||||
| def test_check_optimizers_legality(self): | def test_check_optimizers_legality(self): | ||||
| """ | """ | ||||
| 测试传入optimizer list的表现 | |||||
| 测试传入 optimizer list 的表现 | |||||
| """ | """ | ||||
| optimizers = [ | optimizers = [ | ||||
| paddle.optimizer.Adam( | paddle.optimizer.Adam( | ||||
| @@ -70,13 +70,13 @@ class TestPaddleDriverFunctions: | |||||
| def test_check_dataloader_legality_in_train(self): | def test_check_dataloader_legality_in_train(self): | ||||
| """ | """ | ||||
| 测试is_train参数为True时,_check_dataloader_legality函数的表现 | |||||
| 测试 `is_train` 参数为 True 时,_check_dataloader_legality 函数的表现 | |||||
| """ | """ | ||||
| dataloader = paddle.io.DataLoader(PaddleNormalDataset()) | |||||
| dataloader = DataLoader(PaddleNormalDataset()) | |||||
| PaddleSingleDriver.check_dataloader_legality(dataloader, "dataloader", True) | PaddleSingleDriver.check_dataloader_legality(dataloader, "dataloader", True) | ||||
| # batch_size 和 batch_sampler 均为 None 的情形 | # batch_size 和 batch_sampler 均为 None 的情形 | ||||
| dataloader = paddle.io.DataLoader(PaddleNormalDataset(), batch_size=None) | |||||
| dataloader = DataLoader(PaddleNormalDataset(), batch_size=None) | |||||
| with pytest.raises(ValueError): | with pytest.raises(ValueError): | ||||
| PaddleSingleDriver.check_dataloader_legality(dataloader, "dataloader", True) | PaddleSingleDriver.check_dataloader_legality(dataloader, "dataloader", True) | ||||
| @@ -90,29 +90,29 @@ class TestPaddleDriverFunctions: | |||||
| def test_check_dataloader_legality_in_test(self): | def test_check_dataloader_legality_in_test(self): | ||||
| """ | """ | ||||
| 测试is_train参数为False时,_check_dataloader_legality函数的表现 | |||||
| 测试 `is_train` 参数为 False 时,_check_dataloader_legality 函数的表现 | |||||
| """ | """ | ||||
| # 此时传入的应该是dict | # 此时传入的应该是dict | ||||
| dataloader = { | dataloader = { | ||||
| "train": paddle.io.DataLoader(PaddleNormalDataset()), | |||||
| "test":paddle.io.DataLoader(PaddleNormalDataset()) | |||||
| "train": DataLoader(PaddleNormalDataset()), | |||||
| "test":DataLoader(PaddleNormalDataset()) | |||||
| } | } | ||||
| PaddleSingleDriver.check_dataloader_legality(dataloader, "dataloader", False) | PaddleSingleDriver.check_dataloader_legality(dataloader, "dataloader", False) | ||||
| # batch_size 和 batch_sampler 均为 None 的情形 | # batch_size 和 batch_sampler 均为 None 的情形 | ||||
| dataloader = { | dataloader = { | ||||
| "train": paddle.io.DataLoader(PaddleNormalDataset()), | |||||
| "test":paddle.io.DataLoader(PaddleNormalDataset(), batch_size=None) | |||||
| "train": DataLoader(PaddleNormalDataset()), | |||||
| "test":DataLoader(PaddleNormalDataset(), batch_size=None) | |||||
| } | } | ||||
| with pytest.raises(ValueError): | with pytest.raises(ValueError): | ||||
| PaddleSingleDriver.check_dataloader_legality(dataloader, "dataloader", False) | PaddleSingleDriver.check_dataloader_legality(dataloader, "dataloader", False) | ||||
| # 传入的不是dict,应该报错 | |||||
| dataloader = paddle.io.DataLoader(PaddleNormalDataset()) | |||||
| # 传入的不是 dict ,应该报错 | |||||
| dataloader = DataLoader(PaddleNormalDataset()) | |||||
| with pytest.raises(ValueError): | with pytest.raises(ValueError): | ||||
| PaddleSingleDriver.check_dataloader_legality(dataloader, "dataloader", False) | PaddleSingleDriver.check_dataloader_legality(dataloader, "dataloader", False) | ||||
| # 创建torch的dataloader | |||||
| # 创建 torch 的 dataloader | |||||
| train_loader = torch.utils.data.DataLoader( | train_loader = torch.utils.data.DataLoader( | ||||
| TorchNormalDataset(), | TorchNormalDataset(), | ||||
| batch_size=32, shuffle=True | batch_size=32, shuffle=True | ||||
| @@ -127,7 +127,7 @@ class TestPaddleDriverFunctions: | |||||
| def test_tensor_to_numeric(self): | def test_tensor_to_numeric(self): | ||||
| """ | """ | ||||
| 测试tensor_to_numeric函数 | |||||
| 测试 tensor_to_numeric 函数 | |||||
| """ | """ | ||||
| # 单个张量 | # 单个张量 | ||||
| tensor = paddle.to_tensor(3) | tensor = paddle.to_tensor(3) | ||||
| @@ -180,7 +180,7 @@ class TestPaddleDriverFunctions: | |||||
| def test_set_model_mode(self): | def test_set_model_mode(self): | ||||
| """ | """ | ||||
| 测试set_model_mode函数 | |||||
| 测试 set_model_mode 函数 | |||||
| """ | """ | ||||
| self.driver.set_model_mode("train") | self.driver.set_model_mode("train") | ||||
| assert self.driver.model.training | assert self.driver.model.training | ||||
| @@ -192,14 +192,14 @@ class TestPaddleDriverFunctions: | |||||
| def test_move_model_to_device_cpu(self): | def test_move_model_to_device_cpu(self): | ||||
| """ | """ | ||||
| 测试move_model_to_device函数 | |||||
| 测试 move_model_to_device 函数 | |||||
| """ | """ | ||||
| PaddleSingleDriver.move_model_to_device(self.driver.model, "cpu") | PaddleSingleDriver.move_model_to_device(self.driver.model, "cpu") | ||||
| assert self.driver.model.linear1.weight.place.is_cpu_place() | assert self.driver.model.linear1.weight.place.is_cpu_place() | ||||
| def test_move_model_to_device_gpu(self): | def test_move_model_to_device_gpu(self): | ||||
| """ | """ | ||||
| 测试move_model_to_device函数 | |||||
| 测试 move_model_to_device 函数 | |||||
| """ | """ | ||||
| PaddleSingleDriver.move_model_to_device(self.driver.model, "gpu") | PaddleSingleDriver.move_model_to_device(self.driver.model, "gpu") | ||||
| assert self.driver.model.linear1.weight.place.is_gpu_place() | assert self.driver.model.linear1.weight.place.is_gpu_place() | ||||
| @@ -207,7 +207,7 @@ class TestPaddleDriverFunctions: | |||||
| def test_worker_init_function(self): | def test_worker_init_function(self): | ||||
| """ | """ | ||||
| 测试worker_init_function | |||||
| 测试 worker_init_function | |||||
| """ | """ | ||||
| # 先确保不影响运行 | # 先确保不影响运行 | ||||
| # TODO:正确性 | # TODO:正确性 | ||||
| @@ -215,7 +215,7 @@ class TestPaddleDriverFunctions: | |||||
| def test_set_deterministic_dataloader(self): | def test_set_deterministic_dataloader(self): | ||||
| """ | """ | ||||
| 测试set_deterministic_dataloader | |||||
| 测试 set_deterministic_dataloader | |||||
| """ | """ | ||||
| # 先确保不影响运行 | # 先确保不影响运行 | ||||
| # TODO:正确性 | # TODO:正确性 | ||||
| @@ -224,7 +224,7 @@ class TestPaddleDriverFunctions: | |||||
| def test_set_sampler_epoch(self): | def test_set_sampler_epoch(self): | ||||
| """ | """ | ||||
| 测试set_sampler_epoch | |||||
| 测试 set_sampler_epoch | |||||
| """ | """ | ||||
| # 先确保不影响运行 | # 先确保不影响运行 | ||||
| # TODO:正确性 | # TODO:正确性 | ||||
| @@ -336,7 +336,7 @@ class TestSingleDeviceFunction: | |||||
| def test_move_data_to_device(self): | def test_move_data_to_device(self): | ||||
| """ | """ | ||||
| 这个函数仅调用了paddle_move_data_to_device,测试例在tests/core/utils/test_paddle_utils.py中 | |||||
| 这个函数仅调用了 paddle_move_data_to_device ,测试例在 tests/core/utils/test_paddle_utils.py 中 | |||||
| 就不重复测试了 | 就不重复测试了 | ||||
| """ | """ | ||||
| self.driver.move_data_to_device(paddle.rand((32, 64))) | self.driver.move_data_to_device(paddle.rand((32, 64))) | ||||
| @@ -490,9 +490,6 @@ class TestSetDistReproDataloader: | |||||
| else: | else: | ||||
| sampler_states = replaced_loader.batch_sampler.sampler.state_dict() | sampler_states = replaced_loader.batch_sampler.sampler.state_dict() | ||||
| # 加载 num_consumed_samples_array,设置正确取出的 batch 数目 | |||||
| num_consumed_samples_array = sampler_states.pop('num_consumed_samples_array', None) | |||||
| # 重新加载,应该可以输出剩下的内容,且对于 PaddleNormalDataset 来说,排序后应该是一个 range | # 重新加载,应该可以输出剩下的内容,且对于 PaddleNormalDataset 来说,排序后应该是一个 range | ||||
| left_idxes = set() | left_idxes = set() | ||||
| if isinstance(replaced_loader.batch_sampler, RandomBatchSampler): | if isinstance(replaced_loader.batch_sampler, RandomBatchSampler): | ||||
| @@ -510,7 +507,6 @@ class TestSetDistReproDataloader: | |||||
| new_loader.batch_sampler.load_state_dict(sampler_states) | new_loader.batch_sampler.load_state_dict(sampler_states) | ||||
| else: | else: | ||||
| batch_size = replaced_loader.batch_sampler.batch_size | batch_size = replaced_loader.batch_sampler.batch_size | ||||
| num_consumed_samples = num_consumed_batches * batch_size | |||||
| sampler_states["num_consumed_samples"] = num_consumed_batches * batch_size | sampler_states["num_consumed_samples"] = num_consumed_batches * batch_size | ||||
| # 重新构造 dataloader | # 重新构造 dataloader | ||||
| batch_sampler = BatchSampler(replaced_loader.dataset, shuffle=shuffle, batch_size=batch_size) | batch_sampler = BatchSampler(replaced_loader.dataset, shuffle=shuffle, batch_size=batch_size) | ||||
| @@ -0,0 +1,103 @@ | |||||
| import os | |||||
| import pytest | |||||
| os.environ["FASTNLP_BACKEND"] = "torch" | |||||
| from fastNLP.core.drivers import TorchSingleDriver, TorchDDPDriver | |||||
| from fastNLP.core.drivers.torch_driver.initialize_torch_driver import initialize_torch_driver | |||||
| from fastNLP.envs import get_gpu_count | |||||
| from tests.helpers.models.torch_model import TorchNormalModel_Classification_1 | |||||
| from tests.helpers.utils import magic_argv_env_context | |||||
| import torch | |||||
| def test_incorrect_driver(): | |||||
| model = TorchNormalModel_Classification_1(2, 100) | |||||
| with pytest.raises(ValueError): | |||||
| driver = initialize_torch_driver("paddle", 0, model) | |||||
| @pytest.mark.parametrize( | |||||
| "device", | |||||
| ["cpu", "cuda:0", 0, torch.device("cuda:0")] | |||||
| ) | |||||
| @pytest.mark.parametrize( | |||||
| "driver", | |||||
| ["torch"] | |||||
| ) | |||||
| def test_get_single_device(driver, device): | |||||
| """ | |||||
| 测试正常情况下初始化TorchSingleDriver的情况 | |||||
| """ | |||||
| model = TorchNormalModel_Classification_1(2, 100) | |||||
| driver = initialize_torch_driver(driver, device, model) | |||||
| assert isinstance(driver, TorchSingleDriver) | |||||
| @pytest.mark.parametrize( | |||||
| "device", | |||||
| [0, 1] | |||||
| ) | |||||
| @pytest.mark.parametrize( | |||||
| "driver", | |||||
| ["torch_ddp"] | |||||
| ) | |||||
| @magic_argv_env_context | |||||
| def test_get_ddp_2(driver, device): | |||||
| """ | |||||
| 测试 ddp 多卡的初始化情况,但传入了单个 gpu | |||||
| """ | |||||
| model = TorchNormalModel_Classification_1(64, 10) | |||||
| driver = initialize_torch_driver(driver, device, model) | |||||
| assert isinstance(driver, TorchDDPDriver) | |||||
| @pytest.mark.parametrize( | |||||
| "device", | |||||
| [[0, 2, 3], -1] | |||||
| ) | |||||
| @pytest.mark.parametrize( | |||||
| "driver", | |||||
| ["torch", "torch_ddp"] | |||||
| ) | |||||
| @magic_argv_env_context | |||||
| def test_get_ddp(driver, device): | |||||
| """ | |||||
| 测试 ddp 多卡的初始化情况 | |||||
| """ | |||||
| model = TorchNormalModel_Classification_1(64, 10) | |||||
| driver = initialize_torch_driver(driver, device, model) | |||||
| assert isinstance(driver, TorchDDPDriver) | |||||
| @pytest.mark.parametrize( | |||||
| ("driver", "device"), | |||||
| [("torch_ddp", "cpu")] | |||||
| ) | |||||
| @magic_argv_env_context | |||||
| def test_get_ddp_cpu(driver, device): | |||||
| """ | |||||
| 测试试图在 cpu 上初始化分布式训练的情况 | |||||
| """ | |||||
| model = TorchNormalModel_Classification_1(64, 10) | |||||
| with pytest.raises(ValueError): | |||||
| driver = initialize_torch_driver(driver, device, model) | |||||
| @pytest.mark.parametrize( | |||||
| "device", | |||||
| [-2, [0, torch.cuda.device_count() + 1, 3], [-2], torch.cuda.device_count() + 1] | |||||
| ) | |||||
| @pytest.mark.parametrize( | |||||
| "driver", | |||||
| ["torch", "torch_ddp"] | |||||
| ) | |||||
| @magic_argv_env_context | |||||
| def test_device_out_of_range(driver, device): | |||||
| """ | |||||
| 测试传入的device超过范围的情况 | |||||
| """ | |||||
| model = TorchNormalModel_Classification_1(2, 100) | |||||
| with pytest.raises(ValueError): | |||||
| driver = initialize_torch_driver(driver, device, model) | |||||