diff --git a/fastNLP/core/drivers/paddle_driver/fleet.py b/fastNLP/core/drivers/paddle_driver/fleet.py
index a083e42c..762b3114 100644
--- a/fastNLP/core/drivers/paddle_driver/fleet.py
+++ b/fastNLP/core/drivers/paddle_driver/fleet.py
@@ -37,7 +37,7 @@ if _NEED_IMPORT_PADDLE:
     import paddle
     from paddle import DataParallel
     import paddle.distributed.fleet as fleet
-    import paddle.distributed as dist
+    import paddle.distributed as paddledist
     from paddle.io import BatchSampler
     from paddle.optimizer import Optimizer
     from paddle.fluid.reader import _DatasetKind
@@ -185,8 +185,8 @@ class PaddleFleetDriver(PaddleDriver):
             if sorted(pre_gpus) != sorted(self.parallel_device):
                 raise RuntimeError("Notice you are using `PaddleFleetDriver` after one instantiated `PaddleFleetDriver`, it is not"
                                    "allowed that your second `PaddleFleetDriver` has a new setting of parameters `parallel_device`.")
-            self.world_size = dist.get_world_size()
-            self.global_rank = dist.get_rank()
+            self.world_size = paddledist.get_world_size()
+            self.global_rank = paddledist.get_rank()

         if not self.outside_fleet:
             # self.model.to(self.model_device)
@@ -197,12 +197,12 @@ class PaddleFleetDriver(PaddleDriver):
         # Initialize self._pids so that every process can receive the send operation from rank 0;
         # TODO what happens if we do not use .to?
         self._pids = []
-        dist.all_gather(self._pids, paddle.to_tensor(os.getpid(), dtype="int32"))
+        paddledist.all_gather(self._pids, paddle.to_tensor(os.getpid(), dtype="int32"))
         # TODO LOCAL_WORLD_SIZE
         local_world_size = int(os.environ.get("LOCAL_WORLD_SIZE")) if "LOCAL_WORLD_SIZE" in os.environ else None
         if local_world_size is None:
             local_world_size = paddle.to_tensor(self.local_rank, dtype="int32")
-            dist.all_reduce(local_world_size, op=dist.ReduceOp.MAX)
+            paddledist.all_reduce(local_world_size, op=paddledist.ReduceOp.MAX)
             local_world_size = local_world_size.item() + 1

         node_rank = self.global_rank // local_world_size
@@ -232,11 +232,11 @@ class PaddleFleetDriver(PaddleDriver):
         When the user launches with `python -m paddle.distributed.launch xxx.py`, we need to
         obtain the various attributes from the environment variables set by paddle.
         """
-        self.world_size = dist.get_world_size()
-        self.global_rank = dist.get_rank()
+        self.world_size = paddledist.get_world_size()
+        self.global_rank = paddledist.get_rank()

     def barrier(self):
-        dist.barrier()
+        paddledist.barrier()

     def configure_fleet(self):
         if not self._has_fleetwrapped and not isinstance(self.model, DataParallel):
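
Note on the rename above: `paddle.distributed` is now imported as `paddledist` instead of `dist`, and every call site is updated accordingly; the behavior is unchanged. The snippet below is a minimal, illustrative sketch (not part of this patch) of the same collective calls under the new alias. It assumes a multi-process GPU run started with `python -m paddle.distributed.launch`, and the file name `demo.py` is only a placeholder.

```python
# demo.py -- illustrative sketch only, not part of this patch.
# Shows the paddle.distributed collectives touched by the diff, using the
# renamed `paddledist` alias.
import os

import paddle
import paddle.distributed as paddledist


def main():
    # Must be called once per process before any collective operation.
    paddledist.init_parallel_env()

    world_size = paddledist.get_world_size()
    global_rank = paddledist.get_rank()

    # Gather every process's pid into a list, mirroring the `_pids` logic above.
    pids = []
    paddledist.all_gather(pids, paddle.to_tensor(os.getpid(), dtype="int32"))

    # Element-wise MAX across processes, as in the LOCAL_WORLD_SIZE fallback above;
    # after the reduce, `value` equals world_size - 1 on every rank.
    value = paddle.to_tensor(global_rank, dtype="int32")
    paddledist.all_reduce(value, op=paddledist.ReduceOp.MAX)

    # Synchronize all processes, as in `PaddleFleetDriver.barrier`.
    paddledist.barrier()

    if global_rank == 0:
        print(f"world_size={world_size}, max_rank={value.item()}, "
              f"pids={[p.item() for p in pids]}")


if __name__ == "__main__":
    main()
```

Launched, for example, as `python -m paddle.distributed.launch --gpus 0,1 demo.py`, rank 0 prints the world size, the reduced maximum rank, and the gathered pids of all processes.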