@@ -14,11 +14,13 @@ if _NEED_IMPORT_PADDLE: | |||||
import paddle.distributed as dist | import paddle.distributed as dist | ||||
from paddle.fluid.dygraph import parallel_helper | from paddle.fluid.dygraph import parallel_helper | ||||
def _simple_gather_all_tensors(result, group: Any, world_size: int) -> List:
    """Gather ``result`` from the ranks of ``group`` into a list.

    Args:
        result: The local tensor to share. It also serves as the template
            (via ``paddle.zeros_like``) for the receive buffers.
        group: Communication group passed straight to ``dist.all_gather``.
        world_size: Number of receive buffers to allocate and the number of
            entries returned (presumably the size of ``group`` -- confirm
            against the caller).

    Returns:
        A list holding the last ``world_size`` tensors left in the output
        list after ``dist.all_gather`` has run.
    """
    gathered_result = [paddle.zeros_like(result) for _ in range(world_size)]
    dist.all_gather(gathered_result, result, group)
    # NOTE(review): some Paddle releases appear to *append* the gathered
    # tensors to the list instead of overwriting the pre-allocated entries,
    # which would leave ``world_size`` zero placeholders at the front.
    # Keeping only the trailing ``world_size`` entries is correct in that
    # case and a no-op when the buffers are filled in place.
    return gathered_result[-world_size:]
class PaddleBackend(Backend): | class PaddleBackend(Backend): | ||||
def __init__(self): | def __init__(self): | ||||
super().__init__() | super().__init__() | ||||
@@ -124,4 +126,3 @@ class PaddleBackend(Backend): | |||||
# TODO 如果在这里处理的话,会不会在别的地方引起bug? | # TODO 如果在这里处理的话,会不会在别的地方引起bug? | ||||
device = get_device_from_visible(device) | device = get_device_from_visible(device) | ||||
return paddle_to(tensor, device) | return paddle_to(tensor, device) | ||||
@@ -11,7 +11,6 @@ from fastNLP.core.drivers.torch_driver.dist_utils import fastnlp_torch_all_gathe | |||||
if _NEED_IMPORT_TORCH: | if _NEED_IMPORT_TORCH: | ||||
import torch | import torch | ||||
import torch.distributed as dist | import torch.distributed as dist | ||||
import torch.nn.functional as F | |||||
def _simple_gather_all_tensors(result, group: Any, world_size: int) -> List: | def _simple_gather_all_tensors(result, group: Any, world_size: int) -> List: | ||||
@@ -33,7 +32,7 @@ class TorchBackend(Backend): | |||||
if dist.is_initialized(): | if dist.is_initialized(): | ||||
if method is None: | if method is None: | ||||
raise AggregateMethodError(should_have_aggregate_method=True) | raise AggregateMethodError(should_have_aggregate_method=True) | ||||
tensor = fastnlp_torch_all_gather(tensor) | |||||
tensor = self.all_gather_object(tensor) | |||||
if isinstance(tensor[0], torch.Tensor): | if isinstance(tensor[0], torch.Tensor): | ||||
tensor = torch.stack(tensor) | tensor = torch.stack(tensor) | ||||
# 第一步, aggregate结果 | # 第一步, aggregate结果 | ||||
@@ -0,0 +1,59 @@ | |||||
import os | |||||
import pytest | |||||
import paddle | |||||
import paddle.distributed | |||||
import paddle.distributed.fleet.base.role_maker as role_maker | |||||
import paddle.distributed.fleet as fleet | |||||
from fastNLP.core.metrics import Accuracy | |||||
from fastNLP.core.drivers.paddle_driver.fleet_launcher import FleetLauncher | |||||
############################################################################ | |||||
# | |||||
# Test Accuracy in the single-machine, single-device case | |||||
# | |||||
############################################################################ | |||||
def test_accuracy_single():
    """Feed one 10x6 score matrix with 10 targets to ``Accuracy`` and check the result."""
    scores = paddle.to_tensor([
        [1.19812393, -0.82041764, -0.53517765, -0.73061031, -1.45006669, 0.46514302],
        [-0.85775983, -2.18273783, -1.07505429, -1.45561373, 0.40011844, 1.02202022],
        [-0.39487389, 0.65682763, -0.62424040, 0.53692561, -0.28390560, -0.02559055],
        [-0.22586937, -0.07676325, -0.95977223, 0.36395910, -0.91758579, -0.83857095],
        [0.25136873, 2.49652624, 1.06251311, 1.60194016, 1.01451588, 0.08403367],
        [0.10844281, 1.19017303, -0.11378096, 1.12686944, -0.08654942, 0.48605862],
        [1.27320433, -1.13902378, 1.47072780, -0.98665696, -0.42589864, 0.64618838],
        [0.83809763, -0.05356205, 0.03042423, -0.28371972, 0.81611472, -0.45802942],
        [0.38535264, 0.09721313, 2.27187467, 0.32045507, -0.20711982, -0.13550705],
        [-0.75228405, -1.34161997, 1.08697927, 0.33218071, -1.19470012, 2.58735061],
    ])
    targets = paddle.to_tensor([1, 2, 1, 3, 5, 4, 4, 2, 1, 5])

    metric = Accuracy()
    metric.update(scores, targets)

    # Rows 2, 3 and 9 are the only ones whose largest score sits at the
    # target index, which is presumably why 0.3 is the expected accuracy.
    assert {'acc': 0.3} == metric.get_metric()
############################################################################ | |||||
# | |||||
# Test Accuracy in the single-machine, multi-device case | |||||
# | |||||
############################################################################ | |||||
def test_accuracy_ddp():
    """Launch a fleet job on devices 0 and 1 and print the trainer id on workers.

    The visible body makes no assertions; it only starts the launcher,
    initialises ``fleet`` with a collective role and reports the value of
    the ``PADDLE_TRAINER_ID`` environment variable.
    """
    fleet_launcher = FleetLauncher(devices=[0, 1])
    fleet_launcher.launch()
    fleet.init(role_maker.PaddleCloudRoleMaker(is_collective=True))
    # Servers do nothing here; only a worker prints its trainer id.
    if not fleet.is_server() and fleet.is_worker():
        print(os.getenv("PADDLE_TRAINER_ID"))