
pytest tests for the single-device paddle driver and its utils; added checkpoint-resume tests

tags/v1.0.0alpha
x54-729 2 years ago
parent commit 3ab93b2fae
4 changed files with 367 additions and 51 deletions
  1. fastNLP/core/drivers/paddle_driver/paddle_driver.py (+16, -2)
  2. fastNLP/core/drivers/paddle_driver/utils.py (+6, -10)
  3. tests/core/drivers/paddle_driver/test_single_device.py (+291, -37)
  4. tests/core/drivers/paddle_driver/test_utils.py (+54, -2)

fastNLP/core/drivers/paddle_driver/paddle_driver.py (+16, -2)

@@ -191,6 +191,8 @@ class PaddleDriver(Driver):
         :return:
         """
         model = self.unwrap_model()
+        if isinstance(filepath, Path):
+            filepath = str(filepath)
         if only_state_dict:
             states = {name: param.cpu().detach().clone() for name, param in model.state_dict().items()}
             paddle.save(states, filepath)
@@ -211,6 +213,8 @@ class PaddleDriver(Driver):
         :return:
         """
         model = self.unwrap_model()
+        if isinstance(filepath, Path):
+            filepath = str(filepath)
         # In paddle, a model saved with paddle.jit.save can also be loaded back as a state dict via paddle.load,
         # but the path must then have the form dir/filename, otherwise an error is raised.
         dirname, filename = os.path.split(filepath)
@@ -274,11 +278,11 @@ class PaddleDriver(Driver):


logger.debug("Save optimizer state dict.") logger.debug("Save optimizer state dict.")
states["optimizers_state_dict"] = optimizers_state_dict states["optimizers_state_dict"] = optimizers_state_dict
paddle.save(states, Path(folder).joinpath(FASTNLP_CHECKPOINT_FILENAME))
paddle.save(states, str(folder.joinpath(FASTNLP_CHECKPOINT_FILENAME)))


def load(self, folder: Path, dataloader, only_state_dict: bool = True, should_load_model: bool = True, **kwargs) -> Dict: def load(self, folder: Path, dataloader, only_state_dict: bool = True, should_load_model: bool = True, **kwargs) -> Dict:
states = paddle.load(folder.joinpath(FASTNLP_CHECKPOINT_FILENAME))
states = paddle.load(str(folder.joinpath(FASTNLP_CHECKPOINT_FILENAME)))


# 1. 加载 optimizers 的状态; # 1. 加载 optimizers 的状态;
optimizers_state_dict = states["optimizers_state_dict"] optimizers_state_dict = states["optimizers_state_dict"]
@@ -435,6 +439,16 @@ class PaddleDriver(Driver):
                 res.shuffle = True
             else:
                 res.shuffle = False
+        # the RandomBatchSampler case
+        elif hasattr(dataloader.batch_sampler, "batch_sampler"):
+            batch_sampler = dataloader.batch_sampler.batch_sampler
+            res.sampler = batch_sampler.sampler
+            if hasattr(batch_sampler.sampler, "shuffle"):
+                res.shuffle = dataloader.batch_sampler.sampler.shuffle
+            elif isinstance(batch_sampler.sampler, RandomSampler):
+                res.shuffle = True
+            else:
+                res.shuffle = False
         else:
             res.sampler = None
             res.shuffle = False
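For context on the hunks above: paddle.save and paddle.load expect string paths, which is why the Path arguments are converted with str(...) before the calls. A minimal standalone sketch of that conversion (the checkpoint filename below is a placeholder, not fastNLP's FASTNLP_CHECKPOINT_FILENAME constant):

# Minimal sketch of the Path -> str conversion used above; "checkpoint.pdparams"
# is a placeholder filename, not fastNLP's actual checkpoint constant.
from pathlib import Path
import paddle

folder = Path("ckpt_dir")
folder.mkdir(exist_ok=True)
states = {"weight": paddle.ones((2, 2))}

filepath = str(folder.joinpath("checkpoint.pdparams"))  # paddle.save/load take str paths
paddle.save(states, filepath)
restored = paddle.load(filepath)
assert paddle.equal_all(restored["weight"], states["weight"])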


fastNLP/core/drivers/paddle_driver/utils.py (+6, -10)

@@ -4,12 +4,14 @@ import struct
 import random
 import inspect
 import numpy as np
+from copy import deepcopy
 from contextlib import ExitStack, closing
 from enum import IntEnum
 from typing import Dict, Optional, Union
 
 from fastNLP.envs.imports import _NEED_IMPORT_PADDLE
 from fastNLP.core.utils import get_paddle_device_id, auto_param_call, paddle_to
+from fastNLP.core.samplers import RandomSampler
 from fastNLP.envs.env import FASTNLP_GLOBAL_SEED, FASTNLP_SEED_WORKERS, USER_CUDA_VISIBLE_DEVICES
 from fastNLP.core.log import logger


@@ -18,7 +20,7 @@ if _NEED_IMPORT_PADDLE:
     import paddle
     from paddle import nn
     from paddle.nn import Layer
-    from paddle.io import DataLoader, BatchSampler
+    from paddle.io import DataLoader, BatchSampler, Dataset
     from paddle.amp import auto_cast, GradScaler
 else:
     from fastNLP.core.utils.dummy_class import DummyClass as Layer
@@ -206,7 +208,6 @@ class DummyGradScaler:
     def state_dict(self):
         return {}
 
-
 def _build_fp16_env(dummy=False):
     if dummy:
         auto_cast = ExitStack
@@ -260,7 +261,7 @@ def get_device_from_visible(device: Union[str, int], output_type=int):
""" """
在有 CUDA_VISIBLE_DEVICES 的情况下,获取对应的设备。 在有 CUDA_VISIBLE_DEVICES 的情况下,获取对应的设备。
如 CUDA_VISIBLE_DEVICES=2,3 ,device=3 ,则返回1。 如 CUDA_VISIBLE_DEVICES=2,3 ,device=3 ,则返回1。
:param devices: 未转化的设备名
:param device: 未转化的设备名
:param output_type: 返回值的类型 :param output_type: 返回值的类型
:return: 转化后的设备id :return: 转化后的设备id
""" """
@@ -365,13 +366,8 @@ def replace_sampler(dataloader, new_sampler):
""" """
使用 `new_sampler` 重新构建一个 BatchSampler,并替换到 `dataloader` 中 使用 `new_sampler` 重新构建一个 BatchSampler,并替换到 `dataloader` 中
""" """
new_batch_sampler = BatchSampler(
dataset=dataloader.batch_sampler.dataset,
sampler=new_sampler,
shuffle=isinstance(dataloader.batch_sampler.sampler, paddle.io.RandomSampler),
batch_size=dataloader.batch_sampler.batch_size,
drop_last=dataloader.batch_sampler.drop_last
)
new_batch_sampler = deepcopy(dataloader.batch_sampler)
new_batch_sampler.sampler = new_sampler
return replace_batch_sampler(dataloader, new_batch_sampler) return replace_batch_sampler(dataloader, new_batch_sampler)


def optimizer_state_to_device(state, device): def optimizer_state_to_device(state, device):
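The rewritten replace_sampler no longer rebuilds a BatchSampler from scratch; it deep-copies the existing one and swaps only the inner sampler, so batch_size, drop_last, and any subclass state carry over. A rough standalone sketch of that pattern using plain paddle.io (RangeDataset and the SequenceSampler replacement are illustrative assumptions, not fastNLP code):

# Sketch of the deepcopy-and-swap pattern behind the new replace_sampler.
from copy import deepcopy
import paddle
from paddle.io import DataLoader, Dataset, SequenceSampler

class RangeDataset(Dataset):
    def __init__(self, n):
        self.n = n
    def __getitem__(self, idx):
        return paddle.to_tensor([idx], dtype="float32")
    def __len__(self):
        return self.n

dataset = RangeDataset(10)
dataloader = DataLoader(dataset, batch_size=4, shuffle=True, drop_last=True)

# Copy the existing BatchSampler so batch_size / drop_last stay unchanged,
# then replace only its sampler.
new_batch_sampler = deepcopy(dataloader.batch_sampler)
new_batch_sampler.sampler = SequenceSampler(dataset)

assert new_batch_sampler.batch_size == 4 and new_batch_sampler.drop_last is True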


tests/core/drivers/paddle_driver/test_single_device.py (+291, -37)

@@ -1,11 +1,10 @@
 import os
-from numpy import isin
 os.environ["FASTNLP_BACKEND"] = "paddle"
 import pytest
+from pathlib import Path
 
 from fastNLP.core.drivers.paddle_driver.single_device import PaddleSingleDriver
-from fastNLP.core.samplers.reproducible_sampler import RandomSampler
-from fastNLP.core.samplers import RandomBatchSampler
+from fastNLP.core.samplers import RandomBatchSampler, RandomSampler
 from tests.helpers.models.paddle_model import PaddleNormalModel_Classification_1
 from tests.helpers.datasets.paddle_data import PaddleNormalDataset, PaddleRandomMaxDataset
 from tests.helpers.datasets.torch_data import TorchNormalDataset
@@ -42,27 +41,101 @@ def prepare_test_save_load():
     driver1, driver2 = generate_random_driver(10, 10), generate_random_driver(10, 10)
     return driver1, driver2, dataloader
 
-@pytest.mark.parametrize("reproducible", [True, False])
-@pytest.mark.parametrize("only_state_dict", [True, False])
-def test_save_and_load(prepare_test_save_load, reproducible, only_state_dict):
+@pytest.mark.parametrize("only_state_dict", ([True, False]))
+def test_save_and_load_with_randombatchsampler(only_state_dict):
     """
-    Test the save and load functions
-    TODO the optimizer state_dict is empty, skip it for now
+    Test the save and load functions, mainly the case where the dataloader's batch_sampler has been replaced
     """
 
     try:
         path = "model.ckp"
-        driver1, driver2, dataloader = prepare_test_save_load
-        dataloader = driver1.set_dist_repro_dataloader(dataloader, "dist", reproducible)
 
-        driver1.save(path, {}, dataloader, only_state_dict, should_save_model=True)
-        driver2.load(path, dataloader, only_state_dict, should_load_model=True)
+        driver1, driver2 = generate_random_driver(10, 10), generate_random_driver(10, 10)
+        dataset = PaddleRandomMaxDataset(80, 10)
+        dataloader = DataLoader(
+            dataset=dataset,
+            batch_sampler=RandomBatchSampler(BatchSampler(dataset, batch_size=4), 4, False)
+        )
+
+        # TODO iterate a few batches here once checkpoint resumption is fully in place
+
+        sampler_states = dataloader.batch_sampler.state_dict()
+        if only_state_dict:
+            driver1.save(Path(path), {}, dataloader, only_state_dict, should_save_model=True)
+        else:
+            driver1.save(Path(path), {}, dataloader, only_state_dict, should_save_model=True, input_spec=[paddle.ones((16, 10))])
+        states = driver2.load(Path(path), dataloader, only_state_dict, should_load_model=True)
+
+        # 1. check the optimizer state
+        # TODO the optimizer state_dict is always empty
+
+        # 2. check that the batch_sampler was loaded and replaced correctly
+        replaced_loader = states["dataloader"]
+        assert isinstance(replaced_loader.batch_sampler, RandomBatchSampler)
+        assert replaced_loader.batch_sampler.index_list == sampler_states["index_list"]
+        assert replaced_loader.batch_sampler.data_idx == sampler_states["data_idx"]
+
+        # 3. check that the model parameters were loaded correctly
+        for batch in dataloader:
+            res1 = driver1.validate_step(batch)
+            res2 = driver2.validate_step(batch)
+
+            assert paddle.equal_all(res1["pred"], res2["pred"])
+
+        # 4. check batch_idx
+        # TODO
+    finally:
+        synchronize_safe_rm(path)
+
+@pytest.mark.parametrize("only_state_dict", ([True, False]))
+def test_save_and_load_with_randomsampler(only_state_dict):
+    """
+    Test the save and load functions, mainly the case where the dataloader's sampler has been replaced
+    """
+
+    try:
+        path = "model.ckp"
+
+        driver1, driver2 = generate_random_driver(10, 10), generate_random_driver(10, 10)
+        dataset = PaddleRandomMaxDataset(80, 10)
+        batch_sampler = BatchSampler(dataset=dataset, batch_size=2)
+        batch_sampler.sampler = RandomSampler(dataset, True)
+        dataloader = DataLoader(
+            dataset,
+            batch_sampler=batch_sampler
+        )
+
+        # TODO iterate a few batches here once checkpoint resumption is fully in place
+
+        sampler_states = dataloader.batch_sampler.sampler.state_dict()
+        if only_state_dict:
+            driver1.save(Path(path), {}, dataloader, only_state_dict, should_save_model=True)
+        else:
+            driver1.save(Path(path), {}, dataloader, only_state_dict, should_save_model=True, input_spec=[paddle.ones((16, 10))])
+        states = driver2.load(Path(path), dataloader, only_state_dict, should_load_model=True)
+
+        # 1. check the optimizer state
+        # TODO the optimizer state_dict is always empty
+
+        # 2. check that the sampler was loaded and replaced correctly
+        replaced_loader = states["dataloader"]
+
+        assert isinstance(replaced_loader.batch_sampler.sampler, RandomSampler)
+        assert replaced_loader.batch_sampler.sampler.seed == sampler_states["seed"]
+        assert replaced_loader.batch_sampler.sampler.epoch == sampler_states["epoch"]
+        assert replaced_loader.batch_sampler.sampler.num_consumed_samples == sampler_states["num_consumed_samples"]
+        assert len(replaced_loader.batch_sampler.sampler.dataset) == sampler_states["length"]
+        assert replaced_loader.batch_sampler.sampler.shuffle == sampler_states["shuffle"]
+
+        # 3. check that the model parameters were loaded correctly
         for batch in dataloader:
             res1 = driver1.validate_step(batch)
             res2 = driver2.validate_step(batch)
 
             assert paddle.equal_all(res1["pred"], res2["pred"])
 
+        # 4. check batch_idx
+        # TODO
     finally:
         synchronize_safe_rm(path)
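The two tests above pin down the checkpoint-resume contract: snapshot the sampler's state_dict before saving, then verify that the reloaded dataloader carries that state and the model produces identical predictions. The sampler side of that contract, shown with a self-contained toy sampler (ResumableSampler is hypothetical, not fastNLP's RandomSampler):

# Toy illustration of the resume contract exercised above: consume part of the
# data, snapshot the sampler state, restore it elsewhere, and check that the
# two passes together cover the dataset exactly once.
import random

class ResumableSampler:
    def __init__(self, n, seed=0):
        self.n, self.seed, self.consumed = n, seed, 0

    def __iter__(self):
        order = list(range(self.n))
        random.Random(self.seed).shuffle(order)
        for idx in order[self.consumed:]:
            self.consumed += 1
            yield idx

    def state_dict(self):
        return {"seed": self.seed, "consumed": self.consumed}

    def load_state_dict(self, state):
        self.seed, self.consumed = state["seed"], state["consumed"]

sampler = ResumableSampler(20)
seen = set()
for i, idx in enumerate(sampler):
    seen.add(idx)
    if i >= 7:          # stop after 8 samples, as if training were interrupted
        break
state = sampler.state_dict()

restored = ResumableSampler(20)
restored.load_state_dict(state)
left = set(restored)    # the remaining 12 samples

assert not (seen & left) and (seen | left) == set(range(20))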


@@ -144,24 +217,138 @@ class TestSingleDeviceFunction:
""" """
self.driver.move_data_to_device(paddle.rand((32, 64))) self.driver.move_data_to_device(paddle.rand((32, 64)))


# @pytest.mark.parametrize(
# "dist_sampler", [
# "dist",
# RandomBatchSampler(BatchSampler(PaddleRandomMaxDataset(320, 10)), 32, False),
# RandomSampler(PaddleRandomMaxDataset(320, 10))
# ]
# )
# @pytest.mark.parametrize(
# "reproducible",
# [True, False]
# )
# def test_set_dist_repro_dataloader(self, dist_sampler, reproducible):
# """
# 测试set_dist_repro_dataloader函数
# """
# dataloader = DataLoader(PaddleRandomMaxDataset(320, 10), batch_size=100, shuffle=True)

# res = self.driver.set_dist_repro_dataloader(dataloader, dist_sampler, reproducible)

class TestSetDistReproDataloder:
"""
专门测试 set_dist_repro_dataloader 函数的类
"""
def setup_method(self):
self.dataset = PaddleNormalDataset(20)
self.dataloader = DataLoader(self.dataset, batch_size=2, shuffle=True)
model = PaddleNormalModel_Classification_1(10, 32)
self.driver = PaddleSingleDriver(model, device="cpu")
def test_set_dist_repro_dataloader_with_reproducible_false(self):
"""
测试 set_dist_repro_dataloader 参数 `reproducible` 为 False 时的表现
当dist为字符串时,此时应该返回原来的 dataloader
"""
replaced_loader = self.driver.set_dist_repro_dataloader(self.dataloader, dist="dist", reproducible=False)

assert replaced_loader is self.dataloader

def test_set_dist_repro_dataloader_with_reproducible_true(self):
"""
测试 set_dist_repro_dataloader 参数 `reproducible` 为 True 时的表现
当dist为字符串时,此时应该返回新的 dataloader,且 batch_sampler 为 RandomBatchSampler
"""
replaced_loader = self.driver.set_dist_repro_dataloader(self.dataloader, dist="dist", reproducible=True)

assert not (replaced_loader is self.dataloader)
assert isinstance(replaced_loader.batch_sampler, RandomBatchSampler)
assert isinstance(replaced_loader.batch_sampler.batch_sampler, BatchSampler)
assert replaced_loader.batch_sampler.batch_size == self.dataloader.batch_sampler.batch_size
assert replaced_loader.drop_last == self.dataloader.drop_last

# self.check_set_dist_repro_dataloader(self.dataloader, replaced_loader)

def test_set_dist_repro_dataloader_with_dist_batch_sampler(self):
"""
测试 set_dist_repro_dataloader 参数 dist 不是字符串时的表现,且 dist 是 ReproducibleBatchSampler
应该返回新的 dataloader,并将 batch_sampler 替换为 dist 对应的 Sampler
"""
dist = RandomBatchSampler(BatchSampler(self.dataset, batch_size=4), 4, False)
replaced_loader = self.driver.set_dist_repro_dataloader(self.dataloader, dist=dist, reproducible=False)

assert not (replaced_loader is self.dataloader)
assert isinstance(replaced_loader.batch_sampler, RandomBatchSampler)
assert replaced_loader.batch_sampler is dist

# self.check_set_dist_repro_dataloader(self.dataloader, replaced_loader)

def test_set_dist_repro_dataloader_with_dist_sampler(self):
"""
测试 set_dist_repro_dataloader 参数 dist 不是字符串时的表现
应该返回新的 dataloader,并将 batch_sampler.sampler 替换为 dist 对应的 Sampler
"""
dist = RandomSampler(self.dataset, shuffle=True)
replaced_loader = self.driver.set_dist_repro_dataloader(self.dataloader, dist=dist, reproducible=False)

assert not (replaced_loader is self.dataloader)
assert isinstance(replaced_loader.batch_sampler, BatchSampler)
assert isinstance(replaced_loader.batch_sampler.sampler, RandomSampler)
assert not (replaced_loader.batch_sampler is self.dataloader.batch_sampler)
assert replaced_loader.batch_sampler.sampler is dist
assert replaced_loader.batch_sampler.batch_size == self.dataloader.batch_sampler.batch_size

# self.check_set_dist_repro_dataloader(self.dataloader, replaced_loader)

def test_set_dist_repro_dataloader_with_dataloader_reproducible_batch_sampler(self):
"""
测试 set_dist_repro_dataloader 参数 dataloader 已经支持断点重训时的表现
应该返回新的 dataloader,且其余各项设置和原来相同
"""
dataloader = DataLoader(
dataset=self.dataset,
batch_sampler=RandomBatchSampler(BatchSampler(self.dataset, batch_size=4), 4, False)
)
replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, dist="dist", reproducible=False)

assert not (replaced_loader is dataloader)
assert not (replaced_loader.batch_sampler is dataloader.batch_sampler)
assert replaced_loader.batch_sampler.batch_size == dataloader.batch_sampler.batch_size
assert replaced_loader.drop_last == dataloader.drop_last

# self.check_set_dist_repro_dataloader(dataloader, replaced_loader)

def test_set_dist_repro_dataloader_with_dataloader_reproducible_sampler(self):
"""
测试 set_dist_repro_dataloader 参数 dataloader 已经支持断点重训时的表现
应该返回新的 dataloader,且其余各项设置和原来相同
"""
batch_sampler = BatchSampler(dataset=self.dataset, batch_size=2)
batch_sampler.sampler = RandomSampler(self.dataset, True)
dataloader = DataLoader(
self.dataset,
batch_sampler=batch_sampler
)
replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, dist="dist", reproducible=False)

assert not (replaced_loader is dataloader)
assert not (replaced_loader.batch_sampler is dataloader.batch_sampler)
assert not (replaced_loader.batch_sampler.sampler is dataloader.batch_sampler.sampler)
assert replaced_loader.batch_sampler.batch_size == 2

# self.check_set_dist_repro_dataloader(dataloader, replaced_loader)

def check_set_dist_repro_dataloader(self, dataloader, replaced_loader):
"""
测试单卡下 set_dist_repro_dataloader 函数的执行结果是否正确
"""
# 迭代两个 batch
# 这里会发生 BatchSampler 里 yield 了多次但 dataloader 只取出一次的情况。
already_seen_idx = set()
for idx, batch in replaced_loader:
already_seen_idx.update(batch)
if idx >= 1:
break
if isinstance(replaced_loader.batch_sampler, RandomBatchSampler):
sampler_states = replaced_loader.batch_sampler.state_dict()
else:
sampler_states = replaced_loader.batch_sampler.sampler.state_dict()
print(sampler_states["data_idx"])

# 重新加载,应该可以输出剩下的内容,且对于 PaddleNormalDataset 来说,排序后应该是一个 range
left_idxes = set()
if isinstance(replaced_loader.batch_sampler, RandomBatchSampler):
replaced_loader.batch_sampler.load_state_dict(sampler_states)
else:
replaced_loader.batch_sampler.sampler.load_state_dict(sampler_states)
for idx, batch in enumerate(replaced_loader):
left_idxes.update(batch)

assert len(left_idxes) + len(already_seen_idx) == len(self.dataset)
assert len(left_idxes | already_seen_idx) == len(self.dataset)


class TestPaddleDriverFunctions: class TestPaddleDriverFunctions:
""" """
@@ -229,7 +416,7 @@ class TestPaddleDriverFunctions:
         with pytest.raises(ValueError):
             PaddleSingleDriver._check_dataloader_legality(dataloader, "dataloader", True)
 
-    def test_check_dataloader_legacy_in_test(self):
+    def test_check_dataloader_legality_in_test(self):
         """
         Test the behaviour of the _check_dataloader_legality function when is_train is False
         """
@@ -372,11 +559,78 @@ class TestPaddleDriverFunctions:
         dataloader = DataLoader(PaddleNormalDataset())
         self.driver.set_sampler_epoch(dataloader, 0)
 
-    def test_get_dataloader_args(self):
+    @pytest.mark.parametrize("batch_size", [16])
+    @pytest.mark.parametrize("shuffle", [True, False])
+    @pytest.mark.parametrize("drop_last", [True, False])
+    def test_get_dataloader_args(self, batch_size, shuffle, drop_last):
         """
-        Test get_dataloader_args
+        Test the behaviour of get_dataloader_args in the normal case
         """
-        # first make sure it does not break execution
-        # TODO: correctness
-        dataloader = DataLoader(PaddleNormalDataset())
-        res = PaddleSingleDriver.get_dataloader_args(dataloader)
+        dataloader = DataLoader(
+            PaddleNormalDataset(),
+            batch_size=batch_size,
+            shuffle=shuffle,
+            drop_last=drop_last,
+        )
+        res = PaddleSingleDriver.get_dataloader_args(dataloader)
+
+        assert isinstance(res.dataset, PaddleNormalDataset)
+        assert isinstance(res.batch_sampler, BatchSampler)
+        if shuffle:
+            assert isinstance(res.sampler, paddle.io.RandomSampler)
+        else:
+            assert isinstance(res.sampler, paddle.io.SequenceSampler)
+        assert res.shuffle == shuffle
+        assert res.batch_size == batch_size
+        assert res.drop_last == drop_last
+
+    @pytest.mark.parametrize("batch_size", [16])
+    @pytest.mark.parametrize("shuffle", [True, False])
+    @pytest.mark.parametrize("drop_last", [True, False])
+    def test_get_dataloader_args_with_randombatchsampler(self, batch_size, shuffle, drop_last):
+        """
+        Test the behaviour of get_dataloader_args after the batch_sampler has been replaced
+        """
+        dataset = PaddleNormalDataset()
+        dataloader = DataLoader(
+            dataset,
+            batch_sampler=RandomBatchSampler(
+                BatchSampler(dataset, batch_size=batch_size, shuffle=shuffle),
+                batch_size,
+                drop_last,
+            )
+        )
+        res = PaddleSingleDriver.get_dataloader_args(dataloader)
+
+        assert isinstance(res.dataset, PaddleNormalDataset)
+        assert isinstance(res.batch_sampler, RandomBatchSampler)
+        if shuffle:
+            assert isinstance(res.sampler, paddle.io.RandomSampler)
+        else:
+            assert isinstance(res.sampler, paddle.io.SequenceSampler)
+        assert res.shuffle == shuffle
+        assert res.batch_size == batch_size
+        assert res.drop_last == drop_last
+
+    @pytest.mark.parametrize("batch_size", [16])
+    @pytest.mark.parametrize("shuffle", [True, False])
+    @pytest.mark.parametrize("drop_last", [True, False])
+    def test_get_dataloader_args_with_randomsampler(self, batch_size, shuffle, drop_last):
+        """
+        Test the behaviour of get_dataloader_args after the sampler has been replaced
+        """
+        dataset = PaddleNormalDataset()
+        batch_sampler = BatchSampler(dataset, batch_size=batch_size, drop_last=drop_last)
+        batch_sampler.sampler = RandomSampler(dataset, shuffle)
+        dataloader = DataLoader(
+            dataset,
+            batch_sampler=batch_sampler,
+        )
+        res = PaddleSingleDriver.get_dataloader_args(dataloader)
+
+        assert isinstance(res.dataset, PaddleNormalDataset)
+        assert isinstance(res.batch_sampler, BatchSampler)
+        assert isinstance(res.sampler, RandomSampler)
+        assert res.shuffle == shuffle
+        assert res.batch_size == batch_size
+        assert res.drop_last == drop_last
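The plain-DataLoader case above leans on how paddle.io appears to build its default batch sampler: with batch_size and shuffle given, the loader holds a BatchSampler whose inner sampler is paddle.io.RandomSampler when shuffle=True and SequenceSampler otherwise. A quick standalone check of that assumption (RangeDataset is an illustrative stand-in for the test datasets):

# Standalone check of the paddle.io behaviour the asserts above depend on.
import paddle
from paddle.io import DataLoader, Dataset

class RangeDataset(Dataset):
    def __init__(self, n):
        self.n = n
    def __getitem__(self, idx):
        return paddle.to_tensor([idx], dtype="float32")
    def __len__(self):
        return self.n

for shuffle in (True, False):
    loader = DataLoader(RangeDataset(20), batch_size=16, shuffle=shuffle, drop_last=False)
    expected = paddle.io.RandomSampler if shuffle else paddle.io.SequenceSampler
    assert isinstance(loader.batch_sampler, paddle.io.BatchSampler)
    assert isinstance(loader.batch_sampler.sampler, expected)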

tests/core/drivers/paddle_driver/test_utils.py (+54, -2)

@@ -1,4 +1,56 @@
-import unittest
+import os
+import pytest
+os.environ["FASTNLP_BACKEND"] = "paddle"
 
+from fastNLP.core.drivers.paddle_driver.utils import (
+    get_device_from_visible,
+    replace_batch_sampler,
+    replace_sampler,
+)
+from fastNLP.core.samplers import RandomBatchSampler, RandomSampler
+
 import paddle
-from paddle.io import Dataset, DataLoader, DistributedBatchSampler
+from paddle.io import DataLoader, BatchSampler
+
+from tests.helpers.datasets.paddle_data import PaddleNormalDataset
+
+@pytest.mark.parametrize(
+    ("user_visible_devices, cuda_visible_devices, device, output_type, correct"),
+    (
+        ("0,1,2,3,4,5,6,7", "0", "cpu", str, "cpu"),
+        ("0,1,2,3,4,5,6,7", "0", "cpu", int, "cpu"),
+        ("0,1,2,3,4,5,6,7", "3,4,5", "gpu:4", int, 1),
+        ("0,1,2,3,4,5,6,7", "3,4,5", "gpu:5", str, "gpu:2"),
+        ("3,4,5,6", "3,5", 0, int, 0),
+        ("3,6,7,8", "6,7,8", "gpu:2", str, "gpu:1"),
+    )
+)
+def test_get_device_from_visible_str(user_visible_devices, cuda_visible_devices, device, output_type, correct):
+    os.environ["CUDA_VISIBLE_DEVICES"] = cuda_visible_devices
+    os.environ["USER_CUDA_VISIBLE_DEVICES"] = user_visible_devices
+    res = get_device_from_visible(device, output_type)
+    assert res == correct
+
+def test_replace_batch_sampler():
+    dataset = PaddleNormalDataset(10)
+    dataloader = DataLoader(dataset, batch_size=32)
+    batch_sampler = RandomBatchSampler(dataloader.batch_sampler, batch_size=16, drop_last=False)
+
+    replaced_loader = replace_batch_sampler(dataloader, batch_sampler)
+
+    assert not (replaced_loader is dataloader)
+    assert isinstance(replaced_loader.batch_sampler, RandomBatchSampler)
+    assert isinstance(replaced_loader.dataset, PaddleNormalDataset)
+    assert len(replaced_loader.dataset) == len(dataset)
+    assert replaced_loader.batch_sampler.batch_size == 16
+
+def test_replace_sampler():
+    dataset = PaddleNormalDataset(10)
+    dataloader = DataLoader(dataset, batch_size=32)
+    sampler = RandomSampler(dataset)
+
+    replaced_loader = replace_sampler(dataloader, sampler)
+
+    assert not (replaced_loader is dataloader)
+    assert isinstance(replaced_loader.batch_sampler, BatchSampler)
+    assert isinstance(replaced_loader.batch_sampler.sampler, RandomSampler)
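The parametrized cases above encode the remapping rule: the incoming device indexes into USER_CUDA_VISIBLE_DEVICES, and the result is that physical device's position inside CUDA_VISIBLE_DEVICES, with "cpu" passed through untouched. A standalone illustration of that rule, written fresh for clarity rather than copied from fastNLP's get_device_from_visible:

# Illustration of the mapping the test cases above expect; not fastNLP's implementation.
import os
from typing import Union

def map_visible_device(device: Union[str, int], output_type=int):
    if device == "cpu":
        return "cpu"
    idx = int(str(device).replace("gpu:", ""))   # index into USER_CUDA_VISIBLE_DEVICES
    user = os.environ["USER_CUDA_VISIBLE_DEVICES"].split(",")
    cuda = os.environ["CUDA_VISIBLE_DEVICES"].split(",")
    local = cuda.index(user[idx])                # position within CUDA_VISIBLE_DEVICES
    return local if output_type is int else f"gpu:{local}"

os.environ["USER_CUDA_VISIBLE_DEVICES"] = "0,1,2,3,4,5,6,7"
os.environ["CUDA_VISIBLE_DEVICES"] = "3,4,5"
assert map_visible_device("gpu:4", int) == 1
assert map_visible_device("gpu:5", str) == "gpu:2"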
