| @@ -42,7 +42,7 @@ def initialize_paddle_driver(driver: str, device: Optional[Union[str, int, List[ | |||||
| # 优先级 user > cuda | # 优先级 user > cuda | ||||
| # 判断单机情况 device 的合法性 | # 判断单机情况 device 的合法性 | ||||
| # 分布式情况下通过 world_device 判断 | # 分布式情况下通过 world_device 判断 | ||||
| if user_visible_devices is not None: | |||||
| if user_visible_devices != "": | |||||
| _could_use_device_num = len(user_visible_devices.split(",")) | _could_use_device_num = len(user_visible_devices.split(",")) | ||||
| elif cuda_visible_devices is not None: | elif cuda_visible_devices is not None: | ||||
| _could_use_device_num = len(cuda_visible_devices.split(",")) | _could_use_device_num = len(cuda_visible_devices.split(",")) | ||||
| @@ -51,8 +51,8 @@ def initialize_paddle_driver(driver: str, device: Optional[Union[str, int, List[ | |||||
| if isinstance(device, int): | if isinstance(device, int): | ||||
| if device < 0 and device != -1: | if device < 0 and device != -1: | ||||
| raise ValueError("Parameter `device` can only be '-1' when it is smaller than 0.") | raise ValueError("Parameter `device` can only be '-1' when it is smaller than 0.") | ||||
| if device >= _could_use_device_num: | |||||
| raise ValueError("The gpu device that parameter `device` specifies is not existed.") | |||||
| # if device >= _could_use_device_num: | |||||
| # raise ValueError("The gpu device that parameter `device` specifies is not existed.") | |||||
| device = f"gpu:{device}" | device = f"gpu:{device}" | ||||
| elif isinstance(device, Sequence) and not isinstance(device, str): | elif isinstance(device, Sequence) and not isinstance(device, str): | ||||
| device = list(set(device)) | device = list(set(device)) | ||||
| @@ -1,8 +1,14 @@ | |||||
| import os | |||||
| from typing import Optional, Dict, Union | from typing import Optional, Dict, Union | ||||
| from .paddle_driver import PaddleDriver | from .paddle_driver import PaddleDriver | ||||
| from fastNLP.envs.imports import _NEED_IMPORT_PADDLE | from fastNLP.envs.imports import _NEED_IMPORT_PADDLE | ||||
| from fastNLP.core.utils import auto_param_call, get_paddle_gpu_str | |||||
| from fastNLP.core.utils import ( | |||||
| auto_param_call, | |||||
| get_paddle_gpu_str, | |||||
| get_paddle_device_id, | |||||
| paddle_move_data_to_device, | |||||
| ) | |||||
| from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleIterator | from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleIterator | ||||
| from fastNLP.core.log import logger | from fastNLP.core.log import logger | ||||
| @@ -86,8 +92,9 @@ class PaddleSingleDriver(PaddleDriver): | |||||
| self._test_signature_fn = model.forward | self._test_signature_fn = model.forward | ||||
| def setup(self): | def setup(self): | ||||
| paddle.device.set_device(self.model_device) | |||||
| self.model.to(self.model_device) | |||||
| os.environ["CUDA_VISIBLE_DEVICES"] = str(get_paddle_device_id(self.model_device)) | |||||
| paddle.device.set_device("gpu:0") | |||||
| self.model.to("gpu:0") | |||||
| def train_step(self, batch) -> Dict: | def train_step(self, batch) -> Dict: | ||||
| # 如果 batch 是一个 Dict,我们就默认帮其做参数匹配,否则就直接传入到 `train_step` 函数中,让用户自己处理; | # 如果 batch 是一个 Dict,我们就默认帮其做参数匹配,否则就直接传入到 `train_step` 函数中,让用户自己处理; | ||||
| @@ -116,6 +123,16 @@ class PaddleSingleDriver(PaddleDriver): | |||||
| else: | else: | ||||
| return self._test_step(batch) | return self._test_step(batch) | ||||
| def move_data_to_device(self, batch: 'paddle.Tensor'): | |||||
| r""" | |||||
| 将数据迁移到指定的机器上;batch 可能是 list 也可能 dict ,或其嵌套结构。 | |||||
| 在 Paddle 中使用可能会引起因与设置的设备不一致而产生的问题,请注意。 | |||||
| 在单卡时,由于 CUDA_VISIBLE_DEVICES 始终被限制在一个设备上,因此实际上只会迁移到 `gpu:0` | |||||
| :return: 将移动到指定机器上的 batch 对象返回; | |||||
| """ | |||||
| return paddle_move_data_to_device(batch, "gpu:0") | |||||
| def replace_sampler(self, dataloader, dist_sampler: Union[str, ReproducibleBatchSampler, ReproducibleIterator], reproducible: bool = False): | def replace_sampler(self, dataloader, dist_sampler: Union[str, ReproducibleBatchSampler, ReproducibleIterator], reproducible: bool = False): | ||||
| # 暂时不支持IteratorDataset | # 暂时不支持IteratorDataset | ||||
| assert dataloader.dataset_kind != _DatasetKind.ITER, \ | assert dataloader.dataset_kind != _DatasetKind.ITER, \ | ||||
| @@ -272,7 +272,7 @@ def get_device_from_visible(device: Union[str, int]): | |||||
| else: | else: | ||||
| # 利用 USER_CUDA_VISIBLDE_DEVICES 获取用户期望的设备 | # 利用 USER_CUDA_VISIBLDE_DEVICES 获取用户期望的设备 | ||||
| user_visiblde_devices = os.getenv(USER_CUDA_VISIBLE_DEVICES) | user_visiblde_devices = os.getenv(USER_CUDA_VISIBLE_DEVICES) | ||||
| if user_visiblde_devices is None or user_visiblde_devices != "": | |||||
| if user_visiblde_devices is not None and user_visiblde_devices != "": | |||||
| # 不为空,说明用户设置了 CUDA_VISIBLDE_DEVICES | # 不为空,说明用户设置了 CUDA_VISIBLDE_DEVICES | ||||
| idx = user_visiblde_devices.split(",")[idx] | idx = user_visiblde_devices.split(",")[idx] | ||||
| else: | else: | ||||
| @@ -11,11 +11,12 @@ from fastNLP.envs.imports import _NEED_IMPORT_PADDLE | |||||
| if _NEED_IMPORT_PADDLE: | if _NEED_IMPORT_PADDLE: | ||||
| import paddle | import paddle | ||||
| import paddle.distributed as dist | |||||
| from paddle.fluid.dygraph import parallel_helper | from paddle.fluid.dygraph import parallel_helper | ||||
| def _simple_gather_all_tensors(result, group: Any, world_size: int) -> List: | def _simple_gather_all_tensors(result, group: Any, world_size: int) -> List: | ||||
| gathered_result = [paddle.zeros_like(result) for _ in range(world_size)] | gathered_result = [paddle.zeros_like(result) for _ in range(world_size)] | ||||
| paddle.distributed.all_gather(gathered_result, result, group) | |||||
| dist.all_gather(gathered_result, result, group) | |||||
| return gathered_result | return gathered_result | ||||
| class PaddleBackend(Backend): | class PaddleBackend(Backend): | ||||
| @@ -36,13 +37,13 @@ class PaddleBackend(Backend): | |||||
| tensor = paddle.stack(tensor) | tensor = paddle.stack(tensor) | ||||
| # 第一步, aggregate结果 | # 第一步, aggregate结果 | ||||
| if method == 'sum': | if method == 'sum': | ||||
| tensor = paddle.sum(tensor, dim=0) | |||||
| tensor = paddle.sum(tensor, axis=0) | |||||
| elif method == 'mean': | elif method == 'mean': | ||||
| tensor = paddle.mean(tensor, dim=0) | |||||
| tensor = paddle.mean(tensor, axis=0) | |||||
| elif method == 'max': | elif method == 'max': | ||||
| tensor, _ = paddle.max(tensor, dim=0) | |||||
| tensor, _ = paddle.max(tensor, axis=0) | |||||
| elif method == 'min': | elif method == 'min': | ||||
| tensor, _ = paddle.min(tensor, dim=0) | |||||
| tensor, _ = paddle.min(tensor, axis=0) | |||||
| else: | else: | ||||
| raise AggregateMethodError(should_have_aggregate_method=False) | raise AggregateMethodError(should_have_aggregate_method=False) | ||||
| @@ -80,11 +81,12 @@ class PaddleBackend(Backend): | |||||
| 聚合 group 中所有的 result;由于不同 group 中 result 大小不同,因此在适当的时候需要进行 padding | 聚合 group 中所有的 result;由于不同 group 中 result 大小不同,因此在适当的时候需要进行 padding | ||||
| """ | """ | ||||
| # TODO check 正确性 | # TODO check 正确性 | ||||
| if group is None: | |||||
| group = paddle.distributed.get_group(0) | |||||
| # 有 paddle 那边的 bug,2.3 版本的时候修复了,到时候改一下 | |||||
| # if group is None: | |||||
| # group = dist.get_group(0) | |||||
| world_size = group.nranks | |||||
| paddle.distributed.barrier(group=group) | |||||
| world_size = group.nranks if group is not None else dist.get_world_size() | |||||
| dist.barrier(group=group) | |||||
| # 张量为 标量的情况,简单地gather就好 | # 张量为 标量的情况,简单地gather就好 | ||||
| if result.ndim == 0: | if result.ndim == 0: | ||||
| @@ -93,10 +95,10 @@ class PaddleBackend(Backend): | |||||
| # 获得 result 的 shape | # 获得 result 的 shape | ||||
| local_size = paddle.to_tensor(result.shape) | local_size = paddle.to_tensor(result.shape) | ||||
| # 将 group 中所有 result 的大小聚合在一起 | # 将 group 中所有 result 的大小聚合在一起 | ||||
| local_sizes = [paddle.zeros_like(local_size) for _ in range(world_size)] | |||||
| paddle.distributed.all_gather(local_sizes, local_size, group=group) | |||||
| local_sizes = [] | |||||
| dist.all_gather(local_sizes, local_size, group=group) | |||||
| # 堆叠后,计算出 shape 每一维度的最大值 | # 堆叠后,计算出 shape 每一维度的最大值 | ||||
| max_size = paddle.stack(local_sizes).max(axis=0).values | |||||
| max_size = paddle.stack(local_sizes).max(axis=0) | |||||
| all_sizes_equal = all(all(ls == max_size) for ls in local_sizes) | all_sizes_equal = all(all(ls == max_size) for ls in local_sizes) | ||||
| # 如果所有的结果大小相同,那么可以直接聚合 | # 如果所有的结果大小相同,那么可以直接聚合 | ||||
| @@ -111,16 +113,15 @@ class PaddleBackend(Backend): | |||||
| pad_dims.append(val.item()) | pad_dims.append(val.item()) | ||||
| result_padded = paddle.nn.functional.pad(result, pad_dims) | result_padded = paddle.nn.functional.pad(result, pad_dims) | ||||
| # 重新进行聚合 | # 重新进行聚合 | ||||
| gathered_result = [paddle.zeros_like(result_padded) for _ in range(world_size)] | |||||
| paddle.distributed.all_gather(gathered_result, result_padded, group) | |||||
| gathered_result = [] | |||||
| dist.all_gather(gathered_result, result_padded, group) | |||||
| for idx, item_size in enumerate(local_sizes): | for idx, item_size in enumerate(local_sizes): | ||||
| slice_param = [slice(dim_size) for dim_size in item_size] | |||||
| slice_param = [slice(dim_size) for dim_size in item_size.tolist()] | |||||
| gathered_result[idx] = gathered_result[idx][slice_param] | gathered_result[idx] = gathered_result[idx][slice_param] | ||||
| return gathered_result | return gathered_result | ||||
| def move_tensor_to_device(self, tensor, device): | def move_tensor_to_device(self, tensor, device): | ||||
| # TODO 如果在这里处理的话,会不会在别的地方引起bug? | # TODO 如果在这里处理的话,会不会在别的地方引起bug? | ||||
| if is_in_paddle_dist(): | |||||
| device = get_device_from_visible(device) | |||||
| device = get_device_from_visible(device) | |||||
| return paddle_to(tensor, device) | return paddle_to(tensor, device) | ||||
| @@ -4,17 +4,18 @@ __all__ = [ | |||||
| from typing import Any | from typing import Any | ||||
| from functools import wraps | from functools import wraps | ||||
| from fastNLP.envs.imports import _NEED_IMPORT_PADDLE | |||||
| from fastNLP.envs.imports import _NEED_IMPORT_PADDLE, _NEED_IMPORT_TORCH | |||||
| from fastNLP.envs.utils import _module_available | from fastNLP.envs.utils import _module_available | ||||
| _IS_TORCHMETRICS_AVAILABLE = _module_available('torchmetrics') | _IS_TORCHMETRICS_AVAILABLE = _module_available('torchmetrics') | ||||
| if _IS_TORCHMETRICS_AVAILABLE: | |||||
| from torchmetrics import Metric as torchmetrics_Metric | |||||
| _IS_ALLENNLP_AVAILABLE = _module_available('allennlp') | _IS_ALLENNLP_AVAILABLE = _module_available('allennlp') | ||||
| if _IS_ALLENNLP_AVAILABLE: | if _IS_ALLENNLP_AVAILABLE: | ||||
| from allennlp.training.metrics import Metric as allennlp_Metric | from allennlp.training.metrics import Metric as allennlp_Metric | ||||
| if _NEED_IMPORT_TORCH and _IS_TORCHMETRICS_AVAILABLE: | |||||
| if _IS_TORCHMETRICS_AVAILABLE: | |||||
| from torchmetrics import Metric as torchmetrics_Metric | |||||
| if _NEED_IMPORT_PADDLE: | if _NEED_IMPORT_PADDLE: | ||||
| from paddle.metric import Metric as paddle_Metric | from paddle.metric import Metric as paddle_Metric | ||||
| @@ -9,6 +9,7 @@ __all__ = [ | |||||
| ] | ] | ||||
| import os | import os | ||||
| import re | |||||
| from typing import Any, Optional, Union | from typing import Any, Optional, Union | ||||
| from fastNLP.envs.imports import _NEED_IMPORT_PADDLE | from fastNLP.envs.imports import _NEED_IMPORT_PADDLE | ||||
| @@ -42,10 +43,19 @@ def get_paddle_device_id(device: Union[str, int]): | |||||
| if isinstance(device, int): | if isinstance(device, int): | ||||
| return device | return device | ||||
| device = device.lower() | |||||
| if device == "cpu": | if device == "cpu": | ||||
| raise ValueError("Cannot get device id from `cpu`.") | raise ValueError("Cannot get device id from `cpu`.") | ||||
| return paddle.device._convert_to_place(device).get_device_id() | |||||
| match_res = re.match(r"gpu:\d+", device) | |||||
| if not match_res: | |||||
| raise ValueError( | |||||
| "The device must be a string which is like 'cpu', 'gpu', 'gpu:x'" | |||||
| ) | |||||
| device_id = device.split(':', 1)[1] | |||||
| device_id = int(device_id) | |||||
| return device_id | |||||
| def paddle_move_data_to_device(batch: Any, device: Optional[str] = None, | def paddle_move_data_to_device(batch: Any, device: Optional[str] = None, | ||||
| data_device: Optional[str] = None) -> Any: | data_device: Optional[str] = None) -> Any: | ||||
| @@ -52,21 +52,33 @@ def _set_backend(): | |||||
| if backend == 'paddle': | if backend == 'paddle': | ||||
| assert _module_available(backend), f"You must have {backend} available to use {backend} backend." | assert _module_available(backend), f"You must have {backend} available to use {backend} backend." | ||||
| assert 'paddle' not in sys.modules, "You have to use `set_backend()` before `import paddle`." | assert 'paddle' not in sys.modules, "You have to use `set_backend()` before `import paddle`." | ||||
| if 'CUDA_VISIBLE_DEVICES' not in os.environ and 'PADDLE_RANK_IN_NODE' not in os.environ \ | |||||
| and 'FLAGS_selected_gpus' not in os.environ: | |||||
| os.environ['CUDA_VISIBLE_DEVICES'] = '0' | |||||
| os.environ[USER_CUDA_VISIBLE_DEVICES] = '' | |||||
| user_visible_devices = os.getenv(USER_CUDA_VISIBLE_DEVICES) | |||||
| if 'PADDLE_RANK_IN_NODE' in os.environ and 'FLAGS_selected_gpus' in os.environ: | |||||
| # 在分布式子进程下,根据 USER_VISIBLE_DEVICES 得到进程真正占有的设备 | |||||
| selected_gpus = os.environ['FLAGS_selected_gpus'].split(',') | |||||
| if user_visible_devices is not None and user_visible_devices != "": | |||||
| # 用户通过 CUDA_VISIBLE_DEVICES 启动了分布式训练 | |||||
| # 此时经过 set_backend,用户的设置会保存在 USER_CUDA_VISIBLE_DEVICES 中 | |||||
| # 我们需要从中找到真正使用的设备编号 | |||||
| user_visible_devices = user_visible_devices.split(",") | |||||
| selected_gpus = ",".join([user_visible_devices[int(i)] for i in selected_gpus]) | |||||
| else: | |||||
| # 设置 USER_CUDA_VISIBLE_DEVICES 表明用户视角中所有设备可见 | |||||
| os.environ[USER_CUDA_VISIBLE_DEVICES] = "" | |||||
| # TODO 这里的 [0] 可能在单个节点多卡的时候有问题 | |||||
| os.environ['CUDA_VISIBLE_DEVICES'] = selected_gpus[0] | |||||
| os.environ['FLAGS_selected_gpus'] = ",".join([str(g) for g in range(len(selected_gpus))]) | |||||
| os.environ['FLAGS_selected_accelerators'] = ",".join([str(g) for g in range(len(selected_gpus))]) | |||||
| elif 'CUDA_VISIBLE_DEVICES' in os.environ: | elif 'CUDA_VISIBLE_DEVICES' in os.environ: | ||||
| # 主进程中,用户设置了 CUDA_VISIBLE_DEVICES | |||||
| # 将用户设置的 CUDA_VISIBLE_DEVICES hack 掉 | |||||
| CUDA_VISIBLE_DEVICES = os.environ['CUDA_VISIBLE_DEVICES'] | CUDA_VISIBLE_DEVICES = os.environ['CUDA_VISIBLE_DEVICES'] | ||||
| os.environ[USER_CUDA_VISIBLE_DEVICES] = CUDA_VISIBLE_DEVICES | os.environ[USER_CUDA_VISIBLE_DEVICES] = CUDA_VISIBLE_DEVICES | ||||
| os.environ['CUDA_VISIBLE_DEVICES'] = CUDA_VISIBLE_DEVICES.split(',')[0] | os.environ['CUDA_VISIBLE_DEVICES'] = CUDA_VISIBLE_DEVICES.split(',')[0] | ||||
| elif 'PADDLE_RANK_IN_NODE' in os.environ and 'FLAGS_selected_gpus' in os.environ: | |||||
| # TODO 这里由于fastNLP需要hack CUDA_VISIBLE_DEVICES,因此需要相应滴修改FLAGS等paddle变量 @xsh | |||||
| CUDA_VISIBLE_DEVICES = os.environ['FLAGS_selected_gpus'] | |||||
| os.environ[USER_CUDA_VISIBLE_DEVICES] = CUDA_VISIBLE_DEVICES | |||||
| os.environ['CUDA_VISIBLE_DEVICES'] = CUDA_VISIBLE_DEVICES.split(',')[0] | |||||
| os.environ['FLAGS_selected_gpus'] = "0" | |||||
| os.environ['FLAGS_selected_accelerators'] = "0" | |||||
| else: | |||||
| # 没有设置的话限制在单卡上,防止多进程时占用别的卡 | |||||
| os.environ['CUDA_VISIBLE_DEVICES'] = '0' | |||||
| os.environ[USER_CUDA_VISIBLE_DEVICES] = '' | |||||
| elif backend == 'jittor': | elif backend == 'jittor': | ||||
| assert _module_available(backend), f"You must have {backend} available to use {backend} backend." | assert _module_available(backend), f"You must have {backend} available to use {backend} backend." | ||||
| @@ -36,8 +36,14 @@ def set_env_on_import_torch(): | |||||
| # TODO paddle may need set this | # TODO paddle may need set this | ||||
| def set_env_on_import_paddle(): | def set_env_on_import_paddle(): | ||||
| # todo 需要设置 FASTNLP_GLOBAL_RANK 和 FASTNLP_BACKEND_LAUNCH | |||||
| pass | |||||
| # todo 需要设置 FASTNLP_GLOBAL_RANK 和 FASTNLP_LAUNCH_PROCESS | |||||
| if "PADDLE_TRANERS_NUM" in os.environ and "PADDLE_TRAINER_ID" in os.environ \ | |||||
| and "PADDLE_RANK_IN_NODE" in os.environ: | |||||
| # 检测到了分布式环境的环境变量 | |||||
| os.environ[FASTNLP_GLOBAL_RANK] = os.environ["PADDLE_TRAINER_ID"] | |||||
| # 如果不是由 fastnlp 启动的 | |||||
| if FASTNLP_DISTRIBUTED_CHECK not in os.environ: | |||||
| os.environ[FASTNLP_BACKEND_LAUNCH] = "1" | |||||
| # TODO jittor may need set this | # TODO jittor may need set this | ||||
| def set_env_on_import_jittor(): | def set_env_on_import_jittor(): | ||||
| @@ -0,0 +1,151 @@ | |||||
| import pytest | |||||
| import os | |||||
| from typing import Any | |||||
| from dataclasses import dataclass | |||||
| from paddle.optimizer import Adam | |||||
| from paddle.io import DataLoader | |||||
| from fastNLP.core.controllers.trainer import Trainer | |||||
| from fastNLP.core.metrics.accuracy import Accuracy | |||||
| from fastNLP.core.callbacks.progress_callback import RichCallback | |||||
| from fastNLP.envs import FASTNLP_DISTRIBUTED_CHECK | |||||
| from tests.helpers.models.paddle_model import PaddleNormalModel_Classification | |||||
| from tests.helpers.datasets.paddle_data import PaddleDataset_MNIST | |||||
| from tests.helpers.callbacks.helper_callbacks import RecordLossCallback, RecordMetricCallback | |||||
| from tests.helpers.utils import magic_argv_env_context | |||||
| @dataclass | |||||
| class MNISTTrainPaddleConfig: | |||||
| num_labels: int = 10 | |||||
| feature_dimension: int = 784 | |||||
| batch_size: int = 32 | |||||
| shuffle: bool = True | |||||
| validate_every = -5 | |||||
| driver: str = "paddle" | |||||
| device = "gpu" | |||||
| @dataclass | |||||
| class MNISTTrainFleetConfig: | |||||
| num_labels: int = 10 | |||||
| feature_dimension: int = 784 | |||||
| batch_size: int = 32 | |||||
| shuffle: bool = True | |||||
| validate_every = -5 | |||||
| @dataclass | |||||
| class TrainerParameters: | |||||
| model: Any = None | |||||
| optimizers: Any = None | |||||
| train_dataloader: Any = None | |||||
| validate_dataloaders: Any = None | |||||
| input_mapping: Any = None | |||||
| output_mapping: Any = None | |||||
| metrics: Any = None | |||||
| # @pytest.fixture(params=[0], autouse=True) | |||||
| # def model_and_optimizers(request): | |||||
| # """ | |||||
| # 初始化单卡模式的模型和优化器 | |||||
| # """ | |||||
| # trainer_params = TrainerParameters() | |||||
| # print(paddle.device.get_device()) | |||||
| # if request.param == 0: | |||||
| # trainer_params.model = PaddleNormalModel_Classification( | |||||
| # num_labels=MNISTTrainPaddleConfig.num_labels, | |||||
| # feature_dimension=MNISTTrainPaddleConfig.feature_dimension | |||||
| # ) | |||||
| # trainer_params.optimizers = Adam(parameters=trainer_params.model.parameters(), learning_rate=0.0001) | |||||
| # train_dataloader = DataLoader( | |||||
| # dataset=PaddleDataset_MNIST("train"), | |||||
| # batch_size=MNISTTrainPaddleConfig.batch_size, | |||||
| # shuffle=True | |||||
| # ) | |||||
| # val_dataloader = DataLoader( | |||||
| # dataset=PaddleDataset_MNIST(mode="test"), | |||||
| # batch_size=MNISTTrainPaddleConfig.batch_size, | |||||
| # shuffle=True | |||||
| # ) | |||||
| # trainer_params.train_dataloader = train_dataloader | |||||
| # trainer_params.validate_dataloaders = val_dataloader | |||||
| # trainer_params.validate_every = MNISTTrainPaddleConfig.validate_every | |||||
| # trainer_params.metrics = {"acc": Accuracy()} | |||||
| # return trainer_params | |||||
| @pytest.mark.parametrize("driver,device", [("paddle", "cpu"), ("paddle", 1)]) | |||||
| # @pytest.mark.parametrize("driver,device", [("fleet", [0, 1])]) | |||||
| @pytest.mark.parametrize("callbacks", [[RecordMetricCallback(monitor="acc#acc", metric_threshold=0.7, larger_better=True), | |||||
| RichCallback(5), RecordLossCallback(loss_threshold=0.3)]]) | |||||
| @magic_argv_env_context | |||||
| def test_trainer_paddle( | |||||
| # model_and_optimizers: TrainerParameters, | |||||
| driver, | |||||
| device, | |||||
| callbacks, | |||||
| n_epochs=15, | |||||
| ): | |||||
| trainer_params = TrainerParameters() | |||||
| trainer_params.model = PaddleNormalModel_Classification( | |||||
| num_labels=MNISTTrainPaddleConfig.num_labels, | |||||
| feature_dimension=MNISTTrainPaddleConfig.feature_dimension | |||||
| ) | |||||
| trainer_params.optimizers = Adam(parameters=trainer_params.model.parameters(), learning_rate=0.0001) | |||||
| train_dataloader = DataLoader( | |||||
| dataset=PaddleDataset_MNIST("train"), | |||||
| batch_size=MNISTTrainPaddleConfig.batch_size, | |||||
| shuffle=True | |||||
| ) | |||||
| val_dataloader = DataLoader( | |||||
| dataset=PaddleDataset_MNIST(mode="test"), | |||||
| batch_size=MNISTTrainPaddleConfig.batch_size, | |||||
| shuffle=True | |||||
| ) | |||||
| trainer_params.train_dataloader = train_dataloader | |||||
| trainer_params.validate_dataloaders = val_dataloader | |||||
| trainer_params.validate_every = MNISTTrainPaddleConfig.validate_every | |||||
| trainer_params.metrics = {"acc": Accuracy(backend="paddle")} | |||||
| if not isinstance(device, (int, str)) and len(device) > 1 and FASTNLP_DISTRIBUTED_CHECK not in os.environ: | |||||
| with pytest.raises(SystemExit) as exc: | |||||
| trainer = Trainer( | |||||
| model=trainer_params.model, | |||||
| driver=driver, | |||||
| device=device, | |||||
| optimizers=trainer_params.optimizers, | |||||
| train_dataloader=trainer_params.train_dataloader, | |||||
| validate_dataloaders=trainer_params.validate_dataloaders, | |||||
| validate_every=trainer_params.validate_every, | |||||
| input_mapping=trainer_params.input_mapping, | |||||
| output_mapping=trainer_params.output_mapping, | |||||
| metrics=trainer_params.metrics, | |||||
| n_epochs=n_epochs, | |||||
| callbacks=callbacks, | |||||
| ) | |||||
| assert exc.value.code == 0 | |||||
| return | |||||
| else: | |||||
| trainer = Trainer( | |||||
| model=trainer_params.model, | |||||
| driver=driver, | |||||
| device=device, | |||||
| optimizers=trainer_params.optimizers, | |||||
| train_dataloader=trainer_params.train_dataloader, | |||||
| validate_dataloaders=trainer_params.validate_dataloaders, | |||||
| validate_every=trainer_params.validate_every, | |||||
| input_mapping=trainer_params.input_mapping, | |||||
| output_mapping=trainer_params.output_mapping, | |||||
| metrics=trainer_params.metrics, | |||||
| n_epochs=n_epochs, | |||||
| callbacks=callbacks, | |||||
| ) | |||||
| trainer.run() | |||||
| @@ -1,17 +1,11 @@ | |||||
| import unittest | import unittest | ||||
| import torch | import torch | ||||
| from fastNLP.envs.set_env import set_env | |||||
| from fastNLP.envs.set_env_on_import import set_env_on_import_paddle | |||||
| set_env_on_import_paddle() | |||||
| set_env("paddle") | |||||
| from fastNLP.core.drivers.paddle_driver.paddle_driver import PaddleDriver | |||||
| import paddle | import paddle | ||||
| from paddle.io import Dataset, DataLoader | from paddle.io import Dataset, DataLoader | ||||
| from fastNLP.core.drivers.paddle_driver.paddle_driver import PaddleDriver | |||||
| class Net(paddle.nn.Layer): | class Net(paddle.nn.Layer): | ||||
| def __init__(self): | def __init__(self): | ||||
| super(Net, self).__init__() | super(Net, self).__init__() | ||||