From 78596ea11ce886e28fe97fef2275da9e7fa2f7d6 Mon Sep 17 00:00:00 2001
From: x54-729 <17307130121@fudan.edu.cn>
Date: Wed, 22 Jun 2022 16:49:03 +0800
Subject: [PATCH] Add an 'auto' option for the `driver` parameter of Trainer
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 fastNLP/core/controllers/trainer.py   | 11 ++++++-----
 fastNLP/core/drivers/choose_driver.py | 11 +++++++++++
 fastNLP/core/utils/__init__.py        |  9 ++++++---
 fastNLP/core/utils/jittor_utils.py    | 14 +++++++++++++-
 fastNLP/core/utils/paddle_utils.py    | 15 ++++++++++++++-
 fastNLP/core/utils/torch_utils.py     | 15 ++++++++++++++-
 6 files changed, 64 insertions(+), 11 deletions(-)

diff --git a/fastNLP/core/controllers/trainer.py b/fastNLP/core/controllers/trainer.py
index 0f22e63c..7a84598e 100644
--- a/fastNLP/core/controllers/trainer.py
+++ b/fastNLP/core/controllers/trainer.py
@@ -55,9 +55,10 @@ class Trainer(TrainerEventTrigger):
 
         you should use ``TorchDDPDriver``, which means you need to launch training via ``python -m torch.distributed.launch``;
         in that case the parameter ``device`` should be set to None (we will then ignore it); see the more detailed explanation of the parameter ``device`` below.
 
-    :param driver: the concrete backend used to drive model training; should be one of ["torch"]. Training modes for domestic
-        frameworks such as jittor and paddle will be added later; "torch" means ``TorchSingleDriver`` or ``TorchDDPDriver``
-        will be used, and which one is used depends on the setting of the parameter ``device``;
+    :param driver: the concrete backend used to drive model training; should be one of ["auto", "torch", "paddle", "jittor", "fairscale"]. When it is ``"auto"``,
+        **FastNLP** decides which mode to use based on the type of the model passed in; "torch" means ``TorchSingleDriver`` or ``TorchDDPDriver``;
+        "paddle" means ``PaddleSingleDriver`` or ``PaddleFleetDriver``; "jittor" means ``JittorSingleDriver``
+        or ``JittorMPIDriver``; "fairscale" means ``FairScaleDriver``. Once a framework is specified, which concrete driver is used depends on the setting of the parameter ``device``;
 
     .. warning::
 
@@ -81,7 +82,7 @@ class Trainer(TrainerEventTrigger):
 
         The accepted inputs for device are as follows:
 
-        * *str*: e.g. 'cpu', 'cuda', 'cuda:0', 'cuda:1', etc.;
+        * *str*: e.g. 'cpu', 'cuda', 'cuda:0', 'cuda:1', 'gpu:0', etc.;
         * *torch.device*: e.g. 'torch.device("cuda:0")';
         * *int*: training will run on the ``gpu`` whose ``device_id`` equals this value; if the value is -1, all GPUs are used by default, and the driver instance will then be `TorchDDPDriver`;
         * *list(int)*: if more than one device is involved, it should be specified this way; note that in this case we will always use ``TorchDDPDriver``, no matter whether the list you pass in has length 1 or any other value;
@@ -365,9 +366,9 @@ class Trainer(TrainerEventTrigger):
     def __init__(
             self,
             model,
-            driver,
             train_dataloader,
             optimizers,
+            driver: str = "auto",
             device: Optional[Union[int, List[int], str]] = "cpu",
             n_epochs: int = 20,
             evaluate_dataloaders=None,
diff --git a/fastNLP/core/drivers/choose_driver.py b/fastNLP/core/drivers/choose_driver.py
index 4be1e502..75df97c4 100644
--- a/fastNLP/core/drivers/choose_driver.py
+++ b/fastNLP/core/drivers/choose_driver.py
@@ -1,6 +1,7 @@
 from typing import Union, Optional, List
 
 from .driver import Driver
+from ..utils import is_torch_module, is_paddle_module, is_jittor_module
 
 
 def choose_driver(model, driver: Union[str, Driver], device: Optional[Union[int, List[int], str]], **kwargs) -> Driver:
@@ -17,6 +18,16 @@ def choose_driver(model, driver: Union[str, Driver], device: Optional[Union[int,
     if isinstance(driver, Driver):
         return driver
 
+    if driver == "auto":
+        if is_torch_module(model):
+            driver = "torch"
+        elif is_paddle_module(model):
+            driver = "paddle"
+        elif is_jittor_module(model):
+            driver = "jittor"
+        else:
+            raise ValueError("Cannot choose driver automatically based on model, please set `driver` specifically.")
+
     if driver in {"torch", "fairscale"}:
         from fastNLP.core.drivers.torch_driver.initialize_torch_driver import initialize_torch_driver
         return initialize_torch_driver(driver, device, model, **kwargs)
diff --git a/fastNLP/core/utils/__init__.py b/fastNLP/core/utils/__init__.py
index 0857f450..2825b5ac 100644
--- a/fastNLP/core/utils/__init__.py
+++ b/fastNLP/core/utils/__init__.py
@@ -1,5 +1,6 @@
 __all__ = [
     'cache_results',
+    'is_jittor_module',
     'is_jittor_dataset',
     'jittor_collate_wraps',
     'paddle_to',
@@ -9,8 +10,10 @@ __all__ = [
     'is_in_paddle_dist',
     'is_in_fnlp_paddle_dist',
     'is_in_paddle_launch_dist',
+    'is_paddle_module',
     'f_rich_progress',
     'torch_move_data_to_device',
+    'is_torch_module',
     'get_fn_arg_names',
     'auto_param_call',
     'check_user_specific_params',
@@ -28,11 +31,11 @@ __all__ = [
 ]
 
 from .cache_results import cache_results
-from .jittor_utils import is_jittor_dataset, jittor_collate_wraps
+from .jittor_utils import is_jittor_dataset, jittor_collate_wraps, is_jittor_module
 from .paddle_utils import paddle_to, paddle_move_data_to_device, get_paddle_device_id, get_paddle_gpu_str, is_in_paddle_dist, \
-    is_in_fnlp_paddle_dist, is_in_paddle_launch_dist
+    is_in_fnlp_paddle_dist, is_in_paddle_launch_dist, is_paddle_module
 from .rich_progress import f_rich_progress
-from .torch_utils import torch_move_data_to_device
+from .torch_utils import torch_move_data_to_device, is_torch_module
 from .utils import *
 from .tqdm_progress import f_tqdm_progress
 from .seq_len_to_mask import seq_len_to_mask
diff --git a/fastNLP/core/utils/jittor_utils.py b/fastNLP/core/utils/jittor_utils.py
index f29b1f46..ac00cd22 100644
--- a/fastNLP/core/utils/jittor_utils.py
+++ b/fastNLP/core/utils/jittor_utils.py
@@ -1,6 +1,7 @@
 __all__ = [
+    'is_jittor_module',
     'is_jittor_dataset',
-    'jittor_collate_wraps'
+    'jittor_collate_wraps',
 ]
 
 from collections.abc import Mapping, Callable
@@ -13,6 +14,17 @@ if _NEED_IMPORT_JITTOR:
 from fastNLP.core.dataset import Instance
 
+def is_jittor_module(model) -> bool:
+    """
+    Check whether the given ``model`` is a :class:`jittor.Module`.
+
+    :param model: the model;
+    :return: whether the model is a ``jittor`` model;
+    """
+    try:
+        return isinstance(model, jt.Module)
+    except BaseException:
+        return False
 
 def is_jittor_dataset(dataset) -> bool:
     """
diff --git a/fastNLP/core/utils/paddle_utils.py b/fastNLP/core/utils/paddle_utils.py
index 9e7e73a4..adcbcabd 100644
--- a/fastNLP/core/utils/paddle_utils.py
+++ b/fastNLP/core/utils/paddle_utils.py
@@ -6,6 +6,7 @@ __all__ = [
     "is_in_paddle_dist",
     "is_in_fnlp_paddle_dist",
     "is_in_paddle_launch_dist",
+    "is_paddle_module",
 ]
 
 import os
@@ -174,4 +175,16 @@ def is_in_paddle_launch_dist() -> bool:
     """
     Check whether we are inside a **paddle** distributed process launched via ``python -m paddle.distributed.launch``.
     """
-    return FASTNLP_BACKEND_LAUNCH in os.environ
\ No newline at end of file
+    return FASTNLP_BACKEND_LAUNCH in os.environ
+
+def is_paddle_module(model) -> bool:
+    """
+    Check whether the given ``model`` is a :class:`paddle.nn.Layer`.
+
+    :param model: the model;
+    :return: whether the model is a ``paddle`` model;
+    """
+    try:
+        return isinstance(model, paddle.nn.Layer)
+    except BaseException:
+        return False
\ No newline at end of file
diff --git a/fastNLP/core/utils/torch_utils.py b/fastNLP/core/utils/torch_utils.py
index 0cef2205..c58715b8 100644
--- a/fastNLP/core/utils/torch_utils.py
+++ b/fastNLP/core/utils/torch_utils.py
@@ -8,7 +8,8 @@ if _NEED_IMPORT_TORCH:
     DEFAULT_TORCH_GROUP = torch.distributed.distributed_c10d.group.WORLD
 
 __all__ = [
-    'torch_move_data_to_device'
+    'torch_move_data_to_device',
+    'is_torch_module',
 ]
 
 from .utils import apply_to_collection
@@ -64,3 +65,15 @@ def torch_move_data_to_device(batch: Any, device: Optional[Union[str, "torch.dev
 
     dtype = TorchTransferableDataType
     return apply_to_collection(batch, dtype=dtype, function=batch_to)
+
+def is_torch_module(model) -> bool:
+    """
+    Check whether the given ``model`` is a :class:`torch.nn.Module`.
+
+    :param model: the model;
+    :return: whether the model is a ``torch`` model;
+    """
+    try:
+        return isinstance(model, torch.nn.Module)
+    except BaseException:
+        return False
\ No newline at end of file
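
Note (not part of the patch): the try/except BaseException in each helper is
there because the framework import is guarded by _NEED_IMPORT_*; if, say,
jittor is not installed, the name `jt` is undefined inside
`is_jittor_module`, the isinstance check raises, and the helper falls back to
returning False instead of crashing `choose_driver`. A small sketch of the
resulting behaviour, assuming a Python environment where torch is installed
(the other frameworks may or may not be):

    import torch
    from fastNLP.core.utils import is_torch_module, is_paddle_module, is_jittor_module

    model = torch.nn.Linear(2, 2)
    assert is_torch_module(model)        # a torch.nn.Module subclass
    assert not is_paddle_module(model)   # False even if paddle is not installed
    assert not is_jittor_module(model)   # False even if jittor is not installed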
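Usage sketch (also not part of the patch): a minimal, hedged example of the
new `driver="auto"` default, assuming fastNLP with this patch applied and
torch installed. `ToyModel`, the collate function, and the dataloader are
hypothetical stand-ins; the `train_step` convention (returning a dict with a
"loss" key) and `trainer.run()` follow fastNLP's usual model/Trainer
interface, and `Trainer` keyword names match the patched `__init__` above.

    import torch
    from torch import nn
    from torch.utils.data import DataLoader, TensorDataset
    from fastNLP import Trainer

    class ToyModel(nn.Module):
        def __init__(self):
            super().__init__()
            self.linear = nn.Linear(4, 2)

        def train_step(self, x, y):
            # fastNLP drivers call `train_step`; returning {"loss": ...}
            # lets the Trainer drive the backward pass.
            logits = self.linear(x)
            return {"loss": nn.functional.cross_entropy(logits, y)}

    def collate(batch):
        # Yield dict batches so fastNLP can match them to `train_step`'s
        # parameter names (x, y) via auto_param_call.
        xs, ys = zip(*batch)
        return {"x": torch.stack(xs), "y": torch.stack(ys)}

    dataset = TensorDataset(torch.randn(8, 4), torch.randint(0, 2, (8,)))
    model = ToyModel()
    trainer = Trainer(
        model=model,
        train_dataloader=DataLoader(dataset, batch_size=4, collate_fn=collate),
        optimizers=torch.optim.SGD(model.parameters(), lr=0.1),
        driver="auto",   # is_torch_module(model) holds, so "torch" is chosen
        device="cpu",
        n_epochs=1,
    )
    trainer.run()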