Browse Source

Merge branch 'dev0.8.0' of github.com:fastnlp/fastNLP into dev0.8.0

tags/v1.0.0alpha
YWMditto 2 years ago
parent
commit
10db8c9373
10 changed files with 241 additions and 49 deletions
  1. +3
    -3
      fastNLP/core/drivers/paddle_driver/initialize_paddle_driver.py
  2. +20
    -3
      fastNLP/core/drivers/paddle_driver/single_device.py
  3. +1
    -1
      fastNLP/core/drivers/paddle_driver/utils.py
  4. +18
    -17
      fastNLP/core/metrics/backend/paddle_backend/backend.py
  5. +5
    -4
      fastNLP/core/metrics/utils.py
  6. +11
    -1
      fastNLP/core/utils/paddle_utils.py
  7. +23
    -11
      fastNLP/envs/set_backend.py
  8. +8
    -2
      fastNLP/envs/set_env_on_import.py
  9. +151
    -0
      tests/core/controllers/test_trainer_paddle.py
  10. +1
    -7
      tests/core/drivers/paddle_driver/test_paddle_driver.py

+ 3
- 3
fastNLP/core/drivers/paddle_driver/initialize_paddle_driver.py View File

@@ -42,7 +42,7 @@ def initialize_paddle_driver(driver: str, device: Optional[Union[str, int, List[
# 优先级 user > cuda # 优先级 user > cuda
# 判断单机情况 device 的合法性 # 判断单机情况 device 的合法性
# 分布式情况下通过 world_device 判断 # 分布式情况下通过 world_device 判断
if user_visible_devices is not None:
if user_visible_devices != "":
_could_use_device_num = len(user_visible_devices.split(",")) _could_use_device_num = len(user_visible_devices.split(","))
elif cuda_visible_devices is not None: elif cuda_visible_devices is not None:
_could_use_device_num = len(cuda_visible_devices.split(",")) _could_use_device_num = len(cuda_visible_devices.split(","))
@@ -51,8 +51,8 @@ def initialize_paddle_driver(driver: str, device: Optional[Union[str, int, List[
if isinstance(device, int): if isinstance(device, int):
if device < 0 and device != -1: if device < 0 and device != -1:
raise ValueError("Parameter `device` can only be '-1' when it is smaller than 0.") raise ValueError("Parameter `device` can only be '-1' when it is smaller than 0.")
if device >= _could_use_device_num:
raise ValueError("The gpu device that parameter `device` specifies is not existed.")
# if device >= _could_use_device_num:
# raise ValueError("The gpu device that parameter `device` specifies is not existed.")
device = f"gpu:{device}" device = f"gpu:{device}"
elif isinstance(device, Sequence) and not isinstance(device, str): elif isinstance(device, Sequence) and not isinstance(device, str):
device = list(set(device)) device = list(set(device))


+ 20
- 3
fastNLP/core/drivers/paddle_driver/single_device.py View File

@@ -1,8 +1,14 @@
import os
from typing import Optional, Dict, Union from typing import Optional, Dict, Union


from .paddle_driver import PaddleDriver from .paddle_driver import PaddleDriver
from fastNLP.envs.imports import _NEED_IMPORT_PADDLE from fastNLP.envs.imports import _NEED_IMPORT_PADDLE
from fastNLP.core.utils import auto_param_call, get_paddle_gpu_str
from fastNLP.core.utils import (
auto_param_call,
get_paddle_gpu_str,
get_paddle_device_id,
paddle_move_data_to_device,
)
from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleIterator from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleIterator
from fastNLP.core.log import logger from fastNLP.core.log import logger


@@ -86,8 +92,9 @@ class PaddleSingleDriver(PaddleDriver):
self._test_signature_fn = model.forward self._test_signature_fn = model.forward


def setup(self): def setup(self):
paddle.device.set_device(self.model_device)
self.model.to(self.model_device)
os.environ["CUDA_VISIBLE_DEVICES"] = str(get_paddle_device_id(self.model_device))
paddle.device.set_device("gpu:0")
self.model.to("gpu:0")


def train_step(self, batch) -> Dict: def train_step(self, batch) -> Dict:
# 如果 batch 是一个 Dict,我们就默认帮其做参数匹配,否则就直接传入到 `train_step` 函数中,让用户自己处理; # 如果 batch 是一个 Dict,我们就默认帮其做参数匹配,否则就直接传入到 `train_step` 函数中,让用户自己处理;
@@ -116,6 +123,16 @@ class PaddleSingleDriver(PaddleDriver):
else: else:
return self._test_step(batch) return self._test_step(batch)


def move_data_to_device(self, batch: 'paddle.Tensor'):
r"""
将数据迁移到指定的机器上;batch 可能是 list 也可能 dict ,或其嵌套结构。
在 Paddle 中使用可能会引起因与设置的设备不一致而产生的问题,请注意。
在单卡时,由于 CUDA_VISIBLE_DEVICES 始终被限制在一个设备上,因此实际上只会迁移到 `gpu:0`

:return: 将移动到指定机器上的 batch 对象返回;
"""
return paddle_move_data_to_device(batch, "gpu:0")

def replace_sampler(self, dataloader, dist_sampler: Union[str, ReproducibleBatchSampler, ReproducibleIterator], reproducible: bool = False): def replace_sampler(self, dataloader, dist_sampler: Union[str, ReproducibleBatchSampler, ReproducibleIterator], reproducible: bool = False):
# 暂时不支持IteratorDataset # 暂时不支持IteratorDataset
assert dataloader.dataset_kind != _DatasetKind.ITER, \ assert dataloader.dataset_kind != _DatasetKind.ITER, \


+ 1
- 1
fastNLP/core/drivers/paddle_driver/utils.py View File

@@ -272,7 +272,7 @@ def get_device_from_visible(device: Union[str, int]):
else: else:
# 利用 USER_CUDA_VISIBLDE_DEVICES 获取用户期望的设备 # 利用 USER_CUDA_VISIBLDE_DEVICES 获取用户期望的设备
user_visiblde_devices = os.getenv(USER_CUDA_VISIBLE_DEVICES) user_visiblde_devices = os.getenv(USER_CUDA_VISIBLE_DEVICES)
if user_visiblde_devices is None or user_visiblde_devices != "":
if user_visiblde_devices is not None and user_visiblde_devices != "":
# 不为空,说明用户设置了 CUDA_VISIBLDE_DEVICES # 不为空,说明用户设置了 CUDA_VISIBLDE_DEVICES
idx = user_visiblde_devices.split(",")[idx] idx = user_visiblde_devices.split(",")[idx]
else: else:


+ 18
- 17
fastNLP/core/metrics/backend/paddle_backend/backend.py View File

@@ -11,11 +11,12 @@ from fastNLP.envs.imports import _NEED_IMPORT_PADDLE


if _NEED_IMPORT_PADDLE: if _NEED_IMPORT_PADDLE:
import paddle import paddle
import paddle.distributed as dist
from paddle.fluid.dygraph import parallel_helper from paddle.fluid.dygraph import parallel_helper


def _simple_gather_all_tensors(result, group: Any, world_size: int) -> List: def _simple_gather_all_tensors(result, group: Any, world_size: int) -> List:
gathered_result = [paddle.zeros_like(result) for _ in range(world_size)] gathered_result = [paddle.zeros_like(result) for _ in range(world_size)]
paddle.distributed.all_gather(gathered_result, result, group)
dist.all_gather(gathered_result, result, group)
return gathered_result return gathered_result


class PaddleBackend(Backend): class PaddleBackend(Backend):
@@ -36,13 +37,13 @@ class PaddleBackend(Backend):
tensor = paddle.stack(tensor) tensor = paddle.stack(tensor)
# 第一步, aggregate结果 # 第一步, aggregate结果
if method == 'sum': if method == 'sum':
tensor = paddle.sum(tensor, dim=0)
tensor = paddle.sum(tensor, axis=0)
elif method == 'mean': elif method == 'mean':
tensor = paddle.mean(tensor, dim=0)
tensor = paddle.mean(tensor, axis=0)
elif method == 'max': elif method == 'max':
tensor, _ = paddle.max(tensor, dim=0)
tensor, _ = paddle.max(tensor, axis=0)
elif method == 'min': elif method == 'min':
tensor, _ = paddle.min(tensor, dim=0)
tensor, _ = paddle.min(tensor, axis=0)
else: else:
raise AggregateMethodError(should_have_aggregate_method=False) raise AggregateMethodError(should_have_aggregate_method=False)


@@ -80,11 +81,12 @@ class PaddleBackend(Backend):
聚合 group 中所有的 result;由于不同 group 中 result 大小不同,因此在适当的时候需要进行 padding 聚合 group 中所有的 result;由于不同 group 中 result 大小不同,因此在适当的时候需要进行 padding
""" """
# TODO check 正确性 # TODO check 正确性
if group is None:
group = paddle.distributed.get_group(0)
# 有 paddle 那边的 bug,2.3 版本的时候修复了,到时候改一下
# if group is None:
# group = dist.get_group(0)


world_size = group.nranks
paddle.distributed.barrier(group=group)
world_size = group.nranks if group is not None else dist.get_world_size()
dist.barrier(group=group)


# 张量为 标量的情况,简单地gather就好 # 张量为 标量的情况,简单地gather就好
if result.ndim == 0: if result.ndim == 0:
@@ -93,10 +95,10 @@ class PaddleBackend(Backend):
# 获得 result 的 shape # 获得 result 的 shape
local_size = paddle.to_tensor(result.shape) local_size = paddle.to_tensor(result.shape)
# 将 group 中所有 result 的大小聚合在一起 # 将 group 中所有 result 的大小聚合在一起
local_sizes = [paddle.zeros_like(local_size) for _ in range(world_size)]
paddle.distributed.all_gather(local_sizes, local_size, group=group)
local_sizes = []
dist.all_gather(local_sizes, local_size, group=group)
# 堆叠后,计算出 shape 每一维度的最大值 # 堆叠后,计算出 shape 每一维度的最大值
max_size = paddle.stack(local_sizes).max(axis=0).values
max_size = paddle.stack(local_sizes).max(axis=0)
all_sizes_equal = all(all(ls == max_size) for ls in local_sizes) all_sizes_equal = all(all(ls == max_size) for ls in local_sizes)


# 如果所有的结果大小相同,那么可以直接聚合 # 如果所有的结果大小相同,那么可以直接聚合
@@ -111,16 +113,15 @@ class PaddleBackend(Backend):
pad_dims.append(val.item()) pad_dims.append(val.item())
result_padded = paddle.nn.functional.pad(result, pad_dims) result_padded = paddle.nn.functional.pad(result, pad_dims)
# 重新进行聚合 # 重新进行聚合
gathered_result = [paddle.zeros_like(result_padded) for _ in range(world_size)]
paddle.distributed.all_gather(gathered_result, result_padded, group)
gathered_result = []
dist.all_gather(gathered_result, result_padded, group)
for idx, item_size in enumerate(local_sizes): for idx, item_size in enumerate(local_sizes):
slice_param = [slice(dim_size) for dim_size in item_size]
slice_param = [slice(dim_size) for dim_size in item_size.tolist()]
gathered_result[idx] = gathered_result[idx][slice_param] gathered_result[idx] = gathered_result[idx][slice_param]
return gathered_result return gathered_result


def move_tensor_to_device(self, tensor, device): def move_tensor_to_device(self, tensor, device):
# TODO 如果在这里处理的话,会不会在别的地方引起bug? # TODO 如果在这里处理的话,会不会在别的地方引起bug?
if is_in_paddle_dist():
device = get_device_from_visible(device)
device = get_device_from_visible(device)
return paddle_to(tensor, device) return paddle_to(tensor, device)



+ 5
- 4
fastNLP/core/metrics/utils.py View File

@@ -4,17 +4,18 @@ __all__ = [


from typing import Any from typing import Any
from functools import wraps from functools import wraps
from fastNLP.envs.imports import _NEED_IMPORT_PADDLE
from fastNLP.envs.imports import _NEED_IMPORT_PADDLE, _NEED_IMPORT_TORCH
from fastNLP.envs.utils import _module_available from fastNLP.envs.utils import _module_available


_IS_TORCHMETRICS_AVAILABLE = _module_available('torchmetrics') _IS_TORCHMETRICS_AVAILABLE = _module_available('torchmetrics')
if _IS_TORCHMETRICS_AVAILABLE:
from torchmetrics import Metric as torchmetrics_Metric

_IS_ALLENNLP_AVAILABLE = _module_available('allennlp') _IS_ALLENNLP_AVAILABLE = _module_available('allennlp')
if _IS_ALLENNLP_AVAILABLE: if _IS_ALLENNLP_AVAILABLE:
from allennlp.training.metrics import Metric as allennlp_Metric from allennlp.training.metrics import Metric as allennlp_Metric


if _NEED_IMPORT_TORCH and _IS_TORCHMETRICS_AVAILABLE:
if _IS_TORCHMETRICS_AVAILABLE:
from torchmetrics import Metric as torchmetrics_Metric

if _NEED_IMPORT_PADDLE: if _NEED_IMPORT_PADDLE:
from paddle.metric import Metric as paddle_Metric from paddle.metric import Metric as paddle_Metric




+ 11
- 1
fastNLP/core/utils/paddle_utils.py View File

@@ -9,6 +9,7 @@ __all__ = [
] ]


import os import os
import re
from typing import Any, Optional, Union from typing import Any, Optional, Union


from fastNLP.envs.imports import _NEED_IMPORT_PADDLE from fastNLP.envs.imports import _NEED_IMPORT_PADDLE
@@ -42,10 +43,19 @@ def get_paddle_device_id(device: Union[str, int]):
if isinstance(device, int): if isinstance(device, int):
return device return device


device = device.lower()
if device == "cpu": if device == "cpu":
raise ValueError("Cannot get device id from `cpu`.") raise ValueError("Cannot get device id from `cpu`.")


return paddle.device._convert_to_place(device).get_device_id()
match_res = re.match(r"gpu:\d+", device)
if not match_res:
raise ValueError(
"The device must be a string which is like 'cpu', 'gpu', 'gpu:x'"
)
device_id = device.split(':', 1)[1]
device_id = int(device_id)

return device_id


def paddle_move_data_to_device(batch: Any, device: Optional[str] = None, def paddle_move_data_to_device(batch: Any, device: Optional[str] = None,
data_device: Optional[str] = None) -> Any: data_device: Optional[str] = None) -> Any:


+ 23
- 11
fastNLP/envs/set_backend.py View File

@@ -52,21 +52,33 @@ def _set_backend():
if backend == 'paddle': if backend == 'paddle':
assert _module_available(backend), f"You must have {backend} available to use {backend} backend." assert _module_available(backend), f"You must have {backend} available to use {backend} backend."
assert 'paddle' not in sys.modules, "You have to use `set_backend()` before `import paddle`." assert 'paddle' not in sys.modules, "You have to use `set_backend()` before `import paddle`."
if 'CUDA_VISIBLE_DEVICES' not in os.environ and 'PADDLE_RANK_IN_NODE' not in os.environ \
and 'FLAGS_selected_gpus' not in os.environ:
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
os.environ[USER_CUDA_VISIBLE_DEVICES] = ''
user_visible_devices = os.getenv(USER_CUDA_VISIBLE_DEVICES)
if 'PADDLE_RANK_IN_NODE' in os.environ and 'FLAGS_selected_gpus' in os.environ:
# 在分布式子进程下,根据 USER_VISIBLE_DEVICES 得到进程真正占有的设备
selected_gpus = os.environ['FLAGS_selected_gpus'].split(',')
if user_visible_devices is not None and user_visible_devices != "":
# 用户通过 CUDA_VISIBLE_DEVICES 启动了分布式训练
# 此时经过 set_backend,用户的设置会保存在 USER_CUDA_VISIBLE_DEVICES 中
# 我们需要从中找到真正使用的设备编号
user_visible_devices = user_visible_devices.split(",")
selected_gpus = ",".join([user_visible_devices[int(i)] for i in selected_gpus])
else:
# 设置 USER_CUDA_VISIBLE_DEVICES 表明用户视角中所有设备可见
os.environ[USER_CUDA_VISIBLE_DEVICES] = ""
# TODO 这里的 [0] 可能在单个节点多卡的时候有问题
os.environ['CUDA_VISIBLE_DEVICES'] = selected_gpus[0]
os.environ['FLAGS_selected_gpus'] = ",".join([str(g) for g in range(len(selected_gpus))])
os.environ['FLAGS_selected_accelerators'] = ",".join([str(g) for g in range(len(selected_gpus))])
elif 'CUDA_VISIBLE_DEVICES' in os.environ: elif 'CUDA_VISIBLE_DEVICES' in os.environ:
# 主进程中,用户设置了 CUDA_VISIBLE_DEVICES
# 将用户设置的 CUDA_VISIBLE_DEVICES hack 掉
CUDA_VISIBLE_DEVICES = os.environ['CUDA_VISIBLE_DEVICES'] CUDA_VISIBLE_DEVICES = os.environ['CUDA_VISIBLE_DEVICES']
os.environ[USER_CUDA_VISIBLE_DEVICES] = CUDA_VISIBLE_DEVICES os.environ[USER_CUDA_VISIBLE_DEVICES] = CUDA_VISIBLE_DEVICES
os.environ['CUDA_VISIBLE_DEVICES'] = CUDA_VISIBLE_DEVICES.split(',')[0] os.environ['CUDA_VISIBLE_DEVICES'] = CUDA_VISIBLE_DEVICES.split(',')[0]
elif 'PADDLE_RANK_IN_NODE' in os.environ and 'FLAGS_selected_gpus' in os.environ:
# TODO 这里由于fastNLP需要hack CUDA_VISIBLE_DEVICES,因此需要相应滴修改FLAGS等paddle变量 @xsh
CUDA_VISIBLE_DEVICES = os.environ['FLAGS_selected_gpus']
os.environ[USER_CUDA_VISIBLE_DEVICES] = CUDA_VISIBLE_DEVICES
os.environ['CUDA_VISIBLE_DEVICES'] = CUDA_VISIBLE_DEVICES.split(',')[0]
os.environ['FLAGS_selected_gpus'] = "0"
os.environ['FLAGS_selected_accelerators'] = "0"
else:
# 没有设置的话限制在单卡上,防止多进程时占用别的卡
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
os.environ[USER_CUDA_VISIBLE_DEVICES] = ''


elif backend == 'jittor': elif backend == 'jittor':
assert _module_available(backend), f"You must have {backend} available to use {backend} backend." assert _module_available(backend), f"You must have {backend} available to use {backend} backend."


+ 8
- 2
fastNLP/envs/set_env_on_import.py View File

@@ -36,8 +36,14 @@ def set_env_on_import_torch():


# TODO paddle may need set this # TODO paddle may need set this
def set_env_on_import_paddle(): def set_env_on_import_paddle():
# todo 需要设置 FASTNLP_GLOBAL_RANK 和 FASTNLP_BACKEND_LAUNCH
pass
# todo 需要设置 FASTNLP_GLOBAL_RANK 和 FASTNLP_LAUNCH_PROCESS
if "PADDLE_TRANERS_NUM" in os.environ and "PADDLE_TRAINER_ID" in os.environ \
and "PADDLE_RANK_IN_NODE" in os.environ:
# 检测到了分布式环境的环境变量
os.environ[FASTNLP_GLOBAL_RANK] = os.environ["PADDLE_TRAINER_ID"]
# 如果不是由 fastnlp 启动的
if FASTNLP_DISTRIBUTED_CHECK not in os.environ:
os.environ[FASTNLP_BACKEND_LAUNCH] = "1"


# TODO jittor may need set this # TODO jittor may need set this
def set_env_on_import_jittor(): def set_env_on_import_jittor():


+ 151
- 0
tests/core/controllers/test_trainer_paddle.py View File

@@ -0,0 +1,151 @@
import pytest
import os
from typing import Any
from dataclasses import dataclass

from paddle.optimizer import Adam
from paddle.io import DataLoader

from fastNLP.core.controllers.trainer import Trainer
from fastNLP.core.metrics.accuracy import Accuracy
from fastNLP.core.callbacks.progress_callback import RichCallback
from fastNLP.envs import FASTNLP_DISTRIBUTED_CHECK


from tests.helpers.models.paddle_model import PaddleNormalModel_Classification
from tests.helpers.datasets.paddle_data import PaddleDataset_MNIST
from tests.helpers.callbacks.helper_callbacks import RecordLossCallback, RecordMetricCallback
from tests.helpers.utils import magic_argv_env_context

@dataclass
class MNISTTrainPaddleConfig:
num_labels: int = 10
feature_dimension: int = 784

batch_size: int = 32
shuffle: bool = True
validate_every = -5

driver: str = "paddle"
device = "gpu"

@dataclass
class MNISTTrainFleetConfig:
num_labels: int = 10
feature_dimension: int = 784

batch_size: int = 32
shuffle: bool = True
validate_every = -5

@dataclass
class TrainerParameters:
model: Any = None
optimizers: Any = None
train_dataloader: Any = None
validate_dataloaders: Any = None
input_mapping: Any = None
output_mapping: Any = None
metrics: Any = None

# @pytest.fixture(params=[0], autouse=True)
# def model_and_optimizers(request):
# """
# 初始化单卡模式的模型和优化器
# """
# trainer_params = TrainerParameters()
# print(paddle.device.get_device())

# if request.param == 0:
# trainer_params.model = PaddleNormalModel_Classification(
# num_labels=MNISTTrainPaddleConfig.num_labels,
# feature_dimension=MNISTTrainPaddleConfig.feature_dimension
# )
# trainer_params.optimizers = Adam(parameters=trainer_params.model.parameters(), learning_rate=0.0001)
# train_dataloader = DataLoader(
# dataset=PaddleDataset_MNIST("train"),
# batch_size=MNISTTrainPaddleConfig.batch_size,
# shuffle=True
# )
# val_dataloader = DataLoader(
# dataset=PaddleDataset_MNIST(mode="test"),
# batch_size=MNISTTrainPaddleConfig.batch_size,
# shuffle=True
# )
# trainer_params.train_dataloader = train_dataloader
# trainer_params.validate_dataloaders = val_dataloader
# trainer_params.validate_every = MNISTTrainPaddleConfig.validate_every
# trainer_params.metrics = {"acc": Accuracy()}

# return trainer_params


@pytest.mark.parametrize("driver,device", [("paddle", "cpu"), ("paddle", 1)])
# @pytest.mark.parametrize("driver,device", [("fleet", [0, 1])])
@pytest.mark.parametrize("callbacks", [[RecordMetricCallback(monitor="acc#acc", metric_threshold=0.7, larger_better=True),
RichCallback(5), RecordLossCallback(loss_threshold=0.3)]])
@magic_argv_env_context
def test_trainer_paddle(
# model_and_optimizers: TrainerParameters,
driver,
device,
callbacks,
n_epochs=15,
):
trainer_params = TrainerParameters()

trainer_params.model = PaddleNormalModel_Classification(
num_labels=MNISTTrainPaddleConfig.num_labels,
feature_dimension=MNISTTrainPaddleConfig.feature_dimension
)
trainer_params.optimizers = Adam(parameters=trainer_params.model.parameters(), learning_rate=0.0001)
train_dataloader = DataLoader(
dataset=PaddleDataset_MNIST("train"),
batch_size=MNISTTrainPaddleConfig.batch_size,
shuffle=True
)
val_dataloader = DataLoader(
dataset=PaddleDataset_MNIST(mode="test"),
batch_size=MNISTTrainPaddleConfig.batch_size,
shuffle=True
)
trainer_params.train_dataloader = train_dataloader
trainer_params.validate_dataloaders = val_dataloader
trainer_params.validate_every = MNISTTrainPaddleConfig.validate_every
trainer_params.metrics = {"acc": Accuracy(backend="paddle")}
if not isinstance(device, (int, str)) and len(device) > 1 and FASTNLP_DISTRIBUTED_CHECK not in os.environ:
with pytest.raises(SystemExit) as exc:
trainer = Trainer(
model=trainer_params.model,
driver=driver,
device=device,
optimizers=trainer_params.optimizers,
train_dataloader=trainer_params.train_dataloader,
validate_dataloaders=trainer_params.validate_dataloaders,
validate_every=trainer_params.validate_every,
input_mapping=trainer_params.input_mapping,
output_mapping=trainer_params.output_mapping,
metrics=trainer_params.metrics,

n_epochs=n_epochs,
callbacks=callbacks,
)
assert exc.value.code == 0
return
else:
trainer = Trainer(
model=trainer_params.model,
driver=driver,
device=device,
optimizers=trainer_params.optimizers,
train_dataloader=trainer_params.train_dataloader,
validate_dataloaders=trainer_params.validate_dataloaders,
validate_every=trainer_params.validate_every,
input_mapping=trainer_params.input_mapping,
output_mapping=trainer_params.output_mapping,
metrics=trainer_params.metrics,

n_epochs=n_epochs,
callbacks=callbacks,
)
trainer.run()

+ 1
- 7
tests/core/drivers/paddle_driver/test_paddle_driver.py View File

@@ -1,17 +1,11 @@
import unittest import unittest


import torch import torch
from fastNLP.envs.set_env import set_env
from fastNLP.envs.set_env_on_import import set_env_on_import_paddle


set_env_on_import_paddle()
set_env("paddle")
from fastNLP.core.drivers.paddle_driver.paddle_driver import PaddleDriver
import paddle import paddle
from paddle.io import Dataset, DataLoader from paddle.io import Dataset, DataLoader


from fastNLP.core.drivers.paddle_driver.paddle_driver import PaddleDriver


class Net(paddle.nn.Layer): class Net(paddle.nn.Layer):
def __init__(self): def __init__(self):
super(Net, self).__init__() super(Net, self).__init__()


Loading…
Cancel
Save