
Implement multi-GPU training for the jittor driver (#418)

tags/v1.0.0alpha
Letian Li committed 2 years ago
parent commit 005b0e055e
3 changed files with 76 additions and 21 deletions
  1. fastNLP/core/drivers/jittor_driver/initialize_jittor_driver.py (+7, -2)
  2. fastNLP/core/drivers/jittor_driver/mpi.py (+64, -16)
  3. tests/core/controllers/test_trainer_jittor.py (+5, -3)

fastNLP/core/drivers/jittor_driver/initialize_jittor_driver.py (+7, -2)

@@ -1,5 +1,6 @@
 from typing import Union, List
 
+from fastNLP.core.drivers.jittor_driver.mpi import JittorMPIDriver
 from fastNLP.core.drivers.jittor_driver.jittor_driver import JittorDriver
 from fastNLP.core.drivers.jittor_driver.single_device import JittorSingleDriver
 from fastNLP.envs.imports import _NEED_IMPORT_JITTOR
@@ -29,7 +30,11 @@ def initialize_jittor_driver(driver: str, device: Union[str, int, List[int]], mo
         raise ValueError("Parameter `driver` can only be one of these values: ['jittor'].")
 
     # TODO: add more fine-grained checks here
-    if driver == "jittor":
+    if device in ["cpu", "gpu", "cuda", "cuda:0", 0, None]:
         return JittorSingleDriver(model, device, **kwargs)
+    elif type(device) is int:
+        return JittorMPIDriver(model, device, **kwargs)
+    elif type(device) is list:
+        return JittorMPIDriver(model, device, **kwargs)
     else:
-        raise NotImplementedError
+        raise NotImplementedError(f"Device={device}")
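For reference, a hedged sketch of how the new dispatch behaves; `model` here is a placeholder jittor model, not part of the commit:

# Hypothetical usage of initialize_jittor_driver; `model` is a placeholder.
driver = initialize_jittor_driver("jittor", "cpu", model)   # -> JittorSingleDriver
driver = initialize_jittor_driver("jittor", 1, model)       # -> JittorMPIDriver, one card
driver = initialize_jittor_driver("jittor", [0, 1], model)  # -> JittorMPIDriver, two cards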

fastNLP/core/drivers/jittor_driver/mpi.py (+64, -16)

@@ -2,11 +2,14 @@ import os
 from typing import Optional, Union, Callable, Dict, Tuple
 
 from .jittor_driver import JittorDriver
+from fastNLP.core.utils import auto_param_call
+from fastNLP.core.utils.utils import _get_fun_msg
 from fastNLP.envs.imports import _NEED_IMPORT_JITTOR
-from fastNLP.core.samplers import ReproducibleSampler
+from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler
+from fastNLP.core.log import logger
 
 if _NEED_IMPORT_JITTOR:
-    import jittor
+    import jittor as jt
 
 __all__ = [
     "JittorMPIDriver",
@@ -42,7 +45,31 @@ class JittorMPIDriver(JittorDriver):
         self.outside_mpi = False
 
     def setup(self):
-        pass
+        self.__fork_with_mpi__()
+
+    def __fork_with_mpi__(self):
+        import sys
+        if jt.in_mpi:
+            # mute the output of every process except rank 0
+            if jt.rank != 0:
+                sys.stdout = open("/dev/null", "w")
+            return
+        else:
+            if self.parallel_device == -1:  # device is -1: use all GPUs by default
+                raise NotImplementedError(f"Device={self.parallel_device}")
+            elif type(self.parallel_device) is int:  # device is an *int*: train on the gpu whose ``device_id`` is this value
+                num_procs = 1
+                devices = self.parallel_device
+            elif type(self.parallel_device) is list:  # device is a *list(int)*: more than one device should be set this way
+                num_procs = len(self.parallel_device)
+                devices = str(self.parallel_device)[1:-1]
+            else:
+                raise NotImplementedError(f"Device={self.parallel_device}")
+            print(sys.argv)
+            cmd = " ".join(["CUDA_VISIBLE_DEVICES='%s'" % devices, "mpirun", "-np", str(num_procs), sys.executable] + sys.argv)
+            print("[RUN CMD]:", cmd)
+            os.system(cmd)
+            exit(0)
 
     def configure_mpi(self):
         pass
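The setup path is a re-exec trick: on the first launch `jt.in_mpi` is false, so the driver rebuilds its own command line under `mpirun` (pinning `CUDA_VISIBLE_DEVICES` to the requested cards) and exits, while each spawned rank sees `jt.in_mpi` true and continues, with stdout silenced on all ranks except 0. A stripped-down sketch of the same pattern; the function name and prints are illustrative, not from the commit:

import os
import sys
import jittor as jt

def launch_with_mpi(devices):
    """Re-exec this script under mpirun, one process per GPU (illustrative)."""
    if jt.in_mpi:
        # Already inside an mpirun rank: report and carry on with the real work.
        print(f"rank {jt.rank} of {jt.world_size} is up")
        return
    visible = ",".join(str(d) for d in devices)
    cmd = " ".join([f"CUDA_VISIBLE_DEVICES='{visible}'", "mpirun",
                    "-np", str(len(devices)), sys.executable] + sys.argv)
    os.system(cmd)  # blocks until all ranks finish
    sys.exit(0)     # the launcher process has nothing left to do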
@@ -71,25 +98,46 @@ class JittorMPIDriver(JittorDriver):
     def data_device(self):
         if self.outside_mpi:
             return self._data_device
-        return self.model_device
+        return self.parallel_device
+
+    def step(self):
+        # for optimizer in self.optimizers:
+        #     self.grad_scaler.step(optimizer)
+        #     self.grad_scaler.update()
+        for optimizer in self.optimizers:
+            optimizer.step()
+
+    def backward(self, loss):
+        # self.grad_scaler.scale(loss).backward()
+        for optimizer in self.optimizers:
+            optimizer.backward(loss)
+
+    def zero_grad(self):
+        for optimizer in self.optimizers:
+            optimizer.zero_grad()
 
     def model_call(self, batch, fn: Callable, signature_fn: Optional[Callable]) -> Dict:
-        pass
+        if isinstance(batch, Dict) and not self.wo_auto_param_call:
+            return auto_param_call(fn, batch, signature_fn=signature_fn)
+        else:
+            return fn(batch)
 
     def get_model_call_fn(self, fn: str) -> Tuple:
-        pass
+        if hasattr(self.model, fn):
+            fn = getattr(self.model, fn)
+            if not callable(fn):
+                raise RuntimeError(f"The `{fn}` attribute is not `Callable`.")
+            logger.debug(f'Use {_get_fun_msg(fn, with_fp=False)}...')
+            return fn, None
+        elif fn in {"train_step", "evaluate_step"}:
+            logger.debug(f'Use {_get_fun_msg(self.model.execute, with_fp=False)}...')
+            return self.model, self.model.execute
+        else:
+            raise RuntimeError(f"There is no `{fn}` method in your {type(self.model)}.")
 
     def set_dist_repro_dataloader(self, dataloader, dist: Optional[Union[str, ReproducibleSampler]],
                                   reproducible: bool = False, sampler_or_batch_sampler=None):
-        pass
-
-    def backward(self, loss):
-        self.grad_scaler.scale(loss).backward()
-
-    def step(self):
-        for optimizer in self.optimizers:
-            self.grad_scaler.step(optimizer)
-            self.grad_scaler.update()
+        return dataloader
 
     def is_global_zero(self):
         return self.global_rank == 0
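Two conventions from the hunk above are worth spelling out: jittor modules implement `execute` rather than `forward`, which is why `get_model_call_fn` falls back to `self.model.execute` for `train_step`/`evaluate_step`, and `model_call` unpacks a dict batch into named parameters via `auto_param_call`. An illustrative sketch; the toy model and batch are made up:

# Illustrative only: how a dict batch reaches a jittor model's train_step.
import jittor as jt

class ToyModel(jt.Module):
    def execute(self, x):        # jittor's forward pass
        return x
    def train_step(self, x, y):  # parameters matched by batch keys
        return {"loss": ((self.execute(x) - y) ** 2).mean()}

batch = {"x": jt.randn(4, 8), "y": jt.randn(4, 8)}
# model_call with a dict batch behaves like:
out = ToyModel().train_step(x=batch["x"], y=batch["y"])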
@@ -107,4 +155,4 @@ class JittorMPIDriver(JittorDriver):
         pass
 
     def is_distributed(self):
-        return True
\ No newline at end of file
+        return True
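Note also that `step`/`backward`/`zero_grad` now drive each registered jittor optimizer directly (the grad-scaler variant survives only as comments). A hedged sketch of one update through this driver API, assuming `driver` is a set-up JittorMPIDriver whose model's `train_step` returns a dict with a "loss" key:

# Sketch of one update through the driver API above (names as stated in the lead-in).
fn, signature_fn = driver.get_model_call_fn("train_step")
outputs = driver.model_call(batch, fn, signature_fn)
driver.zero_grad()                # optimizer.zero_grad() on every optimizer
driver.backward(outputs["loss"])  # jittor: optimizer.backward(loss) computes grads
driver.step()                     # optimizer.step() applies the update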

tests/core/controllers/test_trainer_jittor.py (+5, -3)

@@ -70,7 +70,7 @@ class TrainJittorConfig:
 
 
 @pytest.mark.parametrize("driver", ["jittor"])
-@pytest.mark.parametrize("device", ["cpu", 1])
+@pytest.mark.parametrize("device", ["cpu", "gpu", "cuda:0"])
 @pytest.mark.parametrize("callbacks", [[RichCallback(100)]])
 @pytest.mark.jittor
 def test_trainer_jittor(
@@ -134,6 +134,8 @@ def test_trainer_jittor(
 
 
 if __name__ == "__main__":
-    # test_trainer_jittor("jittor", None, [RichCallback(100)])
-    # test_trainer_jittor("jittor", 1, [RichCallback(100)])
+    # test_trainer_jittor("jittor", "cpu", [RichCallback(100)])     # test on CPU
+    # test_trainer_jittor("jittor", "cuda:0", [RichCallback(100)])  # test on a single GPU
+    # test_trainer_jittor("jittor", 1, [RichCallback(100)])         # test on a specified GPU
+    # test_trainer_jittor("jittor", [0, 1], [RichCallback(100)])    # test on multiple GPUs
     pytest.main(['test_trainer_jittor.py'])  # run only this module
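Since the test carries a `@pytest.mark.jittor` marker, it can also be selected by marker rather than by file; a small, hypothetical variant of the `__main__` entry above:

import pytest
# Collect only the tests marked `jittor` from this module.
pytest.main(['-m', 'jittor', 'test_trainer_jittor.py'])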
