
Add tests/core/utils/

tags/v1.0.0alpha
x54-729 committed 2 years ago
commit 3a580edb99
7 changed files with 815 additions and 0 deletions
  1. tests/core/utils/__init__.py (+0, -0)
  2. tests/core/utils/helper_for_cache_results_1.py (+7, -0)
  3. tests/core/utils/helper_for_cache_results_2.py (+8, -0)
  4. tests/core/utils/test_cache_results.py (+304, -0)
  5. tests/core/utils/test_distributed.py (+91, -0)
  6. tests/core/utils/test_paddle_utils.py (+200, -0)
  7. tests/core/utils/test_torch_paddle_utils.py (+205, -0)

tests/core/utils/__init__.py (+0, -0)


tests/core/utils/helper_for_cache_results_1.py (+7, -0)

@@ -0,0 +1,7 @@
class Demo:
    def __init__(self):
        pass

    def demo(self):
        b = 1
        return b

tests/core/utils/helper_for_cache_results_2.py (+8, -0)

@@ -0,0 +1,8 @@
class Demo:
    def __init__(self):
        self.b = 1

    def demo(self):
        b = 1
        return b


tests/core/utils/test_cache_results.py (+304, -0)

@@ -0,0 +1,304 @@
import time
import os
import pytest
from subprocess import Popen, PIPE
from io import StringIO
import sys

from fastNLP.core.utils.cache_results import cache_results
from tests.helpers.common.utils import check_time_elapse

from fastNLP.core import synchronize_safe_rm


def get_subprocess_results(cmd):
    pipe = Popen(cmd, shell=True, stdout=PIPE, stderr=PIPE)
    output, err = pipe.communicate()
    if output:
        output = output.decode('utf8')
    else:
        output = ''
    if err:
        err = err.decode('utf8')
    else:
        err = ''
    res = output + err
    return res


class Capturing(list):
    # Captures stdout and stderr of the current environment; the stderr output
    # is appended after the stdout output.
    def __enter__(self):
        self._stdout = sys.stdout
        self._stderr = sys.stderr
        sys.stdout = self._stringio = StringIO()
        sys.stderr = self._stringioerr = StringIO()
        return self

    def __exit__(self, *args):
        self.append(self._stringio.getvalue() + self._stringioerr.getvalue())
        del self._stringio, self._stringioerr  # free up some memory
        sys.stdout = self._stdout
        sys.stderr = self._stderr


class TestCacheResults:
    def test_cache_save(self):
        cache_fp = 'demo.pkl'
        try:
            @cache_results(cache_fp)
            def demo():
                time.sleep(1)
                return 1

            res = demo()
            with check_time_elapse(1, op='lt'):
                res = demo()

        finally:
            synchronize_safe_rm(cache_fp)

    def test_cache_save_refresh(self):
        cache_fp = 'demo.pkl'
        try:
            @cache_results(cache_fp, _refresh=True)
            def demo():
                time.sleep(1.5)
                return 1

            res = demo()
            with check_time_elapse(1, op='ge'):
                res = demo()
        finally:
            synchronize_safe_rm(cache_fp)

    def test_cache_no_func_change(self):
        cache_fp = os.path.abspath('demo.pkl')
        try:
            @cache_results(cache_fp)
            def demo():
                time.sleep(2)
                return 1

            with check_time_elapse(1, op='gt'):
                res = demo()

            @cache_results(cache_fp)
            def demo():
                time.sleep(2)
                return 1

            with check_time_elapse(1, op='lt'):
                res = demo()
        finally:
            synchronize_safe_rm('demo.pkl')

    def test_cache_func_change(self, capsys):
        cache_fp = 'demo.pkl'
        try:
            @cache_results(cache_fp)
            def demo():
                time.sleep(2)
                return 1

            with check_time_elapse(1, op='gt'):
                res = demo()

            @cache_results(cache_fp)
            def demo():
                time.sleep(1)
                return 1

            with check_time_elapse(1, op='lt'):
                with Capturing() as output:
                    res = demo()
                assert 'is different from its last cache' in output[0]

            # with hash checking turned off there should be no warning
            with check_time_elapse(1, op='lt'):
                with Capturing() as output:
                    res = demo(_check_hash=0)
                assert 'is different from its last cache' not in output[0]

        finally:
            synchronize_safe_rm('demo.pkl')

    def test_cache_check_hash(self):
        cache_fp = 'demo.pkl'
        try:
            @cache_results(cache_fp, _check_hash=False)
            def demo():
                time.sleep(2)
                return 1

            with check_time_elapse(1, op='gt'):
                res = demo()

            @cache_results(cache_fp, _check_hash=False)
            def demo():
                time.sleep(1)
                return 1

            # no hash check by default
            with check_time_elapse(1, op='lt'):
                with Capturing() as output:
                    res = demo()
                assert 'is different from its last cache' not in output[0]

            # checking can also be requested per call
            with check_time_elapse(1, op='lt'):
                with Capturing() as output:
                    res = demo(_check_hash=True)
                assert 'is different from its last cache' in output[0]

        finally:
            synchronize_safe_rm('demo.pkl')

    # a change in a referenced external function also invalidates the cache
    def test_refer_fun_change(self):
        cache_fp = 'demo.pkl'
        test_type = 'func_refer_fun_change'
        try:
            with check_time_elapse(3, op='gt'):
                cmd = f'python {__file__} --cache_fp {cache_fp} --test_type {test_type} --turn 0'
                res = get_subprocess_results(cmd)

            # the referenced function is unchanged
            with check_time_elapse(2, op='lt'):
                cmd = f'python {__file__} --cache_fp {cache_fp} --test_type {test_type} --turn 0'
                res = get_subprocess_results(cmd)
                assert 'Read cache from' in res
                assert 'is different from its last cache' not in res

            # the referenced function has changed
            with check_time_elapse(2, op='lt'):
                cmd = f'python {__file__} --cache_fp {cache_fp} --test_type {test_type} --turn 1'
                res = get_subprocess_results(cmd)
                assert 'is different from its last cache' in res

        finally:
            synchronize_safe_rm(cache_fp)

    # a change in a referenced external method also invalidates the cache
    def test_refer_class_method_change(self):
        cache_fp = 'demo.pkl'
        test_type = 'refer_class_method_change'
        try:
            with check_time_elapse(3, op='gt'):
                cmd = f'python {__file__} --cache_fp {cache_fp} --test_type {test_type} --turn 0'
                res = get_subprocess_results(cmd)

            # the referenced class is unchanged
            with check_time_elapse(2, op='lt'):
                cmd = f'python {__file__} --cache_fp {cache_fp} --test_type {test_type} --turn 0'
                res = get_subprocess_results(cmd)
                assert 'Read cache from' in res
                assert 'is different from its last cache' not in res

            # the referenced class has changed
            with check_time_elapse(2, op='lt'):
                cmd = f'python {__file__} --cache_fp {cache_fp} --test_type {test_type} --turn 1'
                res = get_subprocess_results(cmd)
                assert 'is different from its last cache' in res

        finally:
            synchronize_safe_rm(cache_fp)

    def test_duplicate_keyword(self):
        with pytest.raises(RuntimeError):
            @cache_results(None)
            def func_verbose(a, _verbose):
                pass

            func_verbose(0, 1)
        with pytest.raises(RuntimeError):
            @cache_results(None)
            def func_cache(a, _cache_fp):
                pass

            func_cache(1, 2)
        with pytest.raises(RuntimeError):
            @cache_results(None)
            def func_refresh(a, _refresh):
                pass

            func_refresh(1, 2)

        with pytest.raises(RuntimeError):
            @cache_results(None)
            def func_check_hash(a, _check_hash):
                pass

            func_check_hash(1, 2)

    def test_create_cache_dir(self):
        @cache_results('demo/demo.pkl')
        def cache():
            return 1, 2

        try:
            results = cache()
            assert (1, 2) == results
        finally:
            synchronize_safe_rm('demo/')

    def test_result_none_error(self):
        @cache_results('demo.pkl')
        def cache():
            pass

        try:
            with pytest.raises(RuntimeError):
                results = cache()
        finally:
            synchronize_safe_rm('demo.pkl')


if __name__ == '__main__':
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument('--test_type', type=str, default='refer_class_method_change')
    parser.add_argument('--turn', type=int, default=1)
    parser.add_argument('--cache_fp', type=str, default='demo.pkl')
    args = parser.parse_args()

    test_type = args.test_type
    cache_fp = args.cache_fp
    turn = args.turn

    if test_type == 'func_refer_fun_change':
        if turn == 0:
            def demo():
                b = 1
                return b
        else:
            def demo():
                b = 2
                return b

        @cache_results(cache_fp)
        def demo_refer_other_func():
            time.sleep(3)
            b = demo()
            return b

        res = demo_refer_other_func()

    if test_type == 'refer_class_method_change':
        print(f"Turn:{turn}")
        if turn == 0:
            from helper_for_cache_results_1 import Demo
        else:
            from helper_for_cache_results_2 import Demo

        demo = Demo()
        @cache_results(cache_fp)
        def demo_func():
            time.sleep(3)
            b = demo.demo()
            return b

        res = demo_func()
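
For orientation, a minimal usage sketch of the cache_results decorator as these tests exercise it. The cache path and the reserved _refresh/_check_hash call-time keywords are taken from the assertions above; the sketch itself (function name, cache path) is illustrative, not fastNLP documentation:

    from fastNLP.core.utils.cache_results import cache_results

    @cache_results('caches/vocab.pkl')        # parent dir is created if missing (test_create_cache_dir)
    def build_vocab():
        # expensive work; returning None raises RuntimeError (test_result_none_error)
        return {'hello': 0, 'world': 1}

    vocab = build_vocab()                     # first call: compute and write the cache
    vocab = build_vocab()                     # second call: read back from disk
    vocab = build_vocab(_refresh=True)        # force a recompute and overwrite the cache
    vocab = build_vocab(_check_hash=False)    # silence the 'is different from its last cache' warning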


tests/core/utils/test_distributed.py (+91, -0)

@@ -0,0 +1,91 @@
import os

from fastNLP.envs.distributed import rank_zero_call, all_rank_call
from tests.helpers.utils import re_run_current_cmd_for_torch, Capturing, magic_argv_env_context


@rank_zero_call
def write_something():
    print(os.environ.get('RANK', '0') * 5, flush=True)


def write_other_thing():
    print(os.environ.get('RANK', '0') * 5, flush=True)


class PaddleTest:
    # @x54-729
    def test_rank_zero_call(self):
        pass

    def test_all_rank_run(self):
        pass


class JittorTest:
    # @x54-729
    def test_rank_zero_call(self):
        pass

    def test_all_rank_run(self):
        pass


class TestTorch:
    @magic_argv_env_context
    def test_rank_zero_call(self):
        os.environ['MASTER_ADDR'] = '127.0.0.1'
        os.environ['MASTER_PORT'] = '29500'
        if 'LOCAL_RANK' not in os.environ and 'RANK' not in os.environ and 'WORLD_SIZE' not in os.environ:
            os.environ['LOCAL_RANK'] = '0'
            os.environ['RANK'] = '0'
            os.environ['WORLD_SIZE'] = '2'
        re_run_current_cmd_for_torch(1, output_from_new_proc='all')
        with Capturing() as output:
            write_something()
        output = output[0]

        if os.environ['LOCAL_RANK'] == '0':
            assert '00000' in output and '11111' not in output
        else:
            assert '00000' not in output and '11111' not in output

        with Capturing() as output:
            rank_zero_call(write_other_thing)()

        output = output[0]
        if os.environ['LOCAL_RANK'] == '0':
            assert '00000' in output and '11111' not in output
        else:
            assert '00000' not in output and '11111' not in output

    @magic_argv_env_context
    def test_all_rank_run(self):
        os.environ['MASTER_ADDR'] = '127.0.0.1'
        os.environ['MASTER_PORT'] = '29500'
        if 'LOCAL_RANK' not in os.environ and 'RANK' not in os.environ and 'WORLD_SIZE' not in os.environ:
            os.environ['LOCAL_RANK'] = '0'
            os.environ['RANK'] = '0'
            os.environ['WORLD_SIZE'] = '2'
        re_run_current_cmd_for_torch(1, output_from_new_proc='all')
        # torch.distributed.init_process_group(backend='nccl')
        # torch.distributed.barrier()
        with all_rank_call():
            with Capturing(no_del=True) as output:
                write_something()
            output = output[0]

        if os.environ['LOCAL_RANK'] == '0':
            assert '00000' in output
        else:
            assert '11111' in output

        with all_rank_call():
            with Capturing(no_del=True) as output:
                rank_zero_call(write_other_thing)()
            output = output[0]

        if os.environ['LOCAL_RANK'] == '0':
            assert '00000' in output
        else:
            assert '11111' in output
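
As a mental model for what these assertions rely on: rank_zero_call acts as a rank guard around the wrapped function. A minimal sketch of that pattern, assuming the rank is read from the RANK environment variable (an illustration, not fastNLP's actual implementation):

    import os
    from functools import wraps

    def rank_zero_call_sketch(fn):
        # Run fn only in the rank-0 process; other ranks skip it and get None.
        @wraps(fn)
        def wrapper(*args, **kwargs):
            if int(os.environ.get('RANK', '0')) == 0:
                return fn(*args, **kwargs)
            return None
        return wrapper

Under this model, all_rank_call() temporarily lifts the guard so every rank prints, which is exactly the behavioral difference the two tests above assert.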

tests/core/utils/test_paddle_utils.py (+200, -0)

@@ -0,0 +1,200 @@
import unittest

import paddle

from fastNLP.core.utils.paddle_utils import paddle_to, paddle_move_data_to_device


############################################################################
#
# Tests moving a single paddle tensor to a given device
#
############################################################################

class PaddleToDeviceTestCase(unittest.TestCase):
    def test_case(self):
        tensor = paddle.rand((4, 5))

        res = paddle_to(tensor, "gpu")
        self.assertTrue(res.place.is_gpu_place())
        self.assertEqual(res.place.gpu_device_id(), 0)
        res = paddle_to(tensor, "cpu")
        self.assertTrue(res.place.is_cpu_place())
        res = paddle_to(tensor, "gpu:2")
        self.assertTrue(res.place.is_gpu_place())
        self.assertEqual(res.place.gpu_device_id(), 2)
        res = paddle_to(tensor, "gpu:1")
        self.assertTrue(res.place.is_gpu_place())
        self.assertEqual(res.place.gpu_device_id(), 1)


############################################################################
#
# Tests moving every paddle tensor contained in the arguments to a given device
#
############################################################################

class PaddleMoveDataToDeviceTestCase(unittest.TestCase):

    def check_gpu(self, tensor, idx):
        """
        Helper that checks a tensor is on the given GPU device.
        """
        self.assertTrue(tensor.place.is_gpu_place())
        self.assertEqual(tensor.place.gpu_device_id(), idx)

    def check_cpu(self, tensor):
        """
        Helper that checks a tensor is on the CPU.
        """
        self.assertTrue(tensor.place.is_cpu_place())

    def test_tensor_transfer(self):
        """
        Tests moving a single tensor.
        """
        paddle_tensor = paddle.rand((3, 4, 5)).cpu()
        res = paddle_move_data_to_device(paddle_tensor, device=None, data_device=None)
        self.check_cpu(res)

        res = paddle_move_data_to_device(paddle_tensor, device="gpu:0", data_device=None)
        self.check_gpu(res, 0)

        res = paddle_move_data_to_device(paddle_tensor, device="gpu:1", data_device=None)
        self.check_gpu(res, 1)

        res = paddle_move_data_to_device(paddle_tensor, device="gpu:0", data_device="cpu")
        self.check_gpu(res, 0)

        res = paddle_move_data_to_device(paddle_tensor, device=None, data_device="gpu:0")
        self.check_gpu(res, 0)

        res = paddle_move_data_to_device(paddle_tensor, device=None, data_device="gpu:1")
        self.check_gpu(res, 1)

    def test_list_transfer(self):
        """
        Tests moving a list of tensors.
        """
        paddle_list = [paddle.rand((6, 4, 2)) for i in range(10)]
        res = paddle_move_data_to_device(paddle_list, device=None, data_device="gpu:1")
        self.assertIsInstance(res, list)
        for r in res:
            self.check_gpu(r, 1)

        res = paddle_move_data_to_device(paddle_list, device="cpu", data_device="gpu:1")
        self.assertIsInstance(res, list)
        for r in res:
            self.check_cpu(r)

        res = paddle_move_data_to_device(paddle_list, device="gpu:0", data_device=None)
        self.assertIsInstance(res, list)
        for r in res:
            self.check_gpu(r, 0)

        res = paddle_move_data_to_device(paddle_list, device="gpu:1", data_device="cpu")
        self.assertIsInstance(res, list)
        for r in res:
            self.check_gpu(r, 1)

    def test_tensor_tuple_transfer(self):
        """
        Tests moving a tuple of tensors.
        """
        paddle_list = [paddle.rand((6, 4, 2)) for i in range(10)]
        paddle_tuple = tuple(paddle_list)
        res = paddle_move_data_to_device(paddle_tuple, device=None, data_device="gpu:1")
        self.assertIsInstance(res, tuple)
        for r in res:
            self.check_gpu(r, 1)

        res = paddle_move_data_to_device(paddle_tuple, device="cpu", data_device="gpu:1")
        self.assertIsInstance(res, tuple)
        for r in res:
            self.check_cpu(r)

        res = paddle_move_data_to_device(paddle_tuple, device="gpu:0", data_device=None)
        self.assertIsInstance(res, tuple)
        for r in res:
            self.check_gpu(r, 0)

        res = paddle_move_data_to_device(paddle_tuple, device="gpu:1", data_device="cpu")
        self.assertIsInstance(res, tuple)
        for r in res:
            self.check_gpu(r, 1)

    def test_dict_transfer(self):
        """
        Tests moving a dict structure.
        """
        paddle_dict = {
            "tensor": paddle.rand((3, 4)),
            "list": [paddle.rand((6, 4, 2)) for i in range(10)],
            "dict": {
                "list": [paddle.rand((6, 4, 2)) for i in range(10)],
                "tensor": paddle.rand((3, 4))
            },
            "int": 2,
            "string": "test string"
        }

        res = paddle_move_data_to_device(paddle_dict, device="gpu:0", data_device=None)
        self.assertIsInstance(res, dict)
        self.check_gpu(res["tensor"], 0)
        self.assertIsInstance(res["list"], list)
        for t in res["list"]:
            self.check_gpu(t, 0)
        self.assertIsInstance(res["int"], int)
        self.assertIsInstance(res["string"], str)
        self.assertIsInstance(res["dict"], dict)
        self.assertIsInstance(res["dict"]["list"], list)
        for t in res["dict"]["list"]:
            self.check_gpu(t, 0)
        self.check_gpu(res["dict"]["tensor"], 0)

        res = paddle_move_data_to_device(paddle_dict, device="gpu:0", data_device="cpu")
        self.assertIsInstance(res, dict)
        self.check_gpu(res["tensor"], 0)
        self.assertIsInstance(res["list"], list)
        for t in res["list"]:
            self.check_gpu(t, 0)
        self.assertIsInstance(res["int"], int)
        self.assertIsInstance(res["string"], str)
        self.assertIsInstance(res["dict"], dict)
        self.assertIsInstance(res["dict"]["list"], list)
        for t in res["dict"]["list"]:
            self.check_gpu(t, 0)
        self.check_gpu(res["dict"]["tensor"], 0)

        res = paddle_move_data_to_device(paddle_dict, device=None, data_device="gpu:1")
        self.assertIsInstance(res, dict)
        self.check_gpu(res["tensor"], 1)
        self.assertIsInstance(res["list"], list)
        for t in res["list"]:
            self.check_gpu(t, 1)
        self.assertIsInstance(res["int"], int)
        self.assertIsInstance(res["string"], str)
        self.assertIsInstance(res["dict"], dict)
        self.assertIsInstance(res["dict"]["list"], list)
        for t in res["dict"]["list"]:
            self.check_gpu(t, 1)
        self.check_gpu(res["dict"]["tensor"], 1)

        res = paddle_move_data_to_device(paddle_dict, device="cpu", data_device="gpu:0")
        self.assertIsInstance(res, dict)
        self.check_cpu(res["tensor"])
        self.assertIsInstance(res["list"], list)
        for t in res["list"]:
            self.check_cpu(t)
        self.assertIsInstance(res["int"], int)
        self.assertIsInstance(res["string"], str)
        self.assertIsInstance(res["dict"], dict)
        self.assertIsInstance(res["dict"]["list"], list)
        for t in res["dict"]["list"]:
            self.check_cpu(t)
        self.check_cpu(res["dict"]["tensor"])
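
The assertions above pin down the contract these tests expect from paddle_move_data_to_device: device wins when both arguments are given, data_device is only consulted when device is None, containers keep their types, and non-tensor leaves (int, str) pass through unchanged. A minimal sketch of that recursion (move_sketch is a hypothetical helper written to illustrate the contract, not fastNLP's code):

    import paddle

    def move_sketch(obj, device, data_device):
        # device takes precedence; data_device is the fallback.
        target = device if device is not None else data_device
        if target is None:
            return obj
        if isinstance(obj, paddle.Tensor):
            if target == "cpu":
                return obj.cpu()
            # "gpu" or "gpu:N" -> move to that GPU
            idx = int(target.split(":")[1]) if ":" in target else 0
            return obj.cuda(idx)
        if isinstance(obj, list):
            return [move_sketch(x, device, data_device) for x in obj]
        if isinstance(obj, tuple):
            return tuple(move_sketch(x, device, data_device) for x in obj)
        if isinstance(obj, dict):
            return {k: move_sketch(v, device, data_device) for k, v in obj.items()}
        return obj  # ints, strings, etc. are returned unchanged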

tests/core/utils/test_torch_paddle_utils.py (+205, -0)

@@ -0,0 +1,205 @@
import unittest

import paddle
import torch

from fastNLP.core.utils.torch_paddle_utils import torch_paddle_move_data_to_device


############################################################################
#
# Tests moving every torch and paddle tensor contained in the arguments
# to a given device
#
############################################################################

class TorchPaddleMoveDataToDeviceTestCase(unittest.TestCase):

    def check_gpu(self, tensor, idx):
        """
        Helper that checks a tensor is on the given GPU device.
        """
        if isinstance(tensor, paddle.Tensor):
            self.assertTrue(tensor.place.is_gpu_place())
            self.assertEqual(tensor.place.gpu_device_id(), idx)
        elif isinstance(tensor, torch.Tensor):
            self.assertTrue(tensor.is_cuda)
            self.assertEqual(tensor.device.index, idx)

    def check_cpu(self, tensor):
        if isinstance(tensor, paddle.Tensor):
            self.assertTrue(tensor.place.is_cpu_place())
        elif isinstance(tensor, torch.Tensor):
            self.assertFalse(tensor.is_cuda)

    def test_tensor_transfer(self):
        """
        Tests moving a single tensor.
        """
        paddle_tensor = paddle.rand((3, 4, 5)).cpu()
        res = torch_paddle_move_data_to_device(paddle_tensor, device=None, data_device=None)
        self.check_cpu(res)

        res = torch_paddle_move_data_to_device(paddle_tensor, device="gpu:0", data_device=None)
        self.check_gpu(res, 0)

        res = torch_paddle_move_data_to_device(paddle_tensor, device="gpu:1", data_device=None)
        self.check_gpu(res, 1)

        res = torch_paddle_move_data_to_device(paddle_tensor, device="cuda:0", data_device="cpu")
        self.check_gpu(res, 0)

        res = torch_paddle_move_data_to_device(paddle_tensor, device=None, data_device="gpu:0")
        self.check_gpu(res, 0)

        res = torch_paddle_move_data_to_device(paddle_tensor, device=None, data_device="cuda:1")
        self.check_gpu(res, 1)

        torch_tensor = torch.rand(3, 4, 5)
        res = torch_paddle_move_data_to_device(torch_tensor, device=None, data_device=None)
        self.check_cpu(res)

        res = torch_paddle_move_data_to_device(torch_tensor, device="gpu:0", data_device=None)
        self.check_gpu(res, 0)

        res = torch_paddle_move_data_to_device(torch_tensor, device="gpu:1", data_device=None)
        self.check_gpu(res, 1)

        res = torch_paddle_move_data_to_device(torch_tensor, device="gpu:0", data_device="cpu")
        self.check_gpu(res, 0)

        res = torch_paddle_move_data_to_device(torch_tensor, device=None, data_device="gpu:0")
        self.check_gpu(res, 0)

        res = torch_paddle_move_data_to_device(torch_tensor, device=None, data_device="gpu:1")
        self.check_gpu(res, 1)

    def test_list_transfer(self):
        """
        Tests moving a list of tensors.
        """
        paddle_list = [paddle.rand((6, 4, 2)) for i in range(5)] + [torch.rand((6, 4, 2)) for i in range(5)]
        res = torch_paddle_move_data_to_device(paddle_list, device=None, data_device="gpu:1")
        self.assertIsInstance(res, list)
        for r in res:
            self.check_gpu(r, 1)

        res = torch_paddle_move_data_to_device(paddle_list, device="cpu", data_device="gpu:1")
        self.assertIsInstance(res, list)
        for r in res:
            self.check_cpu(r)

        res = torch_paddle_move_data_to_device(paddle_list, device="gpu:0", data_device=None)
        self.assertIsInstance(res, list)
        for r in res:
            self.check_gpu(r, 0)

        res = torch_paddle_move_data_to_device(paddle_list, device="gpu:1", data_device="cpu")
        self.assertIsInstance(res, list)
        for r in res:
            self.check_gpu(r, 1)

    def test_tensor_tuple_transfer(self):
        """
        Tests moving a tuple of tensors.
        """
        paddle_list = [paddle.rand((6, 4, 2)) for i in range(10)] + [torch.rand((6, 4, 2)) for i in range(5)]
        paddle_tuple = tuple(paddle_list)
        res = torch_paddle_move_data_to_device(paddle_tuple, device=None, data_device="gpu:1")
        self.assertIsInstance(res, tuple)
        for r in res:
            self.check_gpu(r, 1)

        res = torch_paddle_move_data_to_device(paddle_tuple, device="cpu", data_device="gpu:1")
        self.assertIsInstance(res, tuple)
        for r in res:
            self.check_cpu(r)

        res = torch_paddle_move_data_to_device(paddle_tuple, device="gpu:0", data_device=None)
        self.assertIsInstance(res, tuple)
        for r in res:
            self.check_gpu(r, 0)

        res = torch_paddle_move_data_to_device(paddle_tuple, device="gpu:1", data_device="cpu")
        self.assertIsInstance(res, tuple)
        for r in res:
            self.check_gpu(r, 1)

    def test_dict_transfer(self):
        """
        Tests moving a complex dict structure.
        """
        paddle_dict = {
            "torch_tensor": torch.rand((3, 4)),
            "torch_list": [torch.rand((6, 4, 2)) for i in range(10)],
            "dict": {
                "list": [paddle.rand((6, 4, 2)) for i in range(5)] + [torch.rand((6, 4, 2)) for i in range(5)],
                "torch_tensor": torch.rand((3, 4)),
                "paddle_tensor": paddle.rand((3, 4))
            },
            "paddle_tensor": paddle.rand((3, 4)),
            "list": [paddle.rand((6, 4, 2)) for i in range(10)],
            "int": 2,
            "string": "test string"
        }

        res = torch_paddle_move_data_to_device(paddle_dict, device="gpu:0", data_device=None)
        self.assertIsInstance(res, dict)
        self.check_gpu(res["torch_tensor"], 0)
        self.check_gpu(res["paddle_tensor"], 0)
        self.assertIsInstance(res["torch_list"], list)
        for t in res["torch_list"]:
            self.check_gpu(t, 0)
        self.assertIsInstance(res["list"], list)
        for t in res["list"]:
            self.check_gpu(t, 0)
        self.assertIsInstance(res["int"], int)
        self.assertIsInstance(res["string"], str)
        self.assertIsInstance(res["dict"], dict)
        self.assertIsInstance(res["dict"]["list"], list)
        for t in res["dict"]["list"]:
            self.check_gpu(t, 0)
        self.check_gpu(res["dict"]["torch_tensor"], 0)
        self.check_gpu(res["dict"]["paddle_tensor"], 0)

        res = torch_paddle_move_data_to_device(paddle_dict, device=None, data_device="gpu:1")
        self.assertIsInstance(res, dict)
        self.check_gpu(res["torch_tensor"], 1)
        self.check_gpu(res["paddle_tensor"], 1)
        self.assertIsInstance(res["torch_list"], list)
        for t in res["torch_list"]:
            self.check_gpu(t, 1)
        self.assertIsInstance(res["list"], list)
        for t in res["list"]:
            self.check_gpu(t, 1)
        self.assertIsInstance(res["int"], int)
        self.assertIsInstance(res["string"], str)
        self.assertIsInstance(res["dict"], dict)
        self.assertIsInstance(res["dict"]["list"], list)
        for t in res["dict"]["list"]:
            self.check_gpu(t, 1)
        self.check_gpu(res["dict"]["torch_tensor"], 1)
        self.check_gpu(res["dict"]["paddle_tensor"], 1)

        res = torch_paddle_move_data_to_device(paddle_dict, device="cpu", data_device="gpu:0")
        self.assertIsInstance(res, dict)
        self.check_cpu(res["torch_tensor"])
        self.check_cpu(res["paddle_tensor"])
        self.assertIsInstance(res["torch_list"], list)
        for t in res["torch_list"]:
            self.check_cpu(t)
        self.assertIsInstance(res["list"], list)
        for t in res["list"]:
            self.check_cpu(t)
        self.assertIsInstance(res["int"], int)
        self.assertIsInstance(res["string"], str)
        self.assertIsInstance(res["dict"], dict)
        self.assertIsInstance(res["dict"]["list"], list)
        for t in res["dict"]["list"]:
            self.check_cpu(t)
        self.check_cpu(res["dict"]["torch_tensor"])
        self.check_cpu(res["dict"]["paddle_tensor"])
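
One detail worth noting in test_tensor_transfer above: torch tensors are moved with both paddle-style ("gpu:0") and torch-style ("cuda:0") device strings, so the helper presumably normalizes the former before handing it to torch. A hypothetical normalizer illustrating that assumption (not fastNLP's actual code):

    import torch

    def to_torch_device(device: str) -> torch.device:
        # Accept paddle-style "gpu:N" as well as torch's own "cuda:N"/"cpu"
        # (an assumption inferred from the mixed strings used in these tests).
        return torch.device(device.replace("gpu", "cuda"))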
