|
|
@@ -14,7 +14,7 @@ from fastNLP.core.samplers import ( |
|
|
|
from tests.helpers.models.paddle_model import PaddleNormalModel_Classification_1 |
|
|
|
from tests.helpers.datasets.paddle_data import PaddleNormalDataset, PaddleRandomMaxDataset |
|
|
|
from tests.helpers.utils import magic_argv_env_context |
|
|
|
from fastNLP.core import synchronize_safe_rm |
|
|
|
from fastNLP.core import rank_zero_rm |
|
|
|
|
|
|
|
import paddle |
|
|
|
import paddle.distributed as dist |
|
|
@@ -112,6 +112,35 @@ class TestFleetDriverFunction: |
|
|
|
self.driver.get_local_rank() |
|
|
|
dist.barrier() |
|
|
|
|
|
|
|
@magic_argv_env_context |
|
|
|
def test_all_gather(self): |
|
|
|
""" |
|
|
|
测试 all_gather 函数 |
|
|
|
详细的测试在 test_dist_utils.py 中完成 |
|
|
|
""" |
|
|
|
obj = { |
|
|
|
"rank": self.driver.global_rank |
|
|
|
} |
|
|
|
obj_list = self.driver.all_gather(obj, group=None) |
|
|
|
for i, res in enumerate(obj_list): |
|
|
|
assert res["rank"] == i |
|
|
|
|
|
|
|
@magic_argv_env_context |
|
|
|
@pytest.mark.parametrize("src_rank", ([0, 1])) |
|
|
|
def test_broadcast_object(self, src_rank): |
|
|
|
""" |
|
|
|
测试 broadcast_object 函数 |
|
|
|
详细的函数在 test_dist_utils.py 中完成 |
|
|
|
""" |
|
|
|
if self.driver.global_rank == src_rank: |
|
|
|
obj = { |
|
|
|
"rank": self.driver.global_rank |
|
|
|
} |
|
|
|
else: |
|
|
|
obj = None |
|
|
|
res = self.driver.broadcast_object(obj, src=src_rank) |
|
|
|
assert res["rank"] == src_rank |
|
|
|
|
|
|
|
############################################################################ |
|
|
|
# |
|
|
|
# 测试 set_dist_repro_dataloader 函数 |
|
|
@@ -543,11 +572,11 @@ class TestSaveLoad: |
|
|
|
assert paddle.equal_all(res1["pred"], res2["pred"]) |
|
|
|
finally: |
|
|
|
if only_state_dict: |
|
|
|
synchronize_safe_rm(path) |
|
|
|
rank_zero_rm(path) |
|
|
|
else: |
|
|
|
synchronize_safe_rm(path + ".pdiparams") |
|
|
|
synchronize_safe_rm(path + ".pdiparams.info") |
|
|
|
synchronize_safe_rm(path + ".pdmodel") |
|
|
|
rank_zero_rm(path + ".pdiparams") |
|
|
|
rank_zero_rm(path + ".pdiparams.info") |
|
|
|
rank_zero_rm(path + ".pdmodel") |
|
|
|
|
|
|
|
@magic_argv_env_context |
|
|
|
@pytest.mark.parametrize("only_state_dict", ([True, False])) |
|
|
@@ -658,7 +687,7 @@ class TestSaveLoad: |
|
|
|
assert len(left_y_batches) + len(already_seen_y_set) == len(self.dataset) / num_replicas |
|
|
|
assert len(left_y_batches | already_seen_y_set) == len(self.dataset) / num_replicas |
|
|
|
finally: |
|
|
|
synchronize_safe_rm(path) |
|
|
|
rank_zero_rm(path) |
|
|
|
|
|
|
|
@magic_argv_env_context |
|
|
|
@pytest.mark.parametrize("only_state_dict", ([True, False])) |
|
|
@@ -769,4 +798,4 @@ class TestSaveLoad: |
|
|
|
assert len(left_y_batches | already_seen_y_set) == len(self.dataset) / num_replicas |
|
|
|
|
|
|
|
finally: |
|
|
|
synchronize_safe_rm(path) |
|
|
|
rank_zero_rm(path) |