@@ -0,0 +1,9 @@ | |||
__all__ = [
    "JittorDriver",
    "JittorSingleDriver",
    "JittorMPIDriver",
]

# Import order matters: ``single_device`` and ``mpi`` both build on
# ``jittor_driver``, so the base module is imported first.
from .jittor_driver import JittorDriver
from .single_device import JittorSingleDriver
from .mpi import JittorMPIDriver
@@ -0,0 +1,31 @@ | |||
from typing import Union, List | |||
from fastNLP.core.drivers.jittor_driver.jittor_driver import JittorDriver | |||
from fastNLP.core.drivers.jittor_driver.single_device import JittorSingleDriver | |||
from fastNLP.envs.imports import _NEED_IMPORT_JITTOR | |||
if _NEED_IMPORT_JITTOR: | |||
import jittor | |||
def initialize_jittor_driver(driver: str, device: Union[str, int, List[int]], model: "jittor.Module", **kwargs) -> "JittorDriver":
    r"""
    Pick and construct the concrete jittor ``Driver`` from ``driver``/``device``.

    :param driver: must be one of ``["jittor"]``;
    :param device: the device jittor runs on;
    :param model: the model to train or evaluate;
    :param kwargs: extra arguments forwarded to the driver's constructor;
    :return: a concrete jittor-based ``Driver`` instance (NOT a tuple — the
        previous docstring claiming a ``(driver, name)`` tuple was wrong).
    :raises ValueError: if ``driver`` is not ``"jittor"``.
    """
    # The annotations above are quoted so the function can be defined even when
    # jittor itself is not importable (``_NEED_IMPORT_JITTOR`` is False).
    if driver != "jittor":
        raise ValueError("Parameter `driver` can only be one of these values: ['jittor'].")
    # TODO: finer-grained dispatch (e.g. the MPI driver) once multi-device
    # support is implemented; today every request maps to the single driver.
    return JittorSingleDriver(model, device, **kwargs)
@@ -0,0 +1,155 @@ | |||
import os | |||
import warnings | |||
from typing import Optional, Callable, Dict | |||
from .utils import _build_fp16_env | |||
from fastNLP.envs.imports import _NEED_IMPORT_JITTOR | |||
from fastNLP.core.drivers.driver import Driver | |||
from fastNLP.core.dataloaders import JittorDataLoader | |||
from fastNLP.core.log import logger | |||
from fastNLP.core.utils import apply_to_collection | |||
if _NEED_IMPORT_JITTOR:
    import jittor as jt
    from jittor import Module
    from jittor.optim import Optimizer

    # Reduction ops selectable by name in ``JittorDriver.tensor_to_numeric``.
    _reduces = {
        'max': jt.max,
        'min': jt.min,
        'mean': jt.mean,
        'sum': jt.sum
    }
class JittorDriver(Driver):
    r"""
    Base ``Driver`` for the Jittor framework.

    Implements the model/optimizer/dataloader legality checks, checkpoint
    (de)serialization helpers and tensor-to-number conversion shared by the
    concrete jittor drivers (single-device and MPI).
    """

    def __init__(self, model, fp16: bool = False, **kwargs):
        """
        :param model: a ``jittor.Module`` instance to drive;
        :param fp16: whether to enable mixed precision; only the dummy
            (disabled) path is implemented for jittor so far;
        :param kwargs: ignored here, kept for signature parity with the
            drivers of other backends.
        """
        if not isinstance(model, Module):
            raise ValueError(f"Parameter `model` can not be `{type(model)}` in `JittorDriver`, it should be exactly "
                             f"`jittor.Module` type.")
        super(JittorDriver, self).__init__(model)
        self.model = model

        # ``_build_fp16_env`` hands back no-op stand-ins when dummy=True, so
        # the rest of the driver never needs ``if fp16:`` branches.
        self.auto_cast, _grad_scaler = _build_fp16_env(dummy=not fp16)
        self.grad_scaler = _grad_scaler()

    @staticmethod
    def _check_dataloader_legality(dataloader, dataloader_name, is_train: bool = False):
        """
        Validate that ``dataloader`` is a ``JittorDataLoader`` (training) or a
        dict mapping names to ``JittorDataLoader`` (evaluation).
        """
        # fastNLP ships JittorDataLoader for the jittor backend.
        # TODO: should raw Dataset objects be accepted as well?
        if is_train:
            if not isinstance(dataloader, JittorDataLoader):
                raise ValueError(f"Parameter `{dataloader_name}` should be 'JittorDataLoader' type, not {type(dataloader)}.")
        else:
            # fix: isinstance must use the builtin ``dict``; ``typing.Dict``
            # as an isinstance target is deprecated.
            if not isinstance(dataloader, dict):
                raise ValueError(f"Parameter `{dataloader_name}` should be 'Dict' type, not {type(dataloader)}.")
            else:
                for each_dataloader in dataloader.values():
                    if not isinstance(each_dataloader, JittorDataLoader):
                        raise ValueError(f"Each dataloader of parameter `{dataloader_name}` should be 'JittorDataLoader' "
                                         f"type, not {type(each_dataloader)}.")

    @staticmethod
    def _check_optimizer_legality(optimizers):
        """Ensure every element of ``optimizers`` is a ``jittor.optim.Optimizer``."""
        for each_optimizer in optimizers:
            if not isinstance(each_optimizer, Optimizer):
                raise ValueError(f"Each optimizer of parameter `optimizers` should be 'jittor.optim.Optimizer' type, "
                                 f"not {type(each_optimizer)}.")

    def check_evaluator_mode(self, mode: str):
        """
        Warn when the model lacks the step method matching ``mode`` but has
        the other mode's method, since that one will be used as a substitute.
        """
        model = self.unwrap_model()
        if mode == "validate":
            if not hasattr(model, "validate_step"):
                if hasattr(model, "test_step"):
                    # fix: the adjacent string literals were missing separating
                    # spaces ("youare", "for'validate_step'").
                    logger.warning(
                        "Your model does not have 'validate_step' method but has 'test_step' method, but you "
                        "are using 'mode=validate', we are going to use 'test_step' to substitute for "
                        "'validate_step'.")
        else:
            if not hasattr(model, "test_step"):
                if hasattr(model, "validate_step"):
                    # fix: message said "'validate' method" although the check
                    # is for 'validate_step'; also repaired missing spaces.
                    logger.warning("Your model does not have 'test_step' method but has 'validate_step' method, but you "
                                   "are using 'mode=test', we are going to use 'validate_step' to substitute for "
                                   "'test_step'.")

    def save_model(self, filepath: str, only_state_dict: bool = False, model_save_fn: Optional[Callable] = None):
        """
        Save the model to ``filepath``.

        :param filepath: destination path (including the file name);
        :param only_state_dict: jittor can only serialize state dicts; when
            ``False`` a warning is emitted and the state dict is saved anyway;
        :param model_save_fn: optional user callable that performs the save
            itself; any non-``None`` return value is persisted via ``jt.save``.
        """
        if model_save_fn is not None:
            outputs = model_save_fn(filepath)
            if outputs is not None:
                jt.save(outputs, filepath)
        else:
            if not only_state_dict:
                warnings.warn("Saving the whole model is not supported now in Jittor. Save state dict instead.")
            # fix: ``states`` used to be assigned only in the
            # ``only_state_dict`` branch, causing a NameError on the
            # whole-model (warning) path.
            states = self.model.state_dict()
            jt.save(states, filepath)

    def load_model(self, filepath: str):
        """
        Load a checkpoint from disk.

        :param filepath: path of the saved file (including the file name);
        :return: the loaded state dict.
        :raises FileNotFoundError: if no file exists at ``filepath``.
        """
        if not os.path.exists(filepath):
            raise FileNotFoundError("Checkpoint at {} not found.".format(filepath))
        return jt.load(filepath)

    def save(self):
        # TODO: full training-state checkpointing is not implemented yet.
        ...

    def load(self):
        # TODO: full training-state restoring is not implemented yet.
        ...

    def get_evaluate_context(self):
        # Evaluation runs under ``jt.no_grad`` to skip gradient bookkeeping.
        return jt.no_grad

    def get_model_device(self):
        # ``model_device`` is set by the concrete subclass — TODO confirm all
        # subclasses do so before this is called.
        return self.model_device

    @staticmethod
    def tensor_to_numeric(tensor, reduce=None):
        """
        Convert every ``jt.Var`` leaf inside ``tensor`` (possibly a nested
        collection) into plain Python numbers / lists.

        :param reduce: one of ``'max'/'min'/'mean'/'sum'`` to collapse a
            multi-element Var to a scalar, or ``None`` to return a list.
        """
        if tensor is None:
            return None

        def _translate(_data):
            # A single-element Var becomes the bare scalar, not a 1-item list.
            if _data.numel() == 1:
                return _data.item()
            if reduce is None:
                return _data.tolist()
            return _reduces[reduce](_data).item()

        return apply_to_collection(
            data=tensor,
            dtype=jt.Var,
            function=_translate
        )

    def set_model_mode(self, mode: str):
        """Switch the model between ``train()`` and ``eval()``."""
        assert mode in {"train", "eval"}
        getattr(self.model, mode)()

    @property
    def data_device(self):
        # Data lives wherever the model lives; jittor manages placement.
        return self.model_device

    def move_data_to_device(self, batch: 'jt.Var'):
        """
        Jittor does not expose an explicit data-transfer API yet, so the
        batch is returned unchanged.
        """
        return batch

    # def set_sampler_epoch(self, dataloader: JittorDataLoader, cur_epoch_idx):
    #     # Keep shuffle=True correct under ddp: every process must seed its
    #     # sampler's shuffle RNG identically for the epoch.
    #     if callable(getattr(dataloader.batch_sampler, "set_epoch", None)):
    #         dataloader.batch_sampler.set_epoch(cur_epoch_idx)
@@ -0,0 +1,100 @@ | |||
import os | |||
from typing import Optional, Union | |||
from .jittor_driver import JittorDriver | |||
from fastNLP.envs.imports import _NEED_IMPORT_JITTOR | |||
from fastNLP.core.samplers import ReproducibleIterator | |||
if _NEED_IMPORT_JITTOR: | |||
import jittor | |||
__all__ = [ | |||
"JittorMPIDriver", | |||
] | |||
class JittorMPIDriver(JittorDriver):
    """
    Skeleton of the MPI (multi-process) driver for jittor.

    NOTE(review): most of the distributed plumbing (``setup``, ``barrier``,
    sampler replacement, the ``_train_step`` bindings) is still unimplemented;
    the class currently only records rank/world-size state and defers the rest.
    """

    def __init__(
        self,
        model,
        parallel_device,  # fix: was annotated ``: None`` — ``None`` is a value, not a valid type hint
        is_pull_by_jittor_run: bool = False,
        fp16: bool = False,
        **kwargs
    ):
        """
        :param model: the ``jittor.Module`` to run;
        :param parallel_device: the device(s) used for the parallel run —
            exact type depends on the caller, TODO confirm once implemented;
        :param is_pull_by_jittor_run: whether this process was spawned by an
            external jittor launcher;
        :param fp16: whether to enable mixed precision (dummy path only).
        """
        super(JittorMPIDriver, self).__init__(model, fp16=fp16, **kwargs)

        self.is_pull_by_jittor_run = is_pull_by_jittor_run
        self.parallel_device = parallel_device
        # True only when the driver runs outside an mpi launch; then
        # ``data_device`` falls back to the explicitly stored ``_data_device``.
        self.outside_mpi = False

    def setup(self):
        # TODO: initialize the mpi environment.
        pass

    def configure_mpi(self):
        # TODO: configure mpi communication.
        pass

    @property
    def world_size(self) -> int:
        # NOTE(review): ``_world_size`` must be set via the setter before this
        # is read, otherwise AttributeError — confirm callers do so.
        return self._world_size

    @world_size.setter
    def world_size(self, size: int):
        self._world_size = size

    @property
    def global_rank(self) -> int:
        return self._global_rank

    @global_rank.setter
    def global_rank(self, rank: int) -> None:
        self._global_rank = rank

    @property
    def local_rank(self) -> int:
        # Rank within the node, exported by the launcher; defaults to 0.
        return int(os.environ.get("LOCAL_RANK", 0))

    @property
    def data_device(self):
        if self.outside_mpi:
            return self._data_device
        return self.model_device

    def train_step(self, batch):
        # ``_train_step`` is expected to be bound during setup — TODO confirm.
        return self._train_step(batch)

    def validate_step(self, batch):
        return self._validate_step(batch)

    def test_step(self, batch):
        return self._test_step(batch)

    def replace_sampler(self, dataloader, dist_sampler: Optional[Union[str, ReproducibleIterator]] = "dist", reproducible: bool = False):
        # TODO: install a distributed, reproducible sampler.
        pass

    def backward(self, loss):
        # Scale the loss (no-op with the dummy scaler) before backprop.
        self.grad_scaler.scale(loss).backward()

    def step(self):
        for optimizer in self.optimizers:
            self.grad_scaler.step(optimizer)
        self.grad_scaler.update()

    def is_global_zero(self):
        return self.global_rank == 0

    def get_no_sync_context(self):
        return self.model.no_sync

    def unwrap_model(self):
        # TODO: return the wrapped user model once mpi wrapping exists.
        pass

    def get_local_rank(self) -> int:
        return self.local_rank

    def barrier(self):
        # TODO: synchronize all mpi processes.
        pass

    def is_distributed(self):
        return True
@@ -0,0 +1,127 @@ | |||
from typing import Dict, Union | |||
from .jittor_driver import JittorDriver | |||
from fastNLP.core.utils import auto_param_call | |||
from fastNLP.envs.imports import _NEED_IMPORT_JITTOR | |||
from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleIterator | |||
if _NEED_IMPORT_JITTOR: | |||
import jittor | |||
__all__ = [ | |||
"JittorSingleDriver", | |||
] | |||
class JittorSingleDriver(JittorDriver):
    r"""
    Driver for cpu and single-gpu jittor runs.

    TODO: fp16 support for jittor.
    """

    def __init__(self, model, device=None, fp16: bool = False, **kwargs):
        super(JittorSingleDriver, self).__init__(model, fp16)
        self.model_device = device

        # Single process: trivial, fixed rank/world-size bookkeeping.
        self.local_rank = 0
        self.global_rank = 0
        self.world_size = 1

        # Resolve each entry point once: prefer the model's dedicated
        # ``*_step`` method; for validate/test fall back to the other mode's
        # method; otherwise call the model itself (jittor modules implement
        # ``execute`` rather than ``forward``).
        if hasattr(self.model, "train_step"):
            self._train_step = self.model.train_step
            self._train_signature_fn = None
        else:
            self._train_step = self.model
            model = self.unwrap_model()
            self._train_signature_fn = model.execute

        if hasattr(self.model, "validate_step"):
            self._validate_step = self.model.validate_step
            self._validate_signature_fn = None
        elif hasattr(self.model, "test_step"):
            self._validate_step = self.model.test_step
            self._validate_signature_fn = self.model.test_step
        else:
            self._validate_step = self.model
            model = self.unwrap_model()
            self._validate_signature_fn = model.execute

        if hasattr(self.model, "test_step"):
            self._test_step = self.model.test_step
            self._test_signature_fn = None
        elif hasattr(self.model, "validate_step"):
            self._test_step = self.model.validate_step
            self._test_signature_fn = self.model.validate_step
        else:
            self._test_step = self.model
            model = self.unwrap_model()
            self._test_signature_fn = model.execute

    def train_step(self, batch) -> Dict:
        # dict batches are matched to the step function's parameters by name;
        # anything else is passed through positionally.
        if isinstance(batch, dict):
            return auto_param_call(self._train_step, batch, signature_fn=self._train_signature_fn)
        else:
            return self._train_step(batch)

    def step(self):
        """
        jittor optimizers' ``step`` can take the loss directly (performing
        zero_grad and backward in one call); for uniformity with the other
        backends we deliberately use the plain no-argument form here.
        """
        for optimizer in self.optimizers:
            optimizer.step()

    def backward(self, loss):
        # In jittor, gradients are computed through the optimizer, not via
        # ``loss.backward()``.
        for optimizer in self.optimizers:
            optimizer.backward(loss)

    def zero_grad(self, set_to_none=False):
        # ``set_to_none`` is accepted for interface parity with other
        # backends; jittor's ``zero_grad`` takes no such flag.
        for optimizer in self.optimizers:
            optimizer.zero_grad()

    def validate_step(self, batch):
        if isinstance(batch, dict):
            return auto_param_call(self._validate_step, batch, signature_fn=self._validate_signature_fn)
        else:
            return self._validate_step(batch)

    def test_step(self, batch):
        if isinstance(batch, dict):
            return auto_param_call(self._test_step, batch, signature_fn=self._test_signature_fn)
        else:
            return self._test_step(batch)

    def unwrap_model(self):
        # Nothing wraps the model on a single device.
        return self.model

    def is_distributed(self):
        return False

    def replace_sampler(self, dataloader, dist_sampler: Union[str, ReproducibleBatchSampler, ReproducibleIterator], reproducible: bool = False):
        """
        Reproducible / distributed sampling is not implemented for jittor yet.

        Every branch that would actually install a sampler raises
        ``NotImplementedError``; the dead statements that followed each raise
        (including a ``dist_sample`` NameError typo) have been removed —
        behavior is unchanged since they could never execute.
        """
        if isinstance(dist_sampler, ReproducibleBatchSampler):
            # intended, once supported: dataloader.batch_sampler = dist_sampler
            raise NotImplementedError
        if isinstance(dist_sampler, ReproducibleIterator):
            # intended, once supported: dataloader.batch_sampler.sampler = dist_sampler
            raise NotImplementedError
        if reproducible:
            # intended, once supported: wrap the existing batch_sampler in a
            # ReproducibleBatchSampler when it is not already reproducible.
            raise NotImplementedError
        return dataloader
@@ -0,0 +1,55 @@ | |||
from contextlib import ExitStack | |||
from fastNLP.envs.imports import _NEED_IMPORT_JITTOR | |||
if _NEED_IMPORT_JITTOR: | |||
import jittor | |||
class DummyGradScaler:
    """No-op stand-in for a fp16 ``GradScaler``.

    Exposes the same surface as a real gradient scaler so calling code can
    avoid sprinkling ``if fp16:`` checks everywhere; every operation here is
    an identity or a no-op.
    """

    def __init__(self, *args, **kwargs):
        pass

    def get_scale(self):
        # No scaling is applied, so the scale factor is always 1.
        return 1.0

    def is_enabled(self):
        # This scaler never actually scales anything.
        return False

    def scale(self, outputs):
        # Identity: hand the loss/outputs back untouched.
        return outputs

    def step(self, optimizer, *args, **kwargs):
        # Delegate straight to the optimizer; no unscaling is needed.
        optimizer.step(*args, **kwargs)

    def update(self, new_scale=None):
        pass

    def unscale_(self, optimizer):
        pass

    def load_state_dict(self, state_dict):
        pass

    def state_dict(self):
        # Nothing to persist.
        return {}
def _build_fp16_env(dummy=False): | |||
if dummy: | |||
auto_cast = ExitStack | |||
GradScaler = DummyGradScaler | |||
else: | |||
raise NotImplementedError("JittorDriver does not support fp16 now.") | |||
# if not jt.flags.use_cuda: | |||
# raise RuntimeError("No cuda") | |||
# if paddle.device.cuda.get_device_capability(0)[0] < 7: | |||
# log.warning( | |||
# "NOTE: your device does NOT support faster training with fp16, " | |||
# "please switch to FP32 which is likely to be faster" | |||
# ) | |||
# from paddle.amp import auto_cast, GradScaler | |||
return auto_cast, GradScaler |