Browse Source

1. 支持在不设置backend的情况下运行单卡的paddle程序 2.当通过launch启动且限制显卡时的paddle多卡逻辑

tags/v1.0.0alpha
x54-729 2 years ago
parent
commit
ef892a7aed
14 changed files with 477 additions and 83 deletions
  1. +15
    -6
      fastNLP/core/drivers/paddle_driver/fleet.py
  2. +5
    -2
      fastNLP/core/drivers/paddle_driver/fleet_launcher.py
  3. +12
    -10
      fastNLP/core/drivers/paddle_driver/initialize_paddle_driver.py
  4. +6
    -1
      fastNLP/core/drivers/paddle_driver/paddle_driver.py
  5. +8
    -9
      fastNLP/core/drivers/paddle_driver/single_device.py
  6. +4
    -1
      fastNLP/core/metrics/backend/paddle_backend/backend.py
  7. +18
    -24
      fastNLP/core/utils/paddle_utils.py
  8. +14
    -4
      fastNLP/envs/set_backend.py
  9. +13
    -6
      tests/core/controllers/_test_trainer_fleet.py
  10. +11
    -6
      tests/core/controllers/_test_trainer_fleet_outside.py
  11. +237
    -0
      tests/core/controllers/_test_trainer_jittor.py
  12. +110
    -0
      tests/core/controllers/imdb.py
  13. +5
    -0
      tests/core/controllers/test_trainer_paddle.py
  14. +19
    -14
      tests/core/utils/test_paddle_utils.py

+ 15
- 6
fastNLP/core/drivers/paddle_driver/fleet.py View File

@@ -19,6 +19,7 @@ from fastNLP.core.utils import (
check_user_specific_params,
is_in_paddle_dist,
is_in_paddle_dist,
get_paddle_device_id,
)
from fastNLP.envs.distributed import rank_zero_rm
from fastNLP.core.samplers import (
@@ -31,7 +32,12 @@ from fastNLP.core.samplers import (
re_instantiate_sampler,
conversion_between_reproducible_and_unrepeated_sampler,
)
from fastNLP.envs.env import FASTNLP_DISTRIBUTED_CHECK, FASTNLP_GLOBAL_SEED, FASTNLP_NO_SYNC
from fastNLP.envs.env import (
FASTNLP_DISTRIBUTED_CHECK,
FASTNLP_GLOBAL_SEED,
FASTNLP_NO_SYNC,
USER_CUDA_VISIBLE_DEVICES,
)
from fastNLP.core.log import logger

if _NEED_IMPORT_PADDLE:
@@ -51,7 +57,7 @@ class PaddleFleetDriver(PaddleDriver):
def __init__(
self,
model,
parallel_device: Optional[Union[List[int], int]],
parallel_device: Optional[Union[List[str], str]],
is_pull_by_paddle_run: bool = False,
fp16: bool = False,
**kwargs
@@ -185,6 +191,8 @@ class PaddleFleetDriver(PaddleDriver):
不管是什么情况,`PaddleFleetDriver` 在 `setup` 函数的最后,都会将所有进程的 pid 主动记录下来,这样当一个进程出现 exception 后,
driver 的 on_exception 函数就会被 trainer 调用,其会调用 os.kill 指令将其它进程 kill 掉;
"""
# if USER_CUDA_VISIBLE_DEVICES not in os.environ:
# raise RuntimeError("To run paddle distributed training, please set `FASTNLP_BACKEND` to 'paddle' before using FastNLP.")
super(PaddleFleetDriver, self).__init__(model, fp16=fp16, **kwargs)

# 如果不是通过 launch 启动,要求用户必须传入 parallel_device
@@ -229,9 +237,9 @@ class PaddleFleetDriver(PaddleDriver):
self._data_device = f"gpu:{self._data_device}"
elif not isinstance(self._data_device, str):
raise ValueError("Parameter `device` is wrong type, please check our documentation for the right use.")
if self.outside_fleet and paddle.device.get_device() != self._data_device:
logger.warning("`Parameter data_device` is not equal to paddle.deivce.get_device(), "
"please keep them equal to avoid some potential bugs.")
# if self.outside_fleet and paddle.device.get_device() != self._data_device:
# logger.warning("`Parameter data_device` is not equal to paddle.deivce.get_device(), "
# "please keep them equal to avoid some potential bugs.")

self.world_size = None
self.global_rank = 0
@@ -304,7 +312,8 @@ class PaddleFleetDriver(PaddleDriver):
else:
# 已经设置过一次,保证参数必须是一样的
pre_gpus = os.environ[FASTNLP_DISTRIBUTED_CHECK]
pre_gpus = [int (x) for x in pre_gpus.split(",")]
pre_gpus = [int(x) for x in pre_gpus.split(",")]
cur_gpus = [get_paddle_device_id(g) for g in self.parallel_device]
if sorted(pre_gpus) != sorted(self.parallel_device):
raise RuntimeError("Notice you are using `PaddleFleetDriver` after one instantiated `PaddleFleetDriver`, it is not"
"allowed that your second `PaddleFleetDriver` has a new setting of parameters `parallel_device`.")


+ 5
- 2
fastNLP/core/drivers/paddle_driver/fleet_launcher.py View File

@@ -11,11 +11,14 @@ from fastNLP.envs.env import (
FASTNLP_LOG_LEVEL,
FASTNLP_GLOBAL_SEED,
)
from fastNLP.core.utils import get_paddle_device_id
from .utils import (
find_free_ports,
reset_seed,
)

__all__ = []

# 记录各个进程信息
class SubTrainer(object):
"""
@@ -34,11 +37,11 @@ class FleetLauncher:
"""
def __init__(
self,
devices: List[int],
devices: List[str],
output_from_new_proc: str = "only_error"
):

self.devices = devices
self.devices = [ get_paddle_device_id(g) for g in devices]
self.output_from_new_proc = output_from_new_proc

self.setup()


+ 12
- 10
fastNLP/core/drivers/paddle_driver/initialize_paddle_driver.py View File

@@ -7,7 +7,7 @@ from .single_device import PaddleSingleDriver
from .fleet import PaddleFleetDriver

from fastNLP.envs.imports import _NEED_IMPORT_PADDLE
from fastNLP.core.utils import is_in_paddle_launch_dist
from fastNLP.core.utils import is_in_paddle_launch_dist, get_paddle_gpu_str
from fastNLP.core.log import logger

if _NEED_IMPORT_PADDLE:
@@ -30,27 +30,28 @@ def initialize_paddle_driver(driver: str, device: Optional[Union[str, int, List[
"""
if driver != "paddle":
raise ValueError("When initialize PaddleDriver, parameter `driver` must be 'paddle'.")
user_visible_devices = os.getenv("USER_CUDA_VISIBLE_DEVICES")
if is_in_paddle_launch_dist():
if device is not None:
logger.warning_once("Parameter `device` would be ignored when you are using `paddle.distributed.launch` to pull "
"up your script. And we will directly get the local device via "
"and `os.environ['CUDA_VISIBLE_DEVICES']``.")
device = [int(g) for g in os.environ["CUDA_VISIBLE_DEVICES"].split(",")]
# TODO 目前一个进程仅对应一个卡,所以暂时传入一个 int
"up your script. And we will directly get the local device via environment variables.")
_visible_list = user_visible_devices.split(",")
device = [ f"gpu:{_visible_list.index(g) }" for g in os.environ["CUDA_VISIBLE_DEVICES"].split(",")]
# TODO 目前一个进程仅对应一个卡,所以暂时传入单个
return PaddleFleetDriver(model, device[0], True, **kwargs)

user_visible_devices = os.getenv("USER_CUDA_VISIBLE_DEVICES")
if user_visible_devices is None:
raise RuntimeError("`USER_CUDA_VISIBLE_DEVICES` cannot be None, please check if you have set "
"`FASTNLP_BACKEND` to 'paddle' before using FastNLP.")
_could_use_device_num = len(user_visible_devices.split(","))
_could_use_device_num = paddle.device.cuda.device_count()
else:
_could_use_device_num = len(user_visible_devices.split(","))

if isinstance(device, int):
if device < 0 and device != -1:
raise ValueError("Parameter `device` can only be '-1' when it is smaller than 0.")
if device >= _could_use_device_num:
raise ValueError("The gpu device that parameter `device` specifies is not existed.")
if device == -1:
device = list(range(_could_use_device_num))
device = [ get_paddle_gpu_str(g) for g in range(_could_use_device_num)]
elif isinstance(device, Sequence) and not isinstance(device, str):
device = list(set(device))
for each in device:
@@ -61,6 +62,7 @@ def initialize_paddle_driver(driver: str, device: Optional[Union[str, int, List[
elif each >= _could_use_device_num:
raise ValueError("When parameter `device` is 'Sequence' type, the value in it should not be bigger than"
" the available gpu number.")
device = [get_paddle_gpu_str(g) for g in device]
elif device is not None and not isinstance(device, str):
raise ValueError("Parameter `device` is wrong type, please check our documentation for the right use.")
if isinstance(device, List):


+ 6
- 1
fastNLP/core/drivers/paddle_driver/paddle_driver.py View File

@@ -7,6 +7,8 @@ from dataclasses import dataclass

import numpy as np

from fastNLP.envs.env import USER_CUDA_VISIBLE_DEVICES

from .utils import _build_fp16_env, optimizer_state_to_device, DummyGradScaler
from fastNLP.envs.imports import _NEED_IMPORT_PADDLE
from fastNLP.core.drivers.driver import Driver
@@ -369,7 +371,10 @@ class PaddleDriver(Driver):

:return: 将移动到指定机器上的 batch 对象返回;
"""
device = get_device_from_visible(self.data_device)
if USER_CUDA_VISIBLE_DEVICES in os.environ:
device = get_device_from_visible(self.data_device)
else:
device = self.data_device
return paddle_move_data_to_device(batch, device)

@staticmethod


+ 8
- 9
fastNLP/core/drivers/paddle_driver/single_device.py View File

@@ -40,9 +40,6 @@ class PaddleSingleDriver(PaddleDriver):
raise ValueError("`paddle.DataParallel` is not supported in `PaddleSingleDriver`")

cuda_visible_devices = os.getenv(USER_CUDA_VISIBLE_DEVICES)
if cuda_visible_devices is None:
raise RuntimeError("`USER_CUDA_VISIBLE_DEVICES` cannot be None, please check if you have set "
"`FASTNLP_BACKEND` to 'paddle' before using FastNLP.")
if cuda_visible_devices == "":
device = "cpu"
logger.info("You have set `CUDA_VISIBLE_DEVICES` to '' in system environment variable, and we are gonna to"
@@ -54,11 +51,9 @@ class PaddleSingleDriver(PaddleDriver):
raise ValueError("Parameter `device` can not be None in `PaddleSingleDriver`.")

if device != "cpu":
if isinstance(device, int):
device_id = device
else:
device_id = get_paddle_device_id(device)
os.environ["CUDA_VISIBLE_DEVICES"] = cuda_visible_devices.split(",")[device_id]
device_id = get_paddle_device_id(device)
if cuda_visible_devices is not None:
os.environ["CUDA_VISIBLE_DEVICES"] = cuda_visible_devices.split(",")[device_id]
self.model_device = get_paddle_gpu_str(device)

self.local_rank = 0
@@ -69,7 +64,11 @@ class PaddleSingleDriver(PaddleDriver):
r"""
该函数用来初始化训练环境,用于设置当前训练的设备,并将模型迁移到对应设备上。
"""
device = get_device_from_visible(self.model_device, output_type=str)
if USER_CUDA_VISIBLE_DEVICES in os.environ:
device = get_device_from_visible(self.data_device)
else:
device = self.data_device

paddle.device.set_device(device)
with contextlib.redirect_stdout(None):
self.model.to(device)


+ 4
- 1
fastNLP/core/metrics/backend/paddle_backend/backend.py View File

@@ -1,3 +1,4 @@
import os
from typing import List, Any

import numpy as np
@@ -7,6 +8,7 @@ from fastNLP.core.utils.paddle_utils import paddle_to, get_device_from_visible
from fastNLP.core.metrics.utils import AggregateMethodError
from fastNLP.core.drivers.paddle_driver.dist_utils import fastnlp_paddle_all_gather
from fastNLP.envs.imports import _NEED_IMPORT_PADDLE
from fastNLP.envs.env import USER_CUDA_VISIBLE_DEVICES

if _NEED_IMPORT_PADDLE:
import paddle
@@ -79,7 +81,8 @@ class PaddleBackend(Backend):
raise ValueError(f"tensor: {tensor} can not convert to ndarray!")

def move_tensor_to_device(self, tensor, device):
device = get_device_from_visible(device)
if USER_CUDA_VISIBLE_DEVICES in os.environ:
device = get_device_from_visible(device)
return paddle_to(tensor, device)

def all_gather_object(self, obj, group=None) -> List:


+ 18
- 24
fastNLP/core/utils/paddle_utils.py View File

@@ -21,38 +21,32 @@ if _NEED_IMPORT_PADDLE:

from .utils import apply_to_collection

def get_device_from_visible(device: Union[str, int], output_type=int):
def get_device_from_visible(device: Union[str, int]) -> str:
"""
在有 CUDA_VISIBLE_DEVICES 的情况下,获取对应的设备。
在有 ``CUDA_VISIBLE_DEVICES`` 的情况下,获取对应的设备。
如 CUDA_VISIBLE_DEVICES=2,3 ,device=3 ,则返回1。

:param device: 未转化的设备名
:param output_type: 返回值的类型
:return: 转化后的设备id
:return: 转化后的设备,格式为 ``gpu:x``
"""
if output_type not in [int, str]:
raise ValueError("Parameter `output_type` should be one of these types: [int, str]")
if device == "cpu":
return device
cuda_visible_devices = os.getenv("CUDA_VISIBLE_DEVICES")
user_visible_devices = os.getenv(USER_CUDA_VISIBLE_DEVICES)
if user_visible_devices is None:
raise RuntimeError("`USER_CUDA_VISIBLE_DEVICES` cannot be None, please check if you have set "
"`FASTNLP_BACKEND` to 'paddle' before using FastNLP.")
idx = get_paddle_device_id(device)
# 利用 USER_CUDA_VISIBLDE_DEVICES 获取用户期望的设备
if user_visible_devices is None:
raise RuntimeError("This situation cannot happen, please report a bug to us.")
idx = user_visible_devices.split(",")[idx]

cuda_visible_devices_list = cuda_visible_devices.split(',')
if idx not in cuda_visible_devices_list:
raise ValueError(f"Can't find your devices {idx} in CUDA_VISIBLE_DEVICES[{cuda_visible_devices}]. ")
res = cuda_visible_devices_list.index(idx)
if output_type == int:
return res
if cuda_visible_devices is not None:
idx = get_paddle_device_id(device)
if user_visible_devices is not None:
# 此时一定发生在分布式的情况下,利用 USER_CUDA_VISIBLDE_DEVICES 获取用户期望的设备
idx = user_visible_devices.split(",")[idx]
else:
idx = str(idx)

cuda_visible_devices_list = cuda_visible_devices.split(',')
if idx not in cuda_visible_devices_list:
raise ValueError(f"Can't find your devices {idx} in CUDA_VISIBLE_DEVICES[{cuda_visible_devices}]. ")
return f"gpu:{cuda_visible_devices_list.index(idx)}"
else:
return f"gpu:{res}"
return get_paddle_gpu_str(device)

def paddle_to(data, device: Union[str, int]):
"""
@@ -70,7 +64,7 @@ def paddle_to(data, device: Union[str, int]):
return data.cuda(get_paddle_device_id(device))


def get_paddle_gpu_str(device: Union[str, int]):
def get_paddle_gpu_str(device: Union[str, int]) -> str:
"""
获得 `gpu:x` 类型的设备名

@@ -82,7 +76,7 @@ def get_paddle_gpu_str(device: Union[str, int]):
return f"gpu:{device}"


def get_paddle_device_id(device: Union[str, int]):
def get_paddle_device_id(device: Union[str, int]) -> int:
"""
获得 gpu 的设备id



+ 14
- 4
fastNLP/envs/set_backend.py View File

@@ -51,23 +51,33 @@ def _set_backend():
assert _module_available(backend), f"You must have {backend} available to use {backend} backend."
assert 'paddle' not in sys.modules, "You have to use `set_backend()` before `import paddle`."
user_visible_devices = os.getenv(USER_CUDA_VISIBLE_DEVICES)
cuda_visible_devices = os.getenv("CUDA_VISIBLE_DEVICES")
if 'PADDLE_RANK_IN_NODE' in os.environ and 'FLAGS_selected_gpus' in os.environ:
# 在分布式子进程下,根据 USER_VISIBLE_DEVICES 得到进程真正占有的设备
selected_gpus = os.environ['FLAGS_selected_gpus'].split(',')
if user_visible_devices is not None:
# 用户通过 CUDA_VISIBLE_DEVICES 启动了分布式训练
# 用户使用 fastNLP 启动了分布式训练
# 此时经过 set_backend,用户的设置会保存在 USER_CUDA_VISIBLE_DEVICES 中
# 我们需要从中找到真正使用的设备编号
# 我们需要从中转换为用户找到真正使用的设备编号
user_visible_devices = user_visible_devices.split(",")
selected_gpus = ",".join([user_visible_devices[int(i)] for i in selected_gpus])
selected_gpus = [user_visible_devices[int(i)] for i in selected_gpus]
# 没有找到 USER_CUDA_VISIBLE_DEVICES,说明用户是直接用 launch 启动的
elif cuda_visible_devices:
# 用户设置了可见设备,需要进行转换
# 如 CUDA_VISIBLE_DEVICES = 0,2,3 --gpus=0,2,3
# 在 rank1 中此时 selected_gpus = ['1'],需要转换为设备 2
os.environ[USER_CUDA_VISIBLE_DEVICES] = cuda_visible_devices
cuda_visible_devices = cuda_visible_devices.split(",")
selected_gpus = [cuda_visible_devices[int(i)] for i in selected_gpus]
else:
# 没有找到 USER_CUDA_VISIBLE_DEVICES,则将之设置为所有的设备
# 用户没有设置可见设备,则赋值成所有的设备
os.environ[USER_CUDA_VISIBLE_DEVICES] = ",".join(map(str, list(
range(get_gpu_count())
)))
os.environ['CUDA_VISIBLE_DEVICES'] = ",".join(selected_gpus)
os.environ['FLAGS_selected_gpus'] = ",".join([str(g) for g in range(len(selected_gpus))])
os.environ['FLAGS_selected_accelerators'] = ",".join([str(g) for g in range(len(selected_gpus))])
elif 'CUDA_VISIBLE_DEVICES' in os.environ:
# 主进程中,用户设置了 CUDA_VISIBLE_DEVICES
# 将用户设置的 CUDA_VISIBLE_DEVICES hack 掉


+ 13
- 6
tests/core/controllers/_test_trainer_fleet.py View File

@@ -1,7 +1,15 @@
"""
这个文件测试用户以python -m paddle.distributed.launch 启动的情况
看看有没有用pytest执行的机会
FASTNLP_BACKEND=paddle python -m paddle.distributed.launch --gpus=0,2,3 _test_trainer_fleet.py
这个文件测试多卡情况下使用 paddle 的情况::

>>> # 测试用 python -m paddle.distributed.launch 启动
>>> FASTNLP_BACKEND=paddle python -m paddle.distributed.launch --gpus=0,2,3 _test_trainer_fleet.py
>>> # 测试在限制 GPU 的情况下用 python -m paddle.distributed.launch 启动
>>> CUDA_VISIBLE_DEVICES=0,2,3 FASTNLP_BACKEND=paddle python -m paddle.distributed.launch --gpus=0,2,3 _test_trainer_fleet.py
>>> # 测试直接使用多卡
>>> FASTNLP_BACKEND=paddle python _test_trainer_fleet.py
>>> # 测试在限制 GPU 的情况下直接使用多卡
>>> CUDA_VISIBLE_DEVICES=3,4,5,6 FASTNLP_BACKEND=paddle python _test_trainer_fleet.py

"""
import os
import sys
@@ -71,14 +79,13 @@ def test_trainer_fleet(

n_epochs=n_epochs,
callbacks=callbacks,
output_from_new_proc="logs",
# output_from_new_proc="logs",
)
trainer.run()

if __name__ == "__main__":
driver = "paddle"
device = [0,2,3]
# driver = "paddle"
device = [0,1,3]
# device = 2
callbacks = [
# RecordMetricCallback(monitor="acc#acc", metric_threshold=0.0, larger_better=True),


+ 11
- 6
tests/core/controllers/_test_trainer_fleet_outside.py View File

@@ -1,7 +1,11 @@
"""
这个文件测试用户以python -m paddle.distributed.launch 启动的情况
并且自己初始化了 fleet
FASTNLP_BACKEND=paddle python -m paddle.distributed.launch --gpus=0,2,3 _test_trainer_fleet_outside.py
这个文件测试用户自己初始化分布式环境后使用 paddle 的情况:

>>> # 测试用 python -m paddle.distributed.launch 启动
>>> FASTNLP_BACKEND=paddle python -m paddle.distributed.launch --gpus=0,2,3 _test_trainer_fleet_outside.py
>>> # 测试在限制 GPU 的情况下用 python -m paddle.distributed.launch 启动
>>> CUDA_VISIBLE_DEVICES=0,2,3 FASTNLP_BACKEND=paddle python -m paddle.distributed.launch --gpus=0,2,3 _test_trainer_fleet_outside.py

"""
import os
import sys
@@ -63,6 +67,7 @@ def test_trainer_fleet(
validate_dataloaders = val_dataloader
validate_every = MNISTTrainFleetConfig.validate_every
metrics = {"acc": Accuracy()}
data_device = f'gpu:{os.environ["USER_CUDA_VISIBLE_DEVICES"].split(",").index(os.environ["CUDA_VISIBLE_DEVICES"])}'
trainer = Trainer(
model=model,
driver=driver,
@@ -77,14 +82,14 @@ def test_trainer_fleet(

n_epochs=n_epochs,
callbacks=callbacks,
output_from_new_proc="logs",
data_device=f"gpu:{os.environ['CUDA_VISIBLE_DEVICES']}"
# output_from_new_proc="logs",
data_device=data_device
)
trainer.run()

if __name__ == "__main__":
driver = "paddle"
device = [0,2,3]
device = [0,1,3]
callbacks = [
# RecordMetricCallback(monitor="acc#acc", metric_threshold=0.0, larger_better=True),
RichCallback(5),


+ 237
- 0
tests/core/controllers/_test_trainer_jittor.py View File

@@ -0,0 +1,237 @@
import os
import sys
import time
# os.environ["cuda_archs"] = "61"
# os.environ["FAS"]
os.environ["log_silent"] = "1"
sys.path.append("../../../")

from datasets import load_dataset
from datasets import DatasetDict
import jittor as jt
from jittor import nn, Module
from jittor.dataset import Dataset
jt.flags.use_cuda = True

from fastNLP.core.controllers.trainer import Trainer
from fastNLP.core.metrics.accuracy import Accuracy
from fastNLP.core.vocabulary import Vocabulary
from fastNLP.core.callbacks.progress_callback import RichCallback
from fastNLP.core.callbacks.callback import Callback
from fastNLP.core.dataloaders.jittor_dataloader.fdl import JittorDataLoader

class TextClassificationDataset(Dataset):
def __init__(self, dataset):
super(TextClassificationDataset, self).__init__()
self.dataset = dataset
self.set_attrs(total_len=len(dataset))

def __getitem__(self, idx):
return {"x": self.dataset["input_ids"][idx], "y": self.dataset["label"][idx]}


class LSTM(Module):
def __init__(self, num_of_words, hidden_size, features):

self.embedding = nn.Embedding(num_of_words, features)
self.lstm = nn.LSTM(features, hidden_size, batch_first=True)
self.layer = nn.Linear(hidden_size, 2)
self.softmax = nn.Softmax(dim=1)
self.loss_fn = nn.CrossEntropyLoss()
self.hidden_size = hidden_size
self.features = features

def init_hidden(self, x):
# batch_first
batch_size = x.shape[0]
h0 = jt.randn(1, batch_size, hidden_size)
c0 = jt.randn(1, batch_size, hidden_size)

return h0, c0

def execute(self, input_ids):

output = self.embedding(input_ids)
# TODO 去除padding
output, (h, c) = self.lstm(output, self.init_hidden(output))
# len, batch, hidden_size
output = self.layer(output[-1])

return output

def train_step(self, x, y):
x = self(x)
outputs = self.loss_fn(x, y)
return {"loss": outputs}

def evaluate_step(self, x, y):
x = self(x)
return {"pred": x, "target": y.reshape((-1,))}


class PrintWhileTrainingCallBack(Callback):
"""
通过该Callback实现训练过程中loss的输出
"""

def __init__(self, print_every_epoch, print_every_batch):
self.print_every_epoch = print_every_epoch
self.print_every_batch = print_every_batch

self.loss = 0
self.start = 0
self.epoch_start = 0

def on_train_begin(self, trainer):
"""
在训练开始前输出信息
"""
print("Start training. Total {} epochs and {} batches in each epoch.".format(
trainer.n_epochs, trainer.num_batches_per_epoch
))
self.start = time.time()

def on_before_backward(self, trainer, outputs):
"""
每次反向传播前统计loss,用于计算平均值
"""
loss = trainer.extract_loss_from_outputs(outputs)
loss = trainer.driver.tensor_to_numeric(loss)
self.loss += loss

def on_train_epoch_begin(self, trainer):
self.epoch_start = time.time()

def on_train_epoch_end(self, trainer):
"""
在每经过一定epoch或最后一个epoch时输出当前epoch的平均loss和使用时间
"""
if trainer.cur_epoch_idx % self.print_every_epoch == 0 \
or trainer.cur_epoch_idx == trainer.n_epochs:
print("Epoch: {} Loss: {} Current epoch training time: {}s".format(
trainer.cur_epoch_idx, self.loss / trainer.num_batches_per_epoch, time.time() - self.epoch_start
))
# 将loss清零
self.loss = 0
def on_train_batch_end(self, trainer):
"""
在每经过一定batch或最后一个batch时输出当前epoch截止目前的平均loss
"""
if trainer.batch_idx_in_epoch % self.print_every_batch == 0 \
or trainer.batch_idx_in_epoch == trainer.num_batches_per_epoch:
print("\tBatch: {} Loss: {}".format(
trainer.batch_idx_in_epoch, self.loss / trainer.batch_idx_in_epoch
))

def on_train_end(self, trainer):
print("Total training time: {}s".format(time.time() - self.start))

def process_data(ds: DatasetDict, vocabulary: Vocabulary, max_len=256) -> DatasetDict:
# 分词
ds = ds.map(lambda x: {"input_ids": text_to_id(vocabulary, x["text"], max_len)})
ds.set_format(type="numpy", columns=ds.column_names)
return ds

def set_vocabulary(vocab, dataset):

for data in dataset:
vocab.update(data["text"].split())
return vocab

def text_to_id(vocab, text: str, max_len):
text = text.split()
# to index
ids = [vocab.to_index(word) for word in text]
# padding
ids += [vocab.padding_idx] * (max_len - len(text))
return ids[:max_len]

def get_dataset(name, max_len, train_format="", test_format=""):

# datasets
train_dataset = load_dataset(name, split="train" + train_format).shuffle(seed=123)
test_dataset = load_dataset(name, split="test" + test_format).shuffle(seed=321)
split = train_dataset.train_test_split(test_size=0.2, seed=123)
train_dataset = split["train"]
val_dataset = split["test"]

vocab = Vocabulary()
vocab = set_vocabulary(vocab, train_dataset)
vocab = set_vocabulary(vocab, val_dataset)

train_dataset = process_data(train_dataset, vocab, max_len)
val_dataset = process_data(val_dataset, vocab, max_len)
test_dataset = process_data(test_dataset, vocab, max_len)

return TextClassificationDataset(train_dataset), TextClassificationDataset(val_dataset), \
TextClassificationDataset(test_dataset), vocab

if __name__ == "__main__":

# 训练参数
max_len = 20
epochs = 40
lr = 1
batch_size = 64

features = 100
hidden_size = 128

# 获取数据集
# imdb.py SetFit/sst2
train_data, val_data, test_data, vocab = get_dataset("SetFit/sst2", max_len, "", "")
# 使用dataloader
train_dataloader = JittorDataLoader(
dataset=train_data,
batch_size=batch_size,
shuffle=True,
num_workers=4,
)
val_dataloader = JittorDataLoader(
dataset=val_data,
batch_size=batch_size,
shuffle=True,
num_workers=4,
)
test_dataloader = JittorDataLoader(
dataset=test_data,
batch_size=1,
shuffle=False,
)

# 初始化模型
model = LSTM(len(vocab), hidden_size, features)

# 优化器
# 也可以是多个优化器的list
optimizer = nn.SGD(model.parameters(), lr)

# Metrics
metrics = {"acc": Accuracy()}

# callbacks
callbacks = [
PrintWhileTrainingCallBack(print_every_epoch=1, print_every_batch=10),
# RichCallback(), # print_every参数默认为1,即每一个batch更新一次进度条
]

trainer = Trainer(
model=model,
driver="jittor",
device=[0,1,2,3,4],
optimizers=optimizer,
train_dataloader=train_dataloader,
validate_dataloaders=val_dataloader,
validate_every=-1,
input_mapping=None,
output_mapping=None,
metrics=metrics,
n_epochs=epochs,
callbacks=callbacks,
# progress_bar="raw"
)
trainer.run()

+ 110
- 0
tests/core/controllers/imdb.py View File

@@ -0,0 +1,110 @@
# coding=utf-8
# Copyright 2020 The TensorFlow Datasets Authors and the HuggingFace Datasets Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Lint as: python3
"""IMDB movie reviews dataset."""

import datasets
from datasets.tasks import TextClassification


_DESCRIPTION = """\
Large Movie Review Dataset.
This is a dataset for binary sentiment classification containing substantially \
more data than previous benchmark datasets. We provide a set of 25,000 highly \
polar movie reviews for training, and 25,000 for testing. There is additional \
unlabeled data for use as well.\
"""

_CITATION = """\
@InProceedings{maas-EtAl:2011:ACL-HLT2011,
author = {Maas, Andrew L. and Daly, Raymond E. and Pham, Peter T. and Huang, Dan and Ng, Andrew Y. and Potts, Christopher},
title = {Learning Word Vectors for Sentiment Analysis},
booktitle = {Proceedings of the 49th Annual Meeting of the Association for Computational Linguistics: Human Language Technologies},
month = {June},
year = {2011},
address = {Portland, Oregon, USA},
publisher = {Association for Computational Linguistics},
pages = {142--150},
url = {http://www.aclweb.org/anthology/P11-1015}
}
"""

_DOWNLOAD_URL = "http://ai.stanford.edu/~amaas/data/sentiment/aclImdb_v1.tar.gz"


class IMDBReviewsConfig(datasets.BuilderConfig):
"""BuilderConfig for IMDBReviews."""

def __init__(self, **kwargs):
"""BuilderConfig for IMDBReviews.
Args:
**kwargs: keyword arguments forwarded to super.
"""
super(IMDBReviewsConfig, self).__init__(version=datasets.Version("1.0.0", ""), **kwargs)


class Imdb(datasets.GeneratorBasedBuilder):
"""IMDB movie reviews dataset."""

BUILDER_CONFIGS = [
IMDBReviewsConfig(
name="plain_text",
description="Plain text",
)
]

def _info(self):
return datasets.DatasetInfo(
description=_DESCRIPTION,
features=datasets.Features(
{"text": datasets.Value("string"), "label": datasets.features.ClassLabel(names=["neg", "pos"])}
),
supervised_keys=None,
homepage="http://ai.stanford.edu/~amaas/data/sentiment/",
citation=_CITATION,
task_templates=[TextClassification(text_column="text", label_column="label")],
)

def _split_generators(self, dl_manager):
archive = dl_manager.download(_DOWNLOAD_URL)
return [
datasets.SplitGenerator(
name=datasets.Split.TRAIN, gen_kwargs={"files": dl_manager.iter_archive(archive), "split": "train"}
),
datasets.SplitGenerator(
name=datasets.Split.TEST, gen_kwargs={"files": dl_manager.iter_archive(archive), "split": "test"}
),
datasets.SplitGenerator(
name=datasets.Split("unsupervised"),
gen_kwargs={"files": dl_manager.iter_archive(archive), "split": "train", "labeled": False},
),
]

def _generate_examples(self, files, split, labeled=True):
"""Generate aclImdb examples."""
# For labeled examples, extract the label from the path.
if labeled:
label_mapping = {"pos": 1, "neg": 0}
for path, f in files:
if path.startswith(f"aclImdb/{split}"):
label = label_mapping.get(path.split("/")[2])
if label is not None:
yield path, {"text": f.read().decode("utf-8"), "label": label}
else:
for path, f in files:
if path.startswith(f"aclImdb/{split}"):
if path.split("/")[2] == "unsup":
yield path, {"text": f.read().decode("utf-8"), "label": -1}

+ 5
- 0
tests/core/controllers/test_trainer_paddle.py View File

@@ -1,3 +1,5 @@
import os
from typing import List
import pytest
from dataclasses import dataclass

@@ -5,6 +7,7 @@ from fastNLP.core.controllers.trainer import Trainer
from fastNLP.core.metrics.accuracy import Accuracy
from fastNLP.core.callbacks.progress_callback import RichCallback
from fastNLP.envs.imports import _NEED_IMPORT_PADDLE
from fastNLP.envs.env import USER_CUDA_VISIBLE_DEVICES

if _NEED_IMPORT_PADDLE:
from paddle.optimizer import Adam
@@ -34,6 +37,8 @@ def test_trainer_paddle(
callbacks,
n_epochs=2,
):
if isinstance(device, List) and USER_CUDA_VISIBLE_DEVICES not in os.environ:
pytest.skip("Skip test fleet if FASTNLP_BACKEND is not set to paddle.")
model = PaddleNormalModel_Classification_1(
num_labels=TrainPaddleConfig.num_labels,
feature_dimension=TrainPaddleConfig.feature_dimension


+ 19
- 14
tests/core/utils/test_paddle_utils.py View File

@@ -6,33 +6,38 @@ from fastNLP.core.utils.paddle_utils import get_device_from_visible, paddle_to,
from fastNLP.envs.imports import _NEED_IMPORT_PADDLE
if _NEED_IMPORT_PADDLE:
import paddle

@pytest.mark.parametrize(
("user_visible_devices, cuda_visible_devices, device, output_type, correct"),
("user_visible_devices, cuda_visible_devices, device, correct"),
(
("0,1,2,3,4,5,6,7", "0", "cpu", str, "cpu"),
("0,1,2,3,4,5,6,7", "0", "cpu", int, "cpu"),
("0,1,2,3,4,5,6,7", "3,4,5", "gpu:4", int, 1),
("0,1,2,3,4,5,6,7", "3,4,5", "gpu:5", str, "gpu:2"),
("3,4,5,6", "3,5", 0, int, 0),
("3,6,7,8", "6,7,8", "gpu:2", str, "gpu:1"),
(None, None, 1, "gpu:1"),
(None, "2,4,5,6", 5, "gpu:2"),
(None, "3,4,5", 4, "gpu:1"),
("0,1,2,3,4,5,6,7", "0", "cpu", "cpu"),
("0,1,2,3,4,5,6,7", "0", "cpu", "cpu"),
("0,1,2,3,4,5,6,7", "3,4,5", "gpu:4", "gpu:1"),
("0,1,2,3,4,5,6,7", "3,4,5", "gpu:5", "gpu:2"),
("3,4,5,6", "3,5", 0, "gpu:0"),
("3,6,7,8", "6,7,8", "gpu:2", "gpu:1"),
)
)
@pytest.mark.paddle
def test_get_device_from_visible(user_visible_devices, cuda_visible_devices, device, output_type, correct):
def test_get_device_from_visible(user_visible_devices, cuda_visible_devices, device, correct):
_cuda_visible_devices = os.getenv("CUDA_VISIBLE_DEVICES")
_user_visible_devices = os.getenv("USER_CUDA_VISIBLE_DEVICES")
os.environ["CUDA_VISIBLE_DEVICES"] = cuda_visible_devices
os.environ["USER_CUDA_VISIBLE_DEVICES"] = user_visible_devices
res = get_device_from_visible(device, output_type)
if cuda_visible_devices is not None:
os.environ["CUDA_VISIBLE_DEVICES"] = cuda_visible_devices
if user_visible_devices is not None:
os.environ["USER_CUDA_VISIBLE_DEVICES"] = user_visible_devices
res = get_device_from_visible(device)
assert res == correct

# 还原环境变量
if _cuda_visible_devices is None:
del os.environ["CUDA_VISIBLE_DEVICES"]
os.environ.pop("CUDA_VISIBLE_DEVICES", None)
else:
os.environ["CUDA_VISIBLE_DEVICES"] = _cuda_visible_devices
if _user_visible_devices is None:
del os.environ["USER_CUDA_VISIBLE_DEVICES"]
os.environ.pop("USER_CUDA_VISIBLE_DEVICES", None)
else:
os.environ["USER_CUDA_VISIBLE_DEVICES"] = _user_visible_devices



Loading…
Cancel
Save