Browse Source

PaddleDriver的测试例调整

tags/v1.0.0alpha
x54-729 3 years ago
parent
commit
ebfa118ff2
1 changed files with 44 additions and 76 deletions
  1. +44
    -76
      tests/core/drivers/paddle_driver/test_paddle_driver.py

+ 44
- 76
tests/core/drivers/paddle_driver/test_paddle_driver.py View File

@@ -1,75 +1,28 @@
import unittest
import torch
import os
import pytest
os.environ["FASTNLP_BACKEND"] = "paddle"

from fastNLP.core.drivers.paddle_driver.paddle_driver import PaddleDriver
import paddle
from paddle.io import Dataset, DataLoader

class Net(paddle.nn.Layer):
def __init__(self):
super(Net, self).__init__()

self.fc1 = paddle.nn.Linear(784, 64)
self.fc2 = paddle.nn.Linear(64, 32)
self.fc3 = paddle.nn.Linear(32, 10)
self.fc4 = paddle.nn.Linear(10, 10)

def forward(self, x):
x = self.fc1(x)
x = self.fc2(x)
x = self.fc3(x)
x = self.fc4(x)

return x


class PaddleDataset(Dataset):
def __init__(self):
super(PaddleDataset, self).__init__()
self.items = [paddle.rand((3, 4)) for i in range(320)]

def __len__(self):
return len(self.items)

def __getitem__(self, idx):
return self.items[idx]


class TorchNet(torch.nn.Module):
def __init__(self):
super(TorchNet, self).__init__()

self.torch_fc1 = torch.nn.Linear(10, 10)
self.torch_softmax = torch.nn.Softmax(0)
self.torch_conv2d1 = torch.nn.Conv2d(10, 10, 3)
self.torch_tensor = torch.ones(3, 3)
self.torch_param = torch.nn.Parameter(torch.ones(4, 4))


class TorchDataset(torch.utils.data.Dataset):
def __init__(self):
super(TorchDataset, self).__init__()
self.items = [torch.ones(3, 4) for i in range(320)]

def __len__(self):
return len(self.items)

def __getitem__(self, idx):
return self.items[idx]
from tests.helpers.models.paddle_model import PaddleNormalModel_Classification_1
from tests.helpers.datasets.paddle_data import PaddleNormalDataset
from tests.helpers.datasets.torch_data import TorchNormalDataset
from tests.helpers.models.torch_model import TorchNormalModel_Classification_1

import torch
import paddle
from paddle.io import DataLoader

class PaddleDriverTestCase(unittest.TestCase):
class TestPaddleDriverFunctions:
"""
PaddleDriver的测试类,由于类的特殊性仅测试部分函数,其它的由PaddleSingleDriver和PaddleFleetDriver完成测试
PaddleDriver的测试类,使用仅测试部分函数,其它的由PaddleSingleDriver和PaddleFleetDriver完成测试
"""

def setUp(self):
model = Net()
@classmethod
def setup_class(self):
model = PaddleNormalModel_Classification_1(10, 32)
self.driver = PaddleDriver(model)

def test_check_single_optimizer_legacy(self):
def test_check_single_optimizer_legality(self):
"""
测试传入单个optimizer时的表现
"""
@@ -80,12 +33,12 @@ class PaddleDriverTestCase(unittest.TestCase):

self.driver.set_optimizers(optimizer)

optimizer = torch.optim.Adam(TorchNet().parameters(), 0.01)
optimizer = torch.optim.Adam(TorchNormalModel_Classification_1(10, 32).parameters(), 0.01)
# 传入torch的optimizer时,应该报错ValueError
with self.assertRaises(ValueError) as cm:
self.driver.set_optimizers(optimizer)

def test_check_optimizers_legacy(self):
def test_check_optimizers_legality(self):
"""
测试传入optimizer list的表现
"""
@@ -99,22 +52,27 @@ class PaddleDriverTestCase(unittest.TestCase):
self.driver.set_optimizers(optimizers)

optimizers += [
torch.optim.Adam(TorchNet().parameters(), 0.01)
torch.optim.Adam(TorchNormalModel_Classification_1(10, 32).parameters(), 0.01)
]

with self.assertRaises(ValueError) as cm:
self.driver.set_optimizers(optimizers)

def test_check_dataloader_legacy_in_train(self):
def test_check_dataloader_legality_in_train(self):
"""
测试is_train参数为True时,_check_dataloader_legality函数的表现
"""
dataloader = paddle.io.DataLoader(PaddleDataset())
dataloader = paddle.io.DataLoader(PaddleNormalDataset())
PaddleDriver._check_dataloader_legality(dataloader, "dataloader", True)

# batch_size 和 batch_sampler 均为 None 的情形
dataloader = paddle.io.DataLoader(PaddleNormalDataset(), batch_size=None)
with self.assertRaises(ValueError) as cm:
PaddleDriver._check_dataloader_legality(dataloader, "dataloader", True)

# 创建torch的dataloader
dataloader = torch.utils.data.DataLoader(
TorchDataset(),
TorchNormalDataset(),
batch_size=32, shuffle=True
)
with self.assertRaises(ValueError) as cm:
@@ -125,21 +83,31 @@ class PaddleDriverTestCase(unittest.TestCase):
测试is_train参数为False时,_check_dataloader_legality函数的表现
"""
# 此时传入的应该是dict
dataloader = {"train": paddle.io.DataLoader(PaddleDataset()), "test":paddle.io.DataLoader(PaddleDataset())}
dataloader = {
"train": paddle.io.DataLoader(PaddleNormalDataset()),
"test":paddle.io.DataLoader(PaddleNormalDataset())
}
PaddleDriver._check_dataloader_legality(dataloader, "dataloader", False)

# batch_size 和 batch_sampler 均为 None 的情形
dataloader = {
"train": paddle.io.DataLoader(PaddleNormalDataset()),
"test":paddle.io.DataLoader(PaddleNormalDataset(), batch_size=None)
}
PaddleDriver._check_dataloader_legality(dataloader, "dataloader", False)

# 传入的不是dict,应该报错
dataloader = paddle.io.DataLoader(PaddleDataset())
dataloader = paddle.io.DataLoader(PaddleNormalDataset())
with self.assertRaises(ValueError) as cm:
PaddleDriver._check_dataloader_legality(dataloader, "dataloader", False)

# 创建torch的dataloader
train_loader = torch.utils.data.DataLoader(
TorchDataset(),
TorchNormalDataset(),
batch_size=32, shuffle=True
)
test_loader = torch.utils.data.DataLoader(
TorchDataset(),
TorchNormalDataset(),
batch_size=32, shuffle=True
)
dataloader = {"train": train_loader, "test": test_loader}
@@ -240,7 +208,7 @@ class PaddleDriverTestCase(unittest.TestCase):
"""
# 先确保不影响运行
# TODO:正确性
dataloader = DataLoader(PaddleDataset())
dataloader = DataLoader(PaddleNormalDataset())
self.driver.set_deterministic_dataloader(dataloader)

def test_set_sampler_epoch(self):
@@ -249,7 +217,7 @@ class PaddleDriverTestCase(unittest.TestCase):
"""
# 先确保不影响运行
# TODO:正确性
dataloader = DataLoader(PaddleDataset())
dataloader = DataLoader(PaddleNormalDataset())
self.driver.set_sampler_epoch(dataloader, 0)

def test_get_dataloader_args(self):
@@ -258,5 +226,5 @@ class PaddleDriverTestCase(unittest.TestCase):
"""
# 先确保不影响运行
# TODO:正确性
dataloader = DataLoader(PaddleDataset())
dataloader = DataLoader(PaddleNormalDataset())
res = PaddleDriver.get_dataloader_args(dataloader)

Loading…
Cancel
Save