From 05ca80bb6f50d50c51c720612bcaead95e6fe1b8 Mon Sep 17 00:00:00 2001 From: x54-729 <17307130121@fudan.edu.cn> Date: Fri, 8 Apr 2022 12:15:33 +0000 Subject: [PATCH] =?UTF-8?q?=E6=8F=90=E4=BA=A4tests/modules/?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- tests/modules/__init__.py | 0 tests/modules/mix_modules/__init__.py | 0 tests/modules/mix_modules/test_mix_module.py | 376 ++++++++++++++++ tests/modules/mix_modules/test_utils.py | 435 +++++++++++++++++++ 4 files changed, 811 insertions(+) create mode 100644 tests/modules/__init__.py create mode 100644 tests/modules/mix_modules/__init__.py create mode 100644 tests/modules/mix_modules/test_mix_module.py create mode 100644 tests/modules/mix_modules/test_utils.py diff --git a/tests/modules/__init__.py b/tests/modules/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/modules/mix_modules/__init__.py b/tests/modules/mix_modules/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/modules/mix_modules/test_mix_module.py b/tests/modules/mix_modules/test_mix_module.py new file mode 100644 index 00000000..ae249c74 --- /dev/null +++ b/tests/modules/mix_modules/test_mix_module.py @@ -0,0 +1,376 @@ +import unittest +import os +from itertools import chain + +import torch +import paddle +from paddle.io import Dataset, DataLoader +import numpy as np + +from fastNLP.modules.mix_modules.mix_module import MixModule +from fastNLP.modules.mix_modules.utils import paddle2torch, torch2paddle +from fastNLP.core import synchronize_safe_rm + + +############################################################################ +# +# 测试类的基本功能 +# +############################################################################ + +class TestMixModule(MixModule): + def __init__(self): + super(TestMixModule, self).__init__() + + self.torch_fc1 = torch.nn.Linear(10, 10) + self.torch_softmax = torch.nn.Softmax(0) + self.torch_conv2d1 = torch.nn.Conv2d(10, 10, 3) + self.torch_tensor = torch.ones(3, 3) + self.torch_param = torch.nn.Parameter(torch.ones(4, 4)) + + self.paddle_fc1 = paddle.nn.Linear(10, 10) + self.paddle_softmax = paddle.nn.Softmax(0) + self.paddle_conv2d1 = paddle.nn.Conv2D(10, 10, 3) + self.paddle_tensor = paddle.ones((4, 4)) + +class TestTorchModule(torch.nn.Module): + def __init__(self): + super(TestTorchModule, self).__init__() + + self.torch_fc1 = torch.nn.Linear(10, 10) + self.torch_softmax = torch.nn.Softmax(0) + self.torch_conv2d1 = torch.nn.Conv2d(10, 10, 3) + self.torch_tensor = torch.ones(3, 3) + self.torch_param = torch.nn.Parameter(torch.ones(4, 4)) + +class TestPaddleModule(paddle.nn.Layer): + def __init__(self): + super(TestPaddleModule, self).__init__() + + self.paddle_fc1 = paddle.nn.Linear(10, 10) + self.paddle_softmax = paddle.nn.Softmax(0) + self.paddle_conv2d1 = paddle.nn.Conv2D(10, 10, 3) + self.paddle_tensor = paddle.ones((4, 4)) + + +class TorchPaddleMixModuleTestCase(unittest.TestCase): + + def setUp(self): + + self.model = TestMixModule() + self.torch_model = TestTorchModule() + self.paddle_model = TestPaddleModule() + + def test_to(self): + """ + 测试混合模型的to函数 + """ + + self.model.to("cuda") + self.torch_model.to("cuda") + self.paddle_model.to("gpu") + self.if_device_correct("cuda") + + self.model.to("cuda:2") + self.torch_model.to("cuda:2") + self.paddle_model.to("gpu:2") + self.if_device_correct("cuda:2") + + self.model.to("gpu:1") + self.torch_model.to("cuda:1") + self.paddle_model.to("gpu:1") + self.if_device_correct("cuda:1") + + self.model.to("cpu") + self.torch_model.to("cpu") + self.paddle_model.to("cpu") + self.if_device_correct("cpu") + + def test_train_eval(self): + """ + 测试train和eval函数 + """ + + self.model.eval() + self.if_training_correct(False) + + self.model.train() + self.if_training_correct(True) + + def test_parameters(self): + """ + 测试parameters()函数,由于初始化是随机的,目前仅比较得到结果的长度 + """ + mix_params = [] + params = [] + + for value in self.model.named_parameters(): + mix_params.append(value) + + for value in chain(self.torch_model.named_parameters(), self.paddle_model.named_parameters()): + params.append(value) + + self.assertEqual(len(params), len(mix_params)) + + def test_named_parameters(self): + """ + 测试named_parameters函数 + """ + + mix_param_names = [] + param_names = [] + + for name, value in self.model.named_parameters(): + mix_param_names.append(name) + + for name, value in chain(self.torch_model.named_parameters(), self.paddle_model.named_parameters()): + param_names.append(name) + + self.assertListEqual(sorted(param_names), sorted(mix_param_names)) + + def test_torch_named_parameters(self): + """ + 测试对torch参数的提取 + """ + + mix_param_names = [] + param_names = [] + + for name, value in self.model.named_parameters(backend="torch"): + mix_param_names.append(name) + + for name, value in self.torch_model.named_parameters(): + param_names.append(name) + + self.assertListEqual(sorted(param_names), sorted(mix_param_names)) + + def test_paddle_named_parameters(self): + """ + 测试对paddle参数的提取 + """ + + mix_param_names = [] + param_names = [] + + for name, value in self.model.named_parameters(backend="paddle"): + mix_param_names.append(name) + + for name, value in self.paddle_model.named_parameters(): + param_names.append(name) + + self.assertListEqual(sorted(param_names), sorted(mix_param_names)) + + def test_torch_state_dict(self): + """ + 测试提取torch的state dict + """ + torch_dict = self.torch_model.state_dict() + mix_dict = self.model.state_dict(backend="torch") + + self.assertListEqual(sorted(torch_dict.keys()), sorted(mix_dict.keys())) + + def test_paddle_state_dict(self): + """ + 测试提取paddle的state dict + """ + paddle_dict = self.paddle_model.state_dict() + mix_dict = self.model.state_dict(backend="paddle") + + # TODO 测试程序会显示passed后显示paddle的异常退出信息 + self.assertListEqual(sorted(paddle_dict.keys()), sorted(mix_dict.keys())) + + def test_state_dict(self): + """ + 测试提取所有的state dict + """ + all_dict = self.torch_model.state_dict() + all_dict.update(self.paddle_model.state_dict()) + mix_dict = self.model.state_dict() + + # TODO 测试程序会显示passed后显示paddle的异常退出信息 + self.assertListEqual(sorted(all_dict.keys()), sorted(mix_dict.keys())) + + def test_load_state_dict(self): + """ + 测试load_state_dict函数 + """ + state_dict = self.model.state_dict() + + new_model = TestMixModule() + new_model.load_state_dict(state_dict) + new_state_dict = new_model.state_dict() + + for name, value in state_dict.items(): + state_dict[name] = value.tolist() + for name, value in new_state_dict.items(): + new_state_dict[name] = value.tolist() + + self.assertDictEqual(state_dict, new_state_dict) + + def test_save_and_load_state_dict(self): + """ + 测试save_state_dict_to_file和load_state_dict_from_file函数 + """ + path = "model" + try: + self.model.save_state_dict_to_file(path) + new_model = TestMixModule() + new_model.load_state_dict_from_file(path) + + state_dict = self.model.state_dict() + new_state_dict = new_model.state_dict() + + for name, value in state_dict.items(): + state_dict[name] = value.tolist() + for name, value in new_state_dict.items(): + new_state_dict[name] = value.tolist() + + self.assertDictEqual(state_dict, new_state_dict) + finally: + synchronize_safe_rm(path) + + def if_device_correct(self, device): + + + self.assertEqual(self.model.torch_fc1.weight.device, self.torch_model.torch_fc1.weight.device) + self.assertEqual(self.model.torch_conv2d1.weight.device, self.torch_model.torch_fc1.bias.device) + self.assertEqual(self.model.torch_conv2d1.bias.device, self.torch_model.torch_conv2d1.bias.device) + self.assertEqual(self.model.torch_tensor.device, self.torch_model.torch_tensor.device) + self.assertEqual(self.model.torch_param.device, self.torch_model.torch_param.device) + + if device == "cpu": + self.assertTrue(self.model.paddle_fc1.weight.place.is_cpu_place()) + self.assertTrue(self.model.paddle_fc1.bias.place.is_cpu_place()) + self.assertTrue(self.model.paddle_conv2d1.weight.place.is_cpu_place()) + self.assertTrue(self.model.paddle_conv2d1.bias.place.is_cpu_place()) + self.assertTrue(self.model.paddle_tensor.place.is_cpu_place()) + elif device.startswith("cuda"): + self.assertTrue(self.model.paddle_fc1.weight.place.is_gpu_place()) + self.assertTrue(self.model.paddle_fc1.bias.place.is_gpu_place()) + self.assertTrue(self.model.paddle_conv2d1.weight.place.is_gpu_place()) + self.assertTrue(self.model.paddle_conv2d1.bias.place.is_gpu_place()) + self.assertTrue(self.model.paddle_tensor.place.is_gpu_place()) + + self.assertEqual(self.model.paddle_fc1.weight.place.gpu_device_id(), self.paddle_model.paddle_fc1.weight.place.gpu_device_id()) + self.assertEqual(self.model.paddle_fc1.bias.place.gpu_device_id(), self.paddle_model.paddle_fc1.bias.place.gpu_device_id()) + self.assertEqual(self.model.paddle_conv2d1.weight.place.gpu_device_id(), self.paddle_model.paddle_conv2d1.weight.place.gpu_device_id()) + self.assertEqual(self.model.paddle_conv2d1.bias.place.gpu_device_id(), self.paddle_model.paddle_conv2d1.bias.place.gpu_device_id()) + self.assertEqual(self.model.paddle_tensor.place.gpu_device_id(), self.paddle_model.paddle_tensor.place.gpu_device_id()) + else: + raise NotImplementedError + + def if_training_correct(self, training): + + self.assertEqual(self.model.torch_fc1.training, training) + self.assertEqual(self.model.torch_softmax.training, training) + self.assertEqual(self.model.torch_conv2d1.training, training) + + self.assertEqual(self.model.paddle_fc1.training, training) + self.assertEqual(self.model.paddle_softmax.training, training) + self.assertEqual(self.model.paddle_conv2d1.training, training) + + +############################################################################ +# +# 测试在MNIST数据集上的表现 +# +############################################################################ + +class MNISTDataset(Dataset): + def __init__(self, dataset): + + self.dataset = [ + ( + np.array(img).astype('float32').reshape(-1), + label + ) for img, label in dataset + ] + + def __getitem__(self, idx): + return self.dataset[idx] + + def __len__(self): + return len(self.dataset) + +class MixMNISTModel(MixModule): + def __init__(self): + super(MixMNISTModel, self).__init__() + + self.fc1 = paddle.nn.Linear(784, 64) + self.fc2 = paddle.nn.Linear(64, 32) + self.fc3 = torch.nn.Linear(32, 10) + self.fc4 = torch.nn.Linear(10, 10) + + def forward(self, x): + + paddle_out = self.fc1(x) + paddle_out = self.fc2(paddle_out) + torch_in = paddle2torch(paddle_out) + torch_out = self.fc3(torch_in) + torch_out = self.fc4(torch_out) + + return torch_out + +class TestMNIST(unittest.TestCase): + + @classmethod + def setUpClass(self): + + self.train_dataset = paddle.vision.datasets.MNIST(mode='train') + self.test_dataset = paddle.vision.datasets.MNIST(mode='test') + self.train_dataset = MNISTDataset(self.train_dataset) + + self.lr = 0.0003 + self.epochs = 20 + + self.dataloader = DataLoader(self.train_dataset, batch_size=100, shuffle=True) + + def setUp(self): + + self.model = MixMNISTModel().to("cuda") + self.torch_loss_func = torch.nn.CrossEntropyLoss() + + self.torch_opt = torch.optim.Adam(self.model.parameters(backend="torch"), self.lr) + self.paddle_opt = paddle.optimizer.Adam(parameters=self.model.parameters(backend="paddle"), learning_rate=self.lr) + + def test_case1(self): + + # 开始训练 + for epoch in range(self.epochs): + epoch_loss, batch = 0, 0 + for batch, (img, label) in enumerate(self.dataloader): + + img = paddle.to_tensor(img).cuda() + torch_out = self.model(img) + label = torch.from_numpy(label.numpy()).reshape(-1) + loss = self.torch_loss_func(torch_out.cpu(), label) + epoch_loss += loss.item() + + loss.backward() + self.torch_opt.step() + self.paddle_opt.step() + self.torch_opt.zero_grad() + self.paddle_opt.clear_grad() + + else: + self.assertLess(epoch_loss / (batch + 1), 0.3) + + # 开始测试 + correct = 0 + for img, label in self.test_dataset: + + img = paddle.to_tensor(np.array(img).astype('float32').reshape(1, -1)) + torch_out = self.model(img) + res = torch_out.softmax(-1).argmax().item() + label = label.item() + if res == label: + correct += 1 + + acc = correct / len(self.test_dataset) + self.assertGreater(acc, 0.85) + +############################################################################ +# +# 测试在ERNIE中文数据集上的表现 +# +############################################################################ diff --git a/tests/modules/mix_modules/test_utils.py b/tests/modules/mix_modules/test_utils.py new file mode 100644 index 00000000..92d0580b --- /dev/null +++ b/tests/modules/mix_modules/test_utils.py @@ -0,0 +1,435 @@ +import unittest +import os + +os.environ["log_silent"] = "1" +import torch +import paddle +import jittor + +from fastNLP.modules.mix_modules.utils import ( + paddle2torch, + torch2paddle, + jittor2torch, + torch2jittor, +) + +############################################################################ +# +# 测试paddle到torch的转换 +# +############################################################################ + +class Paddle2TorchTestCase(unittest.TestCase): + + def check_torch_tensor(self, tensor, device, requires_grad): + """ + 检查张量设备和梯度情况的工具函数 + """ + + self.assertIsInstance(tensor, torch.Tensor) + self.assertEqual(tensor.device, torch.device(device)) + self.assertEqual(tensor.requires_grad, requires_grad) + + def test_gradient(self): + """ + 测试张量转换后的反向传播是否正确 + """ + + x = paddle.to_tensor([1.0, 2.0, 3.0, 4.0, 5.0], stop_gradient=False) + y = paddle2torch(x) + z = 3 * (y ** 2) + z.sum().backward() + self.assertListEqual(y.grad.tolist(), [6, 12, 18, 24, 30]) + + def test_tensor_transfer(self): + """ + 测试单个张量的设备和梯度转换是否正确 + """ + + paddle_tensor = paddle.rand((3, 4, 5)).cpu() + res = paddle2torch(paddle_tensor) + self.check_torch_tensor(res, "cpu", not paddle_tensor.stop_gradient) + + res = paddle2torch(paddle_tensor, target_device="cuda:2", no_gradient=None) + self.check_torch_tensor(res, "cuda:2", not paddle_tensor.stop_gradient) + + res = paddle2torch(paddle_tensor, target_device="cuda:1", no_gradient=True) + self.check_torch_tensor(res, "cuda:1", False) + + res = paddle2torch(paddle_tensor, target_device="cuda:1", no_gradient=False) + self.check_torch_tensor(res, "cuda:1", True) + + def test_list_transfer(self): + """ + 测试张量列表的转换 + """ + + paddle_list = [paddle.rand((6, 4, 2)).cuda(1) for i in range(10)] + res = paddle2torch(paddle_list) + self.assertIsInstance(res, list) + for t in res: + self.check_torch_tensor(t, "cuda:1", False) + + res = paddle2torch(paddle_list, target_device="cpu", no_gradient=False) + self.assertIsInstance(res, list) + for t in res: + self.check_torch_tensor(t, "cpu", True) + + def test_tensor_tuple_transfer(self): + """ + 测试张量元组的转换 + """ + + paddle_list = [paddle.rand((6, 4, 2)).cuda(1) for i in range(10)] + paddle_tuple = tuple(paddle_list) + res = paddle2torch(paddle_tuple) + self.assertIsInstance(res, tuple) + for t in res: + self.check_torch_tensor(t, "cuda:1", False) + + def test_dict_transfer(self): + """ + 测试包含复杂结构的字典的转换 + """ + + paddle_dict = { + "tensor": paddle.rand((3, 4)).cuda(0), + "list": [paddle.rand((6, 4, 2)).cuda(0) for i in range(10)], + "dict":{ + "list": [paddle.rand((6, 4, 2)).cuda(0) for i in range(10)], + "tensor": paddle.rand((3, 4)).cuda(0) + }, + "int": 2, + "string": "test string" + } + res = paddle2torch(paddle_dict) + self.assertIsInstance(res, dict) + self.check_torch_tensor(res["tensor"], "cuda:0", False) + self.assertIsInstance(res["list"], list) + for t in res["list"]: + self.check_torch_tensor(t, "cuda:0", False) + self.assertIsInstance(res["int"], int) + self.assertIsInstance(res["string"], str) + self.assertIsInstance(res["dict"], dict) + self.assertIsInstance(res["dict"]["list"], list) + for t in res["dict"]["list"]: + self.check_torch_tensor(t, "cuda:0", False) + self.check_torch_tensor(res["dict"]["tensor"], "cuda:0", False) + + +############################################################################ +# +# 测试torch到paddle的转换 +# +############################################################################ + +class Torch2PaddleTestCase(unittest.TestCase): + + def check_paddle_tensor(self, tensor, device, stop_gradient): + """ + 检查得到的paddle张量设备和梯度情况的工具函数 + """ + + self.assertIsInstance(tensor, paddle.Tensor) + if device == "cpu": + self.assertTrue(tensor.place.is_cpu_place()) + elif device.startswith("gpu"): + paddle_device = paddle.device._convert_to_place(device) + self.assertTrue(tensor.place.is_gpu_place()) + if hasattr(tensor.place, "gpu_device_id"): + # paddle中,有两种Place + # paddle.fluid.core.Place是创建Tensor时使用的类型 + # 有函数gpu_device_id获取设备 + self.assertEqual(tensor.place.gpu_device_id(), paddle_device.get_device_id()) + else: + # 通过_convert_to_place得到的是paddle.CUDAPlace + # 通过get_device_id获取设备 + self.assertEqual(tensor.place.get_device_id(), paddle_device.get_device_id()) + else: + raise NotImplementedError + self.assertEqual(tensor.stop_gradient, stop_gradient) + + def test_gradient(self): + """ + 测试转换后梯度的反向传播 + """ + + x = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0], requires_grad=True) + y = torch2paddle(x) + z = 3 * (y ** 2) + z.sum().backward() + self.assertListEqual(y.grad.tolist(), [6, 12, 18, 24, 30]) + + def test_tensor_transfer(self): + """ + 测试单个张量的转换 + """ + + torch_tensor = torch.rand((3, 4, 5)) + res = torch2paddle(torch_tensor) + self.check_paddle_tensor(res, "cpu", True) + + res = torch2paddle(torch_tensor, target_device="gpu:2", no_gradient=None) + self.check_paddle_tensor(res, "gpu:2", True) + + res = torch2paddle(torch_tensor, target_device="gpu:2", no_gradient=True) + self.check_paddle_tensor(res, "gpu:2", True) + + res = torch2paddle(torch_tensor, target_device="gpu:2", no_gradient=False) + self.check_paddle_tensor(res, "gpu:2", False) + + def test_tensor_list_transfer(self): + """ + 测试张量列表的转换 + """ + + torch_list = [torch.rand(6, 4, 2) for i in range(10)] + res = torch2paddle(torch_list) + self.assertIsInstance(res, list) + for t in res: + self.check_paddle_tensor(t, "cpu", True) + + res = torch2paddle(torch_list, target_device="gpu:1", no_gradient=False) + self.assertIsInstance(res, list) + for t in res: + self.check_paddle_tensor(t, "gpu:1", False) + + def test_tensor_tuple_transfer(self): + """ + 测试张量元组的转换 + """ + + torch_list = [torch.rand(6, 4, 2) for i in range(10)] + torch_tuple = tuple(torch_list) + res = torch2paddle(torch_tuple, target_device="cpu") + self.assertIsInstance(res, tuple) + for t in res: + self.check_paddle_tensor(t, "cpu", True) + + def test_dict_transfer(self): + """ + 测试复杂的字典结构的转换 + """ + + torch_dict = { + "tensor": torch.rand((3, 4)), + "list": [torch.rand(6, 4, 2) for i in range(10)], + "dict":{ + "list": [torch.rand(6, 4, 2) for i in range(10)], + "tensor": torch.rand((3, 4)) + }, + "int": 2, + "string": "test string" + } + res = torch2paddle(torch_dict) + self.assertIsInstance(res, dict) + self.check_paddle_tensor(res["tensor"], "cpu", True) + self.assertIsInstance(res["list"], list) + for t in res["list"]: + self.check_paddle_tensor(t, "cpu", True) + self.assertIsInstance(res["int"], int) + self.assertIsInstance(res["string"], str) + self.assertIsInstance(res["dict"], dict) + self.assertIsInstance(res["dict"]["list"], list) + for t in res["dict"]["list"]: + self.check_paddle_tensor(t, "cpu", True) + self.check_paddle_tensor(res["dict"]["tensor"], "cpu", True) + + +############################################################################ +# +# 测试jittor到torch的转换 +# +############################################################################ + +class Jittor2TorchTestCase(unittest.TestCase): + + def check_torch_tensor(self, tensor, device, requires_grad): + """ + 检查得到的torch张量的工具函数 + """ + + self.assertIsInstance(tensor, torch.Tensor) + if device == "cpu": + self.assertFalse(tensor.is_cuda) + else: + self.assertEqual(tensor.device, torch.device(device)) + self.assertEqual(tensor.requires_grad, requires_grad) + + def test_var_transfer(self): + """ + 测试单个Jittor Var的转换 + """ + + jittor_var = jittor.rand((3, 4, 5)) + res = jittor2torch(jittor_var) + self.check_torch_tensor(res, "cpu", True) + + res = jittor2torch(jittor_var, target_device="cuda:2", no_gradient=None) + self.check_torch_tensor(res, "cuda:2", True) + + res = jittor2torch(jittor_var, target_device="cuda:2", no_gradient=True) + self.check_torch_tensor(res, "cuda:2", False) + + res = jittor2torch(jittor_var, target_device="cuda:2", no_gradient=False) + self.check_torch_tensor(res, "cuda:2", True) + + def test_var_list_transfer(self): + """ + 测试Jittor列表的转换 + """ + + jittor_list = [jittor.rand((6, 4, 2)) for i in range(10)] + res = jittor2torch(jittor_list) + self.assertIsInstance(res, list) + for t in res: + self.check_torch_tensor(t, "cpu", True) + + res = jittor2torch(jittor_list, target_device="cuda:1", no_gradient=False) + self.assertIsInstance(res, list) + for t in res: + self.check_torch_tensor(t, "cuda:1", True) + + def test_var_tuple_transfer(self): + """ + 测试Jittor变量元组的转换 + """ + + jittor_list = [jittor.rand((6, 4, 2)) for i in range(10)] + jittor_tuple = tuple(jittor_list) + res = jittor2torch(jittor_tuple, target_device="cpu") + self.assertIsInstance(res, tuple) + for t in res: + self.check_torch_tensor(t, "cpu", True) + + def test_dict_transfer(self): + """ + 测试字典结构的转换 + """ + + jittor_dict = { + "tensor": jittor.rand((3, 4)), + "list": [jittor.rand(6, 4, 2) for i in range(10)], + "dict":{ + "list": [jittor.rand(6, 4, 2) for i in range(10)], + "tensor": jittor.rand((3, 4)) + }, + "int": 2, + "string": "test string" + } + res = jittor2torch(jittor_dict) + self.assertIsInstance(res, dict) + self.check_torch_tensor(res["tensor"], "cpu", True) + self.assertIsInstance(res["list"], list) + for t in res["list"]: + self.check_torch_tensor(t, "cpu", True) + self.assertIsInstance(res["int"], int) + self.assertIsInstance(res["string"], str) + self.assertIsInstance(res["dict"], dict) + self.assertIsInstance(res["dict"]["list"], list) + for t in res["dict"]["list"]: + self.check_torch_tensor(t, "cpu", True) + self.check_torch_tensor(res["dict"]["tensor"], "cpu", True) + + +############################################################################ +# +# 测试torch到jittor的转换 +# +############################################################################ + +class Torch2JittorTestCase(unittest.TestCase): + + def check_jittor_var(self, var, requires_grad): + """ + 检查得到的Jittor Var梯度情况的工具函数 + """ + + self.assertIsInstance(var, jittor.Var) + self.assertEqual(var.requires_grad, requires_grad) + + def test_gradient(self): + """ + 测试反向传播的梯度 + """ + + x = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0], requires_grad=True) + y = torch2jittor(x) + z = 3 * (y ** 2) + grad = jittor.grad(z, y) + self.assertListEqual(grad.tolist(), [6.0, 12.0, 18.0, 24.0, 30.0]) + + def test_tensor_transfer(self): + """ + 测试单个张量转换为Jittor + """ + + torch_tensor = torch.rand((3, 4, 5)) + res = torch2jittor(torch_tensor) + self.check_jittor_var(res, False) + + res = torch2jittor(torch_tensor, no_gradient=None) + self.check_jittor_var(res, False) + + res = torch2jittor(torch_tensor, no_gradient=True) + self.check_jittor_var(res, False) + + res = torch2jittor(torch_tensor, no_gradient=False) + self.check_jittor_var(res, True) + + def test_tensor_list_transfer(self): + """ + 测试张量列表的转换 + """ + + torch_list = [torch.rand((6, 4, 2)) for i in range(10)] + res = torch2jittor(torch_list) + self.assertIsInstance(res, list) + for t in res: + self.check_jittor_var(t, False) + + res = torch2jittor(torch_list, no_gradient=False) + self.assertIsInstance(res, list) + for t in res: + self.check_jittor_var(t, True) + + def test_tensor_tuple_transfer(self): + """ + 测试张量元组的转换 + """ + + torch_list = [torch.rand((6, 4, 2)) for i in range(10)] + torch_tuple = tuple(torch_list) + res = torch2jittor(torch_tuple) + self.assertIsInstance(res, tuple) + for t in res: + self.check_jittor_var(t, False) + + def test_dict_transfer(self): + """ + 测试字典结构的转换 + """ + + torch_dict = { + "tensor": torch.rand((3, 4)), + "list": [torch.rand(6, 4, 2) for i in range(10)], + "dict":{ + "list": [torch.rand(6, 4, 2) for i in range(10)], + "tensor": torch.rand((3, 4)) + }, + "int": 2, + "string": "test string" + } + res = torch2jittor(torch_dict) + self.assertIsInstance(res, dict) + self.check_jittor_var(res["tensor"], False) + self.assertIsInstance(res["list"], list) + for t in res["list"]: + self.check_jittor_var(t, False) + self.assertIsInstance(res["int"], int) + self.assertIsInstance(res["string"], str) + self.assertIsInstance(res["dict"], dict) + self.assertIsInstance(res["dict"]["list"], list) + for t in res["dict"]["list"]: + self.check_jittor_var(t, False) + self.check_jittor_var(res["dict"]["tensor"], False)