@@ -55,9 +55,10 @@ class Trainer(TrainerEventTrigger): | |||
您应当使用 ``TorchDDPDriver``,意味着您需要通过 ``python -m torch.distributed.launch`` 的方式来启动训练,此时参数 ``device`` | |||
应当设置为 None(此时我们会忽略该参数),具体见下面对于参数 ``device`` 的更详细的解释。 | |||
:param driver: 训练模型所使用的具体的驱动模式,应当为以下选择中的一个:["torch"],之后我们会加入 jittor、paddle 等 | |||
国产框架的训练模式;其中 "torch" 表示使用 ``TorchSingleDriver`` 或者 ``TorchDDPDriver``,具体使用哪一种取决于参数 ``device`` | |||
的设置; | |||
:param driver: 训练模型所使用的具体的驱动模式,应当为以下选择中的一个:["auto", "torch", "paddle", "jittor", "fairscale"]。其值为 ``"auto"`` 时, | |||
**FastNLP** 会根据传入模型的类型自行判断使用哪一种模式;其值为 "torch" 时,表示使用 ``TorchSingleDriver`` 或者 ``TorchDDPDriver``; | |||
其值为 "paddle" 时,表示使用 ``PaddleSingleDriver`` 或者 ``PaddleFleetDriver``;其值为 "jittor" 时,表示使用 ``JittorSingleDriver`` | |||
或者 ``JittorMPIDriver``;其值为 "fairscale" 时,表示使用 ``FairScaleDriver``。在指定了框架的情况下,具体使用哪一种取决于参数 ``device`` 的设置; | |||
.. warning:: | |||
@@ -81,7 +82,7 @@ class Trainer(TrainerEventTrigger): | |||
device 的可选输入如下所示: | |||
* *str*: 例如 'cpu', 'cuda', 'cuda:0', 'cuda:1' 等; | |||
* *str*: 例如 'cpu', 'cuda', 'cuda:0', 'cuda:1', 'gpu:0' 等; | |||
* *torch.device*: 例如 'torch.device("cuda:0")'; | |||
* *int*: 将使用 ``device_id`` 为该值的 ``gpu`` 进行训练;如果值为 -1,那么默认使用全部的显卡,此时使用的 driver 实例是 `TorchDDPDriver`; | |||
* *list(int)*: 如果多于 1 个device,应当通过该种方式进行设定;注意此时我们一定会使用 ``TorchDDPDriver``,不管您传入的列表的长度是 1 还是其它值; | |||
@@ -365,9 +366,9 @@ class Trainer(TrainerEventTrigger): | |||
def __init__( | |||
self, | |||
model, | |||
driver, | |||
train_dataloader, | |||
optimizers, | |||
driver: str = "auto", | |||
device: Optional[Union[int, List[int], str]] = "cpu", | |||
n_epochs: int = 20, | |||
evaluate_dataloaders=None, | |||
@@ -1,6 +1,7 @@ | |||
from typing import Union, Optional, List | |||
from .driver import Driver | |||
from ..utils import is_torch_module, is_paddle_module, is_jittor_module | |||
def choose_driver(model, driver: Union[str, Driver], device: Optional[Union[int, List[int], str]], **kwargs) -> Driver: | |||
@@ -17,6 +18,16 @@ def choose_driver(model, driver: Union[str, Driver], device: Optional[Union[int, | |||
if isinstance(driver, Driver): | |||
return driver | |||
if driver == "auto": | |||
if is_torch_module(model): | |||
driver = "torch" | |||
elif is_paddle_module(model): | |||
driver = "paddle" | |||
elif is_jittor_module(model): | |||
driver = "jittor" | |||
else: | |||
raise ValueError(f"Cannot choose driver automatically based on model, please set `driver` specifically.") | |||
if driver in {"torch", "fairscale"}: | |||
from fastNLP.core.drivers.torch_driver.initialize_torch_driver import initialize_torch_driver | |||
return initialize_torch_driver(driver, device, model, **kwargs) | |||
@@ -1,5 +1,6 @@ | |||
__all__ = [ | |||
'cache_results', | |||
'is_jittor_module', | |||
'is_jittor_dataset', | |||
'jittor_collate_wraps', | |||
'paddle_to', | |||
@@ -9,8 +10,10 @@ __all__ = [ | |||
'is_in_paddle_dist', | |||
'is_in_fnlp_paddle_dist', | |||
'is_in_paddle_launch_dist', | |||
'is_paddle_module', | |||
'f_rich_progress', | |||
'torch_move_data_to_device', | |||
'is_torch_module', | |||
'get_fn_arg_names', | |||
'auto_param_call', | |||
'check_user_specific_params', | |||
@@ -28,11 +31,11 @@ __all__ = [ | |||
] | |||
from .cache_results import cache_results | |||
from .jittor_utils import is_jittor_dataset, jittor_collate_wraps | |||
from .jittor_utils import is_jittor_dataset, jittor_collate_wraps, is_jittor_module | |||
from .paddle_utils import paddle_to, paddle_move_data_to_device, get_paddle_device_id, get_paddle_gpu_str, is_in_paddle_dist, \ | |||
is_in_fnlp_paddle_dist, is_in_paddle_launch_dist | |||
is_in_fnlp_paddle_dist, is_in_paddle_launch_dist, is_paddle_module | |||
from .rich_progress import f_rich_progress | |||
from .torch_utils import torch_move_data_to_device | |||
from .torch_utils import torch_move_data_to_device, is_torch_module | |||
from .utils import * | |||
from .tqdm_progress import f_tqdm_progress | |||
from .seq_len_to_mask import seq_len_to_mask | |||
@@ -1,6 +1,7 @@ | |||
__all__ = [ | |||
'is_jittor_module', | |||
'is_jittor_dataset', | |||
'jittor_collate_wraps' | |||
'jittor_collate_wraps', | |||
] | |||
from collections.abc import Mapping, Callable | |||
@@ -13,6 +14,17 @@ if _NEED_IMPORT_JITTOR: | |||
from fastNLP.core.dataset import Instance | |||
def is_jittor_module(model) -> bool:
    """
    Check whether ``model`` is a :class:`jittor.Module`.

    :param model: the object to check;
    :return: ``True`` if ``model`` is a ``jittor`` model, ``False`` otherwise;
    """
    try:
        return isinstance(model, jt.Module)
    except Exception:
        # ``jt`` is only bound when jittor is importable (_NEED_IMPORT_JITTOR guard
        # above); a NameError here just means jittor is absent, so the answer is
        # False. Catch Exception — not BaseException — so KeyboardInterrupt and
        # SystemExit still propagate.
        return False
def is_jittor_dataset(dataset) -> bool: | |||
""" | |||
@@ -6,6 +6,7 @@ __all__ = [ | |||
"is_in_paddle_dist", | |||
"is_in_fnlp_paddle_dist", | |||
"is_in_paddle_launch_dist", | |||
"is_paddle_module", | |||
] | |||
import os | |||
@@ -174,4 +175,16 @@ def is_in_paddle_launch_dist() -> bool: | |||
""" | |||
判断是否处于 ``python -m paddle.distributed.launch`` 方法启动的 **paddle** 分布式进程中 | |||
""" | |||
return FASTNLP_BACKEND_LAUNCH in os.environ | |||
return FASTNLP_BACKEND_LAUNCH in os.environ | |||
def is_paddle_module(model) -> bool:
    """
    Check whether ``model`` is a :class:`paddle.nn.Layer`.

    :param model: the object to check;
    :return: ``True`` if ``model`` is a ``paddle`` model, ``False`` otherwise;
    """
    try:
        return isinstance(model, paddle.nn.Layer)
    except Exception:
        # ``paddle`` may be unbound when paddle is not installed; the resulting
        # NameError just means "not a paddle model". Catch Exception — not
        # BaseException — so KeyboardInterrupt and SystemExit still propagate.
        return False
@@ -8,7 +8,8 @@ if _NEED_IMPORT_TORCH: | |||
DEFAULT_TORCH_GROUP = torch.distributed.distributed_c10d.group.WORLD | |||
__all__ = [ | |||
'torch_move_data_to_device' | |||
'torch_move_data_to_device', | |||
'is_torch_module', | |||
] | |||
from .utils import apply_to_collection | |||
@@ -64,3 +65,15 @@ def torch_move_data_to_device(batch: Any, device: Optional[Union[str, "torch.dev | |||
dtype = TorchTransferableDataType | |||
return apply_to_collection(batch, dtype=dtype, function=batch_to) | |||
def is_torch_module(model) -> bool:
    """
    Check whether ``model`` is a :class:`torch.nn.Module`.

    :param model: the object to check;
    :return: ``True`` if ``model`` is a ``torch`` model, ``False`` otherwise;
    """
    try:
        return isinstance(model, torch.nn.Module)
    except Exception:
        # ``torch`` is only bound when torch is importable (_NEED_IMPORT_TORCH
        # guard at the top of this module); a NameError here just means torch is
        # absent, so the answer is False. Catch Exception — not BaseException —
        # so KeyboardInterrupt and SystemExit still propagate.
        return False