diff --git a/tests/core/drivers/paddle_driver/test_paddle_driver.py b/tests/core/drivers/paddle_driver/test_paddle_driver.py
index 9febc27d..9308785a 100644
--- a/tests/core/drivers/paddle_driver/test_paddle_driver.py
+++ b/tests/core/drivers/paddle_driver/test_paddle_driver.py
@@ -1,75 +1,28 @@
-import unittest
-
-import torch
+import os
+import pytest
+os.environ["FASTNLP_BACKEND"] = "paddle"
 
 from fastNLP.core.drivers.paddle_driver.paddle_driver import PaddleDriver
-import paddle
-from paddle.io import Dataset, DataLoader
-
-class Net(paddle.nn.Layer):
-    def __init__(self):
-        super(Net, self).__init__()
-
-        self.fc1 = paddle.nn.Linear(784, 64)
-        self.fc2 = paddle.nn.Linear(64, 32)
-        self.fc3 = paddle.nn.Linear(32, 10)
-        self.fc4 = paddle.nn.Linear(10, 10)
-
-    def forward(self, x):
-
-        x = self.fc1(x)
-        x = self.fc2(x)
-        x = self.fc3(x)
-        x = self.fc4(x)
-
-        return x
-
-
-class PaddleDataset(Dataset):
-    def __init__(self):
-        super(PaddleDataset, self).__init__()
-        self.items = [paddle.rand((3, 4)) for i in range(320)]
-
-    def __len__(self):
-        return len(self.items)
-
-    def __getitem__(self, idx):
-        return self.items[idx]
-
-
-class TorchNet(torch.nn.Module):
-    def __init__(self):
-        super(TorchNet, self).__init__()
-
-        self.torch_fc1 = torch.nn.Linear(10, 10)
-        self.torch_softmax = torch.nn.Softmax(0)
-        self.torch_conv2d1 = torch.nn.Conv2d(10, 10, 3)
-        self.torch_tensor = torch.ones(3, 3)
-        self.torch_param = torch.nn.Parameter(torch.ones(4, 4))
-
-
-class TorchDataset(torch.utils.data.Dataset):
-    def __init__(self):
-        super(TorchDataset, self).__init__()
-        self.items = [torch.ones(3, 4) for i in range(320)]
-
-    def __len__(self):
-        return len(self.items)
-
-    def __getitem__(self, idx):
-        return self.items[idx]
+from tests.helpers.models.paddle_model import PaddleNormalModel_Classification_1
+from tests.helpers.datasets.paddle_data import PaddleNormalDataset
+from tests.helpers.datasets.torch_data import TorchNormalDataset
+from tests.helpers.models.torch_model import TorchNormalModel_Classification_1
+import torch
+import paddle
+from paddle.io import DataLoader
 
 
-class PaddleDriverTestCase(unittest.TestCase):
+class TestPaddleDriverFunctions:
     """
-    Tests for PaddleDriver. Because of the special nature of this class, only some of its functions are tested here; the rest are covered by the PaddleSingleDriver and PaddleFleetDriver tests.
+    Tests for PaddleDriver. Only some of its functions are tested here; the rest are covered by the PaddleSingleDriver and PaddleFleetDriver tests.
     """
 
-    def setUp(self):
-        model = Net()
+    @classmethod
+    def setup_class(self):
+        model = PaddleNormalModel_Classification_1(10, 32)
         self.driver = PaddleDriver(model)
 
-    def test_check_single_optimizer_legacy(self):
+    def test_check_single_optimizer_legality(self):
         """
         Test the behavior when a single optimizer is passed in.
         """
@@ -80,12 +33,12 @@ class PaddleDriverTestCase(unittest.TestCase):
 
         self.driver.set_optimizers(optimizer)
 
-        optimizer = torch.optim.Adam(TorchNet().parameters(), 0.01)
+        optimizer = torch.optim.Adam(TorchNormalModel_Classification_1(10, 32).parameters(), 0.01)
         # Passing a torch optimizer should raise ValueError
         with self.assertRaises(ValueError) as cm:
             self.driver.set_optimizers(optimizer)
 
-    def test_check_optimizers_legacy(self):
+    def test_check_optimizers_legality(self):
         """
         Test the behavior when a list of optimizers is passed in.
         """
@@ -99,22 +52,27 @@ class PaddleDriverTestCase(unittest.TestCase):
         self.driver.set_optimizers(optimizers)
 
         optimizers += [
-            torch.optim.Adam(TorchNet().parameters(), 0.01)
+            torch.optim.Adam(TorchNormalModel_Classification_1(10, 32).parameters(), 0.01)
         ]
 
         with self.assertRaises(ValueError) as cm:
             self.driver.set_optimizers(optimizers)
 
-    def test_check_dataloader_legacy_in_train(self):
+    def test_check_dataloader_legality_in_train(self):
         """
         Test the behavior of _check_dataloader_legality when is_train is True.
         """
-        dataloader = paddle.io.DataLoader(PaddleDataset())
+        dataloader = paddle.io.DataLoader(PaddleNormalDataset())
         PaddleDriver._check_dataloader_legality(dataloader, "dataloader", True)
 
+        # The case where both batch_size and batch_sampler are None
+        dataloader = paddle.io.DataLoader(PaddleNormalDataset(), batch_size=None)
+        with self.assertRaises(ValueError) as cm:
+            PaddleDriver._check_dataloader_legality(dataloader, "dataloader", True)
+
         # Create a torch dataloader
         dataloader = torch.utils.data.DataLoader(
-            TorchDataset(),
+            TorchNormalDataset(),
             batch_size=32, shuffle=True
         )
         with self.assertRaises(ValueError) as cm:
@@ -125,21 +83,31 @@ class PaddleDriverTestCase(unittest.TestCase):
         Test the behavior of _check_dataloader_legality when is_train is False.
         """
         # In this case the argument should be a dict
-        dataloader = {"train": paddle.io.DataLoader(PaddleDataset()), "test": paddle.io.DataLoader(PaddleDataset())}
+        dataloader = {
+            "train": paddle.io.DataLoader(PaddleNormalDataset()),
+            "test": paddle.io.DataLoader(PaddleNormalDataset())
+        }
+        PaddleDriver._check_dataloader_legality(dataloader, "dataloader", False)
+
+        # The case where both batch_size and batch_sampler are None
+        dataloader = {
+            "train": paddle.io.DataLoader(PaddleNormalDataset()),
+            "test": paddle.io.DataLoader(PaddleNormalDataset(), batch_size=None)
+        }
         PaddleDriver._check_dataloader_legality(dataloader, "dataloader", False)
 
         # Passing something other than a dict should raise an error
-        dataloader = paddle.io.DataLoader(PaddleDataset())
+        dataloader = paddle.io.DataLoader(PaddleNormalDataset())
         with self.assertRaises(ValueError) as cm:
             PaddleDriver._check_dataloader_legality(dataloader, "dataloader", False)
 
         # Create torch dataloaders
         train_loader = torch.utils.data.DataLoader(
-            TorchDataset(),
+            TorchNormalDataset(),
             batch_size=32, shuffle=True
         )
         test_loader = torch.utils.data.DataLoader(
-            TorchDataset(),
+            TorchNormalDataset(),
             batch_size=32, shuffle=True
         )
         dataloader = {"train": train_loader, "test": test_loader}
@@ -240,7 +208,7 @@ class PaddleDriverTestCase(unittest.TestCase):
         """
         # For now only make sure this does not break execution
         # TODO: verify correctness
-        dataloader = DataLoader(PaddleDataset())
+        dataloader = DataLoader(PaddleNormalDataset())
         self.driver.set_deterministic_dataloader(dataloader)
 
     def test_set_sampler_epoch(self):
@@ -249,7 +217,7 @@ class PaddleDriverTestCase(unittest.TestCase):
         """
         # For now only make sure this does not break execution
         # TODO: verify correctness
-        dataloader = DataLoader(PaddleDataset())
+        dataloader = DataLoader(PaddleNormalDataset())
         self.driver.set_sampler_epoch(dataloader, 0)
 
     def test_get_dataloader_args(self):
@@ -258,5 +226,5 @@ class PaddleDriverTestCase(unittest.TestCase):
         """
         # For now only make sure this does not break execution
         # TODO: verify correctness
-        dataloader = DataLoader(PaddleDataset())
+        dataloader = DataLoader(PaddleNormalDataset())
         res = PaddleDriver.get_dataloader_args(dataloader)
\ No newline at end of file
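Note: the new module pulls its model and dataset fixtures from `tests/helpers`, which is not part of this diff. As a rough orientation only, the sketch below shows what minimal stand-ins for two of those helpers could look like. The class names come from the imports above, but the constructor arguments (assumed here to be `num_labels` and `feature_dimension`, inferred from the call `PaddleNormalModel_Classification_1(10, 32)`), the `num_of_data` parameter, and the class bodies are illustrative guesses, not the actual fastNLP helper code.

```python
# Hypothetical stand-ins for the helpers imported by the new test module.
# The real implementations live in tests/helpers and may differ.
import paddle
from paddle.io import Dataset


class PaddleNormalDataset(Dataset):
    """A trivial random dataset, playing the role of the removed PaddleDataset."""

    def __init__(self, num_of_data=320):
        super().__init__()
        self.items = [paddle.rand((3, 4)) for _ in range(num_of_data)]

    def __len__(self):
        return len(self.items)

    def __getitem__(self, idx):
        return self.items[idx]


class PaddleNormalModel_Classification_1(paddle.nn.Layer):
    """A small classifier, playing the role of the removed Net."""

    def __init__(self, num_labels, feature_dimension):
        super().__init__()
        self.linear = paddle.nn.Linear(feature_dimension, num_labels)

    def forward(self, x):
        return self.linear(x)
```

With helpers along these lines in place, the module can be collected and run with pytest; the file selects the backend itself by setting `os.environ["FASTNLP_BACKEND"] = "paddle"` before importing fastNLP.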