# Reconstructed from a whitespace-mangled git patch.  Each section below is
# one new file created by the patch; diff metadata (headers, hunk markers,
# '+' prefixes) has been dropped.

# --- fastNLP/core/drivers/jittor_driver/__init__.py -----------------------

__all__ = [
    "JittorDriver",
    "JittorSingleDriver",
    "JittorMPIDriver",
]

from .jittor_driver import JittorDriver
from .single_device import JittorSingleDriver
from .mpi import JittorMPIDriver


# --- fastNLP/core/drivers/jittor_driver/initialize_jittor_driver.py -------

from typing import Union, List

from fastNLP.core.drivers.jittor_driver.jittor_driver import JittorDriver
from fastNLP.core.drivers.jittor_driver.single_device import JittorSingleDriver
from fastNLP.envs.imports import _NEED_IMPORT_JITTOR

if _NEED_IMPORT_JITTOR:
    import jittor


def initialize_jittor_driver(driver: str, device: Union[str, int, List[int]],
                             model: "jittor.Module", **kwargs) -> JittorDriver:
    r"""
    Determine and initialize a concrete ``Driver`` instance from the given
    ``driver`` and ``device`` arguments.

    The ``model`` annotation is deliberately a string: when jittor is not
    installed (``_NEED_IMPORT_JITTOR`` is False) the name ``jittor`` does not
    exist, and an eagerly-evaluated annotation would raise ``NameError`` the
    moment this module is imported.

    :param driver: must be one of ``["jittor"]``;
    :param device: the device jittor runs on;
    :param model: the concrete model to train or evaluate;
    :param kwargs: forwarded verbatim to the driver constructor;
    :return: a concrete jittor-based :class:`JittorDriver` instance.  (The
        original docstring claimed a tuple was returned; the code returns a
        single driver object.)
    """
    if driver not in {"jittor"}:
        raise ValueError("Parameter `driver` can only be one of these values: ['jittor'].")

    # TODO: finer-grained dispatch (e.g. to JittorMPIDriver) once multi-device
    # support is implemented
    if driver == "jittor":
        return JittorSingleDriver(model, device, **kwargs)
    else:
        raise NotImplementedError


# --- fastNLP/core/drivers/jittor_driver/jittor_driver.py ------------------

import os
import warnings
from typing import Optional, Callable, Dict

from .utils import _build_fp16_env
from fastNLP.envs.imports import _NEED_IMPORT_JITTOR
from fastNLP.core.drivers.driver import Driver
from fastNLP.core.dataloaders import JittorDataLoader
from fastNLP.core.log import logger
from fastNLP.core.utils import apply_to_collection

if _NEED_IMPORT_JITTOR:
    import jittor as jt
    from jittor import Module
    from jittor.optim import Optimizer

    # reduction callables used by tensor_to_numeric
    _reduces = {
        'max': jt.max,
        'min': jt.min,
        'mean': jt.mean,
        'sum': jt.sum
    }


class JittorDriver(Driver):
    r"""
    Base ``Driver`` for the Jittor framework.
    """

    def __init__(self, model, fp16: bool = False, **kwargs):
        if not isinstance(model, Module):
            raise ValueError(f"Parameter `model` can not be `{type(model)}` in `JittorDriver`, it should be exactly "
                             f"`jittor.Module` type.")
        super(JittorDriver, self).__init__(model)

        self.model = model

        # fp16 machinery; with dummy=True these are no-op stand-ins
        self.auto_cast, _grad_scaler = _build_fp16_env(dummy=not fp16)
        self.grad_scaler = _grad_scaler()

    @staticmethod
    def _check_dataloader_legality(dataloader, dataloader_name, is_train: bool = False):
        # fastNLP wraps jittor datasets in JittorDataLoader
        # TODO: should raw Dataset objects be accepted as well?
        if is_train:
            if not isinstance(dataloader, JittorDataLoader):
                raise ValueError(f"Parameter `{dataloader_name}` should be 'JittorDataLoader' type, not {type(dataloader)}.")
        else:
            # evaluation takes a name -> dataloader mapping
            # fix: isinstance must check against the builtin ``dict``;
            # ``typing.Dict`` is meant for annotations, not runtime checks
            if not isinstance(dataloader, dict):
                raise ValueError(f"Parameter `{dataloader_name}` should be 'Dict' type, not {type(dataloader)}.")
            else:
                for each_dataloader in dataloader.values():
                    if not isinstance(each_dataloader, JittorDataLoader):
                        raise ValueError(f"Each dataloader of parameter `{dataloader_name}` should be 'JittorDataLoader' "
                                         f"type, not {type(each_dataloader)}.")

    @staticmethod
    def _check_optimizer_legality(optimizers):
        for each_optimizer in optimizers:
            if not isinstance(each_optimizer, Optimizer):
                raise ValueError(f"Each optimizer of parameter `optimizers` should be 'jittor.optim.Optimizer' type, "
                                 f"not {type(each_optimizer)}.")

    def check_evaluator_mode(self, mode: str):
        """Warn when the model lacks the step method matching ``mode`` but has
        the other one, which will then be used as a substitute."""
        model = self.unwrap_model()
        if mode == "validate":
            if not hasattr(model, "validate_step"):
                if hasattr(model, "test_step"):
                    logger.warning(
                        "Your model does not have 'validate_step' method but has 'test_step' method, but you"
                        "are using 'mode=validate', we are going to use 'test_step' to substitute for"
                        "'validate_step'.")

        else:
            if not hasattr(model, "test_step"):
                if hasattr(model, "validate_step"):
                    # fix: the message previously said 'validate' where the
                    # actual method name is 'validate_step'
                    logger.warning("Your model does not have 'test_step' method but has 'validate_step' method, but you"
                                   "are using 'mode=test', we are going to use 'validate_step' to substitute for"
                                   "'test_step'.")

    def save_model(self, filepath: str, only_state_dict: bool = False, model_save_fn: Optional[Callable] = None):
        """
        Save the model.

        :param filepath: destination path (including the file name);
        :param only_state_dict: whether to save only the state dict.  Jittor
            cannot serialize a whole ``Module`` yet, so ``False`` degrades to
            a state-dict save with a warning;
        :param model_save_fn: optional user callback; when it returns a value,
            that value is saved instead of the model state.
        """
        if model_save_fn is not None:
            outputs = model_save_fn(filepath)
            if outputs is not None:
                jt.save(outputs, filepath)
        else:
            if only_state_dict:
                states = self.model.state_dict()
            else:
                warnings.warn("Saving the whole model is not supported now in Jittor. Save state dict instead.")
                # fix: the original left ``states`` unbound on this branch,
                # so the jt.save below raised NameError
                states = self.model.state_dict()
            jt.save(states, filepath)

    def load_model(self, filepath: str):
        """
        Load a model checkpoint.

        :param filepath: path of the saved file (including the file name);
        :return: the loaded state dict
        """
        if not os.path.exists(filepath):
            raise FileNotFoundError("Checkpoint at {} not found.".format(filepath))
        return jt.load(filepath)

    def save(self):
        # checkpointing of the full training state is not implemented yet
        ...

    def load(self):
        # checkpointing of the full training state is not implemented yet
        ...

    def get_evaluate_context(self):
        # evaluation runs without gradient tracking
        return jt.no_grad

    def get_model_device(self):
        return self.model_device

    @staticmethod
    def tensor_to_numeric(tensor, reduce=None):
        """Convert jittor Vars inside ``tensor`` (possibly nested in a
        collection) to plain python numbers/lists; ``reduce`` in
        {'max','min','mean','sum'} collapses each Var to a scalar first."""
        if tensor is None:
            return None

        def _translate(_data):
            # a single-element Var becomes a plain scalar, not a 1-item list
            if _data.numel() == 1:
                return _data.item()
            if reduce is None:
                return _data.tolist()
            return _reduces[reduce](_data).item()

        return apply_to_collection(
            data=tensor,
            dtype=jt.Var,
            function=_translate
        )

    def set_model_mode(self, mode: str):
        assert mode in {"train", "eval"}
        getattr(self.model, mode)()

    @property
    def data_device(self):
        return self.model_device

    def move_data_to_device(self, batch: 'jt.Var'):
        """
        jittor provides no data-transfer API yet, so the batch is returned
        unchanged.
        """
        return batch

    # def set_sampler_epoch(self, dataloader: JittorDataLoader, cur_epoch_idx):
    #     # keep shuffle=True correct in ddp training: every process must seed
    #     # its sampler's shuffle with the same epoch
    #     if callable(getattr(dataloader.batch_sampler, "set_epoch", None)):
    #         dataloader.batch_sampler.set_epoch(cur_epoch_idx)


# --- fastNLP/core/drivers/jittor_driver/mpi.py ----------------------------

import os
from typing import Optional, Union

from .jittor_driver import JittorDriver
from fastNLP.envs.imports import _NEED_IMPORT_JITTOR
from fastNLP.core.samplers import ReproducibleIterator

if _NEED_IMPORT_JITTOR:
    import jittor

__all__ = [
    "JittorMPIDriver",
]


class JittorMPIDriver(JittorDriver):
    """Skeleton MPI (multi-process) driver for jittor; most distributed hooks
    are still unimplemented placeholders."""

    def __init__(
            self,
            model,
            parallel_device: None,
            is_pull_by_jittor_run: bool = False,
            fp16: bool = False,
            **kwargs
    ):

        super(JittorMPIDriver, self).__init__(model, fp16=fp16, **kwargs)

        self.is_pull_by_jittor_run = is_pull_by_jittor_run
        self.parallel_device = parallel_device

        # True when the mpi context is managed outside this driver
        self.outside_mpi = False

    def setup(self):
        pass

    def configure_mpi(self):
        pass

    @property
    def world_size(self) -> int:
        return self._world_size

    @world_size.setter
    def world_size(self, size: int):
        self._world_size = size

    @property
    def global_rank(self) -> int:
        return self._global_rank

    @global_rank.setter
    def global_rank(self, rank: int) -> None:
        self._global_rank = rank

    @property
    def local_rank(self) -> int:
        # the launcher exports LOCAL_RANK; default to 0 for single-process runs
        return int(os.environ.get("LOCAL_RANK", 0))

    @property
    def data_device(self):
        if self.outside_mpi:
            return self._data_device
        return self.model_device

    def train_step(self, batch):
        return self._train_step(batch)

    def validate_step(self, batch):
        return self._validate_step(batch)

    def test_step(self, batch):
        return self._test_step(batch)

    def replace_sampler(self, dataloader, dist_sampler: Optional[Union[str, ReproducibleIterator]] = "dist", reproducible: bool = False):
        pass

    def backward(self, loss):
        self.grad_scaler.scale(loss).backward()

    def step(self):
        for optimizer in self.optimizers:
            self.grad_scaler.step(optimizer)
        self.grad_scaler.update()

    def is_global_zero(self):
        return self.global_rank == 0

    def get_no_sync_context(self):
        return self.model.no_sync

    def unwrap_model(self):
        pass

    def get_local_rank(self) -> int:
        return self.local_rank

    def barrier(self):
        pass

    def is_distributed(self):
        return True
# Reconstructed from a whitespace-mangled git patch.  Each section below is
# one new file created by the patch; diff metadata has been dropped.

# --- fastNLP/core/drivers/jittor_driver/single_device.py ------------------

from typing import Dict, Union

from .jittor_driver import JittorDriver
from fastNLP.core.utils import auto_param_call
from fastNLP.envs.imports import _NEED_IMPORT_JITTOR
from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleIterator

if _NEED_IMPORT_JITTOR:
    import jittor

__all__ = [
    "JittorSingleDriver",
]


class JittorSingleDriver(JittorDriver):
    r"""
    Driver for cpu and single-gpu computation.
    TODO: fp16 for jittor
    """

    def __init__(self, model, device=None, fp16: bool = False, **kwargs):
        # fix: forward **kwargs to the base class — the original dropped them
        super(JittorSingleDriver, self).__init__(model, fp16, **kwargs)

        self.model_device = device

        # a single-device driver is always rank 0 of a world of size 1
        self.local_rank = 0
        self.global_rank = 0
        self.world_size = 1

        # resolve the train entry point: prefer an explicit ``train_step``,
        # otherwise call the model itself (jittor's forward is ``execute``)
        if hasattr(self.model, "train_step"):
            self._train_step = self.model.train_step
            self._train_signature_fn = None
        else:
            self._train_step = self.model
            model = self.unwrap_model()
            self._train_signature_fn = model.execute

        # validate falls back to test_step, then to the model itself
        if hasattr(self.model, "validate_step"):
            self._validate_step = self.model.validate_step
            self._validate_signature_fn = None
        elif hasattr(self.model, "test_step"):
            self._validate_step = self.model.test_step
            self._validate_signature_fn = self.model.test_step
        else:
            self._validate_step = self.model
            model = self.unwrap_model()
            self._validate_signature_fn = model.execute

        # test falls back to validate_step, then to the model itself
        if hasattr(self.model, "test_step"):
            self._test_step = self.model.test_step
            self._test_signature_fn = None
        elif hasattr(self.model, "validate_step"):
            self._test_step = self.model.validate_step
            self._test_signature_fn = self.model.validate_step
        else:
            self._test_step = self.model
            model = self.unwrap_model()
            self._test_signature_fn = model.execute

    def train_step(self, batch) -> Dict:
        # dict batches are matched to the step function's parameters by name
        if isinstance(batch, Dict):
            return auto_param_call(self._train_step, batch, signature_fn=self._train_signature_fn)
        else:
            return self._train_step(batch)

    def step(self):
        """
        jittor optimizers accept a ``loss`` argument to ``step``, which then
        performs zero_grad and backward as well; for uniformity with other
        drivers that combined form is not used here.
        """
        for optimizer in self.optimizers:
            optimizer.step()

    def backward(self, loss):
        for optimizer in self.optimizers:
            optimizer.backward(loss)

    def zero_grad(self, set_to_none=False):
        # jittor's zero_grad takes no set_to_none flag; the parameter is kept
        # for interface compatibility
        for optimizer in self.optimizers:
            optimizer.zero_grad()

    def validate_step(self, batch):
        if isinstance(batch, Dict):
            return auto_param_call(self._validate_step, batch, signature_fn=self._validate_signature_fn)
        else:
            return self._validate_step(batch)

    def test_step(self, batch):
        if isinstance(batch, Dict):
            return auto_param_call(self._test_step, batch, signature_fn=self._test_signature_fn)
        else:
            return self._test_step(batch)

    def unwrap_model(self):
        # no wrapping happens on a single device
        return self.model

    def is_distributed(self):
        return False

    def replace_sampler(self, dataloader, dist_sampler: Union[str, ReproducibleBatchSampler, ReproducibleIterator], reproducible: bool = False):
        # reproducible sampling is not implemented for jittor yet; fix: the
        # original had statements after each ``raise`` that were unreachable
        # and contained a NameError typo (``dist_sample``) — removed
        if isinstance(dist_sampler, ReproducibleBatchSampler):
            raise NotImplementedError
        if isinstance(dist_sampler, ReproducibleIterator):
            raise NotImplementedError

        if reproducible:
            raise NotImplementedError
            # intended logic once implemented:
            # if isinstance(dataloader.batch_sampler.sampler, ReproducibleIterator):
            #     return dataloader
            # elif isinstance(dataloader.batch_sampler, ReproducibleBatchSampler):
            #     return dataloader
            # else:
            #     batch_sampler = ReproducibleBatchSampler(
            #         batch_sampler=dataloader.batch_sampler,
            #         batch_size=dataloader.batch_sampler.batch_size,
            #         drop_last=dataloader.drop_last
            #     )
            #     dataloader.batch_sampler = batch_sampler
            #     return dataloader
        else:
            return dataloader


# --- fastNLP/core/drivers/jittor_driver/utils.py --------------------------

from contextlib import ExitStack

from fastNLP.envs.imports import _NEED_IMPORT_JITTOR

if _NEED_IMPORT_JITTOR:
    import jittor


class DummyGradScaler:
    """
    No-op stand-in for a GradScaler object, so callers do not need repeated
    fp16-enabled checks.
    """

    def __init__(self, *args, **kwargs):
        pass

    def get_scale(self):
        # scale factor 1.0 means losses/gradients pass through unchanged
        return 1.0

    def is_enabled(self):
        return False

    def scale(self, outputs):
        return outputs

    def step(self, optimizer, *args, **kwargs):
        optimizer.step(*args, **kwargs)

    def update(self, new_scale=None):
        pass

    def unscale_(self, optimizer):
        pass

    def load_state_dict(self, state_dict):
        pass

    def state_dict(self):
        return {}


def _build_fp16_env(dummy=False):
    """Return an ``(auto_cast, GradScaler)`` pair.  With ``dummy=True`` both
    are inert stand-ins (``ExitStack`` is a no-op context manager factory);
    real fp16 support for jittor is not implemented yet."""
    if dummy:
        auto_cast = ExitStack
        GradScaler = DummyGradScaler
    else:
        raise NotImplementedError("JittorDriver does not support fp16 now.")
        # if not jt.flags.use_cuda:
        #     raise RuntimeError("No cuda")
        # if paddle.device.cuda.get_device_capability(0)[0] < 7:
        #     log.warning(
        #         "NOTE: your device does NOT support faster training with fp16, "
        #         "please switch to FP32 which is likely to be faster"
        #     )
        # from paddle.amp import auto_cast, GradScaler
    return auto_cast, GradScaler