From 6170976b7f2aebe8e530a22068d01d9cfcff15eb Mon Sep 17 00:00:00 2001 From: x54-729 <17307130121@fudan.edu.cn> Date: Fri, 8 Apr 2022 12:14:33 +0000 Subject: [PATCH] =?UTF-8?q?=E6=8F=90=E4=BA=A4tests/core/drivers/torch=5Fpa?= =?UTF-8?q?ddle=5Fdriver?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../drivers/torch_paddle_driver/__init__.py | 0 .../test_torch_paddle_driver.py | 121 ++++++++++++++++++ .../drivers/torch_paddle_driver/test_utils.py | 0 3 files changed, 121 insertions(+) create mode 100644 tests/core/drivers/torch_paddle_driver/__init__.py create mode 100644 tests/core/drivers/torch_paddle_driver/test_torch_paddle_driver.py create mode 100644 tests/core/drivers/torch_paddle_driver/test_utils.py diff --git a/tests/core/drivers/torch_paddle_driver/__init__.py b/tests/core/drivers/torch_paddle_driver/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/core/drivers/torch_paddle_driver/test_torch_paddle_driver.py b/tests/core/drivers/torch_paddle_driver/test_torch_paddle_driver.py new file mode 100644 index 00000000..0f93161f --- /dev/null +++ b/tests/core/drivers/torch_paddle_driver/test_torch_paddle_driver.py @@ -0,0 +1,121 @@ +import unittest + +from fastNLP.modules.mix_modules.mix_module import MixModule +from fastNLP.core.drivers.torch_paddle_driver.torch_paddle_driver import TorchPaddleDriver +from fastNLP.modules.mix_modules.utils import paddle2torch, torch2paddle + +import torch +import paddle +from paddle.io import Dataset, DataLoader +import numpy as np + +############################################################################ +# +# 测试在MNIST数据集上的表现 +# +############################################################################ + +class MNISTDataset(Dataset): + def __init__(self, dataset): + + self.dataset = [ + ( + np.array(img).astype('float32').reshape(-1), + label + ) for img, label in dataset + ] + + def __getitem__(self, idx): + return self.dataset[idx] + + def __len__(self): + return len(self.dataset) + +class MixMNISTModel(MixModule): + def __init__(self): + super(MixMNISTModel, self).__init__() + + self.fc1 = paddle.nn.Linear(784, 64) + self.fc2 = paddle.nn.Linear(64, 32) + self.fc3 = torch.nn.Linear(32, 10) + self.fc4 = torch.nn.Linear(10, 10) + + def forward(self, x): + + paddle_out = self.fc1(x) + paddle_out = self.fc2(paddle_out) + torch_in = paddle2torch(paddle_out) + torch_out = self.fc3(torch_in) + torch_out = self.fc4(torch_out) + + return torch_out + + def train_step(self, x): + return self.forward(x) + + def test_step(self, x): + return self.forward(x) + +class TestMNIST(unittest.TestCase): + + @classmethod + def setUpClass(self): + + self.train_dataset = paddle.vision.datasets.MNIST(mode='train') + self.test_dataset = paddle.vision.datasets.MNIST(mode='test') + self.train_dataset = MNISTDataset(self.train_dataset) + + self.lr = 0.0003 + self.epochs = 20 + + self.dataloader = DataLoader(self.train_dataset, batch_size=100, shuffle=True) + + def setUp(self): + + model = MixMNISTModel() + self.torch_loss_func = torch.nn.CrossEntropyLoss() + + torch_opt = torch.optim.Adam(model.parameters(backend="torch"), self.lr) + paddle_opt = paddle.optimizer.Adam(parameters=model.parameters(backend="paddle"), learning_rate=self.lr) + + self.driver = TorchPaddleDriver(model=model, device="cuda:0") + self.driver.set_optimizers([torch_opt, paddle_opt]) + + def test_case1(self): + + epochs = 20 + + self.driver.setup() + self.driver.zero_grad() + # 开始训练 + current_epoch_idx = 0 + while current_epoch_idx < epochs: + epoch_loss, batch = 0, 0 + self.driver.set_model_mode("train") + self.driver.set_sampler_epoch(self.dataloader, current_epoch_idx) + for batch, (img, label) in enumerate(self.dataloader): + img = paddle.to_tensor(img).cuda() + torch_out = self.driver.train_step(img) + label = torch.from_numpy(label.numpy()).reshape(-1) + loss = self.torch_loss_func(torch_out.cpu(), label) + epoch_loss += loss.item() + + self.driver.backward(loss) + self.driver.step() + self.driver.zero_grad() + + current_epoch_idx += 1 + + # 开始测试 + correct = 0 + for img, label in self.test_dataset: + + img = paddle.to_tensor(np.array(img).astype('float32').reshape(1, -1)) + torch_out = self.driver.test_step(img) + res = torch_out.softmax(-1).argmax().item() + label = label.item() + if res == label: + correct += 1 + + acc = correct / len(self.test_dataset) + self.assertGreater(acc, 0.85) \ No newline at end of file diff --git a/tests/core/drivers/torch_paddle_driver/test_utils.py b/tests/core/drivers/torch_paddle_driver/test_utils.py new file mode 100644 index 00000000..e69de29b