|
|
@@ -2,11 +2,14 @@ import os

from typing import Optional, Union, Callable, Dict, Tuple

from .jittor_driver import JittorDriver
from fastNLP.core.utils import auto_param_call
from fastNLP.core.utils.utils import _get_fun_msg
from fastNLP.envs.imports import _NEED_IMPORT_JITTOR
from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler
from fastNLP.core.log import logger

if _NEED_IMPORT_JITTOR:
    import jittor as jt

__all__ = [
    "JittorMPIDriver",
]
|
|
@@ -42,7 +45,31 @@ class JittorMPIDriver(JittorDriver):

        self.outside_mpi = False

    def setup(self):
        self.__fork_with_mpi__()

    def __fork_with_mpi__(self):
        import sys
        if jt.in_mpi:
            # mute the output of every process except rank 0
            if jt.rank != 0:
                sys.stdout = open("/dev/null", "w")
            return
        else:
            if self.parallel_device == -1:  # device == -1 would mean "use all GPUs by default", which is not implemented yet
                raise NotImplementedError(f"Device={self.parallel_device}")
            elif type(self.parallel_device) is int:  # device is an *int*: train on the single GPU whose ``device_id`` equals this value
                num_procs = 1
                devices = self.parallel_device
            elif type(self.parallel_device) is list:  # device is a *list(int)*: more than one device must be specified this way
                num_procs = len(self.parallel_device)
                devices = str(self.parallel_device)[1:-1]
            else:
                raise NotImplementedError(f"Device={self.parallel_device}")
            print(sys.argv)
            cmd = " ".join(["CUDA_VISIBLE_DEVICES='%s'" % devices, "mpirun", "-np", str(num_procs), sys.executable] + sys.argv)
            print("[RUN CMD]:", cmd)
            os.system(cmd)
            sys.exit(0)
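
    # A concrete illustration of the relaunch (a sketch; the exact interpreter
    # path and script name depend on how the program was started): with
    # ``parallel_device=[0, 1]`` and an invocation of ``python train.py``,
    # ``devices`` becomes ``'0, 1'`` (from ``str([0, 1])[1:-1]``), ``num_procs``
    # is 2, and the command built above is roughly
    #
    #   CUDA_VISIBLE_DEVICES='0, 1' mpirun -np 2 /usr/bin/python train.py
    #
    # The parent process exits afterwards, and each mpirun child re-enters
    # ``__fork_with_mpi__`` with ``jt.in_mpi`` being true.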
|
|
|
|
|
|
|
    def configure_mpi(self):
        pass

@@ -71,25 +98,46 @@ class JittorMPIDriver(JittorDriver):

    def data_device(self):
        if self.outside_mpi:
            return self._data_device
        return self.model_device

    def step(self):
        for optimizer in self.optimizers:
            optimizer.step()

    def backward(self, loss):
        for optimizer in self.optimizers:
            optimizer.backward(loss)

    def zero_grad(self):
        for optimizer in self.optimizers:
            optimizer.zero_grad()
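
    # Taken together, one optimization step through this driver is expected to
    # look roughly like the sketch below (``driver`` and ``batch`` are
    # illustrative names, and the ``"loss"`` key assumes the model's
    # ``train_step`` returns a dict containing the loss):
    #
    #   outputs = driver.model_call(batch, fn, signature_fn)
    #   driver.backward(outputs["loss"])  # Jittor optimizers take the loss directly
    #   driver.step()
    #   driver.zero_grad()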
|
|
|
|
|
|
|
    def model_call(self, batch, fn: Callable, signature_fn: Optional[Callable]) -> Dict:
        if isinstance(batch, Dict) and not self.wo_auto_param_call:
            return auto_param_call(fn, batch, signature_fn=signature_fn)
        else:
            return fn(batch)
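
    # Sketch of the dispatch above (``x`` and ``y`` are hypothetical parameter
    # names): for ``def train_step(self, x, y)`` and ``batch = {"x": ..., "y": ...}``,
    # ``auto_param_call`` matches dict keys to parameter names and invokes
    # ``fn(x=batch["x"], y=batch["y"])``, reading the signature from
    # ``signature_fn`` when one is given; a non-dict batch is passed through
    # unchanged as ``fn(batch)``.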
|
|
|
|
|
|
|
    def get_model_call_fn(self, fn: str) -> Tuple:
        if hasattr(self.model, fn):
            fn = getattr(self.model, fn)
            if not callable(fn):
                raise RuntimeError(f"The `{fn}` attribute is not `Callable`.")
            logger.debug(f'Use {_get_fun_msg(fn, with_fp=False)}...')
            return fn, None
        elif fn in {"train_step", "evaluate_step"}:
            logger.debug(f'Use {_get_fun_msg(self.model.execute, with_fp=False)}...')
            return self.model, self.model.execute
        else:
            raise RuntimeError(f"There is no `{fn}` method in your {type(self.model)}.")
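
    # Resolution behaviour in short: ``get_model_call_fn("train_step")`` returns
    # the bound ``model.train_step`` (with no ``signature_fn``) when the model
    # defines one; otherwise, for ``"train_step"``/``"evaluate_step"``, the model
    # itself is returned and invoked through ``model.execute`` (Jittor's
    # counterpart of ``forward``).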
|
|
|
|
|
|
|
    def set_dist_repro_dataloader(self, dataloader, dist: Optional[Union[str, ReproducibleSampler]],
                                  reproducible: bool = False, sampler_or_batch_sampler=None):
        return dataloader

    def is_global_zero(self):
        return self.global_rank == 0

@@ -107,4 +155,4 @@ class JittorMPIDriver(JittorDriver):

        pass

    def is_distributed(self):
        return True