
Add all_gather and broadcast_object functions to PaddleFleetDriver

tags/v1.0.0alpha
x54-729 3 years ago
commit fcc45dbf4a
2 changed files with 42 additions and 15 deletions
  1. fastNLP/core/drivers/paddle_driver/fleet.py (+6, -8)
  2. tests/core/drivers/paddle_driver/test_fleet.py (+36, -7)

fastNLP/core/drivers/paddle_driver/fleet.py (+6, -8)

@@ -11,6 +11,7 @@ from .utils import (
replace_sampler,
replace_batch_sampler,
)
from .dist_utils import fastnlp_paddle_all_gather, fastnlp_paddle_broadcast_object

from fastNLP.envs.imports import _NEED_IMPORT_PADDLE
from fastNLP.core.utils import (
@@ -451,12 +452,12 @@ class PaddleFleetDriver(PaddleDriver):
:return: If the current driver is not a distributed driver, the input obj is returned as-is. If the current rank is a receiver (its global rank is included in dst), the
received object is returned; if it is the source, the sent content is returned; if it is neither a sender nor a receiver, None is returned.
"""
return
if int(os.environ.get(FASTNLP_NO_SYNC, 0)) == 2: # if FASTNLP_NO_SYNC == 2, return immediately.
return
return fastnlp_paddle_broadcast_object(obj, src, device=self.data_device, group=group)
device = self.data_device
# since CUDA_VISIBLE_DEVICES is set, using the raw device may cause errors
device = get_device_from_visible(device)
return fastnlp_paddle_broadcast_object(obj, src, device=device, group=group)

def all_gather(self, obj, group) -> List:
def all_gather(self, obj, group=None) -> List:
"""
Exchange obj with all other ranks, where obj may be a Tensor or a nested object. If the data is not of a basic type, it is serialized via
pickle before sending and deserialized after it is received.
@@ -479,7 +480,4 @@ class PaddleFleetDriver(PaddleDriver):
:param group:
:return:
"""
return
if int(os.environ.get(FASTNLP_NO_SYNC, 0)) == 2: # if FASTNLP_NO_SYNC == 2, do not perform the gather
return [obj]
return fastnlp_paddle_all_gather(obj, group=group)
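
The hunk above remaps data_device through get_device_from_visible because, once CUDA_VISIBLE_DEVICES is set, the process can only address the visible subset of GPUs. Below is a minimal sketch of that remapping idea, assuming a comma-separated CUDA_VISIBLE_DEVICES value; it is an illustration only, not fastNLP's actual get_device_from_visible implementation.

import os


def device_index_under_visible(physical_gpu: int) -> int:
    """Illustration only: map a physical GPU id to the index the process must
    use when CUDA_VISIBLE_DEVICES restricts the visible devices. For example,
    with CUDA_VISIBLE_DEVICES="2,3", physical GPU 3 is addressed as index 1."""
    visible = os.environ.get("CUDA_VISIBLE_DEVICES", "")
    if not visible:
        # no restriction: physical id and visible index coincide
        return physical_gpu
    order = [int(v) for v in visible.split(",") if v.strip()]
    return order.index(physical_gpu)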

tests/core/drivers/paddle_driver/test_fleet.py (+36, -7)

@@ -14,7 +14,7 @@ from fastNLP.core.samplers import (
from tests.helpers.models.paddle_model import PaddleNormalModel_Classification_1
from tests.helpers.datasets.paddle_data import PaddleNormalDataset, PaddleRandomMaxDataset
from tests.helpers.utils import magic_argv_env_context
from fastNLP.core import synchronize_safe_rm
from fastNLP.core import rank_zero_rm

import paddle
import paddle.distributed as dist
@@ -112,6 +112,35 @@ class TestFleetDriverFunction:
self.driver.get_local_rank()
dist.barrier()

@magic_argv_env_context
def test_all_gather(self):
"""
Test the all_gather function.
Detailed tests are in test_dist_utils.py.
"""
obj = {
"rank": self.driver.global_rank
}
obj_list = self.driver.all_gather(obj, group=None)
for i, res in enumerate(obj_list):
assert res["rank"] == i

@magic_argv_env_context
@pytest.mark.parametrize("src_rank", ([0, 1]))
def test_broadcast_object(self, src_rank):
"""
Test the broadcast_object function.
Detailed tests are in test_dist_utils.py.
"""
if self.driver.global_rank == src_rank:
obj = {
"rank": self.driver.global_rank
}
else:
obj = None
res = self.driver.broadcast_object(obj, src=src_rank)
assert res["rank"] == src_rank

############################################################################
#
# Test the set_dist_repro_dataloader function
@@ -543,11 +572,11 @@ class TestSaveLoad:
assert paddle.equal_all(res1["pred"], res2["pred"])
finally:
if only_state_dict:
synchronize_safe_rm(path)
rank_zero_rm(path)
else:
synchronize_safe_rm(path + ".pdiparams")
synchronize_safe_rm(path + ".pdiparams.info")
synchronize_safe_rm(path + ".pdmodel")
rank_zero_rm(path + ".pdiparams")
rank_zero_rm(path + ".pdiparams.info")
rank_zero_rm(path + ".pdmodel")

@magic_argv_env_context
@pytest.mark.parametrize("only_state_dict", ([True, False]))
@@ -658,7 +687,7 @@ class TestSaveLoad:
assert len(left_y_batches) + len(already_seen_y_set) == len(self.dataset) / num_replicas
assert len(left_y_batches | already_seen_y_set) == len(self.dataset) / num_replicas
finally:
synchronize_safe_rm(path)
rank_zero_rm(path)

@magic_argv_env_context
@pytest.mark.parametrize("only_state_dict", ([True, False]))
@@ -769,4 +798,4 @@ class TestSaveLoad:
assert len(left_y_batches | already_seen_y_set) == len(self.dataset) / num_replicas

finally:
synchronize_safe_rm(path)
rank_zero_rm(path)
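
Taken together, the new tests show the caller-side pattern for the two driver methods: every rank contributes an object to all_gather, and only the source rank needs to supply the object to broadcast_object. The sketch below is a hedged usage example, assuming driver is an already-initialized PaddleFleetDriver inside a launched distributed run; the helper name exchange_rank_info is made up for illustration and is not part of this commit.

from typing import Any, Dict, List


def exchange_rank_info(driver, payload: Any) -> List[Dict]:
    """Illustrative helper: gather one dict per rank, then have rank 0
    broadcast a small summary back to everyone. `driver` is assumed to be
    an initialized PaddleFleetDriver."""
    per_rank = {"rank": driver.global_rank, "payload": payload}
    gathered = driver.all_gather(per_rank, group=None)  # one entry per rank

    # only the source rank needs to provide the object; other ranks pass None
    summary = {"n_ranks": len(gathered)} if driver.global_rank == 0 else None
    summary = driver.broadcast_object(summary, src=0)
    return gathered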
