diff --git a/tests/core/drivers/paddle_driver/test_fleet.py b/tests/core/drivers/paddle_driver/test_fleet.py index ad471acd..76d1f793 100644 --- a/tests/core/drivers/paddle_driver/test_fleet.py +++ b/tests/core/drivers/paddle_driver/test_fleet.py @@ -1,5 +1,6 @@ import pytest import os +from pathlib import Path os.environ["FASTNLP_BACKEND"] = "paddle" from fastNLP.core.drivers.paddle_driver.fleet import PaddleFleetDriver @@ -33,20 +34,6 @@ def generate_driver(num_labels, feature_dimension, device=[0,1], fp16=False, out return driver -@magic_argv_env_context -def test_multi_drivers(): - """ - 测试使用了多个 PaddleFleetDriver 的情况。 - """ - driver1 = generate_driver(10, 10) - driver2 = generate_driver(20, 10) - - with pytest.raises(RuntimeError): - # 设备设置不同,应该报错 - driver3 = generate_driver(20, 3, device=[0,2]) - - dist.barrier() - ############################################################################ # # 测试 PaddleFleetDriver 的一些函数 @@ -62,6 +49,19 @@ class TestFleetDriverFunction: def setup_class(cls): cls.driver = generate_driver(10, 10) + @magic_argv_env_context + def test_multi_drivers(self): + """ + 测试使用了多个 PaddleFleetDriver 的情况。 + """ + driver2 = generate_driver(20, 10) + + with pytest.raises(RuntimeError): + # 设备设置不同,应该报错 + driver3 = generate_driver(20, 3, device=[0,2]) + + dist.barrier() + @magic_argv_env_context def test_move_data_to_device(self): """ @@ -494,9 +494,14 @@ class TestSaveLoad: """ 测试多卡情况下 save 和 load 相关函数的表现 """ + + @classmethod + def setup_class(cls): + # 不在这里 setup 的话会报错 + cls.driver = generate_driver(10, 10) + def setup_method(self): self.dataset = PaddleRandomMaxDataset(20, 10) - self.driver1, self.driver2 = generate_driver(10, 10), generate_driver(10, 10) @magic_argv_env_context @pytest.mark.parametrize("only_state_dict", ([True, False])) @@ -506,7 +511,9 @@ class TestSaveLoad: """ try: path = "model" + dataloader = DataLoader(self.dataset, batch_size=2) + self.driver1, self.driver2 = generate_driver(10, 10), generate_driver(10, 10) if only_state_dict: self.driver1.save_model(path, only_state_dict) @@ -545,20 +552,30 @@ class TestSaveLoad: @magic_argv_env_context @pytest.mark.parametrize("only_state_dict", ([True, False])) @pytest.mark.parametrize("fp16", ([True, False])) - def test_save_and_load_with_randombatchsampler(self, only_state_dict, fp16): - return + @pytest.mark.parametrize("device", ([[0,1]])) + def test_save_and_load_with_bucketedbatchsampler(self, device, only_state_dict, fp16): """ 测试save和load函数,主要测试 dataloader 被替换了 sampler 之后的情况 """ try: path = "model.ckp" + num_replicas = len(device) - driver1, driver2 = generate_random_driver(10, 10), generate_random_driver(10, 10) - dataset = PaddleRandomMaxDataset(40, 10) + self.driver1, self.driver2 = generate_driver(10, 10, device=device, fp16=fp16), \ + generate_driver(10, 10, device=device, fp16=False) dataloader = DataLoader( - dataset=dataset, - batch_sampler=RandomBatchSampler(BatchSampler(dataset, batch_size=4), 4, False) + dataset=self.dataset, + batch_sampler=BucketedBatchSampler( + self.dataset, + length=[10 for i in range(len(self.dataset))], + batch_size=4, + ) + ) + dataloader.batch_sampler.set_distributed( + num_replicas=self.driver1.world_size, + rank=self.driver1.global_rank, + pad=True ) num_consumed_batches = 2 @@ -570,19 +587,32 @@ class TestSaveLoad: already_seen_x_set.update(batch["x"]) already_seen_y_set.update(batch["y"]) + # 同步 + dist.barrier() + + # 保存状态 sampler_states = dataloader.batch_sampler.state_dict() save_states = {"num_consumed_batches": num_consumed_batches} if only_state_dict: - driver1.save(Path(path), save_states, dataloader, only_state_dict, should_save_model=True) + self.driver1.save(Path(path), save_states, dataloader, only_state_dict, should_save_model=True) else: - driver1.save(Path(path), save_states, dataloader, only_state_dict, should_save_model=True, input_spec=[paddle.ones((16, 10))]) + self.driver1.save(Path(path), save_states, dataloader, only_state_dict, should_save_model=True, input_spec=[paddle.ones((16, 10))]) # 加载 # 更改 batch_size dataloader = DataLoader( - dataset=dataset, - batch_sampler=RandomBatchSampler(BatchSampler(dataset, batch_size=2, shuffle=True), 2, False) + dataset=self.dataset, + batch_sampler=BucketedBatchSampler( + self.dataset, + length=[10 for i in range(len(self.dataset))], + batch_size=4, + ) + ) + dataloader.batch_sampler.set_distributed( + num_replicas=self.driver2.world_size, + rank=self.driver2.global_rank, + pad=True ) - load_states = driver2.load(Path(path), dataloader, only_state_dict, should_load_model=True) + load_states = self.driver2.load(Path(path), dataloader, only_state_dict, should_load_model=True) replaced_loader = load_states.pop("dataloader") # 1. 检查 optimizer 的状态 # TODO optimizer 的 state_dict 总是为空 @@ -590,13 +620,13 @@ class TestSaveLoad: # 2. 检查 batch_sampler 是否被正确地加载和替换 assert not (replaced_loader is dataloader) assert replaced_loader.batch_sampler is dataloader.batch_sampler - assert isinstance(replaced_loader.batch_sampler, RandomBatchSampler) - assert replaced_loader.batch_sampler.index_list == sampler_states["index_list"] - assert replaced_loader.batch_sampler.num_consumed_samples == num_consumed_batches * 4 + assert isinstance(replaced_loader.batch_sampler, BucketedBatchSampler) + assert replaced_loader.batch_sampler.seed == sampler_states["seed"] + assert replaced_loader.batch_sampler.num_consumed_samples == num_consumed_batches * 4 * num_replicas # 3. 检查 fp16 是否被加载 if fp16: - assert isinstance(driver2.grad_scaler, paddle.amp.GradScaler) + assert isinstance(self.driver2.grad_scaler, paddle.amp.GradScaler) # 4. 检查 model 的参数是否正确 # 5. 检查 batch_idx @@ -608,22 +638,33 @@ class TestSaveLoad: left_x_batches.update(batch["x"]) left_y_batches.update(batch["y"]) - res1 = driver1.model.evaluate_step(**batch) - res2 = driver2.model.evaluate_step(**batch) + res1 = self.driver1.model( + batch, + fastnlp_fn=self.driver1.model._layers.model.evaluate_step, + # Driver.model -> DataParallel._layers -> _FleetWrappingModel.model + fastnlp_signature_fn=None, + wo_auto_param_call=False, + ) + res2 = self.driver2.model( + batch, + fastnlp_fn=self.driver2.model._layers.model.evaluate_step, + fastnlp_signature_fn=None, + wo_auto_param_call=False, + ) assert paddle.equal_all(res1["pred"], res2["pred"]) - assert len(left_x_batches) + len(already_seen_x_set) == len(dataset) - assert len(left_x_batches | already_seen_x_set) == len(dataset) - assert len(left_y_batches) + len(already_seen_y_set) == len(dataset) - assert len(left_y_batches | already_seen_y_set) == len(dataset) + assert len(left_x_batches) + len(already_seen_x_set) == len(self.dataset) / num_replicas + assert len(left_x_batches | already_seen_x_set) == len(self.dataset) / num_replicas + assert len(left_y_batches) + len(already_seen_y_set) == len(self.dataset) / num_replicas + assert len(left_y_batches | already_seen_y_set) == len(self.dataset) / num_replicas finally: synchronize_safe_rm(path) @magic_argv_env_context @pytest.mark.parametrize("only_state_dict", ([True, False])) @pytest.mark.parametrize("fp16", ([True, False])) - def test_save_and_load_with_randomsampler(self, only_state_dict, fp16): - return + @pytest.mark.parametrize("device", ([[0,1]])) + def test_save_and_load_with_randomsampler(self, device, only_state_dict, fp16): """ 测试save和load函数,主要测试 dataloader 被替换了 batch_sampler 的情况 """ @@ -631,12 +672,19 @@ class TestSaveLoad: try: path = "model.ckp" - driver1, driver2 = generate_random_driver(10, 10), generate_random_driver(10, 10) - dataset = PaddleRandomMaxDataset(40, 10) - batch_sampler = BatchSampler(dataset=dataset, batch_size=4) - batch_sampler.sampler = RandomSampler(dataset, True) + num_replicas = len(device) + + self.driver1 = generate_driver(10, 10, device=device, fp16=fp16) + self.driver2 = generate_driver(10, 10, device=device, fp16=False) + batch_sampler = BatchSampler(dataset=self.dataset, batch_size=4) + batch_sampler.sampler = RandomSampler(self.dataset, True) + batch_sampler.sampler.set_distributed( + num_replicas=self.driver1.world_size, + rank=self.driver1.global_rank, + pad=True + ) dataloader = DataLoader( - dataset, + self.dataset, batch_sampler=batch_sampler ) num_consumed_batches = 2 @@ -649,22 +697,30 @@ class TestSaveLoad: already_seen_x_set.update(batch["x"]) already_seen_y_set.update(batch["y"]) + # 同步 + dist.barrier() + + # 保存状态 sampler_states = dataloader.batch_sampler.sampler.state_dict() save_states = {"num_consumed_batches": num_consumed_batches} if only_state_dict: - driver1.save(Path(path), save_states, dataloader, only_state_dict, should_save_model=True) + self.driver1.save(Path(path), save_states, dataloader, only_state_dict, should_save_model=True) else: - driver1.save(Path(path), save_states, dataloader, only_state_dict, should_save_model=True, input_spec=[paddle.ones((16, 10))]) - + self.driver1.save(Path(path), save_states, dataloader, only_state_dict, should_save_model=True, input_spec=[paddle.ones((16, 10))]) # 加载 # 更改 batch_size - batch_sampler = BatchSampler(dataset=dataset, batch_size=2) - batch_sampler.sampler = RandomSampler(dataset, True) + batch_sampler = BatchSampler(dataset=self.dataset, batch_size=2) + batch_sampler.sampler = RandomSampler(self.dataset, True) + batch_sampler.sampler.set_distributed( + num_replicas=self.driver2.world_size, + rank=self.driver2.global_rank, + pad=True + ) dataloader = DataLoader( - dataset, + self.dataset, batch_sampler=batch_sampler ) - load_states = driver2.load(Path(path), dataloader, only_state_dict, should_load_model=True) + load_states = self.driver2.load(Path(path), dataloader, only_state_dict, should_load_model=True) replaced_loader = load_states.pop("dataloader") # 1. 检查 optimizer 的状态 @@ -675,12 +731,12 @@ class TestSaveLoad: assert isinstance(replaced_loader.batch_sampler.sampler, RandomSampler) assert replaced_loader.batch_sampler.sampler.seed == sampler_states["seed"] assert replaced_loader.batch_sampler.sampler.epoch == sampler_states["epoch"] - assert replaced_loader.batch_sampler.sampler.num_consumed_samples == 4 * num_consumed_batches + assert replaced_loader.batch_sampler.sampler.num_consumed_samples == 4 * num_consumed_batches * num_replicas assert len(replaced_loader.batch_sampler.sampler.dataset) == sampler_states["length"] assert replaced_loader.batch_sampler.sampler.shuffle == sampler_states["shuffle"] # 3. 检查 fp16 是否被加载 if fp16: - assert isinstance(driver2.grad_scaler, paddle.amp.GradScaler) + assert isinstance(self.driver2.grad_scaler, paddle.amp.GradScaler) # 4. 检查 model 的参数是否正确 # 5. 检查 batch_idx @@ -692,13 +748,25 @@ class TestSaveLoad: left_x_batches.update(batch["x"]) left_y_batches.update(batch["y"]) - res1 = driver1.model.evaluate_step(**batch) - res2 = driver2.model.evaluate_step(**batch) + res1 = self.driver1.model( + batch, + fastnlp_fn=self.driver1.model._layers.model.evaluate_step, + # Driver.model -> DataParallel._layers -> _FleetWrappingModel.model + fastnlp_signature_fn=None, + wo_auto_param_call=False, + ) + res2 = self.driver2.model( + batch, + fastnlp_fn=self.driver2.model._layers.model.evaluate_step, + fastnlp_signature_fn=None, + wo_auto_param_call=False, + ) assert paddle.equal_all(res1["pred"], res2["pred"]) - assert len(left_x_batches) + len(already_seen_x_set) == len(dataset) - assert len(left_x_batches | already_seen_x_set) == len(dataset) - assert len(left_y_batches) + len(already_seen_y_set) == len(dataset) - assert len(left_y_batches | already_seen_y_set) == len(dataset) + assert len(left_x_batches) + len(already_seen_x_set) == len(self.dataset) / num_replicas + assert len(left_x_batches | already_seen_x_set) == len(self.dataset) / num_replicas + assert len(left_y_batches) + len(already_seen_y_set) == len(self.dataset) / num_replicas + assert len(left_y_batches | already_seen_y_set) == len(self.dataset) / num_replicas + finally: - synchronize_safe_rm(path) + synchronize_safe_rm(path) \ No newline at end of file