
Commit tests/modules/

tags/v1.0.0alpha
x54-729 committed 2 years ago
parent · commit 05ca80bb6f
4 changed files with 811 additions and 0 deletions
  1. tests/modules/__init__.py (+0, -0)
  2. tests/modules/mix_modules/__init__.py (+0, -0)
  3. tests/modules/mix_modules/test_mix_module.py (+376, -0)
  4. tests/modules/mix_modules/test_utils.py (+435, -0)
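
For context, the two test files exercise fastNLP's MixModule (a model that mixes torch and paddle submodules) together with the tensor-conversion helpers in mix_modules.utils. Below is a minimal usage sketch, assuming only the class and helpers these tests import; the model itself is illustrative and not part of the commit:

import torch
import paddle
from fastNLP.modules.mix_modules.mix_module import MixModule
from fastNLP.modules.mix_modules.utils import paddle2torch

class TinyMixModel(MixModule):
    def __init__(self):
        super().__init__()
        self.paddle_fc = paddle.nn.Linear(784, 64)   # paddle front half
        self.torch_fc = torch.nn.Linear(64, 10)      # torch back half

    def forward(self, x):
        h = self.paddle_fc(x)                        # x is a paddle tensor
        return self.torch_fc(paddle2torch(h))        # bridge paddle -> torch

model = TinyMixModel().to("cuda")                    # .to() moves both backends
torch_params = list(model.parameters(backend="torch"))     # per-backend access
paddle_params = list(model.parameters(backend="paddle"))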

tests/modules/__init__.py (+0, -0)


tests/modules/mix_modules/__init__.py (+0, -0)


tests/modules/mix_modules/test_mix_module.py (+376, -0)

@@ -0,0 +1,376 @@
import unittest
import os
from itertools import chain

import torch
import paddle
from paddle.io import Dataset, DataLoader
import numpy as np

from fastNLP.modules.mix_modules.mix_module import MixModule
from fastNLP.modules.mix_modules.utils import paddle2torch, torch2paddle
from fastNLP.core import synchronize_safe_rm


############################################################################
#
# Test the basic functionality of the class
#
############################################################################

class TestMixModule(MixModule):
def __init__(self):
super(TestMixModule, self).__init__()

self.torch_fc1 = torch.nn.Linear(10, 10)
self.torch_softmax = torch.nn.Softmax(0)
self.torch_conv2d1 = torch.nn.Conv2d(10, 10, 3)
self.torch_tensor = torch.ones(3, 3)
self.torch_param = torch.nn.Parameter(torch.ones(4, 4))

self.paddle_fc1 = paddle.nn.Linear(10, 10)
self.paddle_softmax = paddle.nn.Softmax(0)
self.paddle_conv2d1 = paddle.nn.Conv2D(10, 10, 3)
self.paddle_tensor = paddle.ones((4, 4))

class TestTorchModule(torch.nn.Module):
def __init__(self):
super(TestTorchModule, self).__init__()

self.torch_fc1 = torch.nn.Linear(10, 10)
self.torch_softmax = torch.nn.Softmax(0)
self.torch_conv2d1 = torch.nn.Conv2d(10, 10, 3)
self.torch_tensor = torch.ones(3, 3)
self.torch_param = torch.nn.Parameter(torch.ones(4, 4))

class TestPaddleModule(paddle.nn.Layer):
def __init__(self):
super(TestPaddleModule, self).__init__()

self.paddle_fc1 = paddle.nn.Linear(10, 10)
self.paddle_softmax = paddle.nn.Softmax(0)
self.paddle_conv2d1 = paddle.nn.Conv2D(10, 10, 3)
self.paddle_tensor = paddle.ones((4, 4))


class TorchPaddleMixModuleTestCase(unittest.TestCase):

def setUp(self):

self.model = TestMixModule()
self.torch_model = TestTorchModule()
self.paddle_model = TestPaddleModule()

def test_to(self):
"""
Test the to() method of the mixed model
"""
self.model.to("cuda")
self.torch_model.to("cuda")
self.paddle_model.to("gpu")
self.if_device_correct("cuda")

self.model.to("cuda:2")
self.torch_model.to("cuda:2")
self.paddle_model.to("gpu:2")
self.if_device_correct("cuda:2")

self.model.to("gpu:1")
self.torch_model.to("cuda:1")
self.paddle_model.to("gpu:1")
self.if_device_correct("cuda:1")

self.model.to("cpu")
self.torch_model.to("cpu")
self.paddle_model.to("cpu")
self.if_device_correct("cpu")

def test_train_eval(self):
"""
Test the train() and eval() methods
"""
self.model.eval()
self.if_training_correct(False)

self.model.train()
self.if_training_correct(True)

def test_parameters(self):
"""
Test the parameters() method; since initialization is random, only the number of returned parameters is compared for now
"""
mix_params = []
params = []

for value in self.model.named_parameters():
mix_params.append(value)

for value in chain(self.torch_model.named_parameters(), self.paddle_model.named_parameters()):
params.append(value)

self.assertEqual(len(params), len(mix_params))

def test_named_parameters(self):
"""
Test the named_parameters() method
"""
mix_param_names = []
param_names = []

for name, value in self.model.named_parameters():
mix_param_names.append(name)

for name, value in chain(self.torch_model.named_parameters(), self.paddle_model.named_parameters()):
param_names.append(name)

self.assertListEqual(sorted(param_names), sorted(mix_param_names))

def test_torch_named_parameters(self):
"""
Test extracting the torch parameters
"""
mix_param_names = []
param_names = []

for name, value in self.model.named_parameters(backend="torch"):
mix_param_names.append(name)

for name, value in self.torch_model.named_parameters():
param_names.append(name)

self.assertListEqual(sorted(param_names), sorted(mix_param_names))

def test_paddle_named_parameters(self):
"""
Test extracting the paddle parameters
"""
mix_param_names = []
param_names = []

for name, value in self.model.named_parameters(backend="paddle"):
mix_param_names.append(name)

for name, value in self.paddle_model.named_parameters():
param_names.append(name)

self.assertListEqual(sorted(param_names), sorted(mix_param_names))

def test_torch_state_dict(self):
"""
Test extracting the torch state dict
"""
torch_dict = self.torch_model.state_dict()
mix_dict = self.model.state_dict(backend="torch")

self.assertListEqual(sorted(torch_dict.keys()), sorted(mix_dict.keys()))

def test_paddle_state_dict(self):
"""
Test extracting the paddle state dict
"""
paddle_dict = self.paddle_model.state_dict()
mix_dict = self.model.state_dict(backend="paddle")

# TODO: the test reports passed, but paddle then prints an abnormal-exit message
self.assertListEqual(sorted(paddle_dict.keys()), sorted(mix_dict.keys()))

def test_state_dict(self):
"""
Test extracting the full state dict
"""
all_dict = self.torch_model.state_dict()
all_dict.update(self.paddle_model.state_dict())
mix_dict = self.model.state_dict()

# TODO: the test reports passed, but paddle then prints an abnormal-exit message
self.assertListEqual(sorted(all_dict.keys()), sorted(mix_dict.keys()))

def test_load_state_dict(self):
"""
Test the load_state_dict() method
"""
state_dict = self.model.state_dict()

new_model = TestMixModule()
new_model.load_state_dict(state_dict)
new_state_dict = new_model.state_dict()

for name, value in state_dict.items():
state_dict[name] = value.tolist()
for name, value in new_state_dict.items():
new_state_dict[name] = value.tolist()

self.assertDictEqual(state_dict, new_state_dict)

def test_save_and_load_state_dict(self):
"""
Test the save_state_dict_to_file() and load_state_dict_from_file() methods
"""
path = "model"
try:
self.model.save_state_dict_to_file(path)
new_model = TestMixModule()
new_model.load_state_dict_from_file(path)

state_dict = self.model.state_dict()
new_state_dict = new_model.state_dict()

for name, value in state_dict.items():
state_dict[name] = value.tolist()
for name, value in new_state_dict.items():
new_state_dict[name] = value.tolist()

self.assertDictEqual(state_dict, new_state_dict)
finally:
synchronize_safe_rm(path)

def if_device_correct(self, device):
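# helper: every torch submodule/tensor should sit on `device`, and every paddle counterpart on the matching cpu/gpu place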


self.assertEqual(self.model.torch_fc1.weight.device, self.torch_model.torch_fc1.weight.device)
self.assertEqual(self.model.torch_fc1.bias.device, self.torch_model.torch_fc1.bias.device)
self.assertEqual(self.model.torch_conv2d1.weight.device, self.torch_model.torch_conv2d1.weight.device)
self.assertEqual(self.model.torch_conv2d1.bias.device, self.torch_model.torch_conv2d1.bias.device)
self.assertEqual(self.model.torch_tensor.device, self.torch_model.torch_tensor.device)
self.assertEqual(self.model.torch_param.device, self.torch_model.torch_param.device)

if device == "cpu":
self.assertTrue(self.model.paddle_fc1.weight.place.is_cpu_place())
self.assertTrue(self.model.paddle_fc1.bias.place.is_cpu_place())
self.assertTrue(self.model.paddle_conv2d1.weight.place.is_cpu_place())
self.assertTrue(self.model.paddle_conv2d1.bias.place.is_cpu_place())
self.assertTrue(self.model.paddle_tensor.place.is_cpu_place())
elif device.startswith("cuda"):
self.assertTrue(self.model.paddle_fc1.weight.place.is_gpu_place())
self.assertTrue(self.model.paddle_fc1.bias.place.is_gpu_place())
self.assertTrue(self.model.paddle_conv2d1.weight.place.is_gpu_place())
self.assertTrue(self.model.paddle_conv2d1.bias.place.is_gpu_place())
self.assertTrue(self.model.paddle_tensor.place.is_gpu_place())

self.assertEqual(self.model.paddle_fc1.weight.place.gpu_device_id(), self.paddle_model.paddle_fc1.weight.place.gpu_device_id())
self.assertEqual(self.model.paddle_fc1.bias.place.gpu_device_id(), self.paddle_model.paddle_fc1.bias.place.gpu_device_id())
self.assertEqual(self.model.paddle_conv2d1.weight.place.gpu_device_id(), self.paddle_model.paddle_conv2d1.weight.place.gpu_device_id())
self.assertEqual(self.model.paddle_conv2d1.bias.place.gpu_device_id(), self.paddle_model.paddle_conv2d1.bias.place.gpu_device_id())
self.assertEqual(self.model.paddle_tensor.place.gpu_device_id(), self.paddle_model.paddle_tensor.place.gpu_device_id())
else:
raise NotImplementedError

def if_training_correct(self, training):
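# helper: all torch and paddle submodules should report the same training flag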

self.assertEqual(self.model.torch_fc1.training, training)
self.assertEqual(self.model.torch_softmax.training, training)
self.assertEqual(self.model.torch_conv2d1.training, training)

self.assertEqual(self.model.paddle_fc1.training, training)
self.assertEqual(self.model.paddle_softmax.training, training)
self.assertEqual(self.model.paddle_conv2d1.training, training)


############################################################################
#
# Test performance on the MNIST dataset
#
############################################################################

class MNISTDataset(Dataset):
def __init__(self, dataset):

self.dataset = [
(
np.array(img).astype('float32').reshape(-1),
label
) for img, label in dataset
]

def __getitem__(self, idx):
return self.dataset[idx]

def __len__(self):
return len(self.dataset)

class MixMNISTModel(MixModule):
def __init__(self):
super(MixMNISTModel, self).__init__()

self.fc1 = paddle.nn.Linear(784, 64)
self.fc2 = paddle.nn.Linear(64, 32)
self.fc3 = torch.nn.Linear(32, 10)
self.fc4 = torch.nn.Linear(10, 10)

def forward(self, x):

paddle_out = self.fc1(x)
paddle_out = self.fc2(paddle_out)
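# bridge: convert the paddle activations to a torch tensor so the torch layers can consume them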
torch_in = paddle2torch(paddle_out)
torch_out = self.fc3(torch_in)
torch_out = self.fc4(torch_out)

return torch_out

class TestMNIST(unittest.TestCase):

@classmethod
def setUpClass(cls):

cls.train_dataset = paddle.vision.datasets.MNIST(mode='train')
cls.test_dataset = paddle.vision.datasets.MNIST(mode='test')
cls.train_dataset = MNISTDataset(cls.train_dataset)

cls.lr = 0.0003
cls.epochs = 20

cls.dataloader = DataLoader(cls.train_dataset, batch_size=100, shuffle=True)

def setUp(self):
self.model = MixMNISTModel().to("cuda")
self.torch_loss_func = torch.nn.CrossEntropyLoss()

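# one optimizer per backend; each receives only the parameters owned by its own framework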
self.torch_opt = torch.optim.Adam(self.model.parameters(backend="torch"), self.lr)
self.paddle_opt = paddle.optimizer.Adam(parameters=self.model.parameters(backend="paddle"), learning_rate=self.lr)

def test_case1(self):

# start training
for epoch in range(self.epochs):
epoch_loss, batch = 0, 0
for batch, (img, label) in enumerate(self.dataloader):

img = paddle.to_tensor(img).cuda()
torch_out = self.model(img)
label = torch.from_numpy(label.numpy()).reshape(-1)
loss = self.torch_loss_func(torch_out.cpu(), label)
epoch_loss += loss.item()

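# backward() on the torch-side loss presumably propagates gradients through the paddle2torch bridge into the paddle layers as well, so both optimizers step and both clear their gradients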
loss.backward()
self.torch_opt.step()
self.paddle_opt.step()
self.torch_opt.zero_grad()
self.paddle_opt.clear_grad()

else:
self.assertLess(epoch_loss / (batch + 1), 0.3)

# start evaluation
correct = 0
for img, label in self.test_dataset:

img = paddle.to_tensor(np.array(img).astype('float32').reshape(1, -1))
torch_out = self.model(img)
res = torch_out.softmax(-1).argmax().item()
label = label.item()
if res == label:
correct += 1

acc = correct / len(self.test_dataset)
self.assertGreater(acc, 0.85)

############################################################################
#
# Test performance on the ERNIE Chinese dataset
#
############################################################################

tests/modules/mix_modules/test_utils.py (+435, -0)

@@ -0,0 +1,435 @@
import unittest
import os

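# "log_silent" presumably silences fastNLP's logger before the framework imports below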
os.environ["log_silent"] = "1"
import torch
import paddle
import jittor

from fastNLP.modules.mix_modules.utils import (
paddle2torch,
torch2paddle,
jittor2torch,
torch2jittor,
)
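
# As exercised below, these helpers appear to accept keyword arguments such as target_device and no_gradient,
# and to recurse through lists, tuples and dicts, converting only the tensors they find.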

############################################################################
#
# Test paddle-to-torch conversion
#
############################################################################

class Paddle2TorchTestCase(unittest.TestCase):

def check_torch_tensor(self, tensor, device, requires_grad):
"""
Helper that checks a tensor's device and gradient settings
"""

self.assertIsInstance(tensor, torch.Tensor)
self.assertEqual(tensor.device, torch.device(device))
self.assertEqual(tensor.requires_grad, requires_grad)

def test_gradient(self):
"""
Test that backpropagation is correct after the tensor conversion
"""

x = paddle.to_tensor([1.0, 2.0, 3.0, 4.0, 5.0], stop_gradient=False)
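# x has stop_gradient=False, so the converted torch tensor is expected to require grad and to receive .grad after backward()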
y = paddle2torch(x)
z = 3 * (y ** 2)
z.sum().backward()
self.assertListEqual(y.grad.tolist(), [6, 12, 18, 24, 30])

def test_tensor_transfer(self):
"""
Test that the device and gradient of a single tensor are converted correctly
"""

paddle_tensor = paddle.rand((3, 4, 5)).cpu()
res = paddle2torch(paddle_tensor)
self.check_torch_tensor(res, "cpu", not paddle_tensor.stop_gradient)

res = paddle2torch(paddle_tensor, target_device="cuda:2", no_gradient=None)
self.check_torch_tensor(res, "cuda:2", not paddle_tensor.stop_gradient)

res = paddle2torch(paddle_tensor, target_device="cuda:1", no_gradient=True)
self.check_torch_tensor(res, "cuda:1", False)

res = paddle2torch(paddle_tensor, target_device="cuda:1", no_gradient=False)
self.check_torch_tensor(res, "cuda:1", True)

def test_list_transfer(self):
"""
Test converting a list of tensors
"""

paddle_list = [paddle.rand((6, 4, 2)).cuda(1) for i in range(10)]
res = paddle2torch(paddle_list)
self.assertIsInstance(res, list)
for t in res:
self.check_torch_tensor(t, "cuda:1", False)

res = paddle2torch(paddle_list, target_device="cpu", no_gradient=False)
self.assertIsInstance(res, list)
for t in res:
self.check_torch_tensor(t, "cpu", True)

def test_tensor_tuple_transfer(self):
"""
Test converting a tuple of tensors
"""

paddle_list = [paddle.rand((6, 4, 2)).cuda(1) for i in range(10)]
paddle_tuple = tuple(paddle_list)
res = paddle2torch(paddle_tuple)
self.assertIsInstance(res, tuple)
for t in res:
self.check_torch_tensor(t, "cuda:1", False)

def test_dict_transfer(self):
"""
Test converting a dict containing nested structures
"""

paddle_dict = {
"tensor": paddle.rand((3, 4)).cuda(0),
"list": [paddle.rand((6, 4, 2)).cuda(0) for i in range(10)],
"dict":{
"list": [paddle.rand((6, 4, 2)).cuda(0) for i in range(10)],
"tensor": paddle.rand((3, 4)).cuda(0)
},
"int": 2,
"string": "test string"
}
res = paddle2torch(paddle_dict)
self.assertIsInstance(res, dict)
self.check_torch_tensor(res["tensor"], "cuda:0", False)
self.assertIsInstance(res["list"], list)
for t in res["list"]:
self.check_torch_tensor(t, "cuda:0", False)
self.assertIsInstance(res["int"], int)
self.assertIsInstance(res["string"], str)
self.assertIsInstance(res["dict"], dict)
self.assertIsInstance(res["dict"]["list"], list)
for t in res["dict"]["list"]:
self.check_torch_tensor(t, "cuda:0", False)
self.check_torch_tensor(res["dict"]["tensor"], "cuda:0", False)


############################################################################
#
# Test torch-to-paddle conversion
#
############################################################################

class Torch2PaddleTestCase(unittest.TestCase):

def check_paddle_tensor(self, tensor, device, stop_gradient):
"""
Helper that checks the device and gradient settings of the resulting paddle tensor
"""

self.assertIsInstance(tensor, paddle.Tensor)
if device == "cpu":
self.assertTrue(tensor.place.is_cpu_place())
elif device.startswith("gpu"):
paddle_device = paddle.device._convert_to_place(device)
self.assertTrue(tensor.place.is_gpu_place())
if hasattr(tensor.place, "gpu_device_id"):
# paddle has two kinds of Place:
# paddle.fluid.core.Place is the type used when a Tensor is created,
# and it exposes gpu_device_id() to get the device
self.assertEqual(tensor.place.gpu_device_id(), paddle_device.get_device_id())
else:
# _convert_to_place returns a paddle.CUDAPlace,
# which exposes get_device_id() to get the device
self.assertEqual(tensor.place.get_device_id(), paddle_device.get_device_id())
else:
raise NotImplementedError
self.assertEqual(tensor.stop_gradient, stop_gradient)

def test_gradient(self):
"""
Test gradient backpropagation after conversion
"""

x = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0], requires_grad=True)
y = torch2paddle(x)
z = 3 * (y ** 2)
z.sum().backward()
self.assertListEqual(y.grad.tolist(), [6, 12, 18, 24, 30])

def test_tensor_transfer(self):
"""
Test converting a single tensor
"""

torch_tensor = torch.rand((3, 4, 5))
res = torch2paddle(torch_tensor)
self.check_paddle_tensor(res, "cpu", True)

res = torch2paddle(torch_tensor, target_device="gpu:2", no_gradient=None)
self.check_paddle_tensor(res, "gpu:2", True)

res = torch2paddle(torch_tensor, target_device="gpu:2", no_gradient=True)
self.check_paddle_tensor(res, "gpu:2", True)

res = torch2paddle(torch_tensor, target_device="gpu:2", no_gradient=False)
self.check_paddle_tensor(res, "gpu:2", False)

def test_tensor_list_transfer(self):
"""
Test converting a list of tensors
"""

torch_list = [torch.rand(6, 4, 2) for i in range(10)]
res = torch2paddle(torch_list)
self.assertIsInstance(res, list)
for t in res:
self.check_paddle_tensor(t, "cpu", True)

res = torch2paddle(torch_list, target_device="gpu:1", no_gradient=False)
self.assertIsInstance(res, list)
for t in res:
self.check_paddle_tensor(t, "gpu:1", False)

def test_tensor_tuple_transfer(self):
"""
Test converting a tuple of tensors
"""
torch_list = [torch.rand(6, 4, 2) for i in range(10)]
torch_tuple = tuple(torch_list)
res = torch2paddle(torch_tuple, target_device="cpu")
self.assertIsInstance(res, tuple)
for t in res:
self.check_paddle_tensor(t, "cpu", True)

def test_dict_transfer(self):
"""
Test converting a complex dict structure
"""

torch_dict = {
"tensor": torch.rand((3, 4)),
"list": [torch.rand(6, 4, 2) for i in range(10)],
"dict":{
"list": [torch.rand(6, 4, 2) for i in range(10)],
"tensor": torch.rand((3, 4))
},
"int": 2,
"string": "test string"
}
res = torch2paddle(torch_dict)
self.assertIsInstance(res, dict)
self.check_paddle_tensor(res["tensor"], "cpu", True)
self.assertIsInstance(res["list"], list)
for t in res["list"]:
self.check_paddle_tensor(t, "cpu", True)
self.assertIsInstance(res["int"], int)
self.assertIsInstance(res["string"], str)
self.assertIsInstance(res["dict"], dict)
self.assertIsInstance(res["dict"]["list"], list)
for t in res["dict"]["list"]:
self.check_paddle_tensor(t, "cpu", True)
self.check_paddle_tensor(res["dict"]["tensor"], "cpu", True)


############################################################################
#
# Test jittor-to-torch conversion
#
############################################################################

class Jittor2TorchTestCase(unittest.TestCase):

def check_torch_tensor(self, tensor, device, requires_grad):
"""
Helper that checks the resulting torch tensor
"""

self.assertIsInstance(tensor, torch.Tensor)
if device == "cpu":
self.assertFalse(tensor.is_cuda)
else:
self.assertEqual(tensor.device, torch.device(device))
self.assertEqual(tensor.requires_grad, requires_grad)

def test_var_transfer(self):
"""
Test converting a single jittor Var
"""

jittor_var = jittor.rand((3, 4, 5))
res = jittor2torch(jittor_var)
self.check_torch_tensor(res, "cpu", True)

res = jittor2torch(jittor_var, target_device="cuda:2", no_gradient=None)
self.check_torch_tensor(res, "cuda:2", True)

res = jittor2torch(jittor_var, target_device="cuda:2", no_gradient=True)
self.check_torch_tensor(res, "cuda:2", False)

res = jittor2torch(jittor_var, target_device="cuda:2", no_gradient=False)
self.check_torch_tensor(res, "cuda:2", True)

def test_var_list_transfer(self):
"""
Test converting a list of jittor Vars
"""

jittor_list = [jittor.rand((6, 4, 2)) for i in range(10)]
res = jittor2torch(jittor_list)
self.assertIsInstance(res, list)
for t in res:
self.check_torch_tensor(t, "cpu", True)

res = jittor2torch(jittor_list, target_device="cuda:1", no_gradient=False)
self.assertIsInstance(res, list)
for t in res:
self.check_torch_tensor(t, "cuda:1", True)

def test_var_tuple_transfer(self):
"""
Test converting a tuple of jittor Vars
"""

jittor_list = [jittor.rand((6, 4, 2)) for i in range(10)]
jittor_tuple = tuple(jittor_list)
res = jittor2torch(jittor_tuple, target_device="cpu")
self.assertIsInstance(res, tuple)
for t in res:
self.check_torch_tensor(t, "cpu", True)

def test_dict_transfer(self):
"""
Test converting a dict structure
"""

jittor_dict = {
"tensor": jittor.rand((3, 4)),
"list": [jittor.rand(6, 4, 2) for i in range(10)],
"dict":{
"list": [jittor.rand(6, 4, 2) for i in range(10)],
"tensor": jittor.rand((3, 4))
},
"int": 2,
"string": "test string"
}
res = jittor2torch(jittor_dict)
self.assertIsInstance(res, dict)
self.check_torch_tensor(res["tensor"], "cpu", True)
self.assertIsInstance(res["list"], list)
for t in res["list"]:
self.check_torch_tensor(t, "cpu", True)
self.assertIsInstance(res["int"], int)
self.assertIsInstance(res["string"], str)
self.assertIsInstance(res["dict"], dict)
self.assertIsInstance(res["dict"]["list"], list)
for t in res["dict"]["list"]:
self.check_torch_tensor(t, "cpu", True)
self.check_torch_tensor(res["dict"]["tensor"], "cpu", True)


############################################################################
#
# Test torch-to-jittor conversion
#
############################################################################

class Torch2JittorTestCase(unittest.TestCase):

def check_jittor_var(self, var, requires_grad):
"""
Helper that checks the gradient settings of the resulting jittor Var
"""

self.assertIsInstance(var, jittor.Var)
self.assertEqual(var.requires_grad, requires_grad)

def test_gradient(self):
"""
Test gradients from backpropagation
"""

x = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0], requires_grad=True)
y = torch2jittor(x)
z = 3 * (y ** 2)
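# jittor computes gradients functionally via jittor.grad(loss, inputs) rather than through tensor.backward()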
grad = jittor.grad(z, y)
self.assertListEqual(grad.tolist(), [6.0, 12.0, 18.0, 24.0, 30.0])

def test_tensor_transfer(self):
"""
Test converting a single tensor to jittor
"""

torch_tensor = torch.rand((3, 4, 5))
res = torch2jittor(torch_tensor)
self.check_jittor_var(res, False)

res = torch2jittor(torch_tensor, no_gradient=None)
self.check_jittor_var(res, False)

res = torch2jittor(torch_tensor, no_gradient=True)
self.check_jittor_var(res, False)

res = torch2jittor(torch_tensor, no_gradient=False)
self.check_jittor_var(res, True)

def test_tensor_list_transfer(self):
"""
Test converting a list of tensors
"""

torch_list = [torch.rand((6, 4, 2)) for i in range(10)]
res = torch2jittor(torch_list)
self.assertIsInstance(res, list)
for t in res:
self.check_jittor_var(t, False)

res = torch2jittor(torch_list, no_gradient=False)
self.assertIsInstance(res, list)
for t in res:
self.check_jittor_var(t, True)

def test_tensor_tuple_transfer(self):
"""
Test converting a tuple of tensors
"""

torch_list = [torch.rand((6, 4, 2)) for i in range(10)]
torch_tuple = tuple(torch_list)
res = torch2jittor(torch_tuple)
self.assertIsInstance(res, tuple)
for t in res:
self.check_jittor_var(t, False)

def test_dict_transfer(self):
"""
Test converting a dict structure
"""

torch_dict = {
"tensor": torch.rand((3, 4)),
"list": [torch.rand(6, 4, 2) for i in range(10)],
"dict":{
"list": [torch.rand(6, 4, 2) for i in range(10)],
"tensor": torch.rand((3, 4))
},
"int": 2,
"string": "test string"
}
res = torch2jittor(torch_dict)
self.assertIsInstance(res, dict)
self.check_jittor_var(res["tensor"], False)
self.assertIsInstance(res["list"], list)
for t in res["list"]:
self.check_jittor_var(t, False)
self.assertIsInstance(res["int"], int)
self.assertIsInstance(res["string"], str)
self.assertIsInstance(res["dict"], dict)
self.assertIsInstance(res["dict"]["list"], list)
for t in res["dict"]["list"]:
self.check_jittor_var(t, False)
self.check_jittor_var(res["dict"]["tensor"], False)
