From 665d79a3ede01e6252d6ddf9d867900c14adf998 Mon Sep 17 00:00:00 2001 From: MorningForest <2297662686@qq.com> Date: Fri, 15 Apr 2022 20:03:44 +0800 Subject: [PATCH] =?UTF-8?q?=E5=A2=9E=E5=8A=A0paddle=E5=8D=95=E5=8D=A1?= =?UTF-8?q?=E7=9A=84accuracy=E6=B5=8B=E8=AF=95=E7=94=A8=E4=BE=8B?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../metrics/backend/paddle_backend/backend.py | 3 +- .../metrics/backend/torch_backend/backend.py | 3 +- tests/core/metrics/test_accutacy_paddle.py | 59 +++++++++++++++++++ 3 files changed, 62 insertions(+), 3 deletions(-) create mode 100644 tests/core/metrics/test_accutacy_paddle.py diff --git a/fastNLP/core/metrics/backend/paddle_backend/backend.py b/fastNLP/core/metrics/backend/paddle_backend/backend.py index 12216d4b..7a7e7f7a 100644 --- a/fastNLP/core/metrics/backend/paddle_backend/backend.py +++ b/fastNLP/core/metrics/backend/paddle_backend/backend.py @@ -14,11 +14,13 @@ if _NEED_IMPORT_PADDLE: import paddle.distributed as dist from paddle.fluid.dygraph import parallel_helper + def _simple_gather_all_tensors(result, group: Any, world_size: int) -> List: gathered_result = [paddle.zeros_like(result) for _ in range(world_size)] dist.all_gather(gathered_result, result, group) return gathered_result + class PaddleBackend(Backend): def __init__(self): super().__init__() @@ -124,4 +126,3 @@ class PaddleBackend(Backend): # TODO 如果在这里处理的话,会不会在别的地方引起bug? device = get_device_from_visible(device) return paddle_to(tensor, device) - diff --git a/fastNLP/core/metrics/backend/torch_backend/backend.py b/fastNLP/core/metrics/backend/torch_backend/backend.py index 8945ab01..a602434e 100644 --- a/fastNLP/core/metrics/backend/torch_backend/backend.py +++ b/fastNLP/core/metrics/backend/torch_backend/backend.py @@ -11,7 +11,6 @@ from fastNLP.core.drivers.torch_driver.dist_utils import fastnlp_torch_all_gathe if _NEED_IMPORT_TORCH: import torch import torch.distributed as dist - import torch.nn.functional as F def _simple_gather_all_tensors(result, group: Any, world_size: int) -> List: @@ -33,7 +32,7 @@ class TorchBackend(Backend): if dist.is_initialized(): if method is None: raise AggregateMethodError(should_have_aggregate_method=True) - tensor = fastnlp_torch_all_gather(tensor) + tensor = self.all_gather_object(tensor) if isinstance(tensor[0], torch.Tensor): tensor = torch.stack(tensor) # 第一步, aggregate结果 diff --git a/tests/core/metrics/test_accutacy_paddle.py b/tests/core/metrics/test_accutacy_paddle.py new file mode 100644 index 00000000..1580d3a7 --- /dev/null +++ b/tests/core/metrics/test_accutacy_paddle.py @@ -0,0 +1,59 @@ +import os + +import pytest +import paddle +import paddle.distributed +import paddle.distributed.fleet.base.role_maker as role_maker +import paddle.distributed.fleet as fleet +from fastNLP.core.metrics import Accuracy +from fastNLP.core.drivers.paddle_driver.fleet_launcher import FleetLauncher + +############################################################################ +# +# 测试 单机单卡情况下的Accuracy +# +############################################################################ +def test_accuracy_single(): + pred = paddle.to_tensor([[1.19812393, -0.82041764, -0.53517765, -0.73061031, -1.45006669, + 0.46514302], + [-0.85775983, -2.18273783, -1.07505429, -1.45561373, 0.40011844, + 1.02202022], + [-0.39487389, 0.65682763, -0.62424040, 0.53692561, -0.28390560, + -0.02559055], + [-0.22586937, -0.07676325, -0.95977223, 0.36395910, -0.91758579, + -0.83857095], + [0.25136873, 2.49652624, 1.06251311, 1.60194016, 1.01451588, + 0.08403367], + [0.10844281, 1.19017303, -0.11378096, 1.12686944, -0.08654942, + 0.48605862], + [1.27320433, -1.13902378, 1.47072780, -0.98665696, -0.42589864, + 0.64618838], + [0.83809763, -0.05356205, 0.03042423, -0.28371972, 0.81611472, + -0.45802942], + [0.38535264, 0.09721313, 2.27187467, 0.32045507, -0.20711982, + -0.13550705], + [-0.75228405, -1.34161997, 1.08697927, 0.33218071, -1.19470012, + 2.58735061]]) + tg = paddle.to_tensor([1, 2, 1, 3, 5, 4, 4, 2, 1, 5]) + acc_metric = Accuracy() + acc_metric.update(pred, tg) + result = acc_metric.get_metric() + true_result = {'acc': 0.3} + assert true_result == result + + +############################################################################ +# +# 测试 单机多卡情况下的Accuracy +# +############################################################################ +def test_accuracy_ddp(): + launcher = FleetLauncher(devices=[0, 1]) + launcher.launch() + role = role_maker.PaddleCloudRoleMaker(is_collective=True) + fleet.init(role) + if fleet.is_server(): + pass + elif fleet.is_worker(): + print(os.getenv("PADDLE_TRAINER_ID")) +