Browse Source

提交tests/core/drivers/torch_paddle_driver

tags/v1.0.0alpha
x54-729 2 years ago
parent
commit
6170976b7f
3 changed files with 121 additions and 0 deletions
  1. +0
    -0
      tests/core/drivers/torch_paddle_driver/__init__.py
  2. +121
    -0
      tests/core/drivers/torch_paddle_driver/test_torch_paddle_driver.py
  3. +0
    -0
      tests/core/drivers/torch_paddle_driver/test_utils.py

+ 0
- 0
tests/core/drivers/torch_paddle_driver/__init__.py View File


+ 121
- 0
tests/core/drivers/torch_paddle_driver/test_torch_paddle_driver.py View File

@@ -0,0 +1,121 @@
import unittest

from fastNLP.modules.mix_modules.mix_module import MixModule
from fastNLP.core.drivers.torch_paddle_driver.torch_paddle_driver import TorchPaddleDriver
from fastNLP.modules.mix_modules.utils import paddle2torch, torch2paddle

import torch
import paddle
from paddle.io import Dataset, DataLoader
import numpy as np

############################################################################
#
# 测试在MNIST数据集上的表现
#
############################################################################

class MNISTDataset(Dataset):
def __init__(self, dataset):

self.dataset = [
(
np.array(img).astype('float32').reshape(-1),
label
) for img, label in dataset
]

def __getitem__(self, idx):
return self.dataset[idx]

def __len__(self):
return len(self.dataset)

class MixMNISTModel(MixModule):
def __init__(self):
super(MixMNISTModel, self).__init__()

self.fc1 = paddle.nn.Linear(784, 64)
self.fc2 = paddle.nn.Linear(64, 32)
self.fc3 = torch.nn.Linear(32, 10)
self.fc4 = torch.nn.Linear(10, 10)

def forward(self, x):

paddle_out = self.fc1(x)
paddle_out = self.fc2(paddle_out)
torch_in = paddle2torch(paddle_out)
torch_out = self.fc3(torch_in)
torch_out = self.fc4(torch_out)

return torch_out

def train_step(self, x):
return self.forward(x)

def test_step(self, x):
return self.forward(x)

class TestMNIST(unittest.TestCase):

@classmethod
def setUpClass(self):

self.train_dataset = paddle.vision.datasets.MNIST(mode='train')
self.test_dataset = paddle.vision.datasets.MNIST(mode='test')
self.train_dataset = MNISTDataset(self.train_dataset)

self.lr = 0.0003
self.epochs = 20

self.dataloader = DataLoader(self.train_dataset, batch_size=100, shuffle=True)

def setUp(self):
model = MixMNISTModel()
self.torch_loss_func = torch.nn.CrossEntropyLoss()

torch_opt = torch.optim.Adam(model.parameters(backend="torch"), self.lr)
paddle_opt = paddle.optimizer.Adam(parameters=model.parameters(backend="paddle"), learning_rate=self.lr)

self.driver = TorchPaddleDriver(model=model, device="cuda:0")
self.driver.set_optimizers([torch_opt, paddle_opt])

def test_case1(self):

epochs = 20

self.driver.setup()
self.driver.zero_grad()
# 开始训练
current_epoch_idx = 0
while current_epoch_idx < epochs:
epoch_loss, batch = 0, 0
self.driver.set_model_mode("train")
self.driver.set_sampler_epoch(self.dataloader, current_epoch_idx)
for batch, (img, label) in enumerate(self.dataloader):
img = paddle.to_tensor(img).cuda()
torch_out = self.driver.train_step(img)
label = torch.from_numpy(label.numpy()).reshape(-1)
loss = self.torch_loss_func(torch_out.cpu(), label)
epoch_loss += loss.item()

self.driver.backward(loss)
self.driver.step()
self.driver.zero_grad()

current_epoch_idx += 1

# 开始测试
correct = 0
for img, label in self.test_dataset:

img = paddle.to_tensor(np.array(img).astype('float32').reshape(1, -1))
torch_out = self.driver.test_step(img)
res = torch_out.softmax(-1).argmax().item()
label = label.item()
if res == label:
correct += 1

acc = correct / len(self.test_dataset)
self.assertGreater(acc, 0.85)

+ 0
- 0
tests/core/drivers/torch_paddle_driver/test_utils.py View File


Loading…
Cancel
Save