@@ -55,9 +55,10 @@ class Trainer(TrainerEventTrigger):
        you should use ``TorchDDPDriver``, which means you need to launch training via ``python -m torch.distributed.launch``; in this case the argument ``device``
        should be set to None (we will then ignore that argument); see the more detailed explanation of the ``device`` argument below.

-    :param driver: the concrete driver mode used to train the model, which should be one of ["torch"]; training modes for domestic frameworks
-        such as jittor and paddle will be added later. "torch" means ``TorchSingleDriver`` or ``TorchDDPDriver`` is used, and which one is used
-        depends on the setting of the ``device`` argument;
+    :param driver: the concrete driver mode used to train the model, which should be one of ["auto", "torch", "paddle", "jittor", "fairscale"]. When it is ``"auto"``,
+        **FastNLP** decides on its own which mode to use based on the type of the model passed in; "torch" means ``TorchSingleDriver`` or ``TorchDDPDriver`` is used;
+        "paddle" means ``PaddleSingleDriver`` or ``PaddleFleetDriver``; "jittor" means ``JittorSingleDriver``
+        or ``JittorMPIDriver``; "fairscale" means ``FairScaleDriver``. Once a framework is specified, which concrete driver is used depends on the setting of the ``device`` argument;

    .. warning::
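
To make the new default concrete, here is a minimal sketch of ``driver="auto"`` in use. The toy model, dataloader, and optimizer are hypothetical stand-ins (not part of this change), and whether ``trainer.run()`` actually trains depends on the model following fastNLP's step contract, which is outside this diff; the point is only that the driver is resolved from the model type:

    import torch
    from torch.utils.data import DataLoader, TensorDataset
    from fastNLP import Trainer

    model = torch.nn.Linear(10, 2)  # any torch.nn.Module suffices for driver resolution
    trainer = Trainer(
        model,
        train_dataloader=DataLoader(
            TensorDataset(torch.randn(8, 10), torch.randint(0, 2, (8,))),
            batch_size=4,
        ),
        optimizers=torch.optim.SGD(model.parameters(), lr=0.1),
        driver="auto",  # resolved to "torch" because the model is a torch.nn.Module
        device="cpu",
    )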
@@ -81,7 +82,7 @@ class Trainer(TrainerEventTrigger):
        The accepted inputs for device are as follows:

-        * *str*: e.g. 'cpu', 'cuda', 'cuda:0', 'cuda:1', etc.;
+        * *str*: e.g. 'cpu', 'cuda', 'cuda:0', 'cuda:1', 'gpu:0', etc.;
        * *torch.device*: e.g. 'torch.device("cuda:0")';
        * *int*: the ``gpu`` whose ``device_id`` equals this value is used for training; if the value is -1, all GPUs are used by default, and the driver instance in that case is `TorchDDPDriver`;
        * *list(int)*: if more than one device is used, they should be specified this way; note that we will then always use ``TorchDDPDriver``, no matter whether the length of the list you pass in is 1 or something else;
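
Spelled out as values, each entry below is one valid setting of ``device`` for the torch backend (a sketch restating the list above, nothing more):

    import torch

    device_options = [
        "cpu",                   # str: run on CPU
        "cuda:1",                # str: one specific GPU
        torch.device("cuda:0"),  # torch.device object
        1,                       # int: the gpu whose device_id is 1
        -1,                      # int: all visible GPUs -> TorchDDPDriver
        [0, 1],                  # list(int): always TorchDDPDriver, even for length 1
    ]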
@@ -365,9 +366,9 @@ class Trainer(TrainerEventTrigger):
    def __init__(
            self,
            model,
-            driver,
            train_dataloader,
            optimizers,
+            driver: str = "auto",
            device: Optional[Union[int, List[int], str]] = "cpu",
            n_epochs: int = 20,
            evaluate_dataloaders=None,
@@ -1,6 +1,7 @@
 from typing import Union, Optional, List

 from .driver import Driver
+from ..utils import is_torch_module, is_paddle_module, is_jittor_module


 def choose_driver(model, driver: Union[str, Driver], device: Optional[Union[int, List[int], str]], **kwargs) -> Driver:
@@ -17,6 +18,16 @@ def choose_driver(model, driver: Union[str, Driver], device: Optional[Union[int,
     if isinstance(driver, Driver):
         return driver

+    if driver == "auto":
+        if is_torch_module(model):
+            driver = "torch"
+        elif is_paddle_module(model):
+            driver = "paddle"
+        elif is_jittor_module(model):
+            driver = "jittor"
+        else:
+            raise ValueError(f"Cannot choose driver automatically based on model, please set `driver` specifically.")
+
     if driver in {"torch", "fairscale"}:
         from fastNLP.core.drivers.torch_driver.initialize_torch_driver import initialize_torch_driver
         return initialize_torch_driver(driver, device, model, **kwargs)
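
A hedged usage sketch of the new branch; the module path in the import below is an assumption inferred from the relative imports in this file, not something this diff confirms:

    import torch
    from fastNLP.core.drivers.choose_driver import choose_driver  # assumed path

    # "auto" probes the model type in a fixed order: torch, then paddle, then jittor.
    driver = choose_driver(torch.nn.Linear(2, 2), "auto", device="cpu")
    # -> a torch driver instance, because the model is a torch.nn.Module

A model that matches none of the three probes raises the ValueError above rather than guessing a backend.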
@@ -1,5 +1,6 @@
 __all__ = [
     'cache_results',
+    'is_jittor_module',
     'is_jittor_dataset',
     'jittor_collate_wraps',
     'paddle_to',
@@ -9,8 +10,10 @@ __all__ = [
     'is_in_paddle_dist',
     'is_in_fnlp_paddle_dist',
     'is_in_paddle_launch_dist',
+    'is_paddle_module',
     'f_rich_progress',
     'torch_move_data_to_device',
+    'is_torch_module',
     'get_fn_arg_names',
     'auto_param_call',
     'check_user_specific_params',
@@ -28,11 +31,11 @@ __all__ = [
 ]

 from .cache_results import cache_results
-from .jittor_utils import is_jittor_dataset, jittor_collate_wraps
+from .jittor_utils import is_jittor_dataset, jittor_collate_wraps, is_jittor_module
 from .paddle_utils import paddle_to, paddle_move_data_to_device, get_paddle_device_id, get_paddle_gpu_str, is_in_paddle_dist, \
-    is_in_fnlp_paddle_dist, is_in_paddle_launch_dist
+    is_in_fnlp_paddle_dist, is_in_paddle_launch_dist, is_paddle_module
 from .rich_progress import f_rich_progress
-from .torch_utils import torch_move_data_to_device
+from .torch_utils import torch_move_data_to_device, is_torch_module
 from .utils import *
 from .tqdm_progress import f_tqdm_progress
 from .seq_len_to_mask import seq_len_to_mask
@@ -1,6 +1,7 @@
 __all__ = [
+    'is_jittor_module',
     'is_jittor_dataset',
-    'jittor_collate_wraps'
+    'jittor_collate_wraps',
 ]

 from collections.abc import Mapping, Callable
@@ -13,6 +14,17 @@ if _NEED_IMPORT_JITTOR:
 from fastNLP.core.dataset import Instance


+def is_jittor_module(model) -> bool:
+    """
+    Check whether the given ``model`` is a :class:`jittor.Module`.
+
+    :param model: the model;
+    :return: whether the model is a ``jittor`` model;
+    """
+    try:
+        return isinstance(model, jt.Module)
+    except BaseException:
+        return False
+
 def is_jittor_dataset(dataset) -> bool:
     """
@@ -6,6 +6,7 @@ __all__ = [
     "is_in_paddle_dist",
     "is_in_fnlp_paddle_dist",
     "is_in_paddle_launch_dist",
+    "is_paddle_module",
 ]

 import os
@@ -174,4 +175,16 @@ def is_in_paddle_launch_dist() -> bool:
     """
     Check whether we are in a **paddle** distributed process started via ``python -m paddle.distributed.launch``.
     """
-    return FASTNLP_BACKEND_LAUNCH in os.environ
+    return FASTNLP_BACKEND_LAUNCH in os.environ
+
+def is_paddle_module(model) -> bool:
+    """
+    Check whether the given ``model`` is a :class:`paddle.nn.Layer`.
+
+    :param model: the model;
+    :return: whether the model is a ``paddle`` model;
+    """
+    try:
+        return isinstance(model, paddle.nn.Layer)
+    except BaseException:
+        return False
@@ -8,7 +8,8 @@ if _NEED_IMPORT_TORCH:
 DEFAULT_TORCH_GROUP = torch.distributed.distributed_c10d.group.WORLD

 __all__ = [
-    'torch_move_data_to_device'
+    'torch_move_data_to_device',
+    'is_torch_module',
 ]

 from .utils import apply_to_collection
@@ -64,3 +65,15 @@ def torch_move_data_to_device(batch: Any, device: Optional[Union[str, "torch.dev
         dtype = TorchTransferableDataType
     return apply_to_collection(batch, dtype=dtype, function=batch_to)
+
+
+def is_torch_module(model) -> bool:
+    """
+    Check whether the given ``model`` is a :class:`torch.nn.Module`.
+
+    :param model: the model;
+    :return: whether the model is a ``torch`` model;
+    """
+    try:
+        return isinstance(model, torch.nn.Module)
+    except BaseException:
+        return False
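
All three probes share this shape, which makes them safe to call no matter which frameworks are installed: when a framework is absent, the name being tested against (``jt``, ``paddle``, ``torch``) is missing, and the ``except BaseException`` turns that failure into ``False`` instead of crashing the "auto" dispatch. A small usage sketch (the names are exported from ``fastNLP.core.utils`` per the ``__all__`` changes above):

    import torch
    from fastNLP.core.utils import is_torch_module, is_paddle_module, is_jittor_module

    model = torch.nn.Linear(2, 2)
    print(is_torch_module(model))   # True
    print(is_paddle_module(model))  # False (also False when paddle is not installed)
    print(is_jittor_module(model))  # False (also False when jittor is not installed)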