diff --git a/fastNLP/core/drivers/paddle_driver/fleet.py b/fastNLP/core/drivers/paddle_driver/fleet.py
index c407ab9f..bde6f37f 100644
--- a/fastNLP/core/drivers/paddle_driver/fleet.py
+++ b/fastNLP/core/drivers/paddle_driver/fleet.py
@@ -11,6 +11,7 @@ from .utils import (
     replace_sampler,
     replace_batch_sampler,
 )
+from .dist_utils import fastnlp_paddle_all_gather, fastnlp_paddle_broadcast_object
 
 from fastNLP.envs.imports import _NEED_IMPORT_PADDLE
 from fastNLP.core.utils import (
@@ -451,12 +452,12 @@ class PaddleFleetDriver(PaddleDriver):
         :return: 如果当前不是分布式 driver 直接返回输入的 obj 。如果当前 rank 是接收端(其 global rank 包含在了 dst 中),则返回
             接收到的参数;如果是 source 端则返回发射的内容;既不是发送端、又不是接收端,则返回 None 。
         """
-        return
-        if int(os.environ.get(FASTNLP_NO_SYNC, 0)) == 2: # 如果 FASTNLP_NO_SYNC == 2 直接返回。
-            return
-        return fastnlp_paddle_broadcast_object(obj, src, device=self.data_device, group=group)
+        device = self.data_device
+        # 因为设置了CUDA_VISIBLE_DEVICES,可能会引起错误
+        device = get_device_from_visible(device)
+        return fastnlp_paddle_broadcast_object(obj, src, device=device, group=group)
 
-    def all_gather(self, obj, group) -> List:
+    def all_gather(self, obj, group=None) -> List:
         """
         将 obj 互相传送到其它所有的 rank 上,其中 obj 可能是 Tensor,也可能是嵌套结构的 object 。如果不是基础类型的数据,尝试通过
             pickle 进行序列化,接收到之后再反序列化。
@@ -479,7 +480,4 @@
         :param group:
         :return:
         """
-        return
-        if int(os.environ.get(FASTNLP_NO_SYNC, 0)) == 2: # 如果 FASTNLP_NO_SYNC 表示不执行
-            return [obj]
         return fastnlp_paddle_all_gather(obj, group=group)
diff --git a/tests/core/drivers/paddle_driver/test_fleet.py b/tests/core/drivers/paddle_driver/test_fleet.py
index 76d1f793..52739f53 100644
--- a/tests/core/drivers/paddle_driver/test_fleet.py
+++ b/tests/core/drivers/paddle_driver/test_fleet.py
@@ -14,7 +14,7 @@ from fastNLP.core.samplers import (
 from tests.helpers.models.paddle_model import PaddleNormalModel_Classification_1
 from tests.helpers.datasets.paddle_data import PaddleNormalDataset, PaddleRandomMaxDataset
 from tests.helpers.utils import magic_argv_env_context
-from fastNLP.core import synchronize_safe_rm
+from fastNLP.core import rank_zero_rm
 
 import paddle
 import paddle.distributed as dist
@@ -112,6 +112,35 @@ class TestFleetDriverFunction:
         self.driver.get_local_rank()
         dist.barrier()
 
+    @magic_argv_env_context
+    def test_all_gather(self):
+        """
+        测试 all_gather 函数
+        详细的测试在 test_dist_utils.py 中完成
+        """
+        obj = {
+            "rank": self.driver.global_rank
+        }
+        obj_list = self.driver.all_gather(obj, group=None)
+        for i, res in enumerate(obj_list):
+            assert res["rank"] == i
+
+    @magic_argv_env_context
+    @pytest.mark.parametrize("src_rank", ([0, 1]))
+    def test_broadcast_object(self, src_rank):
+        """
+        测试 broadcast_object 函数
+        详细的函数在 test_dist_utils.py 中完成
+        """
+        if self.driver.global_rank == src_rank:
+            obj = {
+                "rank": self.driver.global_rank
+            }
+        else:
+            obj = None
+        res = self.driver.broadcast_object(obj, src=src_rank)
+        assert res["rank"] == src_rank
+
 ############################################################################
 #
 # 测试 set_dist_repro_dataloader 函数
@@ -543,11 +572,11 @@ class TestSaveLoad:
                 assert paddle.equal_all(res1["pred"], res2["pred"])
         finally:
             if only_state_dict:
-                synchronize_safe_rm(path)
+                rank_zero_rm(path)
             else:
-                synchronize_safe_rm(path + ".pdiparams")
-                synchronize_safe_rm(path + ".pdiparams.info")
-                synchronize_safe_rm(path + ".pdmodel")
+                rank_zero_rm(path + ".pdiparams")
+                rank_zero_rm(path + ".pdiparams.info")
+                rank_zero_rm(path + ".pdmodel")
 
     @magic_argv_env_context
     @pytest.mark.parametrize("only_state_dict", ([True, False]))
@@ -658,7 +687,7 @@
             assert len(left_y_batches) + len(already_seen_y_set) == len(self.dataset) / num_replicas
             assert len(left_y_batches | already_seen_y_set) == len(self.dataset) / num_replicas
         finally:
-            synchronize_safe_rm(path)
+            rank_zero_rm(path)
 
     @magic_argv_env_context
     @pytest.mark.parametrize("only_state_dict", ([True, False]))
@@ -769,4 +798,4 @@
 
             assert len(left_y_batches | already_seen_y_set) == len(self.dataset) / num_replicas
         finally:
-            synchronize_safe_rm(path)
\ No newline at end of file
+            rank_zero_rm(path)
\ No newline at end of file