
Add an 'auto' option to the Trainer's driver parameter

tags/v1.0.0alpha
x54-729 2 years ago
commit 78596ea11c
6 changed files with 64 additions and 11 deletions
1. +6 -5 fastNLP/core/controllers/trainer.py
2. +11 -0 fastNLP/core/drivers/choose_driver.py
3. +6 -3 fastNLP/core/utils/__init__.py
4. +13 -1 fastNLP/core/utils/jittor_utils.py
5. +14 -1 fastNLP/core/utils/paddle_utils.py
6. +14 -1 fastNLP/core/utils/torch_utils.py

+6 -5 fastNLP/core/controllers/trainer.py

@@ -55,9 +55,10 @@ class Trainer(TrainerEventTrigger):
you should use ``TorchDDPDriver``, which means you need to launch training via ``python -m torch.distributed.launch``; in that case the parameter ``device``
should be set to None (we will then ignore that parameter); see the more detailed explanation of the parameter ``device`` below.


-    :param driver: the concrete driver mode used to train the model; must be one of ["torch"]. Training modes for
-        jittor, paddle and other domestic frameworks will be added later; "torch" means ``TorchSingleDriver`` or
-        ``TorchDDPDriver`` is used, and which one depends on the setting of the parameter ``device``;
+    :param driver: the concrete driver mode used to train the model; must be one of ["auto", "torch", "paddle",
+        "jittor", "fairscale"]. When set to ``"auto"``, **FastNLP** decides which mode to use based on the type of
+        the model passed in; "torch" means ``TorchSingleDriver`` or ``TorchDDPDriver``; "paddle" means
+        ``PaddleSingleDriver`` or ``PaddleFleetDriver``; "jittor" means ``JittorSingleDriver`` or
+        ``JittorMPIDriver``; "fairscale" means ``FairScaleDriver``. Once a framework is specified, which concrete
+        driver is used depends on the setting of the parameter ``device``;


.. warning::


@@ -81,7 +82,7 @@ class Trainer(TrainerEventTrigger):


The accepted values for device are as follows:


-    * *str*: e.g. 'cpu', 'cuda', 'cuda:0', 'cuda:1';
+    * *str*: e.g. 'cpu', 'cuda', 'cuda:0', 'cuda:1', 'gpu:0';
    * *torch.device*: e.g. 'torch.device("cuda:0")';
    * *int*: train on the ``gpu`` whose ``device_id`` equals this value; if the value is -1, all GPUs are used by default, and the driver instance will then be `TorchDDPDriver`;
    * *list(int)*: if more than one device is involved, it must be set this way; note that ``TorchDDPDriver`` will always be used in that case, no matter whether the length of the list you pass in is 1 or something else;
@@ -365,9 +366,9 @@ class Trainer(TrainerEventTrigger):
    def __init__(
        self,
        model,
-        driver,
        train_dataloader,
        optimizers,
+        driver: str = "auto",
        device: Optional[Union[int, List[int], str]] = "cpu",
        n_epochs: int = 20,
        evaluate_dataloaders=None,
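With the new default, a plain torch model can be trained without naming the backend at all. Below is a hedged sketch of the call pattern only; the toy model, the random data, and the assumption that the trainer accepts a raw torch DataLoader whose batches are dicts are ours, not taken from the commit:

import torch
from torch import nn
from torch.utils.data import DataLoader
from fastNLP import Trainer

class ToyModel(nn.Module):
    # fastNLP-style model: train_step receives the batch keys as
    # arguments and returns a dict containing 'loss'.
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(10, 2)

    def train_step(self, x, y):
        loss = nn.functional.cross_entropy(self.fc(x), y)
        return {"loss": loss}

model = ToyModel()
data = [{"x": torch.randn(10), "y": torch.tensor(1)} for _ in range(8)]
train_dataloader = DataLoader(data, batch_size=4)

trainer = Trainer(
    model=model,
    train_dataloader=train_dataloader,
    optimizers=torch.optim.SGD(model.parameters(), lr=0.1),
    driver="auto",   # resolved to "torch" because model is a torch.nn.Module
    device="cpu",
    n_epochs=1,
)
trainer.run()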


+11 -0 fastNLP/core/drivers/choose_driver.py

@@ -1,6 +1,7 @@
from typing import Union, Optional, List


from .driver import Driver
+from ..utils import is_torch_module, is_paddle_module, is_jittor_module




def choose_driver(model, driver: Union[str, Driver], device: Optional[Union[int, List[int], str]], **kwargs) -> Driver:
@@ -17,6 +18,16 @@ def choose_driver(model, driver: Union[str, Driver], device: Optional[Union[int,
    if isinstance(driver, Driver):
        return driver


if driver == "auto":
if is_torch_module(model):
driver = "torch"
elif is_paddle_module(model):
driver = "paddle"
elif is_jittor_module(model):
driver = "jittor"
else:
raise ValueError(f"Cannot choose driver automatically based on model, please set `driver` specifically.")

    if driver in {"torch", "fairscale"}:
        from fastNLP.core.drivers.torch_driver.initialize_torch_driver import initialize_torch_driver
        return initialize_torch_driver(driver, device, model, **kwargs)
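The probe order in the new branch is fixed: torch first, then paddle, then jittor, and anything unrecognized raises ``ValueError``. A standalone sketch of the same dispatch, using the helpers this commit exports (``resolve_driver_name`` is a hypothetical name, not part of the library):

from fastNLP.core.utils import is_torch_module, is_paddle_module, is_jittor_module

def resolve_driver_name(model) -> str:
    # Mirrors the "auto" branch added above: probe the backends in a
    # fixed order and fail loudly when no framework claims the model.
    if is_torch_module(model):
        return "torch"
    if is_paddle_module(model):
        return "paddle"
    if is_jittor_module(model):
        return "jittor"
    raise ValueError("Cannot choose driver automatically based on model, "
                     "please set `driver` specifically.")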


+6 -3 fastNLP/core/utils/__init__.py

@@ -1,5 +1,6 @@
__all__ = [
    'cache_results',
+    'is_jittor_module',
    'is_jittor_dataset',
    'jittor_collate_wraps',
    'paddle_to',
@@ -9,8 +10,10 @@ __all__ = [
    'is_in_paddle_dist',
    'is_in_fnlp_paddle_dist',
    'is_in_paddle_launch_dist',
+    'is_paddle_module',
    'f_rich_progress',
    'torch_move_data_to_device',
+    'is_torch_module',
    'get_fn_arg_names',
    'auto_param_call',
    'check_user_specific_params',
@@ -28,11 +31,11 @@ __all__ = [
]

from .cache_results import cache_results
-from .jittor_utils import is_jittor_dataset, jittor_collate_wraps
+from .jittor_utils import is_jittor_dataset, jittor_collate_wraps, is_jittor_module
from .paddle_utils import paddle_to, paddle_move_data_to_device, get_paddle_device_id, get_paddle_gpu_str, is_in_paddle_dist, \
-    is_in_fnlp_paddle_dist, is_in_paddle_launch_dist
+    is_in_fnlp_paddle_dist, is_in_paddle_launch_dist, is_paddle_module
from .rich_progress import f_rich_progress
-from .torch_utils import torch_move_data_to_device
+from .torch_utils import torch_move_data_to_device, is_torch_module
from .utils import *
from .tqdm_progress import f_tqdm_progress
from .seq_len_to_mask import seq_len_to_mask


+13 -1 fastNLP/core/utils/jittor_utils.py

@@ -1,6 +1,7 @@
__all__ = [
+    'is_jittor_module',
    'is_jittor_dataset',
-    'jittor_collate_wraps'
+    'jittor_collate_wraps',
]


from collections.abc import Mapping, Callable
@@ -13,6 +14,17 @@ if _NEED_IMPORT_JITTOR:


from fastNLP.core.dataset import Instance


+def is_jittor_module(model) -> bool:
+    """
+    Check whether the given ``model`` is a :class:`jittor.Module`.
+
+    :param model: the model;
+    :return: whether the model is a ``jittor`` model;
+    """
+    try:
+        return isinstance(model, jt.Module)
+    except BaseException:
+        return False


def is_jittor_dataset(dataset) -> bool:
    """


+14 -1 fastNLP/core/utils/paddle_utils.py

@@ -6,6 +6,7 @@ __all__ = [
"is_in_paddle_dist", "is_in_paddle_dist",
"is_in_fnlp_paddle_dist", "is_in_fnlp_paddle_dist",
"is_in_paddle_launch_dist", "is_in_paddle_launch_dist",
"is_paddle_module",
] ]


import os
@@ -174,4 +175,16 @@ def is_in_paddle_launch_dist() -> bool:
""" """
判断是否处于 ``python -m paddle.distributed.launch`` 方法启动的 **paddle** 分布式进程中 判断是否处于 ``python -m paddle.distributed.launch`` 方法启动的 **paddle** 分布式进程中
""" """
return FASTNLP_BACKEND_LAUNCH in os.environ
return FASTNLP_BACKEND_LAUNCH in os.environ

+def is_paddle_module(model) -> bool:
+    """
+    Check whether the given ``model`` is a :class:`paddle.nn.Layer`.
+
+    :param model: the model;
+    :return: whether the model is a ``paddle`` model;
+    """
+    try:
+        return isinstance(model, paddle.nn.Layer)
+    except BaseException:
+        return False

+14 -1 fastNLP/core/utils/torch_utils.py

@@ -8,7 +8,8 @@ if _NEED_IMPORT_TORCH:
    DEFAULT_TORCH_GROUP = torch.distributed.distributed_c10d.group.WORLD


__all__ = [
-    'torch_move_data_to_device'
+    'torch_move_data_to_device',
+    'is_torch_module',
]


from .utils import apply_to_collection
@@ -64,3 +65,15 @@ def torch_move_data_to_device(batch: Any, device: Optional[Union[str, "torch.dev


    dtype = TorchTransferableDataType
    return apply_to_collection(batch, dtype=dtype, function=batch_to)

+def is_torch_module(model) -> bool:
+    """
+    Check whether the given ``model`` is a :class:`torch.nn.Module`.
+
+    :param model: the model;
+    :return: whether the model is a ``torch`` model;
+    """
+    try:
+        return isinstance(model, torch.nn.Module)
+    except BaseException:
+        return False
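All three new helpers share one pattern: the ``isinstance`` check is wrapped in ``try``/``except BaseException`` so that a backend that was never imported (leaving the name ``torch``, ``paddle`` or ``jt`` undefined in the module) makes the check return ``False`` instead of leaking a ``NameError``. A quick sanity check, assuming an environment where torch is installed:

import torch
from fastNLP.core.utils import is_torch_module, is_paddle_module

model = torch.nn.Linear(4, 2)
assert is_torch_module(model)       # any torch.nn.Module -> True
# Works even when paddle is absent: the failed lookup of the `paddle`
# name is swallowed by the except branch and reported as False.
assert not is_paddle_module(model)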
