From 005b0e055e184fe2a964410279f9e92750ea4ba0 Mon Sep 17 00:00:00 2001 From: Letian Li <73881739+LetianLee@users.noreply.github.com> Date: Wed, 25 May 2022 12:09:09 +0100 Subject: [PATCH] =?UTF-8?q?=E5=AE=9E=E7=8E=B0=20jittor=20driver=20?= =?UTF-8?q?=E5=A4=9A=E5=8D=A1=E8=AE=AD=E7=BB=83=20(#418)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../jittor_driver/initialize_jittor_driver.py | 9 ++- fastNLP/core/drivers/jittor_driver/mpi.py | 80 +++++++++++++++---- tests/core/controllers/test_trainer_jittor.py | 8 +- 3 files changed, 76 insertions(+), 21 deletions(-) diff --git a/fastNLP/core/drivers/jittor_driver/initialize_jittor_driver.py b/fastNLP/core/drivers/jittor_driver/initialize_jittor_driver.py index 4b1fcba7..eff8fcfe 100644 --- a/fastNLP/core/drivers/jittor_driver/initialize_jittor_driver.py +++ b/fastNLP/core/drivers/jittor_driver/initialize_jittor_driver.py @@ -1,5 +1,6 @@ from typing import Union, List +from fastNLP.core.drivers.jittor_driver.mpi import JittorMPIDriver from fastNLP.core.drivers.jittor_driver.jittor_driver import JittorDriver from fastNLP.core.drivers.jittor_driver.single_device import JittorSingleDriver from fastNLP.envs.imports import _NEED_IMPORT_JITTOR @@ -29,7 +30,11 @@ def initialize_jittor_driver(driver: str, device: Union[str, int, List[int]], mo raise ValueError("Parameter `driver` can only be one of these values: ['jittor'].") # TODO 实现更详细的判断 - if driver == "jittor": + if device in ["cpu", "gpu", "cuda", "cuda:0", 0, None]: return JittorSingleDriver(model, device, **kwargs) + elif type(device) is int: + return JittorMPIDriver(model, device, **kwargs) + elif type(device) is list: + return JittorMPIDriver(model, device, **kwargs) else: - raise NotImplementedError \ No newline at end of file + raise NotImplementedError(f"Device={device}") diff --git a/fastNLP/core/drivers/jittor_driver/mpi.py b/fastNLP/core/drivers/jittor_driver/mpi.py index 4ade3fd1..ee2514e9 100644 --- a/fastNLP/core/drivers/jittor_driver/mpi.py +++ b/fastNLP/core/drivers/jittor_driver/mpi.py @@ -2,11 +2,14 @@ import os from typing import Optional, Union, Callable, Dict, Tuple from .jittor_driver import JittorDriver +from fastNLP.core.utils import auto_param_call +from fastNLP.core.utils.utils import _get_fun_msg from fastNLP.envs.imports import _NEED_IMPORT_JITTOR -from fastNLP.core.samplers import ReproducibleSampler +from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler +from fastNLP.core.log import logger if _NEED_IMPORT_JITTOR: - import jittor + import jittor as jt __all__ = [ "JittorMPIDriver", @@ -42,7 +45,31 @@ class JittorMPIDriver(JittorDriver): self.outside_mpi = False def setup(self): - pass + self.__fork_with_mpi__() + + def __fork_with_mpi__(self): + import sys + if jt.in_mpi: + # you can mult other process output + if jt.rank != 0: + sys.stdout = open("/dev/null", "w") + return + else: + if self.parallel_device == -1: # device 为 -1,那么默认使用全部的显卡 + raise NotImplementedError(f"Device={self.parallel_device}") + elif type(self.parallel_device) is int: # device 为 *int*: 将使用 ``device_id`` 为该值的 ``gpu`` 进行训练 + num_procs = 1 + devices = self.parallel_device + elif type(self.parallel_device) is list: # device 为 *list(int)*: 多于 1 个device,应当通过该种方式进行设定 + num_procs = len(self.parallel_device) + devices = str(self.parallel_device)[1:-1] + else: + raise NotImplementedError(f"Device={self.parallel_device}") + print(sys.argv) + cmd = " ".join(["CUDA_VISIBLE_DEVICES='%s'" % devices, "mpirun", "-np", str(num_procs), sys.executable] + sys.argv) + print("[RUN CMD]:", cmd) + os.system(cmd) + exit(0) def configure_mpi(self): pass @@ -71,25 +98,46 @@ class JittorMPIDriver(JittorDriver): def data_device(self): if self.outside_mpi: return self._data_device - return self.model_device + return self.parallel_device + + def step(self): + # for optimizer in self.optimizers: + # self.grad_scaler.step(optimizer) + # self.grad_scaler.update() + for optimizer in self.optimizers: + optimizer.step() + + def backward(self, loss): + # self.grad_scaler.scale(loss).backward() + for optimizer in self.optimizers: + optimizer.backward(loss) + + def zero_grad(self): + for optimizer in self.optimizers: + optimizer.zero_grad() def model_call(self, batch, fn: Callable, signature_fn: Optional[Callable]) -> Dict: - pass + if isinstance(batch, Dict) and not self.wo_auto_param_call: + return auto_param_call(fn, batch, signature_fn=signature_fn) + else: + return fn(batch) def get_model_call_fn(self, fn: str) -> Tuple: - pass + if hasattr(self.model, fn): + fn = getattr(self.model, fn) + if not callable(fn): + raise RuntimeError(f"The `{fn}` attribute is not `Callable`.") + logger.debug(f'Use {_get_fun_msg(fn, with_fp=False)}...') + return fn, None + elif fn in {"train_step", "evaluate_step"}: + logger.debug(f'Use {_get_fun_msg(self.model.execute, with_fp=False)}...') + return self.model, self.model.execute + else: + raise RuntimeError(f"There is no `{fn}` method in your {type(self.model)}.") def set_dist_repro_dataloader(self, dataloader, dist: Optional[Union[str, ReproducibleSampler]], reproducible: bool = False, sampler_or_batch_sampler=None): - pass - - def backward(self, loss): - self.grad_scaler.scale(loss).backward() - - def step(self): - for optimizer in self.optimizers: - self.grad_scaler.step(optimizer) - self.grad_scaler.update() + return dataloader def is_global_zero(self): return self.global_rank == 0 @@ -107,4 +155,4 @@ class JittorMPIDriver(JittorDriver): pass def is_distributed(self): - return True \ No newline at end of file + return True diff --git a/tests/core/controllers/test_trainer_jittor.py b/tests/core/controllers/test_trainer_jittor.py index c84c24f1..b6cefdf3 100644 --- a/tests/core/controllers/test_trainer_jittor.py +++ b/tests/core/controllers/test_trainer_jittor.py @@ -70,7 +70,7 @@ class TrainJittorConfig: @pytest.mark.parametrize("driver", ["jittor"]) -@pytest.mark.parametrize("device", ["cpu", 1]) +@pytest.mark.parametrize("device", ["cpu", "gpu", "cuda:0"]) @pytest.mark.parametrize("callbacks", [[RichCallback(100)]]) @pytest.mark.jittor def test_trainer_jittor( @@ -134,6 +134,8 @@ def test_trainer_jittor( if __name__ == "__main__": - # test_trainer_jittor("jittor", None, [RichCallback(100)]) - # test_trainer_jittor("jittor", 1, [RichCallback(100)]) + # test_trainer_jittor("jittor", "cpu", [RichCallback(100)]) # 测试 CPU + # test_trainer_jittor("jittor", "cuda:0", [RichCallback(100)]) # 测试 单卡 GPU + # test_trainer_jittor("jittor", 1, [RichCallback(100)]) # 测试 指定 GPU + # test_trainer_jittor("jittor", [0, 1], [RichCallback(100)]) # 测试 多卡 GPU pytest.main(['test_trainer_jittor.py']) # 只运行此模块