
Fix the distributed-mode check when torch.distributed.is_available() is always False

Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10976015
master^2
yuze.zyz, wenmeng.zwm · 2 years ago
parent commit bf97dd7501

2 changed files with 10 additions and 6 deletions:
  1. modelscope/trainers/trainer.py   (+3, -3)
  2. modelscope/utils/torch_utils.py  (+7, -3)
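
The core of the change is a new is_dist() helper in modelscope/utils/torch_utils.py that checks dist.is_available() before dist.is_initialized(). On torch builds where the distributed package is not shipped, is_available() is always False and querying is_initialized() directly may fail, so every call site is routed through the helper instead. A minimal sketch of the helper and of how a call site branches on it (the print statements are illustrative only, not part of the commit):

import torch.distributed as dist

def is_dist():
    # Only query is_initialized() when this torch build actually ships the
    # distributed package; otherwise short-circuit to False.
    return dist.is_available() and dist.is_initialized()

if is_dist():
    print(f'rank {dist.get_rank()} of {dist.get_world_size()}')
else:
    print('running as a single, non-distributed process')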

modelscope/trainers/trainer.py   (+3, -3)

@@ -37,7 +37,7 @@ from modelscope.utils.file_utils import func_receive_dict_inputs
 from modelscope.utils.logger import get_logger
 from modelscope.utils.registry import build_from_cfg
 from modelscope.utils.torch_utils import (get_dist_info, get_local_rank,
-                                          init_dist, is_master,
+                                          init_dist, is_dist, is_master,
                                           set_random_seed)
 from .base import BaseTrainer
 from .builder import TRAINERS
@@ -236,7 +236,7 @@ class EpochBasedTrainer(BaseTrainer):
             device_name: The final device name.
         """
         device_name = device if device is not None else 'gpu'
-        if dist.is_initialized():
+        if is_dist():
            local_rank = get_local_rank()
            device_name = f'cuda:{local_rank}'
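
Read in isolation, the device-selection logic in this hunk amounts to the following standalone sketch (select_device is a hypothetical name used only for illustration; get_local_rank is inlined via the LOCAL_RANK environment variable, matching the helper in torch_utils.py):

import os
import torch.distributed as dist

def select_device(device=None):
    # Default to a generic 'gpu' device name, but bind to the local CUDA
    # device when running under a distributed launcher.
    device_name = device if device is not None else 'gpu'
    if dist.is_available() and dist.is_initialized():  # is_dist()
        local_rank = int(os.environ.get('LOCAL_RANK', 0))  # get_local_rank()
        device_name = f'cuda:{local_rank}'
    return device_name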

@@ -603,7 +603,7 @@ class EpochBasedTrainer(BaseTrainer):
         for key in match_keys:
             value = train_outputs.get(key, None)
             if value is not None:
-                if dist.is_available() and dist.is_initialized():
+                if is_dist():
                     value = value.data.clone().to('cuda')
                     dist.all_reduce(value.div_(dist.get_world_size()))
                 log_vars.update({key: value.item()})
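
The same guard now protects the cross-rank averaging of training metrics above. As a self-contained sketch of that pattern (average_across_ranks is a hypothetical helper, not part of modelscope):

import torch
import torch.distributed as dist

def average_across_ranks(value: torch.Tensor) -> float:
    # Average a scalar metric over all ranks; on a non-distributed run the
    # guard is False and the local value is returned unchanged.
    if dist.is_available() and dist.is_initialized():
        value = value.data.clone().to('cuda')
        dist.all_reduce(value.div_(dist.get_world_size()))
    return value.item()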


modelscope/utils/torch_utils.py  (+7, -3)

@@ -106,7 +106,7 @@ def _init_dist_slurm(backend: str, port: Optional[int] = None) -> None:


 def get_dist_info() -> Tuple[int, int]:
-    if dist.is_available() and dist.is_initialized():
+    if is_dist():
         try:
             from megatron import mpu
             assert mpu.model_parallel_is_initialized()
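
Outside of the Megatron model-parallel branch shown above, a get_dist_info() of this shape typically falls back to plain torch.distributed queries. That fallback is not visible in the hunk, so the sketch below is an assumption about the surrounding, unchanged code:

import torch.distributed as dist

def get_dist_info_fallback():
    # Assumed non-Megatron path: rank and world size come straight from
    # torch.distributed, with (0, 1) for single-process runs.
    if dist.is_available() and dist.is_initialized():
        return dist.get_rank(), dist.get_world_size()
    return 0, 1
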
@@ -125,8 +125,12 @@ def get_local_rank():
     return int(os.environ.get('LOCAL_RANK', 0))


+def is_dist():
+    return dist.is_available() and dist.is_initialized()
+
+
 def is_master():
-    return dist.get_rank() == 0 if dist.is_initialized() else True
+    return dist.get_rank() == 0 if is_dist() else True


 def master_only(func: Callable) -> Callable:
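
A small usage sketch of the two helpers touched by this hunk (write_summary is a made-up example function, and the decorator usage of master_only is inferred from its signature, not shown in the diff):

from modelscope.utils.torch_utils import is_master, master_only

@master_only
def write_summary(msg: str):
    # Assumed decorator semantics: the wrapped body only runs on the master.
    print(msg)

if is_master():
    # True on rank 0, and also on plain single-process (non-dist) runs.
    write_summary('only the master process logs this')
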
@@ -142,7 +146,7 @@ def master_only(func: Callable) -> Callable:
 def make_tmp_dir():
     """Make sure each rank has the same temporary directory on the distributed mode.
     """
-    if not dist.is_initialized():
+    if not is_dist():
         return tempfile.mkdtemp()

     tmpdir = None
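
The rest of make_tmp_dir is truncated in this view. One common way such a function agrees on a single path across ranks is to have rank 0 create the directory and broadcast it; the sketch below is only an assumption of that pattern, not the actual modelscope implementation:

import tempfile
import torch.distributed as dist

def make_shared_tmp_dir() -> str:
    # Non-distributed runs just get a private temp dir.
    if not (dist.is_available() and dist.is_initialized()):
        return tempfile.mkdtemp()
    # Rank 0 creates the directory, then the path is broadcast so every
    # rank returns the same string (broadcast_object_list needs torch >= 1.8).
    holder = [tempfile.mkdtemp() if dist.get_rank() == 0 else None]
    dist.broadcast_object_list(holder, src=0)
    return holder[0]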

