|
|
@@ -1,75 +1,28 @@ |
|
|
|
import unittest |
|
|
|
|
|
|
|
import torch |
|
|
|
import os |
|
|
|
import pytest |
|
|
|
os.environ["FASTNLP_BACKEND"] = "paddle" |
|
|
|
|
|
|
|
from fastNLP.core.drivers.paddle_driver.paddle_driver import PaddleDriver |
|
|
|
import paddle |
|
|
|
from paddle.io import Dataset, DataLoader |
|
|
|
|
|
|
|
class Net(paddle.nn.Layer): |
|
|
|
def __init__(self): |
|
|
|
super(Net, self).__init__() |
|
|
|
|
|
|
|
self.fc1 = paddle.nn.Linear(784, 64) |
|
|
|
self.fc2 = paddle.nn.Linear(64, 32) |
|
|
|
self.fc3 = paddle.nn.Linear(32, 10) |
|
|
|
self.fc4 = paddle.nn.Linear(10, 10) |
|
|
|
|
|
|
|
def forward(self, x): |
|
|
|
|
|
|
|
x = self.fc1(x) |
|
|
|
x = self.fc2(x) |
|
|
|
x = self.fc3(x) |
|
|
|
x = self.fc4(x) |
|
|
|
|
|
|
|
return x |
|
|
|
|
|
|
|
|
|
|
|
class PaddleDataset(Dataset): |
|
|
|
def __init__(self): |
|
|
|
super(PaddleDataset, self).__init__() |
|
|
|
self.items = [paddle.rand((3, 4)) for i in range(320)] |
|
|
|
|
|
|
|
def __len__(self): |
|
|
|
return len(self.items) |
|
|
|
|
|
|
|
def __getitem__(self, idx): |
|
|
|
return self.items[idx] |
|
|
|
|
|
|
|
|
|
|
|
class TorchNet(torch.nn.Module): |
|
|
|
def __init__(self): |
|
|
|
super(TorchNet, self).__init__() |
|
|
|
|
|
|
|
self.torch_fc1 = torch.nn.Linear(10, 10) |
|
|
|
self.torch_softmax = torch.nn.Softmax(0) |
|
|
|
self.torch_conv2d1 = torch.nn.Conv2d(10, 10, 3) |
|
|
|
self.torch_tensor = torch.ones(3, 3) |
|
|
|
self.torch_param = torch.nn.Parameter(torch.ones(4, 4)) |
|
|
|
|
|
|
|
|
|
|
|
class TorchDataset(torch.utils.data.Dataset): |
|
|
|
def __init__(self): |
|
|
|
super(TorchDataset, self).__init__() |
|
|
|
self.items = [torch.ones(3, 4) for i in range(320)] |
|
|
|
|
|
|
|
def __len__(self): |
|
|
|
return len(self.items) |
|
|
|
|
|
|
|
def __getitem__(self, idx): |
|
|
|
return self.items[idx] |
|
|
|
from tests.helpers.models.paddle_model import PaddleNormalModel_Classification_1 |
|
|
|
from tests.helpers.datasets.paddle_data import PaddleNormalDataset |
|
|
|
from tests.helpers.datasets.torch_data import TorchNormalDataset |
|
|
|
from tests.helpers.models.torch_model import TorchNormalModel_Classification_1 |
|
|
|
|
|
|
|
import torch |
|
|
|
import paddle |
|
|
|
from paddle.io import DataLoader |
|
|
|
|
|
|
|
class PaddleDriverTestCase(unittest.TestCase): |
|
|
|
class TestPaddleDriverFunctions: |
|
|
|
""" |
|
|
|
PaddleDriver的测试类,由于类的特殊性仅测试部分函数,其它的由PaddleSingleDriver和PaddleFleetDriver完成测试 |
|
|
|
PaddleDriver的测试类,使用仅测试部分函数,其它的由PaddleSingleDriver和PaddleFleetDriver完成测试 |
|
|
|
""" |
|
|
|
|
|
|
|
def setUp(self): |
|
|
|
model = Net() |
|
|
|
@classmethod |
|
|
|
def setup_class(self): |
|
|
|
model = PaddleNormalModel_Classification_1(10, 32) |
|
|
|
self.driver = PaddleDriver(model) |
|
|
|
|
|
|
|
def test_check_single_optimizer_legacy(self): |
|
|
|
def test_check_single_optimizer_legality(self): |
|
|
|
""" |
|
|
|
测试传入单个optimizer时的表现 |
|
|
|
""" |
|
|
@@ -80,12 +33,12 @@ class PaddleDriverTestCase(unittest.TestCase): |
|
|
|
|
|
|
|
self.driver.set_optimizers(optimizer) |
|
|
|
|
|
|
|
optimizer = torch.optim.Adam(TorchNet().parameters(), 0.01) |
|
|
|
optimizer = torch.optim.Adam(TorchNormalModel_Classification_1(10, 32).parameters(), 0.01) |
|
|
|
# 传入torch的optimizer时,应该报错ValueError |
|
|
|
with self.assertRaises(ValueError) as cm: |
|
|
|
self.driver.set_optimizers(optimizer) |
|
|
|
|
|
|
|
def test_check_optimizers_legacy(self): |
|
|
|
def test_check_optimizers_legality(self): |
|
|
|
""" |
|
|
|
测试传入optimizer list的表现 |
|
|
|
""" |
|
|
@@ -99,22 +52,27 @@ class PaddleDriverTestCase(unittest.TestCase): |
|
|
|
self.driver.set_optimizers(optimizers) |
|
|
|
|
|
|
|
optimizers += [ |
|
|
|
torch.optim.Adam(TorchNet().parameters(), 0.01) |
|
|
|
torch.optim.Adam(TorchNormalModel_Classification_1(10, 32).parameters(), 0.01) |
|
|
|
] |
|
|
|
|
|
|
|
with self.assertRaises(ValueError) as cm: |
|
|
|
self.driver.set_optimizers(optimizers) |
|
|
|
|
|
|
|
def test_check_dataloader_legacy_in_train(self): |
|
|
|
def test_check_dataloader_legality_in_train(self): |
|
|
|
""" |
|
|
|
测试is_train参数为True时,_check_dataloader_legality函数的表现 |
|
|
|
""" |
|
|
|
dataloader = paddle.io.DataLoader(PaddleDataset()) |
|
|
|
dataloader = paddle.io.DataLoader(PaddleNormalDataset()) |
|
|
|
PaddleDriver._check_dataloader_legality(dataloader, "dataloader", True) |
|
|
|
|
|
|
|
# batch_size 和 batch_sampler 均为 None 的情形 |
|
|
|
dataloader = paddle.io.DataLoader(PaddleNormalDataset(), batch_size=None) |
|
|
|
with self.assertRaises(ValueError) as cm: |
|
|
|
PaddleDriver._check_dataloader_legality(dataloader, "dataloader", True) |
|
|
|
|
|
|
|
# 创建torch的dataloader |
|
|
|
dataloader = torch.utils.data.DataLoader( |
|
|
|
TorchDataset(), |
|
|
|
TorchNormalDataset(), |
|
|
|
batch_size=32, shuffle=True |
|
|
|
) |
|
|
|
with self.assertRaises(ValueError) as cm: |
|
|
@@ -125,21 +83,31 @@ class PaddleDriverTestCase(unittest.TestCase): |
|
|
|
测试is_train参数为False时,_check_dataloader_legality函数的表现 |
|
|
|
""" |
|
|
|
# 此时传入的应该是dict |
|
|
|
dataloader = {"train": paddle.io.DataLoader(PaddleDataset()), "test":paddle.io.DataLoader(PaddleDataset())} |
|
|
|
dataloader = { |
|
|
|
"train": paddle.io.DataLoader(PaddleNormalDataset()), |
|
|
|
"test":paddle.io.DataLoader(PaddleNormalDataset()) |
|
|
|
} |
|
|
|
PaddleDriver._check_dataloader_legality(dataloader, "dataloader", False) |
|
|
|
|
|
|
|
# batch_size 和 batch_sampler 均为 None 的情形 |
|
|
|
dataloader = { |
|
|
|
"train": paddle.io.DataLoader(PaddleNormalDataset()), |
|
|
|
"test":paddle.io.DataLoader(PaddleNormalDataset(), batch_size=None) |
|
|
|
} |
|
|
|
PaddleDriver._check_dataloader_legality(dataloader, "dataloader", False) |
|
|
|
|
|
|
|
# 传入的不是dict,应该报错 |
|
|
|
dataloader = paddle.io.DataLoader(PaddleDataset()) |
|
|
|
dataloader = paddle.io.DataLoader(PaddleNormalDataset()) |
|
|
|
with self.assertRaises(ValueError) as cm: |
|
|
|
PaddleDriver._check_dataloader_legality(dataloader, "dataloader", False) |
|
|
|
|
|
|
|
# 创建torch的dataloader |
|
|
|
train_loader = torch.utils.data.DataLoader( |
|
|
|
TorchDataset(), |
|
|
|
TorchNormalDataset(), |
|
|
|
batch_size=32, shuffle=True |
|
|
|
) |
|
|
|
test_loader = torch.utils.data.DataLoader( |
|
|
|
TorchDataset(), |
|
|
|
TorchNormalDataset(), |
|
|
|
batch_size=32, shuffle=True |
|
|
|
) |
|
|
|
dataloader = {"train": train_loader, "test": test_loader} |
|
|
@@ -240,7 +208,7 @@ class PaddleDriverTestCase(unittest.TestCase): |
|
|
|
""" |
|
|
|
# 先确保不影响运行 |
|
|
|
# TODO:正确性 |
|
|
|
dataloader = DataLoader(PaddleDataset()) |
|
|
|
dataloader = DataLoader(PaddleNormalDataset()) |
|
|
|
self.driver.set_deterministic_dataloader(dataloader) |
|
|
|
|
|
|
|
def test_set_sampler_epoch(self): |
|
|
@@ -249,7 +217,7 @@ class PaddleDriverTestCase(unittest.TestCase): |
|
|
|
""" |
|
|
|
# 先确保不影响运行 |
|
|
|
# TODO:正确性 |
|
|
|
dataloader = DataLoader(PaddleDataset()) |
|
|
|
dataloader = DataLoader(PaddleNormalDataset()) |
|
|
|
self.driver.set_sampler_epoch(dataloader, 0) |
|
|
|
|
|
|
|
def test_get_dataloader_args(self): |
|
|
@@ -258,5 +226,5 @@ class PaddleDriverTestCase(unittest.TestCase): |
|
|
|
""" |
|
|
|
# 先确保不影响运行 |
|
|
|
# TODO:正确性 |
|
|
|
dataloader = DataLoader(PaddleDataset()) |
|
|
|
dataloader = DataLoader(PaddleNormalDataset()) |
|
|
|
res = PaddleDriver.get_dataloader_args(dataloader) |