diff --git a/tests/core/drivers/torch_driver/test_ddp.py b/tests/core/drivers/torch_driver/test_ddp.py
index 48299bf4..11799515 100644
--- a/tests/core/drivers/torch_driver/test_ddp.py
+++ b/tests/core/drivers/torch_driver/test_ddp.py
@@ -13,12 +13,13 @@ from tests.helpers.models.torch_model import TorchNormalModel_Classification_1
 from tests.helpers.datasets.torch_data import TorchNormalDataset, TorchArgMaxDataset
 from tests.helpers.utils import magic_argv_env_context
 from fastNLP.core import rank_zero_rm
+from fastNLP.envs.imports import _NEED_IMPORT_TORCH
+if _NEED_IMPORT_TORCH:
+    import torch
+    import torch.distributed as dist
+    from torch.utils.data import DataLoader, BatchSampler
-import torch
-import torch.distributed as dist
-from torch.utils.data import DataLoader, BatchSampler
-
-def generate_driver(num_labels, feature_dimension, device=[0,1], fp16=False, output_from_new_proc="only_error"):
+def generate_driver(num_labels, feature_dimension, device=[0,1], fp16=False, output_from_new_proc="all"):
     torch_model = TorchNormalModel_Classification_1(num_labels, feature_dimension)
     torch_opt = torch.optim.Adam(params=torch_model.parameters(), lr=0.01)
     device = [torch.device(i) for i in device]
@@ -73,107 +74,100 @@ def dataloader_with_randomsampler(dataset, batch_size, shuffle, drop_last, seed=
 ############################################################################
 
 @pytest.mark.torch
+@magic_argv_env_context
+def test_multi_drivers():
+    """
+    Test the case where multiple TorchDDPDriver instances are used.
+    """
+    generate_driver(10, 10)
+    generate_driver(20, 10)
+
+    with pytest.raises(RuntimeError):
+        # The device settings differ, so an error should be raised
+        generate_driver(20, 3, device=[0,1,2])
+        assert False
+    dist.barrier()
+
+    if dist.is_initialized():
+        dist.destroy_process_group()
+
+@pytest.mark.torch
+@pytest.mark.torchtemp
 class TestDDPDriverFunction:
     """
     A test class for some simple TorchDDPDriver functions; it mostly checks that they run and that there are no import errors.
     """
 
-    @classmethod
-    def setup_class(cls):
-        cls.driver = generate_driver(10, 10)
-
     @magic_argv_env_context
-    def test_multi_drivers(self):
+    def test_simple_functions(self):
         """
-        Test the case where multiple TorchDDPDriver instances are used.
+        A quick test of several functions
         """
-
-        driver2 = generate_driver(20, 10)
-
-        with pytest.raises(RuntimeError):
-            # The device settings differ, so an error should be raised
-            driver3 = generate_driver(20, 3, device=[0,1,2])
-            assert False
-        dist.barrier()
+        driver = generate_driver(10, 10)
 
-    @magic_argv_env_context
-    def test_move_data_to_device(self):
         """
-        This function only calls torch_move_data_to_device; its test cases are in tests/core/utils/test_torch_utils.py,
-        so they are not repeated here
+        Test the move_data_to_device function. It only calls torch_move_data_to_device; the test cases live in
+        tests/core/utils/test_torch_utils.py, so they are not repeated here
        """
-        self.driver.move_data_to_device(torch.rand((32, 64)))
-
+        driver.move_data_to_device(torch.rand((32, 64)))
         dist.barrier()
 
-    @magic_argv_env_context
-    def test_is_distributed(self):
         """
        Test the is_distributed function
        """
-        assert self.driver.is_distributed() == True
+        assert driver.is_distributed() == True
         dist.barrier()
 
-    @magic_argv_env_context
-    def test_get_no_sync_context(self):
         """
        Test the get_no_sync_context function
        """
-        res = self.driver.get_model_no_sync_context()
+        res = driver.get_model_no_sync_context()
         dist.barrier()
 
-    @magic_argv_env_context
-    def test_is_global_zero(self):
         """
        Test the is_global_zero function
        """
-        self.driver.is_global_zero()
+        driver.is_global_zero()
         dist.barrier()
 
-    @magic_argv_env_context
-    def test_unwrap_model(self):
         """
        Test the unwrap_model function
        """
-        self.driver.unwrap_model()
+        driver.unwrap_model()
         dist.barrier()
 
-    @magic_argv_env_context
-    def test_get_local_rank(self):
         """
        Test the get_local_rank function
        """
-        self.driver.get_local_rank()
+        driver.get_local_rank()
         dist.barrier()
 
-    @magic_argv_env_context
-    def test_all_gather(self):
         """
        Test the all_gather function
        The detailed tests are in test_dist_utils.py
        """
        obj = {
-            "rank": self.driver.global_rank
+            "rank": driver.global_rank
        }
-        obj_list = self.driver.all_gather(obj, group=None)
+        obj_list = driver.all_gather(obj, group=None)
        for i, res in enumerate(obj_list):
            assert res["rank"] == i
 
-    @magic_argv_env_context
-    @pytest.mark.parametrize("src_rank", ([0, 1]))
-    def test_broadcast_object(self, src_rank):
         """
        Test the broadcast_object function
        The detailed tests are in test_dist_utils.py
        """
-        if self.driver.global_rank == src_rank:
+        if driver.global_rank == 0:
            obj = {
-                "rank": self.driver.global_rank
+                "rank": driver.global_rank
            }
        else:
            obj = None
-        res = self.driver.broadcast_object(obj, src=src_rank)
-        assert res["rank"] == src_rank
+        res = driver.broadcast_object(obj, src=0)
+        assert res["rank"] == 0
+
+        if dist.is_initialized():
+            dist.destroy_process_group()
 
 ############################################################################
 #
@@ -182,12 +176,12 @@ class TestDDPDriverFunction:
 ############################################################################
 
 @pytest.mark.torch
+@pytest.mark.torchtemp
 class TestSetDistReproDataloader:
 
     @classmethod
     def setup_class(cls):
         cls.device = [0, 1]
-        cls.driver = generate_driver(10, 10, device=cls.device)
 
     def setup_method(self):
         self.dataset = TorchNormalDataset(40)
 
@@ -204,17 +198,20 @@ class TestSetDistReproDataloader:
         Test the behavior of set_dist_repro_dataloader when dist is a BucketedBatchSampler
         In this case the batch_sampler should be replaced by the BucketedBatchSampler corresponding to dist
         """
+        driver = generate_driver(10, 10, device=self.device)
         dataloader = DataLoader(self.dataset, batch_size=4, shuffle=not shuffle)
         batch_sampler = BucketedBatchSampler(self.dataset, self.dataset._data, batch_size=4, shuffle=shuffle)
-        replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, batch_sampler, False)
+        replaced_loader = driver.set_dist_repro_dataloader(dataloader, batch_sampler, False)
 
         assert not (replaced_loader is dataloader)
         assert isinstance(replaced_loader.batch_sampler, BucketedBatchSampler)
         assert replaced_loader.batch_sampler is batch_sampler
         self.check_distributed_sampler(replaced_loader.batch_sampler)
-        self.check_set_dist_repro_dataloader(dataloader, replaced_loader, shuffle)
+        self.check_set_dist_repro_dataloader(driver, dataloader, replaced_loader, shuffle)
 
         dist.barrier()
+        if dist.is_initialized():
+            dist.destroy_process_group()
 
     @magic_argv_env_context
     @pytest.mark.parametrize("shuffle", ([True, False]))
@@ -223,9 +220,10 @@ class TestSetDistReproDataloader:
         Test the behavior of set_dist_repro_dataloader when dist is a RandomSampler
         In this case batch_sampler.sampler should be replaced by the RandomSampler corresponding to dist
         """
+        driver = generate_driver(10, 10, device=self.device)
         dataloader = DataLoader(self.dataset, batch_size=4, shuffle=not shuffle)
         sampler = RandomSampler(self.dataset, shuffle=shuffle)
-        replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, sampler, False)
+        replaced_loader = driver.set_dist_repro_dataloader(dataloader, sampler, False)
 
         assert not (replaced_loader is dataloader)
         assert isinstance(replaced_loader.batch_sampler, BatchSampler)
@@ -234,9 +232,11 @@ class TestSetDistReproDataloader:
         assert replaced_loader.batch_sampler.sampler is sampler
         assert replaced_loader.batch_sampler.batch_size == dataloader.batch_sampler.batch_size
         self.check_distributed_sampler(replaced_loader.batch_sampler.sampler)
-        self.check_set_dist_repro_dataloader(dataloader, replaced_loader, shuffle)
+        self.check_set_dist_repro_dataloader(driver, dataloader, replaced_loader, shuffle)
 
         dist.barrier()
+        if dist.is_initialized():
+            dist.destroy_process_group()
 
     """
     The case where the `dist` argument is None; this occurs during trainer and evaluator initialization when the user specifies `use_dist_sampler`
@@ -251,15 +251,17 @@ class TestSetDistReproDataloader:
         Test the behavior of set_dist_repro_dataloader when dist is None and reproducible is True
         When the user has initialized the distributed environment outside the driver, fastNLP does not support checkpoint resumption, so an error should be raised
         """
+        driver = generate_driver(10, 10, device=self.device)
         dataloader = DataLoader(self.dataset, batch_size=4, shuffle=True)
 
         with pytest.raises(RuntimeError):
             # A RuntimeError should be raised
-            replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, None, True)
+            replaced_loader = driver.set_dist_repro_dataloader(dataloader, None, True)
 
         dist.barrier()
+        if dist.is_initialized():
+            dist.destroy_process_group()
 
     @magic_argv_env_context
-    # @pytest.mark.parametrize("shuffle", ([True, False]))
     @pytest.mark.parametrize("shuffle", ([True, False]))
     def test_with_dist_none_reproducible_false_dataloader_reproducible_batch_sampler(self, shuffle):
         """
@@ -268,21 +270,24 @@ class TestSetDistReproDataloader:
         Here the given dataloader's batch_sampler should already have had set_distributed called on it, producing a new dataloader whose batch_sampler
         is the same as the original dataloader's
         """
+        driver = generate_driver(10, 10, device=self.device)
         dataloader = dataloader_with_bucketedbatchsampler(self.dataset, self.dataset._data, 4, shuffle, False)
         dataloader.batch_sampler.set_distributed(
-            num_replicas=self.driver.world_size,
-            rank=self.driver.global_rank,
+            num_replicas=driver.world_size,
+            rank=driver.global_rank,
             pad=True
         )
-        replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, None, False)
+        replaced_loader = driver.set_dist_repro_dataloader(dataloader, None, False)
 
         assert not (replaced_loader is dataloader)
         assert isinstance(replaced_loader.batch_sampler, BucketedBatchSampler)
         assert replaced_loader.batch_sampler.batch_size == 4
         self.check_distributed_sampler(dataloader.batch_sampler)
-        self.check_set_dist_repro_dataloader(dataloader, replaced_loader, shuffle)
+        self.check_set_dist_repro_dataloader(driver, dataloader, replaced_loader, shuffle)
 
         dist.barrier()
+        if dist.is_initialized():
+            dist.destroy_process_group()
 
     @magic_argv_env_context
     @pytest.mark.parametrize("shuffle", ([True, False]))
@@ -292,12 +297,13 @@ class TestSetDistReproDataloader:
         Here the given dataloader's batch_sampler.sampler should already have had set_distributed called on it, producing a new dataloader whose
         batch_sampler.sampler is the same as the original dataloader's
         """
+        driver = generate_driver(10, 10, device=self.device)
         dataloader = dataloader_with_randomsampler(self.dataset, 4, shuffle, False, unrepeated=False)
         dataloader.batch_sampler.sampler.set_distributed(
-            num_replicas=self.driver.world_size,
-            rank=self.driver.global_rank
+            num_replicas=driver.world_size,
+            rank=driver.global_rank
         )
-        replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, None, False)
+        replaced_loader = driver.set_dist_repro_dataloader(dataloader, None, False)
 
         assert not (replaced_loader is dataloader)
         assert isinstance(replaced_loader.batch_sampler, BatchSampler)
@@ -307,9 +313,11 @@ class TestSetDistReproDataloader:
         assert replaced_loader.batch_sampler.batch_size == 4
         assert replaced_loader.batch_sampler.drop_last == False
         self.check_distributed_sampler(replaced_loader.batch_sampler.sampler)
-        self.check_set_dist_repro_dataloader(dataloader, replaced_loader, shuffle)
+        self.check_set_dist_repro_dataloader(driver, dataloader, replaced_loader, shuffle)
 
         dist.barrier()
+        if dist.is_initialized():
+            dist.destroy_process_group()
 
     @magic_argv_env_context
     @pytest.mark.parametrize("shuffle", ([True, False]))
@@ -318,11 +326,14 @@ class TestSetDistReproDataloader:
         Test the behavior of set_dist_repro_dataloader when dist is None, reproducible is False, and the dataloader is an ordinary one
         In this case the original dataloader is returned directly, without any processing.
         """
+        driver = generate_driver(10, 10, device=self.device)
         dataloader = DataLoader(self.dataset, batch_size=4, shuffle=shuffle)
-        replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, None, False)
+        replaced_loader = driver.set_dist_repro_dataloader(dataloader, None, False)
 
         assert replaced_loader is dataloader
         dist.barrier()
+        if dist.is_initialized():
+            dist.destroy_process_group()
 
     """
     The case where the `dist` argument is 'dist'; this occurs during trainer initialization when the user specifies the `use_dist_sampler` parameter
 
@@ -337,12 +348,13 @@
         ...behavior
         A new dataloader should be returned whose batch_sampler is the same as the original dataloader's, and the distributed-related attributes should be set correctly
         """
+        driver = generate_driver(10, 10, device=self.device)
         dataloader = DataLoader(
             dataset=self.dataset,
             batch_sampler=BucketedBatchSampler(self.dataset, self.dataset._data, batch_size=4, shuffle=shuffle)
         )
         dataloader = dataloader_with_bucketedbatchsampler(self.dataset, self.dataset._data, 4, shuffle, False)
-        replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, "dist", False)
+        replaced_loader = driver.set_dist_repro_dataloader(dataloader, "dist", False)
 
         assert not (replaced_loader is dataloader)
         assert isinstance(replaced_loader.batch_sampler, BucketedBatchSampler)
@@ -351,6 +363,8 @@ class TestSetDistReproDataloader:
         assert replaced_loader.drop_last == dataloader.drop_last
         self.check_distributed_sampler(replaced_loader.batch_sampler)
         dist.barrier()
+        if dist.is_initialized():
+            dist.destroy_process_group()
 
     @magic_argv_env_context
     @pytest.mark.parametrize("shuffle", ([True, False]))
@@ -361,8 +375,9 @@
         ...behavior
         A new dataloader should be returned whose batch_sampler.sampler is the same as the original dataloader's, with the distributed-related
         attributes set correctly
         """
+        driver = generate_driver(10, 10, device=self.device)
         dataloader = dataloader_with_randomsampler(self.dataset, 4, shuffle, False, unrepeated=False)
-        replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, "dist", False)
+        replaced_loader = driver.set_dist_repro_dataloader(dataloader, "dist", False)
 
         assert not (replaced_loader is dataloader)
         assert not (replaced_loader.batch_sampler is dataloader.batch_sampler)
@@ -372,6 +387,8 @@ class TestSetDistReproDataloader:
         assert replaced_loader.batch_sampler.sampler.shuffle == shuffle
         self.check_distributed_sampler(replaced_loader.batch_sampler.sampler)
         dist.barrier()
+        if dist.is_initialized():
+            dist.destroy_process_group()
 
     @magic_argv_env_context
     @pytest.mark.parametrize("shuffle", ([True, False]))
@@ -381,8 +398,9 @@
         A new dataloader should be returned with its batch_sampler.sampler replaced by a RandomSampler, and the distributed-related
         attributes set correctly
         """
+        driver = generate_driver(10, 10, device=self.device)
         dataloader = DataLoader(self.dataset, batch_size=4, shuffle=shuffle)
-        replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, "dist", False)
+        replaced_loader = driver.set_dist_repro_dataloader(dataloader, "dist", False)
 
         assert not (replaced_loader is dataloader)
         assert isinstance(replaced_loader.batch_sampler, BatchSampler)
@@ -392,6 +410,8 @@ class TestSetDistReproDataloader:
         assert replaced_loader.batch_sampler.sampler.shuffle == shuffle
         self.check_distributed_sampler(replaced_loader.batch_sampler.sampler)
         dist.barrier()
+        if dist.is_initialized():
+            dist.destroy_process_group()
 
     """
     The case where the `dist` argument is 'unrepeatdist'; this occurs during evaluator initialization when the user specifies the `use_dist_sampler` parameter
 
@@ -407,8 +427,9 @@ class TestSetDistReproDataloader:
         A new dataloader should be returned with the original Sampler replaced by an UnrepeatedRandomSampler, and the distributed-related
         attributes set correctly
         """
+        driver = generate_driver(10, 10, device=self.device)
         dataloader = dataloader_with_randomsampler(self.dataset, 4, shuffle, False, unrepeated=False)
-        replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, "unrepeatdist", False)
+        replaced_loader = driver.set_dist_repro_dataloader(dataloader, "unrepeatdist", False)
 
         assert not (replaced_loader is dataloader)
         assert isinstance(replaced_loader.batch_sampler, BatchSampler)
@@ -418,6 +439,8 @@ class TestSetDistReproDataloader:
         assert replaced_loader.batch_sampler.sampler.shuffle == shuffle
         self.check_distributed_sampler(replaced_loader.batch_sampler.sampler)
         dist.barrier()
+        if dist.is_initialized():
+            dist.destroy_process_group()
 
     @magic_argv_env_context
     @pytest.mark.parametrize("shuffle", ([True, False]))
@@ -427,8 +450,9 @@
         ...behavior
         A new dataloader should be returned, with the original Sampler re-instantiated
         """
+        driver = generate_driver(10, 10, device=self.device)
         dataloader = dataloader_with_randomsampler(self.dataset, 4, shuffle, False, unrepeated=True)
-        replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, "unrepeatdist", False)
+        replaced_loader = driver.set_dist_repro_dataloader(dataloader, "unrepeatdist", False)
 
         assert not (replaced_loader is dataloader)
         assert isinstance(replaced_loader.batch_sampler, BatchSampler)
@@ -439,6 +463,8 @@ class TestSetDistReproDataloader:
         assert replaced_loader.drop_last == dataloader.drop_last
         self.check_distributed_sampler(replaced_loader.batch_sampler.sampler)
         dist.barrier()
+        if dist.is_initialized():
+            dist.destroy_process_group()
 
     @magic_argv_env_context
     @pytest.mark.parametrize("shuffle", ([True, False]))
@@ -448,8 +474,9 @@
         A new dataloader should be returned with the sampler replaced by an UnrepeatedSequentialSampler, and the distributed-related
         attributes set correctly
         """
+        driver = generate_driver(10, 10, device=self.device)
         dataloader = DataLoader(self.dataset, batch_size=4, shuffle=shuffle)
-        replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, "unrepeatdist", False)
+        replaced_loader = driver.set_dist_repro_dataloader(dataloader, "unrepeatdist", False)
 
         assert not (replaced_loader is dataloader)
         assert isinstance(replaced_loader.batch_sampler, BatchSampler)
@@ -459,6 +486,8 @@ class TestSetDistReproDataloader:
         assert replaced_loader.drop_last == dataloader.drop_last
         self.check_distributed_sampler(replaced_loader.batch_sampler.sampler)
         dist.barrier()
+        if dist.is_initialized():
+            dist.destroy_process_group()
 
     def check_distributed_sampler(self, sampler):
         """
@@ -469,7 +498,7 @@ class TestSetDistReproDataloader:
         if not isinstance(sampler, UnrepeatedSampler):
             assert sampler.pad == True
 
-    def check_set_dist_repro_dataloader(self, dataloader, replaced_loader, shuffle):
+    def check_set_dist_repro_dataloader(self, driver, dataloader, replaced_loader, shuffle):
         """
         Check that the result of set_dist_repro_dataloader is correct in the multi-GPU case
         """
@@ -501,8 +530,8 @@ class TestSetDistReproDataloader:
                 drop_last=False,
             )
             new_loader.batch_sampler.set_distributed(
-                num_replicas=self.driver.world_size,
-                rank=self.driver.global_rank,
+                num_replicas=driver.world_size,
+                rank=driver.global_rank,
                 pad=True
             )
             new_loader.batch_sampler.load_state_dict(sampler_states)
@@ -512,8 +541,8 @@ class TestSetDistReproDataloader:
             # Rebuild the dataloader
             new_loader = dataloader_with_randomsampler(replaced_loader.dataset, batch_size, shuffle, drop_last=False)
             new_loader.batch_sampler.sampler.set_distributed(
-                num_replicas=self.driver.world_size,
-                rank=self.driver.global_rank
+                num_replicas=driver.world_size,
+                rank=driver.global_rank
            )
             new_loader.batch_sampler.sampler.load_state_dict(sampler_states)
             for idx, batch in enumerate(new_loader):
@@ -534,11 +563,6 @@ class TestSaveLoad:
     Test the behavior of the save- and load-related functions in the multi-GPU case
     """
 
-    @classmethod
-    def setup_class(cls):
-        # An error is raised if the setup is not done here
-        cls.driver = generate_driver(10, 10)
-
     def setup_method(self):
         self.dataset = TorchArgMaxDataset(10, 20)
 
@@ -552,26 +576,26 @@ class TestSaveLoad:
             path = "model"
             dataloader = DataLoader(self.dataset, batch_size=2)
-            self.driver1, self.driver2 = generate_driver(10, 10), generate_driver(10, 10)
+            driver1, driver2 = generate_driver(10, 10), generate_driver(10, 10)
 
-            self.driver1.save_model(path, only_state_dict)
+            driver1.save_model(path, only_state_dict)
 
             # Synchronize
             dist.barrier()
-            self.driver2.load_model(path, only_state_dict)
+            driver2.load_model(path, only_state_dict)
 
             for idx, batch in enumerate(dataloader):
-                batch = self.driver1.move_data_to_device(batch)
-                res1 = self.driver1.model(
+                batch = driver1.move_data_to_device(batch)
+                res1 = driver1.model(
                     batch,
-                    fastnlp_fn=self.driver1.model.module.model.evaluate_step,
+                    fastnlp_fn=driver1.model.module.model.evaluate_step,
                     # Driver.model -> DataParallel.module -> _FleetWrappingModel.model
                     fastnlp_signature_fn=None,
                     wo_auto_param_call=False,
                 )
-                res2 = self.driver2.model(
+                res2 = driver2.model(
                     batch,
-                    fastnlp_fn=self.driver2.model.module.model.evaluate_step,
+                    fastnlp_fn=driver2.model.module.model.evaluate_step,
                     fastnlp_signature_fn=None,
                     wo_auto_param_call=False,
                 )
@@ -580,6 +604,9 @@ class TestSaveLoad:
         finally:
             rank_zero_rm(path)
 
+        if dist.is_initialized():
+            dist.destroy_process_group()
+
     @magic_argv_env_context
     @pytest.mark.parametrize("only_state_dict", ([True, False]))
     @pytest.mark.parametrize("fp16", ([True, False]))
@@ -593,7 +620,7 @@ class TestSaveLoad:
         path = "model.ckp"
 
         num_replicas = len(device)
 
-        self.driver1, self.driver2 = generate_driver(10, 10, device=device, fp16=fp16), \
+        driver1, driver2 = generate_driver(10, 10, device=device, fp16=fp16), \
                                      generate_driver(10, 10, device=device, fp16=False)
         dataloader = dataloader_with_bucketedbatchsampler(
             self.dataset,
@@ -603,8 +630,8 @@ class TestSaveLoad:
             drop_last=False
         )
         dataloader.batch_sampler.set_distributed(
-            num_replicas=self.driver1.world_size,
-            rank=self.driver1.global_rank,
+            num_replicas=driver1.world_size,
+            rank=driver1.global_rank,
             pad=True
         )
         num_consumed_batches = 2
@@ -623,7 +650,7 @@ class TestSaveLoad:
             # Save the state
             sampler_states = dataloader.batch_sampler.state_dict()
             save_states = {"num_consumed_batches": num_consumed_batches}
-            self.driver1.save(Path(path), save_states, dataloader, only_state_dict, should_save_model=True)
+            driver1.save(Path(path), save_states, dataloader, only_state_dict, should_save_model=True)
             # Load
             # Change the batch_size
             dataloader = dataloader_with_bucketedbatchsampler(
                 self.dataset,
@@ -634,11 +661,11 @@ class TestSaveLoad:
                 drop_last=False
             )
             dataloader.batch_sampler.set_distributed(
-                num_replicas=self.driver2.world_size,
-                rank=self.driver2.global_rank,
+                num_replicas=driver2.world_size,
+                rank=driver2.global_rank,
                 pad=True
             )
-            load_states = self.driver2.load(Path(path), dataloader, only_state_dict, should_load_model=True)
+            load_states = driver2.load(Path(path), dataloader, only_state_dict, should_load_model=True)
             replaced_loader = load_states.pop("dataloader")
             # 1. Check the optimizer state
             # TODO the optimizer's state_dict is always empty
@@ -652,7 +679,7 @@ class TestSaveLoad:
             # 3. Check whether fp16 was loaded
             if fp16:
-                assert isinstance(self.driver2.grad_scaler, torch.cuda.amp.GradScaler)
+                assert isinstance(driver2.grad_scaler, torch.cuda.amp.GradScaler)
 
             # 4. Check whether the model parameters are correct
             # 5. Check batch_idx
@@ -664,16 +691,16 @@ class TestSaveLoad:
                 left_x_batches.update(batch["x"])
                 left_y_batches.update(batch["y"])
-                res1 = self.driver1.model(
+                res1 = driver1.model(
                     batch,
-                    fastnlp_fn=self.driver1.model.module.model.evaluate_step,
+                    fastnlp_fn=driver1.model.module.model.evaluate_step,
                     # Driver.model -> DataParallel.module -> _FleetWrappingModel.model
                     fastnlp_signature_fn=None,
                     wo_auto_param_call=False,
                 )
-                res2 = self.driver2.model(
+                res2 = driver2.model(
                     batch,
-                    fastnlp_fn=self.driver2.model.module.model.evaluate_step,
+                    fastnlp_fn=driver2.model.module.model.evaluate_step,
                     fastnlp_signature_fn=None,
                     wo_auto_param_call=False,
                 )
@@ -686,6 +713,9 @@ class TestSaveLoad:
         finally:
             rank_zero_rm(path)
 
+        if dist.is_initialized():
+            dist.destroy_process_group()
+
     @magic_argv_env_context
     @pytest.mark.parametrize("only_state_dict", ([True, False]))
     @pytest.mark.parametrize("fp16", ([True, False]))
@@ -700,13 +730,13 @@ class TestSaveLoad:
 
         num_replicas = len(device)
 
-        self.driver1 = generate_driver(10, 10, device=device, fp16=fp16)
-        self.driver2 = generate_driver(10, 10, device=device, fp16=False)
+        driver1 = generate_driver(10, 10, device=device, fp16=fp16)
+        driver2 = generate_driver(10, 10, device=device, fp16=False)
         dataloader = dataloader_with_randomsampler(self.dataset, 4, True, False, unrepeated=False)
         dataloader.batch_sampler.sampler.set_distributed(
-            num_replicas=self.driver1.world_size,
-            rank=self.driver1.global_rank,
+            num_replicas=driver1.world_size,
+            rank=driver1.global_rank,
             pad=True
         )
         num_consumed_batches = 2
@@ -726,18 +756,18 @@ class TestSaveLoad:
             sampler_states = dataloader.batch_sampler.sampler.state_dict()
             save_states = {"num_consumed_batches": num_consumed_batches}
             if only_state_dict:
-                self.driver1.save(Path(path), save_states, dataloader, only_state_dict, should_save_model=True)
+                driver1.save(Path(path), save_states, dataloader, only_state_dict, should_save_model=True)
             else:
-                self.driver1.save(Path(path), save_states, dataloader, only_state_dict, should_save_model=True, input_spec=[torch.ones((16, 10))])
+                driver1.save(Path(path), save_states, dataloader, only_state_dict, should_save_model=True, input_spec=[torch.ones((16, 10))])
             # Load
             # Change the batch_size
             dataloader = dataloader_with_randomsampler(self.dataset, 2, True, False, unrepeated=False)
             dataloader.batch_sampler.sampler.set_distributed(
-                num_replicas=self.driver2.world_size,
-                rank=self.driver2.global_rank,
+                num_replicas=driver2.world_size,
+                rank=driver2.global_rank,
                 pad=True
             )
-            load_states = self.driver2.load(Path(path), dataloader, only_state_dict, should_load_model=True)
+            load_states = driver2.load(Path(path), dataloader, only_state_dict, should_load_model=True)
             replaced_loader = load_states.pop("dataloader")
 
             # 1. Check the optimizer state
@@ -753,7 +783,7 @@ class TestSaveLoad:
             assert replaced_loader.batch_sampler.sampler.shuffle == sampler_states["shuffle"]
             # 3. Check whether fp16 was loaded
             if fp16:
-                assert isinstance(self.driver2.grad_scaler, torch.cuda.amp.GradScaler)
+                assert isinstance(driver2.grad_scaler, torch.cuda.amp.GradScaler)
 
             # 4. Check whether the model parameters are correct
             # 5. Check batch_idx
@@ -765,16 +795,16 @@ class TestSaveLoad:
                 left_x_batches.update(batch["x"])
                 left_y_batches.update(batch["y"])
-                res1 = self.driver1.model(
+                res1 = driver1.model(
                     batch,
-                    fastnlp_fn=self.driver1.model.module.model.evaluate_step,
+                    fastnlp_fn=driver1.model.module.model.evaluate_step,
                     # Driver.model -> DataParallel.module -> _FleetWrappingModel.model
                     fastnlp_signature_fn=None,
                     wo_auto_param_call=False,
                 )
-                res2 = self.driver2.model(
+                res2 = driver2.model(
                     batch,
-                    fastnlp_fn=self.driver2.model.module.model.evaluate_step,
+                    fastnlp_fn=driver2.model.module.model.evaluate_step,
                     fastnlp_signature_fn=None,
                     wo_auto_param_call=False,
                 )
@@ -786,4 +816,7 @@ class TestSaveLoad:
             assert len(left_y_batches | already_seen_y_set) == len(self.dataset) / num_replicas
 
         finally:
-            rank_zero_rm(path)
\ No newline at end of file
+            rank_zero_rm(path)
+
+        if dist.is_initialized():
+            dist.destroy_process_group()
diff --git a/tests/core/drivers/torch_driver/test_initialize_torch_driver.py b/tests/core/drivers/torch_driver/test_initialize_torch_driver.py
index f62ccd0c..8992867e 100644
--- a/tests/core/drivers/torch_driver/test_initialize_torch_driver.py
+++ b/tests/core/drivers/torch_driver/test_initialize_torch_driver.py
@@ -2,12 +2,12 @@ import pytest
 
 from fastNLP.core.drivers import TorchSingleDriver, TorchDDPDriver
 from fastNLP.core.drivers.torch_driver.initialize_torch_driver import initialize_torch_driver
-from fastNLP.envs import get_gpu_count
 from tests.helpers.models.torch_model import TorchNormalModel_Classification_1
 from tests.helpers.utils import magic_argv_env_context
-
-import torch
-
+from fastNLP.envs.imports import _NEED_IMPORT_TORCH
+if _NEED_IMPORT_TORCH:
+    import torch
+    import torch.distributed as dist
 
 @pytest.mark.torch
 def test_incorrect_driver():
@@ -55,6 +55,9 @@ def test_get_ddp_2(driver, device):
     driver = initialize_torch_driver(driver, device, model)
 
     assert isinstance(driver, TorchDDPDriver)
+    dist.barrier()
+    if dist.is_initialized():
+        dist.destroy_process_group()
 
 
 @pytest.mark.torch
@@ -76,6 +79,9 @@ def test_get_ddp(driver, device):
     driver = initialize_torch_driver(driver, device, model)
 
     assert isinstance(driver, TorchDDPDriver)
+    dist.barrier()
+    if dist.is_initialized():
+        dist.destroy_process_group()
 
 
 @pytest.mark.torch
@@ -83,7 +89,6 @@ def test_get_ddp(driver, device):
     ("driver", "device"),
     [("torch_ddp", "cpu")]
 )
-@magic_argv_env_context
 def test_get_ddp_cpu(driver, device):
     """
     Test attempting to initialize distributed training on cpu
@@ -102,7 +107,6 @@
     "driver",
     ["torch", "torch_ddp"]
 )
-@magic_argv_env_context
 def test_device_out_of_range(driver, device):
     """
     Test the case where the given device is out of range
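
Note on the pattern this patch applies in both files: the torch imports are guarded behind `_NEED_IMPORT_TORCH` so that test collection does not fail when torch is absent, drivers are created inside each test rather than in `setup_class`, and every test tears down the default process group before returning. If the repeated teardown were ever factored out, it could live in a pytest fixture; a minimal sketch under that assumption (the fixture name `ddp_teardown` is hypothetical and not part of this patch):

    import pytest
    from fastNLP.envs.imports import _NEED_IMPORT_TORCH
    if _NEED_IMPORT_TORCH:
        import torch.distributed as dist

    @pytest.fixture
    def ddp_teardown():
        # Run the test body first, then destroy the default process group so the
        # next test can re-initialize DDP from a clean state.
        yield
        if _NEED_IMPORT_TORCH and dist.is_available() and dist.is_initialized():
            dist.destroy_process_group()

A test that requests `ddp_teardown` would then no longer need its own trailing `dist.destroy_process_group()` call.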