@@ -55,9 +55,10 @@ class Trainer(TrainerEventTrigger): | |||||
您应当使用 ``TorchDDPDriver``,意味着您需要通过 ``python -m torch.distributed.launch`` 的方式来启动训练,此时参数 ``device`` | 您应当使用 ``TorchDDPDriver``,意味着您需要通过 ``python -m torch.distributed.launch`` 的方式来启动训练,此时参数 ``device`` | ||||
应当设置为 None(此时我们会忽略该参数),具体见下面对于参数 ``device`` 的更详细的解释。 | 应当设置为 None(此时我们会忽略该参数),具体见下面对于参数 ``device`` 的更详细的解释。 | ||||
:param driver: 训练模型所使用的具体的驱动模式,应当为以下选择中的一个:["torch"],之后我们会加入 jittor、paddle 等 | |||||
国产框架的训练模式;其中 "torch" 表示使用 ``TorchSingleDriver`` 或者 ``TorchDDPDriver``,具体使用哪一种取决于参数 ``device`` | |||||
的设置; | |||||
:param driver: 训练模型所使用的具体的驱动模式,应当为以下选择中的一个:["auto", "torch", "paddle", "jittor", "fairscale"]。其值为 ``"auto"`` 时, | |||||
**FastNLP** 会根据传入模型的类型自行判断使用哪一种模式;其值为 "torch" 时,表示使用 ``TorchSingleDriver`` 或者 ``TorchDDPDriver``; | |||||
其值为 "paddle" 时,表示使用 ``PaddleSingleDriver`` 或者 ``PaddleFleetDriver``;其值为 "jittor" 时,表示使用 ``JittorSingleDriver`` | |||||
或者 ``JittorMPIDriver``;其值为 "fairscale" 时,表示使用 ``FairScaleDriver``。在指定了框架的情况下,具体使用哪一种取决于参数 ``device`` 的设置; | |||||
.. warning:: | .. warning:: | ||||
@@ -81,7 +82,7 @@ class Trainer(TrainerEventTrigger): | |||||
device 的可选输入如下所示: | device 的可选输入如下所示: | ||||
* *str*: 例如 'cpu', 'cuda', 'cuda:0', 'cuda:1' 等; | |||||
* *str*: 例如 'cpu', 'cuda', 'cuda:0', 'cuda:1', 'gpu:0' 等; | |||||
* *torch.device*: 例如 'torch.device("cuda:0")'; | * *torch.device*: 例如 'torch.device("cuda:0")'; | ||||
* *int*: 将使用 ``device_id`` 为该值的 ``gpu`` 进行训练;如果值为 -1,那么默认使用全部的显卡,此时使用的 driver 实例是 `TorchDDPDriver`; | * *int*: 将使用 ``device_id`` 为该值的 ``gpu`` 进行训练;如果值为 -1,那么默认使用全部的显卡,此时使用的 driver 实例是 `TorchDDPDriver`; | ||||
* *list(int)*: 如果多于 1 个device,应当通过该种方式进行设定;注意此时我们一定会使用 ``TorchDDPDriver``,不管您传入的列表的长度是 1 还是其它值; | * *list(int)*: 如果多于 1 个device,应当通过该种方式进行设定;注意此时我们一定会使用 ``TorchDDPDriver``,不管您传入的列表的长度是 1 还是其它值; | ||||
@@ -365,9 +366,9 @@ class Trainer(TrainerEventTrigger): | |||||
def __init__( | def __init__( | ||||
self, | self, | ||||
model, | model, | ||||
driver, | |||||
train_dataloader, | train_dataloader, | ||||
optimizers, | optimizers, | ||||
driver: str = "auto", | |||||
device: Optional[Union[int, List[int], str]] = "cpu", | device: Optional[Union[int, List[int], str]] = "cpu", | ||||
n_epochs: int = 20, | n_epochs: int = 20, | ||||
evaluate_dataloaders=None, | evaluate_dataloaders=None, | ||||
@@ -1,6 +1,7 @@ | |||||
from typing import Union, Optional, List | from typing import Union, Optional, List | ||||
from .driver import Driver | from .driver import Driver | ||||
from ..utils import is_torch_module, is_paddle_module, is_jittor_module | |||||
def choose_driver(model, driver: Union[str, Driver], device: Optional[Union[int, List[int], str]], **kwargs) -> Driver: | def choose_driver(model, driver: Union[str, Driver], device: Optional[Union[int, List[int], str]], **kwargs) -> Driver: | ||||
@@ -17,6 +18,16 @@ def choose_driver(model, driver: Union[str, Driver], device: Optional[Union[int, | |||||
if isinstance(driver, Driver): | if isinstance(driver, Driver): | ||||
return driver | return driver | ||||
if driver == "auto": | |||||
if is_torch_module(model): | |||||
driver = "torch" | |||||
elif is_paddle_module(model): | |||||
driver = "paddle" | |||||
elif is_jittor_module(model): | |||||
driver = "jittor" | |||||
else: | |||||
raise ValueError(f"Cannot choose driver automatically based on model, please set `driver` specifically.") | |||||
if driver in {"torch", "fairscale"}: | if driver in {"torch", "fairscale"}: | ||||
from fastNLP.core.drivers.torch_driver.initialize_torch_driver import initialize_torch_driver | from fastNLP.core.drivers.torch_driver.initialize_torch_driver import initialize_torch_driver | ||||
return initialize_torch_driver(driver, device, model, **kwargs) | return initialize_torch_driver(driver, device, model, **kwargs) | ||||
@@ -1,5 +1,6 @@ | |||||
__all__ = [ | __all__ = [ | ||||
'cache_results', | 'cache_results', | ||||
'is_jittor_module', | |||||
'is_jittor_dataset', | 'is_jittor_dataset', | ||||
'jittor_collate_wraps', | 'jittor_collate_wraps', | ||||
'paddle_to', | 'paddle_to', | ||||
@@ -9,8 +10,10 @@ __all__ = [ | |||||
'is_in_paddle_dist', | 'is_in_paddle_dist', | ||||
'is_in_fnlp_paddle_dist', | 'is_in_fnlp_paddle_dist', | ||||
'is_in_paddle_launch_dist', | 'is_in_paddle_launch_dist', | ||||
'is_paddle_module', | |||||
'f_rich_progress', | 'f_rich_progress', | ||||
'torch_move_data_to_device', | 'torch_move_data_to_device', | ||||
'is_torch_module', | |||||
'get_fn_arg_names', | 'get_fn_arg_names', | ||||
'auto_param_call', | 'auto_param_call', | ||||
'check_user_specific_params', | 'check_user_specific_params', | ||||
@@ -28,11 +31,11 @@ __all__ = [ | |||||
] | ] | ||||
from .cache_results import cache_results | from .cache_results import cache_results | ||||
from .jittor_utils import is_jittor_dataset, jittor_collate_wraps | |||||
from .jittor_utils import is_jittor_dataset, jittor_collate_wraps, is_jittor_module | |||||
from .paddle_utils import paddle_to, paddle_move_data_to_device, get_paddle_device_id, get_paddle_gpu_str, is_in_paddle_dist, \ | from .paddle_utils import paddle_to, paddle_move_data_to_device, get_paddle_device_id, get_paddle_gpu_str, is_in_paddle_dist, \ | ||||
is_in_fnlp_paddle_dist, is_in_paddle_launch_dist | |||||
is_in_fnlp_paddle_dist, is_in_paddle_launch_dist, is_paddle_module | |||||
from .rich_progress import f_rich_progress | from .rich_progress import f_rich_progress | ||||
from .torch_utils import torch_move_data_to_device | |||||
from .torch_utils import torch_move_data_to_device, is_torch_module | |||||
from .utils import * | from .utils import * | ||||
from .tqdm_progress import f_tqdm_progress | from .tqdm_progress import f_tqdm_progress | ||||
from .seq_len_to_mask import seq_len_to_mask | from .seq_len_to_mask import seq_len_to_mask | ||||
@@ -1,6 +1,7 @@ | |||||
__all__ = [ | __all__ = [ | ||||
'is_jittor_module', | |||||
'is_jittor_dataset', | 'is_jittor_dataset', | ||||
'jittor_collate_wraps' | |||||
'jittor_collate_wraps', | |||||
] | ] | ||||
from collections.abc import Mapping, Callable | from collections.abc import Mapping, Callable | ||||
@@ -13,6 +14,17 @@ if _NEED_IMPORT_JITTOR: | |||||
from fastNLP.core.dataset import Instance | from fastNLP.core.dataset import Instance | ||||
def is_jittor_module(model) -> bool:
    """
    Check whether ``model`` is an instance of :class:`jittor.Module`.

    :param model: the model object to check;
    :return: ``True`` if the model is a ``jittor`` module, ``False`` otherwise
        (including when ``jittor`` is not installed or not imported);
    """
    try:
        return isinstance(model, jt.Module)
    except Exception:
        # NameError/AttributeError when jittor is unavailable. Do NOT catch
        # BaseException: that would also swallow KeyboardInterrupt/SystemExit.
        return False
def is_jittor_dataset(dataset) -> bool: | def is_jittor_dataset(dataset) -> bool: | ||||
""" | """ | ||||
@@ -6,6 +6,7 @@ __all__ = [ | |||||
"is_in_paddle_dist", | "is_in_paddle_dist", | ||||
"is_in_fnlp_paddle_dist", | "is_in_fnlp_paddle_dist", | ||||
"is_in_paddle_launch_dist", | "is_in_paddle_launch_dist", | ||||
"is_paddle_module", | |||||
] | ] | ||||
import os | import os | ||||
@@ -174,4 +175,16 @@ def is_in_paddle_launch_dist() -> bool: | |||||
""" | """ | ||||
判断是否处于 ``python -m paddle.distributed.launch`` 方法启动的 **paddle** 分布式进程中 | 判断是否处于 ``python -m paddle.distributed.launch`` 方法启动的 **paddle** 分布式进程中 | ||||
""" | """ | ||||
return FASTNLP_BACKEND_LAUNCH in os.environ | |||||
return FASTNLP_BACKEND_LAUNCH in os.environ | |||||
def is_paddle_module(model) -> bool:
    """
    Check whether ``model`` is an instance of :class:`paddle.nn.Layer`.

    :param model: the model object to check;
    :return: ``True`` if the model is a ``paddle`` model, ``False`` otherwise
        (including when ``paddle`` is not installed or not imported);
    """
    try:
        return isinstance(model, paddle.nn.Layer)
    except Exception:
        # NameError/AttributeError when paddle is unavailable. Do NOT catch
        # BaseException: that would also swallow KeyboardInterrupt/SystemExit.
        return False
@@ -8,7 +8,8 @@ if _NEED_IMPORT_TORCH: | |||||
DEFAULT_TORCH_GROUP = torch.distributed.distributed_c10d.group.WORLD | DEFAULT_TORCH_GROUP = torch.distributed.distributed_c10d.group.WORLD | ||||
__all__ = [ | __all__ = [ | ||||
'torch_move_data_to_device' | |||||
'torch_move_data_to_device', | |||||
'is_torch_module', | |||||
] | ] | ||||
from .utils import apply_to_collection | from .utils import apply_to_collection | ||||
@@ -64,3 +65,15 @@ def torch_move_data_to_device(batch: Any, device: Optional[Union[str, "torch.dev | |||||
dtype = TorchTransferableDataType | dtype = TorchTransferableDataType | ||||
return apply_to_collection(batch, dtype=dtype, function=batch_to) | return apply_to_collection(batch, dtype=dtype, function=batch_to) | ||||
def is_torch_module(model) -> bool:
    """
    Check whether ``model`` is an instance of :class:`torch.nn.Module`.

    :param model: the model object to check;
    :return: ``True`` if the model is a ``torch`` model, ``False`` otherwise
        (including when ``torch`` is not installed or not imported);
    """
    try:
        return isinstance(model, torch.nn.Module)
    except Exception:
        # NameError/AttributeError when torch is unavailable. Do NOT catch
        # BaseException: that would also swallow KeyboardInterrupt/SystemExit.
        return False