Browse Source

修改paddle.distributed的import名 (rename the import alias of paddle.distributed: `dist` → `paddledist`)

tags/v1.0.0alpha
x54-729 3 years ago
parent
commit
be24572b11
1 changed files with 8 additions and 8 deletions
  1. +8
    -8
      fastNLP/core/drivers/paddle_driver/fleet.py

+ 8
- 8
fastNLP/core/drivers/paddle_driver/fleet.py View File

@@ -37,7 +37,7 @@ if _NEED_IMPORT_PADDLE:
import paddle
from paddle import DataParallel
import paddle.distributed.fleet as fleet
-import paddle.distributed as dist
+import paddle.distributed as paddledist
from paddle.io import BatchSampler
from paddle.optimizer import Optimizer
from paddle.fluid.reader import _DatasetKind
@@ -185,8 +185,8 @@ class PaddleFleetDriver(PaddleDriver):
if sorted(pre_gpus) != sorted(self.parallel_device):
raise RuntimeError("Notice you are using `PaddleFleetDriver` after one instantiated `PaddleFleetDriver`, it is not"
"allowed that your second `PaddleFleetDriver` has a new setting of parameters `parallel_device`.")
-self.world_size = dist.get_world_size()
-self.global_rank = dist.get_rank()
+self.world_size = paddledist.get_world_size()
+self.global_rank = paddledist.get_rank()

if not self.outside_fleet:
# self.model.to(self.model_device)
@@ -197,12 +197,12 @@ class PaddleFleetDriver(PaddleDriver):
# 初始化 self._pids,从而使得每一个进程都能接受到 rank0 的 send 操作;
# TODO 不用.to会怎么样?
self._pids = []
-dist.all_gather(self._pids, paddle.to_tensor(os.getpid(), dtype="int32"))
+paddledist.all_gather(self._pids, paddle.to_tensor(os.getpid(), dtype="int32"))
# TODO LOCAL_WORLD_SIZE
local_world_size = int(os.environ.get("LOCAL_WORLD_SIZE")) if "LOCAL_WORLD_SIZE" in os.environ else None
if local_world_size is None:
local_world_size = paddle.to_tensor(self.local_rank, dtype="int32")
-dist.all_reduce(local_world_size, op=dist.ReduceOp.MAX)
+paddledist.all_reduce(local_world_size, op=paddledist.ReduceOp.MAX)
local_world_size = local_world_size.item() + 1

node_rank = self.global_rank // local_world_size
@@ -232,11 +232,11 @@ class PaddleFleetDriver(PaddleDriver):
当用户使用了 `python -m paddle.distributed.launch xxx.py` 启动时,我们需要
根据 paddle 设置的环境变量来获得各种属性
"""
-self.world_size = dist.get_world_size()
-self.global_rank = dist.get_rank()
+self.world_size = paddledist.get_world_size()
+self.global_rank = paddledist.get_rank()

def barrier(self):
-dist.barrier()
+paddledist.barrier()

def configure_fleet(self):
if not self._has_fleetwrapped and not isinstance(self.model, DataParallel):


Loading…
Cancel
Save