@@ -0,0 +1,7 @@ | |||
class Demo: | |||
def __init__(self): | |||
pass | |||
def demo(self): | |||
b = 1 | |||
return b |
@@ -0,0 +1,8 @@ | |||
class Demo: | |||
def __init__(self): | |||
self.b = 1 | |||
def demo(self): | |||
b = 1 | |||
return b | |||
@@ -0,0 +1,304 @@ | |||
import time | |||
import os | |||
import pytest | |||
from subprocess import Popen, PIPE | |||
from io import StringIO | |||
import sys | |||
from fastNLP.core.utils.cache_results import cache_results | |||
from tests.helpers.common.utils import check_time_elapse | |||
from fastNLP.core import synchronize_safe_rm | |||
def get_subprocess_results(cmd): | |||
pipe = Popen(cmd, shell=True, stdout=PIPE, stderr=PIPE) | |||
output, err = pipe.communicate() | |||
if output: | |||
output = output.decode('utf8') | |||
else: | |||
output = '' | |||
if err: | |||
err = err.decode('utf8') | |||
else: | |||
err = '' | |||
res = output + err | |||
return res | |||
class Capturing(list): | |||
# 用来捕获当前环境中的stdout和stderr,会将其中stderr的输出拼接在stdout的输出后面 | |||
def __enter__(self): | |||
self._stdout = sys.stdout | |||
self._stderr = sys.stderr | |||
sys.stdout = self._stringio = StringIO() | |||
sys.stderr = self._stringioerr = StringIO() | |||
return self | |||
def __exit__(self, *args): | |||
self.append(self._stringio.getvalue() + self._stringioerr.getvalue()) | |||
del self._stringio, self._stringioerr # free up some memory | |||
sys.stdout = self._stdout | |||
sys.stderr = self._stderr | |||
class TestCacheResults: | |||
def test_cache_save(self): | |||
cache_fp = 'demo.pkl' | |||
try: | |||
@cache_results(cache_fp) | |||
def demo(): | |||
time.sleep(1) | |||
return 1 | |||
res = demo() | |||
with check_time_elapse(1, op='lt'): | |||
res = demo() | |||
finally: | |||
synchronize_safe_rm(cache_fp) | |||
def test_cache_save_refresh(self): | |||
cache_fp = 'demo.pkl' | |||
try: | |||
@cache_results(cache_fp, _refresh=True) | |||
def demo(): | |||
time.sleep(1.5) | |||
return 1 | |||
res = demo() | |||
with check_time_elapse(1, op='ge'): | |||
res = demo() | |||
finally: | |||
synchronize_safe_rm(cache_fp) | |||
def test_cache_no_func_change(self): | |||
cache_fp = os.path.abspath('demo.pkl') | |||
try: | |||
@cache_results(cache_fp) | |||
def demo(): | |||
time.sleep(2) | |||
return 1 | |||
with check_time_elapse(1, op='gt'): | |||
res = demo() | |||
@cache_results(cache_fp) | |||
def demo(): | |||
time.sleep(2) | |||
return 1 | |||
with check_time_elapse(1, op='lt'): | |||
res = demo() | |||
finally: | |||
synchronize_safe_rm('demo.pkl') | |||
def test_cache_func_change(self, capsys): | |||
cache_fp = 'demo.pkl' | |||
try: | |||
@cache_results(cache_fp) | |||
def demo(): | |||
time.sleep(2) | |||
return 1 | |||
with check_time_elapse(1, op='gt'): | |||
res = demo() | |||
@cache_results(cache_fp) | |||
def demo(): | |||
time.sleep(1) | |||
return 1 | |||
with check_time_elapse(1, op='lt'): | |||
with Capturing() as output: | |||
res = demo() | |||
assert 'is different from its last cache' in output[0] | |||
# 关闭check_hash应该不warning的 | |||
with check_time_elapse(1, op='lt'): | |||
with Capturing() as output: | |||
res = demo(_check_hash=0) | |||
assert 'is different from its last cache' not in output[0] | |||
finally: | |||
synchronize_safe_rm('demo.pkl') | |||
def test_cache_check_hash(self): | |||
cache_fp = 'demo.pkl' | |||
try: | |||
@cache_results(cache_fp, _check_hash=False) | |||
def demo(): | |||
time.sleep(2) | |||
return 1 | |||
with check_time_elapse(1, op='gt'): | |||
res = demo() | |||
@cache_results(cache_fp, _check_hash=False) | |||
def demo(): | |||
time.sleep(1) | |||
return 1 | |||
# 默认不会check | |||
with check_time_elapse(1, op='lt'): | |||
with Capturing() as output: | |||
res = demo() | |||
assert 'is different from its last cache' not in output[0] | |||
# check也可以 | |||
with check_time_elapse(1, op='lt'): | |||
with Capturing() as output: | |||
res = demo(_check_hash=True) | |||
assert 'is different from its last cache' in output[0] | |||
finally: | |||
synchronize_safe_rm('demo.pkl') | |||
# 外部 function 改变也会 导致改变 | |||
def test_refer_fun_change(self): | |||
cache_fp = 'demo.pkl' | |||
test_type = 'func_refer_fun_change' | |||
try: | |||
with check_time_elapse(3, op='gt'): | |||
cmd = f'python {__file__} --cache_fp {cache_fp} --test_type {test_type} --turn 0' | |||
res = get_subprocess_results(cmd) | |||
# 引用的function没有变化 | |||
with check_time_elapse(2, op='lt'): | |||
cmd = f'python {__file__} --cache_fp {cache_fp} --test_type {test_type} --turn 0' | |||
res = get_subprocess_results(cmd) | |||
assert 'Read cache from' in res | |||
assert 'is different from its last cache' not in res | |||
# 引用的function有变化 | |||
with check_time_elapse(2, op='lt'): | |||
cmd = f'python {__file__} --cache_fp {cache_fp} --test_type {test_type} --turn 1' | |||
res = get_subprocess_results(cmd) | |||
assert 'is different from its last cache' in res | |||
finally: | |||
synchronize_safe_rm(cache_fp) | |||
# 外部 method 改变也会 导致改变 | |||
def test_refer_class_method_change(self): | |||
cache_fp = 'demo.pkl' | |||
test_type = 'refer_class_method_change' | |||
try: | |||
with check_time_elapse(3, op='gt'): | |||
cmd = f'python {__file__} --cache_fp {cache_fp} --test_type {test_type} --turn 0' | |||
res = get_subprocess_results(cmd) | |||
# 引用的class没有变化 | |||
with check_time_elapse(2, op='lt'): | |||
cmd = f'python {__file__} --cache_fp {cache_fp} --test_type {test_type} --turn 0' | |||
res = get_subprocess_results(cmd) | |||
assert 'Read cache from' in res | |||
assert 'is different from its last cache' not in res | |||
# 引用的class有变化 | |||
with check_time_elapse(2, op='lt'): | |||
cmd = f'python {__file__} --cache_fp {cache_fp} --test_type {test_type} --turn 1' | |||
res = get_subprocess_results(cmd) | |||
assert 'is different from its last cache' in res | |||
finally: | |||
synchronize_safe_rm(cache_fp) | |||
def test_duplicate_keyword(self): | |||
with pytest.raises(RuntimeError): | |||
@cache_results(None) | |||
def func_verbose(a, _verbose): | |||
pass | |||
func_verbose(0, 1) | |||
with pytest.raises(RuntimeError): | |||
@cache_results(None) | |||
def func_cache(a, _cache_fp): | |||
pass | |||
func_cache(1, 2) | |||
with pytest.raises(RuntimeError): | |||
@cache_results(None) | |||
def func_refresh(a, _refresh): | |||
pass | |||
func_refresh(1, 2) | |||
with pytest.raises(RuntimeError): | |||
@cache_results(None) | |||
def func_refresh(a, _check_hash): | |||
pass | |||
func_refresh(1, 2) | |||
def test_create_cache_dir(self): | |||
@cache_results('demo/demo.pkl') | |||
def cache(): | |||
return 1, 2 | |||
try: | |||
results = cache() | |||
assert (1, 2) == results | |||
finally: | |||
synchronize_safe_rm('demo/') | |||
def test_result_none_error(self): | |||
@cache_results('demo.pkl') | |||
def cache(): | |||
pass | |||
try: | |||
with pytest.raises(RuntimeError): | |||
results = cache() | |||
finally: | |||
synchronize_safe_rm('demo.pkl') | |||
if __name__ == '__main__': | |||
import argparse | |||
parser = argparse.ArgumentParser() | |||
parser.add_argument('--test_type', type=str, default='refer_class_method_change') | |||
parser.add_argument('--turn', type=int, default=1) | |||
parser.add_argument('--cache_fp', type=str, default='demo.pkl') | |||
args = parser.parse_args() | |||
test_type = args.test_type | |||
cache_fp = args.cache_fp | |||
turn = args.turn | |||
if test_type == 'func_refer_fun_change': | |||
if turn == 0: | |||
def demo(): | |||
b = 1 | |||
return b | |||
else: | |||
def demo(): | |||
b = 2 | |||
return b | |||
@cache_results(cache_fp) | |||
def demo_refer_other_func(): | |||
time.sleep(3) | |||
b = demo() | |||
return b | |||
res = demo_refer_other_func() | |||
if test_type == 'refer_class_method_change': | |||
print(f"Turn:{turn}") | |||
if turn == 0: | |||
from helper_for_cache_results_1 import Demo | |||
else: | |||
from helper_for_cache_results_2 import Demo | |||
demo = Demo() | |||
# import pdb | |||
# pdb.set_trace() | |||
@cache_results(cache_fp) | |||
def demo_func(): | |||
time.sleep(3) | |||
b = demo.demo() | |||
return b | |||
res = demo_func() | |||
@@ -0,0 +1,91 @@ | |||
import os | |||
from fastNLP.envs.distributed import rank_zero_call, all_rank_call | |||
from tests.helpers.utils import re_run_current_cmd_for_torch, Capturing, magic_argv_env_context | |||
@rank_zero_call | |||
def write_something(): | |||
print(os.environ.get('RANK', '0')*5, flush=True) | |||
def write_other_thing(): | |||
print(os.environ.get('RANK', '0')*5, flush=True) | |||
class PaddleTest: | |||
# @x54-729 | |||
def test_rank_zero_call(self): | |||
pass | |||
def test_all_rank_run(self): | |||
pass | |||
class JittorTest: | |||
# @x54-729 | |||
def test_rank_zero_call(self): | |||
pass | |||
def test_all_rank_run(self): | |||
pass | |||
class TestTorch: | |||
@magic_argv_env_context | |||
def test_rank_zero_call(self): | |||
os.environ['MASTER_ADDR'] = '127.0.0.1' | |||
os.environ['MASTER_PORT'] = '29500' | |||
if 'LOCAL_RANK' not in os.environ and 'RANK' not in os.environ and 'WORLD_SIZE' not in os.environ: | |||
os.environ['LOCAL_RANK'] = '0' | |||
os.environ['RANK'] = '0' | |||
os.environ['WORLD_SIZE'] = '2' | |||
re_run_current_cmd_for_torch(1, output_from_new_proc='all') | |||
with Capturing() as output: | |||
write_something() | |||
output = output[0] | |||
if os.environ['LOCAL_RANK'] == '0': | |||
assert '00000' in output and '11111' not in output | |||
else: | |||
assert '00000' not in output and '11111' not in output | |||
with Capturing() as output: | |||
rank_zero_call(write_other_thing)() | |||
output = output[0] | |||
if os.environ['LOCAL_RANK'] == '0': | |||
assert '00000' in output and '11111' not in output | |||
else: | |||
assert '00000' not in output and '11111' not in output | |||
@magic_argv_env_context | |||
def test_all_rank_run(self): | |||
os.environ['MASTER_ADDR'] = '127.0.0.1' | |||
os.environ['MASTER_PORT'] = '29500' | |||
if 'LOCAL_RANK' not in os.environ and 'RANK' not in os.environ and 'WORLD_SIZE' not in os.environ: | |||
os.environ['LOCAL_RANK'] = '0' | |||
os.environ['RANK'] = '0' | |||
os.environ['WORLD_SIZE'] = '2' | |||
re_run_current_cmd_for_torch(1, output_from_new_proc='all') | |||
# torch.distributed.init_process_group(backend='nccl') | |||
# torch.distributed.barrier() | |||
with all_rank_call(): | |||
with Capturing(no_del=True) as output: | |||
write_something() | |||
output = output[0] | |||
if os.environ['LOCAL_RANK'] == '0': | |||
assert '00000' in output | |||
else: | |||
assert '11111' in output | |||
with all_rank_call(): | |||
with Capturing(no_del=True) as output: | |||
rank_zero_call(write_other_thing)() | |||
output = output[0] | |||
if os.environ['LOCAL_RANK'] == '0': | |||
assert '00000' in output | |||
else: | |||
assert '11111' in output |
@@ -0,0 +1,200 @@ | |||
import unittest | |||
import paddle | |||
from fastNLP.core.utils.paddle_utils import paddle_to, paddle_move_data_to_device | |||
############################################################################ | |||
# | |||
# 测试仅将单个paddle张量迁移到指定设备 | |||
# | |||
############################################################################ | |||
class PaddleToDeviceTestCase(unittest.TestCase): | |||
def test_case(self): | |||
tensor = paddle.rand((4, 5)) | |||
res = paddle_to(tensor, "gpu") | |||
self.assertTrue(res.place.is_gpu_place()) | |||
self.assertEqual(res.place.gpu_device_id(), 0) | |||
res = paddle_to(tensor, "cpu") | |||
self.assertTrue(res.place.is_cpu_place()) | |||
res = paddle_to(tensor, "gpu:2") | |||
self.assertTrue(res.place.is_gpu_place()) | |||
self.assertEqual(res.place.gpu_device_id(), 2) | |||
res = paddle_to(tensor, "gpu:1") | |||
self.assertTrue(res.place.is_gpu_place()) | |||
self.assertEqual(res.place.gpu_device_id(), 1) | |||
############################################################################ | |||
# | |||
# 测试将参数中包含的所有paddle张量迁移到指定设备 | |||
# | |||
############################################################################ | |||
class PaddleMoveDataToDeviceTestCase(unittest.TestCase): | |||
def check_gpu(self, tensor, idx): | |||
""" | |||
检查张量是否在指定的设备上的工具函数 | |||
""" | |||
self.assertTrue(tensor.place.is_gpu_place()) | |||
self.assertEqual(tensor.place.gpu_device_id(), idx) | |||
def check_cpu(self, tensor): | |||
""" | |||
检查张量是否在cpu上的工具函数 | |||
""" | |||
self.assertTrue(tensor.place.is_cpu_place()) | |||
def test_tensor_transfer(self): | |||
""" | |||
测试单个张量的迁移 | |||
""" | |||
paddle_tensor = paddle.rand((3, 4, 5)).cpu() | |||
res = paddle_move_data_to_device(paddle_tensor, device=None, data_device=None) | |||
self.check_cpu(res) | |||
res = paddle_move_data_to_device(paddle_tensor, device="gpu:0", data_device=None) | |||
self.check_gpu(res, 0) | |||
res = paddle_move_data_to_device(paddle_tensor, device="gpu:1", data_device=None) | |||
self.check_gpu(res, 1) | |||
res = paddle_move_data_to_device(paddle_tensor, device="gpu:0", data_device="cpu") | |||
self.check_gpu(res, 0) | |||
res = paddle_move_data_to_device(paddle_tensor, device=None, data_device="gpu:0") | |||
self.check_gpu(res, 0) | |||
res = paddle_move_data_to_device(paddle_tensor, device=None, data_device="gpu:1") | |||
self.check_gpu(res, 1) | |||
def test_list_transfer(self): | |||
""" | |||
测试张量列表的迁移 | |||
""" | |||
paddle_list = [paddle.rand((6, 4, 2)) for i in range(10)] | |||
res = paddle_move_data_to_device(paddle_list, device=None, data_device="gpu:1") | |||
self.assertIsInstance(res, list) | |||
for r in res: | |||
self.check_gpu(r, 1) | |||
res = paddle_move_data_to_device(paddle_list, device="cpu", data_device="gpu:1") | |||
self.assertIsInstance(res, list) | |||
for r in res: | |||
self.check_cpu(r) | |||
res = paddle_move_data_to_device(paddle_list, device="gpu:0", data_device=None) | |||
self.assertIsInstance(res, list) | |||
for r in res: | |||
self.check_gpu(r, 0) | |||
res = paddle_move_data_to_device(paddle_list, device="gpu:1", data_device="cpu") | |||
self.assertIsInstance(res, list) | |||
for r in res: | |||
self.check_gpu(r, 1) | |||
def test_tensor_tuple_transfer(self): | |||
""" | |||
测试张量元组的迁移 | |||
""" | |||
paddle_list = [paddle.rand((6, 4, 2)) for i in range(10)] | |||
paddle_tuple = tuple(paddle_list) | |||
res = paddle_move_data_to_device(paddle_tuple, device=None, data_device="gpu:1") | |||
self.assertIsInstance(res, tuple) | |||
for r in res: | |||
self.check_gpu(r, 1) | |||
res = paddle_move_data_to_device(paddle_tuple, device="cpu", data_device="gpu:1") | |||
self.assertIsInstance(res, tuple) | |||
for r in res: | |||
self.check_cpu(r) | |||
res = paddle_move_data_to_device(paddle_tuple, device="gpu:0", data_device=None) | |||
self.assertIsInstance(res, tuple) | |||
for r in res: | |||
self.check_gpu(r, 0) | |||
res = paddle_move_data_to_device(paddle_tuple, device="gpu:1", data_device="cpu") | |||
self.assertIsInstance(res, tuple) | |||
for r in res: | |||
self.check_gpu(r, 1) | |||
def test_dict_transfer(self): | |||
""" | |||
测试字典结构的迁移 | |||
""" | |||
paddle_dict = { | |||
"tensor": paddle.rand((3, 4)), | |||
"list": [paddle.rand((6, 4, 2)) for i in range(10)], | |||
"dict":{ | |||
"list": [paddle.rand((6, 4, 2)) for i in range(10)], | |||
"tensor": paddle.rand((3, 4)) | |||
}, | |||
"int": 2, | |||
"string": "test string" | |||
} | |||
res = paddle_move_data_to_device(paddle_dict, device="gpu:0", data_device=None) | |||
self.assertIsInstance(res, dict) | |||
self.check_gpu(res["tensor"], 0) | |||
self.assertIsInstance(res["list"], list) | |||
for t in res["list"]: | |||
self.check_gpu(t, 0) | |||
self.assertIsInstance(res["int"], int) | |||
self.assertIsInstance(res["string"], str) | |||
self.assertIsInstance(res["dict"], dict) | |||
self.assertIsInstance(res["dict"]["list"], list) | |||
for t in res["dict"]["list"]: | |||
self.check_gpu(t, 0) | |||
self.check_gpu(res["dict"]["tensor"], 0) | |||
res = paddle_move_data_to_device(paddle_dict, device="gpu:0", data_device="cpu") | |||
self.assertIsInstance(res, dict) | |||
self.check_gpu(res["tensor"], 0) | |||
self.assertIsInstance(res["list"], list) | |||
for t in res["list"]: | |||
self.check_gpu(t, 0) | |||
self.assertIsInstance(res["int"], int) | |||
self.assertIsInstance(res["string"], str) | |||
self.assertIsInstance(res["dict"], dict) | |||
self.assertIsInstance(res["dict"]["list"], list) | |||
for t in res["dict"]["list"]: | |||
self.check_gpu(t, 0) | |||
self.check_gpu(res["dict"]["tensor"], 0) | |||
res = paddle_move_data_to_device(paddle_dict, device=None, data_device="gpu:1") | |||
self.assertIsInstance(res, dict) | |||
self.check_gpu(res["tensor"], 1) | |||
self.assertIsInstance(res["list"], list) | |||
for t in res["list"]: | |||
self.check_gpu(t, 1) | |||
self.assertIsInstance(res["int"], int) | |||
self.assertIsInstance(res["string"], str) | |||
self.assertIsInstance(res["dict"], dict) | |||
self.assertIsInstance(res["dict"]["list"], list) | |||
for t in res["dict"]["list"]: | |||
self.check_gpu(t, 1) | |||
self.check_gpu(res["dict"]["tensor"], 1) | |||
res = paddle_move_data_to_device(paddle_dict, device="cpu", data_device="gpu:0") | |||
self.assertIsInstance(res, dict) | |||
self.check_cpu(res["tensor"]) | |||
self.assertIsInstance(res["list"], list) | |||
for t in res["list"]: | |||
self.check_cpu(t) | |||
self.assertIsInstance(res["int"], int) | |||
self.assertIsInstance(res["string"], str) | |||
self.assertIsInstance(res["dict"], dict) | |||
self.assertIsInstance(res["dict"]["list"], list) | |||
for t in res["dict"]["list"]: | |||
self.check_cpu(t) | |||
self.check_cpu(res["dict"]["tensor"]) |
@@ -0,0 +1,205 @@ | |||
import unittest | |||
import paddle | |||
import torch | |||
from fastNLP.core.utils.torch_paddle_utils import torch_paddle_move_data_to_device | |||
############################################################################ | |||
# | |||
# 测试将参数中包含的所有torch和paddle张量迁移到指定设备 | |||
# | |||
############################################################################ | |||
class TorchPaddleMoveDataToDeviceTestCase(unittest.TestCase): | |||
def check_gpu(self, tensor, idx): | |||
""" | |||
检查张量是否在指定显卡上的工具函数 | |||
""" | |||
if isinstance(tensor, paddle.Tensor): | |||
self.assertTrue(tensor.place.is_gpu_place()) | |||
self.assertEqual(tensor.place.gpu_device_id(), idx) | |||
elif isinstance(tensor, torch.Tensor): | |||
self.assertTrue(tensor.is_cuda) | |||
self.assertEqual(tensor.device.index, idx) | |||
def check_cpu(self, tensor): | |||
if isinstance(tensor, paddle.Tensor): | |||
self.assertTrue(tensor.place.is_cpu_place()) | |||
elif isinstance(tensor, torch.Tensor): | |||
self.assertFalse(tensor.is_cuda) | |||
def test_tensor_transfer(self): | |||
""" | |||
测试迁移单个张量 | |||
""" | |||
paddle_tensor = paddle.rand((3, 4, 5)).cpu() | |||
res = torch_paddle_move_data_to_device(paddle_tensor, device=None, data_device=None) | |||
self.check_cpu(res) | |||
res = torch_paddle_move_data_to_device(paddle_tensor, device="gpu:0", data_device=None) | |||
self.check_gpu(res, 0) | |||
res = torch_paddle_move_data_to_device(paddle_tensor, device="gpu:1", data_device=None) | |||
self.check_gpu(res, 1) | |||
res = torch_paddle_move_data_to_device(paddle_tensor, device="cuda:0", data_device="cpu") | |||
self.check_gpu(res, 0) | |||
res = torch_paddle_move_data_to_device(paddle_tensor, device=None, data_device="gpu:0") | |||
self.check_gpu(res, 0) | |||
res = torch_paddle_move_data_to_device(paddle_tensor, device=None, data_device="cuda:1") | |||
self.check_gpu(res, 1) | |||
torch_tensor = torch.rand(3, 4, 5) | |||
res = torch_paddle_move_data_to_device(torch_tensor, device=None, data_device=None) | |||
self.check_cpu(res) | |||
res = torch_paddle_move_data_to_device(torch_tensor, device="gpu:0", data_device=None) | |||
print(res.device) | |||
self.check_gpu(res, 0) | |||
res = torch_paddle_move_data_to_device(torch_tensor, device="gpu:1", data_device=None) | |||
self.check_gpu(res, 1) | |||
res = torch_paddle_move_data_to_device(torch_tensor, device="gpu:0", data_device="cpu") | |||
self.check_gpu(res, 0) | |||
res = torch_paddle_move_data_to_device(torch_tensor, device=None, data_device="gpu:0") | |||
self.check_gpu(res, 0) | |||
res = torch_paddle_move_data_to_device(torch_tensor, device=None, data_device="gpu:1") | |||
self.check_gpu(res, 1) | |||
def test_list_transfer(self): | |||
""" | |||
测试迁移张量的列表 | |||
""" | |||
paddle_list = [paddle.rand((6, 4, 2)) for i in range(5)] + [torch.rand((6, 4, 2)) for i in range(5)] | |||
res = torch_paddle_move_data_to_device(paddle_list, device=None, data_device="gpu:1") | |||
self.assertIsInstance(res, list) | |||
for r in res: | |||
self.check_gpu(r, 1) | |||
res = torch_paddle_move_data_to_device(paddle_list, device="cpu", data_device="gpu:1") | |||
self.assertIsInstance(res, list) | |||
for r in res: | |||
self.check_cpu(r) | |||
res = torch_paddle_move_data_to_device(paddle_list, device="gpu:0", data_device=None) | |||
self.assertIsInstance(res, list) | |||
for r in res: | |||
self.check_gpu(r, 0) | |||
res = torch_paddle_move_data_to_device(paddle_list, device="gpu:1", data_device="cpu") | |||
self.assertIsInstance(res, list) | |||
for r in res: | |||
self.check_gpu(r, 1) | |||
def test_tensor_tuple_transfer(self): | |||
""" | |||
测试迁移张量的元组 | |||
""" | |||
paddle_list = [paddle.rand((6, 4, 2)) for i in range(10)] + [torch.rand((6, 4, 2)) for i in range(5)] | |||
paddle_tuple = tuple(paddle_list) | |||
res = torch_paddle_move_data_to_device(paddle_tuple, device=None, data_device="gpu:1") | |||
self.assertIsInstance(res, tuple) | |||
for r in res: | |||
self.check_gpu(r, 1) | |||
res = torch_paddle_move_data_to_device(paddle_tuple, device="cpu", data_device="gpu:1") | |||
self.assertIsInstance(res, tuple) | |||
for r in res: | |||
self.check_cpu(r) | |||
res = torch_paddle_move_data_to_device(paddle_tuple, device="gpu:0", data_device=None) | |||
self.assertIsInstance(res, tuple) | |||
for r in res: | |||
self.check_gpu(r, 0) | |||
res = torch_paddle_move_data_to_device(paddle_tuple, device="gpu:1", data_device="cpu") | |||
self.assertIsInstance(res, tuple) | |||
for r in res: | |||
self.check_gpu(r, 1) | |||
def test_dict_transfer(self): | |||
""" | |||
测试迁移复杂的字典结构 | |||
""" | |||
paddle_dict = { | |||
"torch_tensor": torch.rand((3, 4)), | |||
"torch_list": [torch.rand((6, 4, 2)) for i in range(10)], | |||
"dict":{ | |||
"list": [paddle.rand((6, 4, 2)) for i in range(5)] + [torch.rand((6, 4, 2)) for i in range(5)], | |||
"torch_tensor": torch.rand((3, 4)), | |||
"paddle_tensor": paddle.rand((3, 4)) | |||
}, | |||
"paddle_tensor": paddle.rand((3, 4)), | |||
"list": [paddle.rand((6, 4, 2)) for i in range(10)] , | |||
"int": 2, | |||
"string": "test string" | |||
} | |||
res = torch_paddle_move_data_to_device(paddle_dict, device="gpu:0", data_device=None) | |||
self.assertIsInstance(res, dict) | |||
self.check_gpu(res["torch_tensor"], 0) | |||
self.check_gpu(res["paddle_tensor"], 0) | |||
self.assertIsInstance(res["torch_list"], list) | |||
for t in res["torch_list"]: | |||
self.check_gpu(t, 0) | |||
self.assertIsInstance(res["list"], list) | |||
for t in res["list"]: | |||
self.check_gpu(t, 0) | |||
self.assertIsInstance(res["int"], int) | |||
self.assertIsInstance(res["string"], str) | |||
self.assertIsInstance(res["dict"], dict) | |||
self.assertIsInstance(res["dict"]["list"], list) | |||
for t in res["dict"]["list"]: | |||
self.check_gpu(t, 0) | |||
self.check_gpu(res["dict"]["torch_tensor"], 0) | |||
self.check_gpu(res["dict"]["paddle_tensor"], 0) | |||
res = torch_paddle_move_data_to_device(paddle_dict, device=None, data_device="gpu:1") | |||
self.assertIsInstance(res, dict) | |||
self.check_gpu(res["torch_tensor"], 1) | |||
self.check_gpu(res["paddle_tensor"], 1) | |||
self.assertIsInstance(res["torch_list"], list) | |||
for t in res["torch_list"]: | |||
self.check_gpu(t, 1) | |||
self.assertIsInstance(res["list"], list) | |||
for t in res["list"]: | |||
self.check_gpu(t, 1) | |||
self.assertIsInstance(res["int"], int) | |||
self.assertIsInstance(res["string"], str) | |||
self.assertIsInstance(res["dict"], dict) | |||
self.assertIsInstance(res["dict"]["list"], list) | |||
for t in res["dict"]["list"]: | |||
self.check_gpu(t, 1) | |||
self.check_gpu(res["dict"]["torch_tensor"], 1) | |||
self.check_gpu(res["dict"]["paddle_tensor"], 1) | |||
res = torch_paddle_move_data_to_device(paddle_dict, device="cpu", data_device="gpu:0") | |||
self.assertIsInstance(res, dict) | |||
self.check_cpu(res["torch_tensor"]) | |||
self.check_cpu(res["paddle_tensor"]) | |||
self.assertIsInstance(res["torch_list"], list) | |||
for t in res["torch_list"]: | |||
self.check_cpu(t) | |||
self.assertIsInstance(res["list"], list) | |||
for t in res["list"]: | |||
self.check_cpu(t) | |||
self.assertIsInstance(res["int"], int) | |||
self.assertIsInstance(res["string"], str) | |||
self.assertIsInstance(res["dict"], dict) | |||
self.assertIsInstance(res["dict"]["list"], list) | |||
for t in res["dict"]["list"]: | |||
self.check_cpu(t) | |||
self.check_cpu(res["dict"]["torch_tensor"]) | |||
self.check_cpu(res["dict"]["paddle_tensor"]) |