|
|
@@ -0,0 +1,121 @@ |
|
|
|
import unittest |
|
|
|
|
|
|
|
from fastNLP.modules.mix_modules.mix_module import MixModule |
|
|
|
from fastNLP.core.drivers.torch_paddle_driver.torch_paddle_driver import TorchPaddleDriver |
|
|
|
from fastNLP.modules.mix_modules.utils import paddle2torch, torch2paddle |
|
|
|
|
|
|
|
import torch |
|
|
|
import paddle |
|
|
|
from paddle.io import Dataset, DataLoader |
|
|
|
import numpy as np |
|
|
|
|
|
|
|
############################################################################ |
|
|
|
# |
|
|
|
# 测试在MNIST数据集上的表现 |
|
|
|
# |
|
|
|
############################################################################ |
|
|
|
|
|
|
|
class MNISTDataset(Dataset): |
|
|
|
def __init__(self, dataset): |
|
|
|
|
|
|
|
self.dataset = [ |
|
|
|
( |
|
|
|
np.array(img).astype('float32').reshape(-1), |
|
|
|
label |
|
|
|
) for img, label in dataset |
|
|
|
] |
|
|
|
|
|
|
|
def __getitem__(self, idx): |
|
|
|
return self.dataset[idx] |
|
|
|
|
|
|
|
def __len__(self): |
|
|
|
return len(self.dataset) |
|
|
|
|
|
|
|
class MixMNISTModel(MixModule): |
|
|
|
def __init__(self): |
|
|
|
super(MixMNISTModel, self).__init__() |
|
|
|
|
|
|
|
self.fc1 = paddle.nn.Linear(784, 64) |
|
|
|
self.fc2 = paddle.nn.Linear(64, 32) |
|
|
|
self.fc3 = torch.nn.Linear(32, 10) |
|
|
|
self.fc4 = torch.nn.Linear(10, 10) |
|
|
|
|
|
|
|
def forward(self, x): |
|
|
|
|
|
|
|
paddle_out = self.fc1(x) |
|
|
|
paddle_out = self.fc2(paddle_out) |
|
|
|
torch_in = paddle2torch(paddle_out) |
|
|
|
torch_out = self.fc3(torch_in) |
|
|
|
torch_out = self.fc4(torch_out) |
|
|
|
|
|
|
|
return torch_out |
|
|
|
|
|
|
|
def train_step(self, x): |
|
|
|
return self.forward(x) |
|
|
|
|
|
|
|
def test_step(self, x): |
|
|
|
return self.forward(x) |
|
|
|
|
|
|
|
class TestMNIST(unittest.TestCase): |
|
|
|
|
|
|
|
@classmethod |
|
|
|
def setUpClass(self): |
|
|
|
|
|
|
|
self.train_dataset = paddle.vision.datasets.MNIST(mode='train') |
|
|
|
self.test_dataset = paddle.vision.datasets.MNIST(mode='test') |
|
|
|
self.train_dataset = MNISTDataset(self.train_dataset) |
|
|
|
|
|
|
|
self.lr = 0.0003 |
|
|
|
self.epochs = 20 |
|
|
|
|
|
|
|
self.dataloader = DataLoader(self.train_dataset, batch_size=100, shuffle=True) |
|
|
|
|
|
|
|
def setUp(self): |
|
|
|
|
|
|
|
model = MixMNISTModel() |
|
|
|
self.torch_loss_func = torch.nn.CrossEntropyLoss() |
|
|
|
|
|
|
|
torch_opt = torch.optim.Adam(model.parameters(backend="torch"), self.lr) |
|
|
|
paddle_opt = paddle.optimizer.Adam(parameters=model.parameters(backend="paddle"), learning_rate=self.lr) |
|
|
|
|
|
|
|
self.driver = TorchPaddleDriver(model=model, device="cuda:0") |
|
|
|
self.driver.set_optimizers([torch_opt, paddle_opt]) |
|
|
|
|
|
|
|
def test_case1(self): |
|
|
|
|
|
|
|
epochs = 20 |
|
|
|
|
|
|
|
self.driver.setup() |
|
|
|
self.driver.zero_grad() |
|
|
|
# 开始训练 |
|
|
|
current_epoch_idx = 0 |
|
|
|
while current_epoch_idx < epochs: |
|
|
|
epoch_loss, batch = 0, 0 |
|
|
|
self.driver.set_model_mode("train") |
|
|
|
self.driver.set_sampler_epoch(self.dataloader, current_epoch_idx) |
|
|
|
for batch, (img, label) in enumerate(self.dataloader): |
|
|
|
img = paddle.to_tensor(img).cuda() |
|
|
|
torch_out = self.driver.train_step(img) |
|
|
|
label = torch.from_numpy(label.numpy()).reshape(-1) |
|
|
|
loss = self.torch_loss_func(torch_out.cpu(), label) |
|
|
|
epoch_loss += loss.item() |
|
|
|
|
|
|
|
self.driver.backward(loss) |
|
|
|
self.driver.step() |
|
|
|
self.driver.zero_grad() |
|
|
|
|
|
|
|
current_epoch_idx += 1 |
|
|
|
|
|
|
|
# 开始测试 |
|
|
|
correct = 0 |
|
|
|
for img, label in self.test_dataset: |
|
|
|
|
|
|
|
img = paddle.to_tensor(np.array(img).astype('float32').reshape(1, -1)) |
|
|
|
torch_out = self.driver.test_step(img) |
|
|
|
res = torch_out.softmax(-1).argmax().item() |
|
|
|
label = label.item() |
|
|
|
if res == label: |
|
|
|
correct += 1 |
|
|
|
|
|
|
|
acc = correct / len(self.test_dataset) |
|
|
|
self.assertGreater(acc, 0.85) |