Browse Source

提交core/drivers/jittor_driver/

tags/v1.0.0alpha
x54-729 3 years ago
parent
commit
961f2da7eb
6 changed files with 477 additions and 0 deletions
  1. +9
    -0
      fastNLP/core/drivers/jittor_driver/__init__.py
  2. +31
    -0
      fastNLP/core/drivers/jittor_driver/initialize_jittor_driver.py
  3. +155
    -0
      fastNLP/core/drivers/jittor_driver/jittor_driver.py
  4. +100
    -0
      fastNLP/core/drivers/jittor_driver/mpi.py
  5. +127
    -0
      fastNLP/core/drivers/jittor_driver/single_device.py
  6. +55
    -0
      fastNLP/core/drivers/jittor_driver/utils.py

+ 9
- 0
fastNLP/core/drivers/jittor_driver/__init__.py View File

@@ -0,0 +1,9 @@
# Public interface of the jittor driver subpackage: the base jittor driver
# plus its single-device and MPI implementations.
__all__ = [
    "JittorDriver",
    "JittorSingleDriver",
    "JittorMPIDriver",
]

from .jittor_driver import JittorDriver
from .single_device import JittorSingleDriver
from .mpi import JittorMPIDriver

+ 31
- 0
fastNLP/core/drivers/jittor_driver/initialize_jittor_driver.py View File

@@ -0,0 +1,31 @@
from typing import Union, List

from fastNLP.core.drivers.jittor_driver.jittor_driver import JittorDriver
from fastNLP.core.drivers.jittor_driver.single_device import JittorSingleDriver
from fastNLP.envs.imports import _NEED_IMPORT_JITTOR

if _NEED_IMPORT_JITTOR:
import jittor

def initialize_jittor_driver(driver: str, device: Union[str, int, List[int]], model: "jittor.Module", **kwargs) -> "JittorDriver":
    r"""
    Determine and initialize a concrete jittor ``Driver`` instance from the
    ``driver`` and ``device`` arguments and return it.

    The annotations are string literals on purpose: ``jittor`` is imported
    conditionally (``_NEED_IMPORT_JITTOR``), so evaluating ``jittor.Module``
    at definition time would raise ``NameError`` when jittor is unavailable.

    :param driver: must be one of ``["jittor"]``;
    :param device: the device jittor runs on;
    :param model: the model to train or evaluate;
    :param kwargs: extra keyword arguments forwarded to the driver;

    :return: a jittor-based ``JittorSingleDriver`` instance (the docstring of
        the original claimed a tuple, but a single driver is returned).
    """
    if driver != "jittor":
        raise ValueError("Parameter `driver` can only be one of these values: ['jittor'].")

    # TODO: inspect `device` in more detail to pick between the single-device
    # driver and the MPI driver once the latter is complete.
    return JittorSingleDriver(model, device, **kwargs)

+ 155
- 0
fastNLP/core/drivers/jittor_driver/jittor_driver.py View File

@@ -0,0 +1,155 @@
import os
import warnings
from typing import Optional, Callable, Dict

from .utils import _build_fp16_env
from fastNLP.envs.imports import _NEED_IMPORT_JITTOR
from fastNLP.core.drivers.driver import Driver
from fastNLP.core.dataloaders import JittorDataLoader
from fastNLP.core.log import logger
from fastNLP.core.utils import apply_to_collection

if _NEED_IMPORT_JITTOR:
import jittor as jt
from jittor import Module
from jittor.optim import Optimizer

_reduces = {
'max': jt.max,
'min': jt.min,
'mean': jt.mean,
'sum': jt.sum
}


class JittorDriver(Driver):
    r"""
    Base ``Driver`` implementation for the Jittor framework.

    Provides the checkpointing, dataloader/optimizer validation and fp16
    plumbing shared by the single-device and MPI jittor drivers.
    """

    def __init__(self, model, fp16: bool = False, **kwargs):
        if not isinstance(model, Module):
            raise ValueError(f"Parameter `model` can not be `{type(model)}` in `JittorDriver`, it should be exactly "
                             f"`jittor.Module` type.")
        super(JittorDriver, self).__init__(model)

        self.model = model

        # fp16 is not implemented for jittor yet: with fp16=False we get no-op
        # stand-ins (ExitStack / DummyGradScaler); fp16=True raises inside
        # _build_fp16_env.
        self.auto_cast, _grad_scaler = _build_fp16_env(dummy=not fp16)
        self.grad_scaler = _grad_scaler()

    @staticmethod
    def _check_dataloader_legality(dataloader, dataloader_name, is_train: bool = False):
        """
        Validate that ``dataloader`` is a ``JittorDataLoader`` (training) or a
        dict of named ``JittorDataLoader`` objects (evaluation).
        """
        # fastNLP wraps jittor datasets in JittorDataLoader.
        # TODO: should raw Dataset objects be accepted as well?
        if is_train:
            if not isinstance(dataloader, JittorDataLoader):
                raise ValueError(f"Parameter `{dataloader_name}` should be 'JittorDataLoader' type, not {type(dataloader)}.")
        else:
            # isinstance against the builtin `dict`, not typing.Dict.
            if not isinstance(dataloader, dict):
                raise ValueError(f"Parameter `{dataloader_name}` should be 'Dict' type, not {type(dataloader)}.")
            for each_dataloader in dataloader.values():
                if not isinstance(each_dataloader, JittorDataLoader):
                    raise ValueError(f"Each dataloader of parameter `{dataloader_name}` should be 'JittorDataLoader' "
                                     f"type, not {type(each_dataloader)}.")

    @staticmethod
    def _check_optimizer_legality(optimizers):
        """Validate that every element is a ``jittor.optim.Optimizer``."""
        for each_optimizer in optimizers:
            if not isinstance(each_optimizer, Optimizer):
                raise ValueError(f"Each optimizer of parameter `optimizers` should be 'jittor.optim.Optimizer' type, "
                                 f"not {type(each_optimizer)}.")

    def check_evaluator_mode(self, mode: str):
        """
        Warn when the model lacks the step method matching ``mode`` but has
        the sibling one, which will be used as a substitute.
        """
        model = self.unwrap_model()
        if mode == "validate":
            if not hasattr(model, "validate_step"):
                if hasattr(model, "test_step"):
                    # Original messages were missing separating spaces
                    # ("...but youare using...") — fixed here.
                    logger.warning(
                        "Your model does not have 'validate_step' method but has 'test_step' method, but you "
                        "are using 'mode=validate', we are going to use 'test_step' to substitute for "
                        "'validate_step'.")
        else:
            if not hasattr(model, "test_step"):
                if hasattr(model, "validate_step"):
                    logger.warning(
                        "Your model does not have 'test_step' method but has 'validate_step' method, but you "
                        "are using 'mode=test', we are going to use 'validate_step' to substitute for "
                        "'test_step'.")

    def save_model(self, filepath: str, only_state_dict: bool = False, model_save_fn: Optional[Callable] = None):
        """
        Save the model.

        :param filepath: destination path (including the file name);
        :param only_state_dict: whether to save only the state dict; saving
            the whole model is not supported by jittor, so the state dict is
            saved either way (with a warning when a full save was requested);
        :param model_save_fn: optional custom save callable; if it returns a
            value, that value is saved with ``jt.save``;
        """
        if model_save_fn is not None:
            outputs = model_save_fn(filepath)
            if outputs is not None:
                jt.save(outputs, filepath)
        else:
            if not only_state_dict:
                warnings.warn("Saving the whole model is not supported now in Jittor. Save state dict instead.")
            # BUGFIX: `states` used to be assigned only in the
            # only_state_dict branch, so a full-model save raised NameError.
            states = self.model.state_dict()
            jt.save(states, filepath)

    def load_model(self, filepath: str):
        """
        Load a saved model.

        :param filepath: path of the checkpoint file (including the file name);
        :return: the loaded state dict;
        :raises FileNotFoundError: when ``filepath`` does not exist;
        """
        if not os.path.exists(filepath):
            raise FileNotFoundError("Checkpoint at {} not found.".format(filepath))
        return jt.load(filepath)

    def save(self):
        # Checkpointing of full training state is not implemented yet.
        ...

    def load(self):
        # Checkpointing of full training state is not implemented yet.
        ...

    def get_evaluate_context(self):
        # Evaluation runs under jittor's no_grad context.
        return jt.no_grad

    def get_model_device(self):
        return self.model_device

    @staticmethod
    def tensor_to_numeric(tensor, reduce=None):
        """
        Convert a (collection of) ``jt.Var`` into plain Python numbers/lists.

        :param tensor: a ``jt.Var`` or any collection containing them;
        :param reduce: one of ``'max'/'min'/'mean'/'sum'`` to reduce each Var
            to a scalar, or None to convert element-wise;
        """
        if tensor is None:
            return None

        def _translate(_data):
            # A single-element Var becomes a scalar, not a one-element list.
            if _data.numel() == 1:
                return _data.item()
            if reduce is None:
                return _data.tolist()
            return _reduces[reduce](_data).item()

        return apply_to_collection(
            data=tensor,
            dtype=jt.Var,
            function=_translate
        )

    def set_model_mode(self, mode: str):
        """Switch the model between ``train`` and ``eval`` mode."""
        assert mode in {"train", "eval"}
        getattr(self.model, mode)()

    @property
    def data_device(self):
        # Data lives wherever the model lives.
        return self.model_device

    def move_data_to_device(self, batch: 'jt.Var'):
        """
        Jittor currently provides no data-migration function, so the batch is
        returned unchanged.
        """
        return batch

    # TODO: re-enable once ddp training is supported — set_sampler_epoch must
    # keep the shuffle seed identical across processes when shuffle=True:
    # def set_sampler_epoch(self, dataloader: JittorDataLoader, cur_epoch_idx):
    #     if callable(getattr(dataloader.batch_sampler, "set_epoch", None)):
    #         dataloader.batch_sampler.set_epoch(cur_epoch_idx)

+ 100
- 0
fastNLP/core/drivers/jittor_driver/mpi.py View File

@@ -0,0 +1,100 @@
import os
from typing import Optional, Union

from .jittor_driver import JittorDriver
from fastNLP.envs.imports import _NEED_IMPORT_JITTOR
from fastNLP.core.samplers import ReproducibleIterator

if _NEED_IMPORT_JITTOR:
import jittor

__all__ = [
"JittorMPIDriver",
]

class JittorMPIDriver(JittorDriver):
    """
    Multi-process (MPI) driver for jittor.

    NOTE(review): much of the distributed machinery is still a skeleton —
    ``setup``/``configure_mpi``/``barrier``/``unwrap_model`` are no-ops, and
    the ``_train_step``/``_validate_step``/``_test_step`` attributes used
    below are never assigned in this class; confirm they are set elsewhere
    before relying on this driver.
    """

    def __init__(
            self,
            model,
            parallel_device: None,
            is_pull_by_jittor_run: bool = False,
            fp16: bool = False,
            **kwargs
    ):
        super(JittorMPIDriver, self).__init__(model, fp16=fp16, **kwargs)

        self.is_pull_by_jittor_run = is_pull_by_jittor_run
        self.parallel_device = parallel_device
        # True when driven by a process group created outside this driver.
        self.outside_mpi = False

    def setup(self):
        # Distributed initialization not implemented yet.
        pass

    def configure_mpi(self):
        pass

    @property
    def world_size(self) -> int:
        return self._world_size

    @world_size.setter
    def world_size(self, size: int):
        self._world_size = size

    @property
    def global_rank(self) -> int:
        return self._global_rank

    @global_rank.setter
    def global_rank(self, rank: int) -> None:
        self._global_rank = rank

    @property
    def local_rank(self) -> int:
        # The launcher exports LOCAL_RANK; fall back to 0 for one process.
        return int(os.environ.get("LOCAL_RANK", 0))

    @property
    def data_device(self):
        # Outside-mpi mode keeps a dedicated data device; otherwise data
        # lives wherever the model lives.
        return self._data_device if self.outside_mpi else self.model_device

    def train_step(self, batch):
        return self._train_step(batch)

    def validate_step(self, batch):
        return self._validate_step(batch)

    def test_step(self, batch):
        return self._test_step(batch)

    def replace_sampler(self, dataloader, dist_sampler: Optional[Union[str, ReproducibleIterator]] = "dist",
                        reproducible: bool = False):
        # Sampler replacement for distributed training not implemented yet.
        pass

    def backward(self, loss):
        # Scale the loss before backprop (no-op scaler unless fp16).
        self.grad_scaler.scale(loss).backward()

    def step(self):
        for optimizer in self.optimizers:
            self.grad_scaler.step(optimizer)
            self.grad_scaler.update()

    def is_global_zero(self):
        return self.global_rank == 0

    def get_no_sync_context(self):
        return self.model.no_sync

    def unwrap_model(self):
        pass

    def get_local_rank(self) -> int:
        return self.local_rank

    def barrier(self):
        pass

    def is_distributed(self):
        return True

+ 127
- 0
fastNLP/core/drivers/jittor_driver/single_device.py View File

@@ -0,0 +1,127 @@
from typing import Dict, Union

from .jittor_driver import JittorDriver
from fastNLP.core.utils import auto_param_call
from fastNLP.envs.imports import _NEED_IMPORT_JITTOR
from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleIterator

if _NEED_IMPORT_JITTOR:
import jittor

__all__ = [
"JittorSingleDriver",
]

class JittorSingleDriver(JittorDriver):
    r"""
    Driver for running on cpu or a single gpu.

    TODO: fp16 support for jittor.
    """

    def __init__(self, model, device=None, fp16: bool = False, **kwargs):
        # BUGFIX: forward **kwargs as well — the original called
        # super().__init__(model, fp16) and silently dropped extra options.
        super(JittorSingleDriver, self).__init__(model, fp16=fp16, **kwargs)

        self.model_device = device

        # Single process: fixed ranks and world size.
        self.local_rank = 0
        self.global_rank = 0
        self.world_size = 1

        # Resolve the concrete train/validate/test callables once, with
        # fallbacks: a missing `*_step` method falls back to the sibling step
        # method, or to calling the model itself (jittor models implement
        # `execute`, hence the signature_fn used for parameter matching).
        if hasattr(self.model, "train_step"):
            self._train_step = self.model.train_step
            self._train_signature_fn = None
        else:
            self._train_step = self.model
            model = self.unwrap_model()
            self._train_signature_fn = model.execute

        if hasattr(self.model, "validate_step"):
            self._validate_step = self.model.validate_step
            self._validate_signature_fn = None
        elif hasattr(self.model, "test_step"):
            self._validate_step = self.model.test_step
            self._validate_signature_fn = self.model.test_step
        else:
            self._validate_step = self.model
            model = self.unwrap_model()
            self._validate_signature_fn = model.execute

        if hasattr(self.model, "test_step"):
            self._test_step = self.model.test_step
            self._test_signature_fn = None
        elif hasattr(self.model, "validate_step"):
            self._test_step = self.model.validate_step
            self._test_signature_fn = self.model.validate_step
        else:
            self._test_step = self.model
            model = self.unwrap_model()
            self._test_signature_fn = model.execute

    def train_step(self, batch) -> Dict:
        # Dict batches are matched to the step function's parameters by name.
        if isinstance(batch, dict):
            return auto_param_call(self._train_step, batch, signature_fn=self._train_signature_fn)
        else:
            return self._train_step(batch)

    def step(self):
        """
        jittor optimizers' ``step`` can take the loss as an argument, in
        which case it performs zero_grad and backward as well; for
        consistency with other drivers that style is not used here.
        """
        for optimizer in self.optimizers:
            optimizer.step()

    def backward(self, loss):
        # jittor optimizers own the backward pass.
        for optimizer in self.optimizers:
            optimizer.backward(loss)

    def zero_grad(self, set_to_none=False):
        # jittor's zero_grad takes no set_to_none flag; the parameter is
        # accepted only for interface compatibility.
        for optimizer in self.optimizers:
            optimizer.zero_grad()

    def validate_step(self, batch):
        if isinstance(batch, dict):
            return auto_param_call(self._validate_step, batch, signature_fn=self._validate_signature_fn)
        else:
            return self._validate_step(batch)

    def test_step(self, batch):
        if isinstance(batch, dict):
            return auto_param_call(self._test_step, batch, signature_fn=self._test_signature_fn)
        else:
            return self._test_step(batch)

    def unwrap_model(self):
        # No wrapping happens on a single device.
        return self.model

    def is_distributed(self):
        return False

    def replace_sampler(self, dataloader, dist_sampler: Union[str, ReproducibleBatchSampler, ReproducibleIterator], reproducible: bool = False):
        """
        Reproducible sampling is not implemented for jittor yet: every path
        that would need it raises NotImplementedError, otherwise the
        dataloader is returned unchanged.

        BUGFIX: the original contained unreachable code after each ``raise``,
        including an assignment to the misspelled name ``dist_sample`` that
        would have raised NameError once the ``raise`` was removed; the dead
        code has been dropped.
        """
        if isinstance(dist_sampler, (ReproducibleBatchSampler, ReproducibleIterator)):
            raise NotImplementedError
        if reproducible:
            # TODO: wrap dataloader.batch_sampler in a ReproducibleBatchSampler
            # once reproducible sampling is supported.
            raise NotImplementedError
        return dataloader

+ 55
- 0
fastNLP/core/drivers/jittor_driver/utils.py View File

@@ -0,0 +1,55 @@
from contextlib import ExitStack

from fastNLP.envs.imports import _NEED_IMPORT_JITTOR

if _NEED_IMPORT_JITTOR:
import jittor

class DummyGradScaler:
    """No-op stand-in for a real gradient scaler.

    Lets callers invoke scaler methods unconditionally when fp16 is
    disabled, instead of guarding every call site with an ``if``.
    """

    def __init__(self, *args, **kwargs):
        pass

    def get_scale(self):
        # A scale of 1.0 means "no scaling applied".
        return 1.0

    def is_enabled(self):
        return False

    def scale(self, outputs):
        # Identity: outputs pass through untouched.
        return outputs

    def step(self, optimizer, *args, **kwargs):
        # Simply delegate to the optimizer; nothing to unscale.
        optimizer.step(*args, **kwargs)

    def update(self, new_scale=None):
        pass

    def unscale_(self, optimizer):
        pass

    def load_state_dict(self, state_dict):
        pass

    def state_dict(self):
        return {}


def _build_fp16_env(dummy=False):
    """
    Return an ``(auto_cast, GradScaler)`` pair for mixed-precision training.

    :param dummy: when True, return no-op stand-ins — ``ExitStack`` as the
        autocast context (entering it does nothing) and
        :class:`DummyGradScaler` — so callers can follow the same code path
        whether or not fp16 is enabled;
    :raises NotImplementedError: when ``dummy`` is False, because fp16
        support for jittor is not implemented yet;
    """
    if not dummy:
        # TODO: build a real autocast/GradScaler pair once jittor gains fp16
        # support (will need a jt.flags.use_cuda / device-capability check).
        # The commented-out code here previously referenced *paddle* APIs —
        # a leftover from the paddle driver — and has been removed as
        # misleading.
        raise NotImplementedError("JittorDriver does not support fp16 now.")
    return ExitStack, DummyGradScaler

Loading…
Cancel
Save