@@ -0,0 +1,9 @@ | |||||
__all__ = [ | |||||
"JittorDriver", | |||||
"JittorSingleDriver", | |||||
"JittorMPIDriver", | |||||
] | |||||
from .jittor_driver import JittorDriver | |||||
from .single_device import JittorSingleDriver | |||||
from .mpi import JittorMPIDriver |
@@ -0,0 +1,31 @@ | |||||
from typing import Union, List | |||||
from fastNLP.core.drivers.jittor_driver.jittor_driver import JittorDriver | |||||
from fastNLP.core.drivers.jittor_driver.single_device import JittorSingleDriver | |||||
from fastNLP.envs.imports import _NEED_IMPORT_JITTOR | |||||
if _NEED_IMPORT_JITTOR: | |||||
import jittor | |||||
def initialize_jittor_driver(driver: str, device: Union[str, int, List[int]], model: "jittor.Module", **kwargs) -> "JittorDriver":
    r"""
    Select and initialize a concrete jittor ``Driver`` instance according to
    ``driver`` and ``device`` and return it.

    :param driver: must be one of ``["jittor"]``;
    :param device: the device jittor runs on;
    :param model: the model to train or evaluate;
    :param kwargs: extra arguments forwarded to the driver's constructor;
    :return: a concrete jittor-based ``Driver`` instance
        (currently always a :class:`JittorSingleDriver`).
    :raises ValueError: when ``driver`` is not ``"jittor"``.
    """
    if driver != "jittor":
        raise ValueError("Parameter `driver` can only be one of these values: ['jittor'].")

    # TODO: finer-grained dispatch (e.g. to JittorMPIDriver) based on `device`.
    return JittorSingleDriver(model, device, **kwargs)
@@ -0,0 +1,155 @@ | |||||
import os | |||||
import warnings | |||||
from typing import Optional, Callable, Dict | |||||
from .utils import _build_fp16_env | |||||
from fastNLP.envs.imports import _NEED_IMPORT_JITTOR | |||||
from fastNLP.core.drivers.driver import Driver | |||||
from fastNLP.core.dataloaders import JittorDataLoader | |||||
from fastNLP.core.log import logger | |||||
from fastNLP.core.utils import apply_to_collection | |||||
if _NEED_IMPORT_JITTOR: | |||||
import jittor as jt | |||||
from jittor import Module | |||||
from jittor.optim import Optimizer | |||||
_reduces = { | |||||
'max': jt.max, | |||||
'min': jt.min, | |||||
'mean': jt.mean, | |||||
'sum': jt.sum | |||||
} | |||||
class JittorDriver(Driver):
    r"""
    Base ``Driver`` for the Jittor framework.

    Implements the behaviour shared by all jittor drivers: legality checks for
    models/dataloaders/optimizers, (currently dummy) fp16 support, model
    saving/loading, and conversion of jittor ``Var`` objects to python numbers.

    :param model: the model to train/evaluate; must be a ``jittor.Module``;
    :param fp16: whether to enable fp16; only the dummy (no-op) implementation
        exists for jittor at the moment;
    :param kwargs: extra arguments, currently unused here.
    """

    def __init__(self, model, fp16: bool = False, **kwargs):
        if not isinstance(model, Module):
            raise ValueError(f"Parameter `model` can not be `{type(model)}` in `JittorDriver`, it should be exactly "
                             f"`jittor.Module` type.")
        super(JittorDriver, self).__init__(model)

        self.model = model

        # `auto_cast` is the mixed-precision context manager; both objects are
        # no-ops unless fp16 is requested (in which case `_build_fp16_env`
        # raises, since jittor fp16 support is not implemented yet).
        self.auto_cast, _grad_scaler = _build_fp16_env(dummy=not fp16)
        self.grad_scaler = _grad_scaler()

    @staticmethod
    def _check_dataloader_legality(dataloader, dataloader_name, is_train: bool = False):
        """
        Validate the dataloader(s) supplied by the user.

        Training expects a single ``JittorDataLoader``; evaluation expects a
        dict mapping names to ``JittorDataLoader`` instances.

        :raises ValueError: when the type requirements above are not met.
        """
        # fastNLP implements JittorDataLoader as the supported loader type.
        # TODO: should raw Dataset objects be accepted as well?
        if is_train:
            if not isinstance(dataloader, JittorDataLoader):
                raise ValueError(f"Parameter `{dataloader_name}` should be 'JittorDataLoader' type, not {type(dataloader)}.")
        else:
            # use the builtin `dict` for the runtime check; `typing.Dict` is
            # meant for annotations
            if not isinstance(dataloader, dict):
                raise ValueError(f"Parameter `{dataloader_name}` should be 'Dict' type, not {type(dataloader)}.")
            else:
                for each_dataloader in dataloader.values():
                    if not isinstance(each_dataloader, JittorDataLoader):
                        raise ValueError(f"Each dataloader of parameter `{dataloader_name}` should be 'JittorDataLoader' "
                                         f"type, not {type(each_dataloader)}.")

    @staticmethod
    def _check_optimizer_legality(optimizers):
        """Ensure every element of ``optimizers`` is a ``jittor.optim.Optimizer``."""
        for each_optimizer in optimizers:
            if not isinstance(each_optimizer, Optimizer):
                raise ValueError(f"Each optimizer of parameter `optimizers` should be 'jittor.optim.Optimizer' type, "
                                 f"not {type(each_optimizer)}.")

    def check_evaluator_mode(self, mode: str):
        """
        Warn when the model lacks the step method matching ``mode`` and a
        substitute method will be used instead.
        """
        model = self.unwrap_model()
        if mode == "validate":
            if not hasattr(model, "validate_step"):
                if hasattr(model, "test_step"):
                    # fixed missing spaces between the concatenated string fragments
                    logger.warning(
                        "Your model does not have 'validate_step' method but has 'test_step' method, but you "
                        "are using 'mode=validate', we are going to use 'test_step' to substitute for "
                        "'validate_step'.")
        else:
            if not hasattr(model, "test_step"):
                if hasattr(model, "validate_step"):
                    # fixed: the message previously said 'validate' instead of
                    # 'validate_step', and fragments were joined without spaces
                    logger.warning("Your model does not have 'test_step' method but has 'validate_step' method, but you "
                                   "are using 'mode=test', we are going to use 'validate_step' to substitute for "
                                   "'test_step'.")

    def save_model(self, filepath: str, only_state_dict: bool = False, model_save_fn: Optional[Callable] = None):
        """
        Save the model to ``filepath``.

        :param filepath: target path (including the file name);
        :param only_state_dict: whether to save only the state dict; jittor
            currently supports nothing else, so a whole-model save falls back
            to the state dict with a warning;
        :param model_save_fn: optional user callable that performs the save
            itself; any non-None return value is saved with ``jt.save``.
        """
        if model_save_fn is not None:
            outputs = model_save_fn(filepath)
            if outputs is not None:
                jt.save(outputs, filepath)
        else:
            if not only_state_dict:
                # bug fix: `states` used to be undefined on this branch, which
                # raised NameError instead of falling back to the state dict
                warnings.warn("Saving the whole model is not supported now in Jittor. Save state dict instead.")
            states = self.model.state_dict()
            jt.save(states, filepath)

    def load_model(self, filepath: str):
        """
        Load a model checkpoint.

        :param filepath: path of the saved file (including the file name);
        :return: the loaded state dict.
        :raises FileNotFoundError: when ``filepath`` does not exist.
        """
        if not os.path.exists(filepath):
            raise FileNotFoundError("Checkpoint at {} not found.".format(filepath))
        return jt.load(filepath)

    def save(self):
        # checkpoint (model + training state) saving is not implemented yet
        ...

    def load(self):
        # checkpoint (model + training state) loading is not implemented yet
        ...

    def get_evaluate_context(self):
        # evaluation runs without gradient tracking
        return jt.no_grad

    def get_model_device(self):
        return self.model_device

    @staticmethod
    def tensor_to_numeric(tensor, reduce=None):
        """
        Convert jittor ``Var`` objects inside ``tensor`` (possibly nested in a
        collection) to plain python numbers/lists.

        :param reduce: one of ``'max'/'min'/'mean'/'sum'``; when given,
            multi-element vars are reduced to a single number instead of a list.
        """
        if tensor is None:
            return None

        def _translate(_data):
            # a single-element var becomes the bare number, not a list
            if _data.numel() == 1:
                return _data.item()
            if reduce is None:
                return _data.tolist()
            return _reduces[reduce](_data).item()

        return apply_to_collection(
            data=tensor,
            dtype=jt.Var,
            function=_translate
        )

    def set_model_mode(self, mode: str):
        # dispatches to `model.train()` / `model.eval()`
        assert mode in {"train", "eval"}
        getattr(self.model, mode)()

    @property
    def data_device(self):
        # data lives wherever the model lives on a single device
        return self.model_device

    def move_data_to_device(self, batch: 'jt.Var'):
        """
        Jittor provides no data-movement API yet, so the batch is returned
        unchanged.
        """
        return batch

    # def set_sampler_epoch(self, dataloader: JittorDataLoader, cur_epoch_idx):
    #     # ensure correctness of shuffle=True in ddp training: the shuffling
    #     # random seed of the sampler must be identical on every process
    #     if callable(getattr(dataloader.batch_sampler, "set_epoch", None)):
    #         dataloader.batch_sampler.set_epoch(cur_epoch_idx)
@@ -0,0 +1,100 @@ | |||||
import os | |||||
from typing import Optional, Union | |||||
from .jittor_driver import JittorDriver | |||||
from fastNLP.envs.imports import _NEED_IMPORT_JITTOR | |||||
from fastNLP.core.samplers import ReproducibleIterator | |||||
if _NEED_IMPORT_JITTOR: | |||||
import jittor | |||||
__all__ = [ | |||||
"JittorMPIDriver", | |||||
] | |||||
class JittorMPIDriver(JittorDriver):
    """
    Driver skeleton for multi-process (MPI) training with jittor.

    NOTE(review): most methods are still stubs (``setup``, ``configure_mpi``,
    ``replace_sampler``, ``barrier``, ``unwrap_model``); the class records
    ranks/world size but performs no actual distributed setup yet.

    :param model: the ``jittor.Module`` to train;
    :param parallel_device: the device(s) to parallelize over; exact semantics
        are not settled yet — TODO confirm once MPI support is implemented;
    :param is_pull_by_jittor_run: whether this process was spawned by an
        external jittor launcher;
    :param fp16: whether to use fp16 (currently unsupported by JittorDriver);
    :param kwargs: extra arguments forwarded to ``JittorDriver``.
    """

    def __init__(
            self,
            model,
            # fixed the bogus `parallel_device: None` annotation; the added
            # default is backward-compatible with existing positional callers
            parallel_device: Optional[Union[int, list]] = None,
            is_pull_by_jittor_run: bool = False,
            fp16: bool = False,
            **kwargs
    ):
        super(JittorMPIDriver, self).__init__(model, fp16=fp16, **kwargs)

        self.is_pull_by_jittor_run = is_pull_by_jittor_run
        self.parallel_device = parallel_device
        # when True, data lives on a device different from the model's;
        # see the `data_device` property
        self.outside_mpi = False

    def setup(self):
        # TODO: initialize the MPI environment
        pass

    def configure_mpi(self):
        # TODO: configure MPI communication
        pass

    @property
    def world_size(self) -> int:
        return self._world_size

    @world_size.setter
    def world_size(self, size: int):
        self._world_size = size

    @property
    def global_rank(self) -> int:
        return self._global_rank

    @global_rank.setter
    def global_rank(self, rank: int) -> None:
        self._global_rank = rank

    @property
    def local_rank(self) -> int:
        # read from the environment set by the launcher; defaults to 0
        return int(os.environ.get("LOCAL_RANK", 0))

    @property
    def data_device(self):
        if self.outside_mpi:
            return self._data_device
        return self.model_device

    def train_step(self, batch):
        # NOTE(review): `_train_step` is never assigned in this class — it is
        # presumably set up elsewhere (cf. JittorSingleDriver.__init__); verify
        return self._train_step(batch)

    def validate_step(self, batch):
        return self._validate_step(batch)

    def test_step(self, batch):
        return self._test_step(batch)

    def replace_sampler(self, dataloader, dist_sampler: Optional[Union[str, ReproducibleIterator]] = "dist", reproducible: bool = False):
        # TODO: distributed sampler replacement is not implemented yet
        pass

    def backward(self, loss):
        # scale the loss before backprop (a no-op unless fp16 is active)
        self.grad_scaler.scale(loss).backward()

    def step(self):
        for optimizer in self.optimizers:
            self.grad_scaler.step(optimizer)
            self.grad_scaler.update()

    def is_global_zero(self):
        return self.global_rank == 0

    def get_no_sync_context(self):
        return self.model.no_sync

    def unwrap_model(self):
        # TODO: should return the underlying (unwrapped) model
        pass

    def get_local_rank(self) -> int:
        return self.local_rank

    def barrier(self):
        # TODO: synchronize all processes
        pass

    def is_distributed(self):
        return True
@@ -0,0 +1,127 @@ | |||||
from typing import Dict, Union | |||||
from .jittor_driver import JittorDriver | |||||
from fastNLP.core.utils import auto_param_call | |||||
from fastNLP.envs.imports import _NEED_IMPORT_JITTOR | |||||
from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleIterator | |||||
if _NEED_IMPORT_JITTOR: | |||||
import jittor | |||||
__all__ = [ | |||||
"JittorSingleDriver", | |||||
] | |||||
class JittorSingleDriver(JittorDriver):
    r"""
    Driver for cpu and single-gpu jittor computation.

    TODO: fp16 support for jittor.

    :param model: the ``jittor.Module`` to run;
    :param device: the device to run on;
    :param fp16: whether to use fp16 (currently unsupported by JittorDriver);
    :param kwargs: extra arguments forwarded to ``JittorDriver``.
    """

    def __init__(self, model, device=None, fp16: bool = False, **kwargs):
        # bug fix: `**kwargs` used to be dropped instead of forwarded
        super(JittorSingleDriver, self).__init__(model, fp16=fp16, **kwargs)

        self.model_device = device

        self.local_rank = 0
        self.global_rank = 0
        self.world_size = 1

        # Resolve the train/validate/test entry points once, preferring the
        # dedicated `*_step` methods, then the sibling step method as a
        # substitute, and finally the model itself (whose forward is `execute`
        # in jittor).
        if hasattr(self.model, "train_step"):
            self._train_step = self.model.train_step
            self._train_signature_fn = None
        else:
            self._train_step = self.model
            model = self.unwrap_model()
            self._train_signature_fn = model.execute

        if hasattr(self.model, "validate_step"):
            self._validate_step = self.model.validate_step
            self._validate_signature_fn = None
        elif hasattr(self.model, "test_step"):
            self._validate_step = self.model.test_step
            self._validate_signature_fn = self.model.test_step
        else:
            self._validate_step = self.model
            model = self.unwrap_model()
            self._validate_signature_fn = model.execute

        if hasattr(self.model, "test_step"):
            self._test_step = self.model.test_step
            self._test_signature_fn = None
        elif hasattr(self.model, "validate_step"):
            self._test_step = self.model.validate_step
            self._test_signature_fn = self.model.validate_step
        else:
            self._test_step = self.model
            model = self.unwrap_model()
            self._test_signature_fn = model.execute

    def train_step(self, batch) -> Dict:
        # dict batches are matched to the step function's signature
        if isinstance(batch, dict):
            return auto_param_call(self._train_step, batch, signature_fn=self._train_signature_fn)
        else:
            return self._train_step(batch)

    def step(self):
        """
        Jittor optimizers' ``step`` can optionally take the loss, in which case
        it performs ``zero_grad`` and ``backward`` at the same time; for
        consistency across drivers that form is not used here.
        """
        for optimizer in self.optimizers:
            optimizer.step()

    def backward(self, loss):
        for optimizer in self.optimizers:
            optimizer.backward(loss)

    def zero_grad(self, set_to_none=False):
        # jittor's zero_grad takes no arguments; `set_to_none` is accepted only
        # for interface compatibility and is ignored
        for optimizer in self.optimizers:
            optimizer.zero_grad()

    def validate_step(self, batch):
        if isinstance(batch, dict):
            return auto_param_call(self._validate_step, batch, signature_fn=self._validate_signature_fn)
        else:
            return self._validate_step(batch)

    def test_step(self, batch):
        if isinstance(batch, dict):
            return auto_param_call(self._test_step, batch, signature_fn=self._test_signature_fn)
        else:
            return self._test_step(batch)

    def unwrap_model(self):
        # single-device models are never wrapped
        return self.model

    def is_distributed(self):
        return False

    def replace_sampler(self, dataloader, dist_sampler: Union[str, ReproducibleBatchSampler, ReproducibleIterator], reproducible: bool = False):
        # the `reproducible` related features are not implemented yet
        if isinstance(dist_sampler, ReproducibleBatchSampler):
            raise NotImplementedError
            # intended once implemented (bug fix: previously referenced the
            # misspelled name `dist_sample`):
            #     dataloader.batch_sampler = dist_sampler
        if isinstance(dist_sampler, ReproducibleIterator):
            raise NotImplementedError
            # intended once implemented:
            #     dataloader.batch_sampler.sampler = dist_sampler

        if reproducible:
            raise NotImplementedError
            # intended implementation, kept for reference:
            #     if isinstance(dataloader.batch_sampler.sampler, ReproducibleIterator):
            #         return dataloader
            #     elif isinstance(dataloader.batch_sampler, ReproducibleBatchSampler):
            #         return dataloader
            #     else:
            #         batch_sampler = ReproducibleBatchSampler(
            #             batch_sampler=dataloader.batch_sampler,
            #             batch_size=dataloader.batch_sampler.batch_size,
            #             drop_last=dataloader.drop_last
            #         )
            #         dataloader.batch_sampler = batch_sampler
            #         return dataloader
        else:
            return dataloader
@@ -0,0 +1,55 @@ | |||||
from contextlib import ExitStack | |||||
from fastNLP.envs.imports import _NEED_IMPORT_JITTOR | |||||
if _NEED_IMPORT_JITTOR: | |||||
import jittor | |||||
class DummyGradScaler:
    """
    A no-op stand-in for a real GradScaler.

    Lets fp16-agnostic caller code use the scaler API unconditionally instead
    of branching on whether mixed precision is active.
    """

    def __init__(self, *args, **kwargs):
        pass

    def get_scale(self):
        # a disabled scaler always reports a scale factor of 1
        return 1.0

    def is_enabled(self):
        return False

    def scale(self, outputs):
        # no scaling applied — hand the loss/outputs back untouched
        return outputs

    def step(self, optimizer, *args, **kwargs):
        # nothing to unscale; delegate straight to the optimizer
        optimizer.step(*args, **kwargs)

    def update(self, new_scale=None):
        pass

    def unscale_(self, optimizer):
        pass

    def load_state_dict(self, state_dict):
        pass

    def state_dict(self):
        # no state to persist
        return {}
def _build_fp16_env(dummy=False): | |||||
if dummy: | |||||
auto_cast = ExitStack | |||||
GradScaler = DummyGradScaler | |||||
else: | |||||
raise NotImplementedError("JittorDriver does not support fp16 now.") | |||||
# if not jt.flags.use_cuda: | |||||
# raise RuntimeError("No cuda") | |||||
# if paddle.device.cuda.get_device_capability(0)[0] < 7: | |||||
# log.warning( | |||||
# "NOTE: your device does NOT support faster training with fp16, " | |||||
# "please switch to FP32 which is likely to be faster" | |||||
# ) | |||||
# from paddle.amp import auto_cast, GradScaler | |||||
return auto_cast, GradScaler |