|
|
@@ -4,23 +4,17 @@ from typing import List, Any |
|
|
|
import numpy as np |
|
|
|
|
|
|
|
from fastNLP.core.metrics.backend import Backend |
|
|
|
from fastNLP.core.utils.paddle_utils import paddle_to, _convert_data_device |
|
|
|
from fastNLP.core.utils.paddle_utils import paddle_to, _convert_data_device, is_in_paddle_dist |
|
|
|
from fastNLP.core.metrics.utils import AggregateMethodError |
|
|
|
from fastNLP.core.drivers.paddle_driver.dist_utils import fastnlp_paddle_all_gather |
|
|
|
from fastNLP.envs.imports import _NEED_IMPORT_PADDLE |
|
|
|
from fastNLP.envs.env import USER_CUDA_VISIBLE_DEVICES |
|
|
|
|
|
|
|
if _NEED_IMPORT_PADDLE: |
|
|
|
import paddle |
|
|
|
import paddle.distributed as dist |
|
|
|
from paddle.fluid.dygraph import parallel_helper |
|
|
|
|
|
|
|
|
|
|
|
def _simple_gather_all_tensors(result, group: Any, world_size: int) -> List: |
|
|
|
gathered_result = [paddle.zeros_like(result) for _ in range(world_size)] |
|
|
|
dist.all_gather(gathered_result, result, group) |
|
|
|
return gathered_result |
|
|
|
|
|
|
|
__all__ = [] |
|
|
|
|
|
|
|
class PaddleBackend(Backend): |
|
|
|
def __init__(self): |
|
|
@@ -80,6 +74,13 @@ class PaddleBackend(Backend): |
|
|
|
else: |
|
|
|
raise ValueError(f"tensor: {tensor} can not convert to ndarray!") |
|
|
|
|
|
|
|
@staticmethod |
|
|
|
def is_distributed() -> bool: |
|
|
|
""" |
|
|
|
:return: |
|
|
|
""" |
|
|
|
return is_in_paddle_dist() |
|
|
|
|
|
|
|
def move_tensor_to_device(self, tensor, device): |
|
|
|
device = _convert_data_device(device) |
|
|
|
return paddle_to(tensor, device) |
|
|
|