From 3a580edb997073016cd595f6515490c782c5b33f Mon Sep 17 00:00:00 2001 From: x54-729 <17307130121@fudan.edu.cn> Date: Fri, 8 Apr 2022 12:16:05 +0000 Subject: [PATCH] =?UTF-8?q?=E6=8F=90=E4=BA=A4tests/core/utils/?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- tests/core/utils/__init__.py | 0 .../core/utils/helper_for_cache_results_1.py | 7 + .../core/utils/helper_for_cache_results_2.py | 8 + tests/core/utils/test_cache_results.py | 304 ++++++++++++++++++ tests/core/utils/test_distributed.py | 91 ++++++ tests/core/utils/test_paddle_utils.py | 200 ++++++++++++ tests/core/utils/test_torch_paddle_utils.py | 205 ++++++++++++ 7 files changed, 815 insertions(+) create mode 100644 tests/core/utils/__init__.py create mode 100644 tests/core/utils/helper_for_cache_results_1.py create mode 100644 tests/core/utils/helper_for_cache_results_2.py create mode 100644 tests/core/utils/test_cache_results.py create mode 100644 tests/core/utils/test_distributed.py create mode 100644 tests/core/utils/test_paddle_utils.py create mode 100644 tests/core/utils/test_torch_paddle_utils.py diff --git a/tests/core/utils/__init__.py b/tests/core/utils/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/core/utils/helper_for_cache_results_1.py b/tests/core/utils/helper_for_cache_results_1.py new file mode 100644 index 00000000..bdac27d3 --- /dev/null +++ b/tests/core/utils/helper_for_cache_results_1.py @@ -0,0 +1,7 @@ +class Demo: + def __init__(self): + pass + + def demo(self): + b = 1 + return b \ No newline at end of file diff --git a/tests/core/utils/helper_for_cache_results_2.py b/tests/core/utils/helper_for_cache_results_2.py new file mode 100644 index 00000000..1cc0a720 --- /dev/null +++ b/tests/core/utils/helper_for_cache_results_2.py @@ -0,0 +1,8 @@ +class Demo: + def __init__(self): + self.b = 1 + + def demo(self): + b = 1 + return b + diff --git a/tests/core/utils/test_cache_results.py b/tests/core/utils/test_cache_results.py new file mode 100644 index 00000000..64303f70 --- /dev/null +++ b/tests/core/utils/test_cache_results.py @@ -0,0 +1,304 @@ +import time +import os +import pytest +from subprocess import Popen, PIPE +from io import StringIO +import sys + +from fastNLP.core.utils.cache_results import cache_results +from tests.helpers.common.utils import check_time_elapse + +from fastNLP.core import synchronize_safe_rm + + +def get_subprocess_results(cmd): + pipe = Popen(cmd, shell=True, stdout=PIPE, stderr=PIPE) + output, err = pipe.communicate() + if output: + output = output.decode('utf8') + else: + output = '' + if err: + err = err.decode('utf8') + else: + err = '' + res = output + err + return res + + +class Capturing(list): + # 用来捕获当前环境中的stdout和stderr,会将其中stderr的输出拼接在stdout的输出后面 + def __enter__(self): + self._stdout = sys.stdout + self._stderr = sys.stderr + sys.stdout = self._stringio = StringIO() + sys.stderr = self._stringioerr = StringIO() + return self + + def __exit__(self, *args): + self.append(self._stringio.getvalue() + self._stringioerr.getvalue()) + del self._stringio, self._stringioerr # free up some memory + sys.stdout = self._stdout + sys.stderr = self._stderr + + +class TestCacheResults: + def test_cache_save(self): + cache_fp = 'demo.pkl' + try: + @cache_results(cache_fp) + def demo(): + time.sleep(1) + return 1 + + res = demo() + with check_time_elapse(1, op='lt'): + res = demo() + + finally: + synchronize_safe_rm(cache_fp) + + def test_cache_save_refresh(self): + cache_fp = 'demo.pkl' + try: + @cache_results(cache_fp, _refresh=True) + def demo(): + time.sleep(1.5) + return 1 + + res = demo() + with check_time_elapse(1, op='ge'): + res = demo() + finally: + synchronize_safe_rm(cache_fp) + + def test_cache_no_func_change(self): + cache_fp = os.path.abspath('demo.pkl') + try: + @cache_results(cache_fp) + def demo(): + time.sleep(2) + return 1 + + with check_time_elapse(1, op='gt'): + res = demo() + + @cache_results(cache_fp) + def demo(): + time.sleep(2) + return 1 + + with check_time_elapse(1, op='lt'): + res = demo() + finally: + synchronize_safe_rm('demo.pkl') + + def test_cache_func_change(self, capsys): + cache_fp = 'demo.pkl' + try: + @cache_results(cache_fp) + def demo(): + time.sleep(2) + return 1 + + with check_time_elapse(1, op='gt'): + res = demo() + + @cache_results(cache_fp) + def demo(): + time.sleep(1) + return 1 + + with check_time_elapse(1, op='lt'): + with Capturing() as output: + res = demo() + assert 'is different from its last cache' in output[0] + + # 关闭check_hash应该不warning的 + with check_time_elapse(1, op='lt'): + with Capturing() as output: + res = demo(_check_hash=0) + assert 'is different from its last cache' not in output[0] + + finally: + synchronize_safe_rm('demo.pkl') + + def test_cache_check_hash(self): + cache_fp = 'demo.pkl' + try: + @cache_results(cache_fp, _check_hash=False) + def demo(): + time.sleep(2) + return 1 + + with check_time_elapse(1, op='gt'): + res = demo() + + @cache_results(cache_fp, _check_hash=False) + def demo(): + time.sleep(1) + return 1 + + # 默认不会check + with check_time_elapse(1, op='lt'): + with Capturing() as output: + res = demo() + assert 'is different from its last cache' not in output[0] + + # check也可以 + with check_time_elapse(1, op='lt'): + with Capturing() as output: + res = demo(_check_hash=True) + assert 'is different from its last cache' in output[0] + + finally: + synchronize_safe_rm('demo.pkl') + + # 外部 function 改变也会 导致改变 + def test_refer_fun_change(self): + cache_fp = 'demo.pkl' + test_type = 'func_refer_fun_change' + try: + with check_time_elapse(3, op='gt'): + cmd = f'python {__file__} --cache_fp {cache_fp} --test_type {test_type} --turn 0' + res = get_subprocess_results(cmd) + + # 引用的function没有变化 + with check_time_elapse(2, op='lt'): + cmd = f'python {__file__} --cache_fp {cache_fp} --test_type {test_type} --turn 0' + res = get_subprocess_results(cmd) + assert 'Read cache from' in res + assert 'is different from its last cache' not in res + + # 引用的function有变化 + with check_time_elapse(2, op='lt'): + cmd = f'python {__file__} --cache_fp {cache_fp} --test_type {test_type} --turn 1' + res = get_subprocess_results(cmd) + assert 'is different from its last cache' in res + + finally: + synchronize_safe_rm(cache_fp) + + # 外部 method 改变也会 导致改变 + def test_refer_class_method_change(self): + cache_fp = 'demo.pkl' + test_type = 'refer_class_method_change' + try: + with check_time_elapse(3, op='gt'): + cmd = f'python {__file__} --cache_fp {cache_fp} --test_type {test_type} --turn 0' + res = get_subprocess_results(cmd) + + # 引用的class没有变化 + with check_time_elapse(2, op='lt'): + cmd = f'python {__file__} --cache_fp {cache_fp} --test_type {test_type} --turn 0' + res = get_subprocess_results(cmd) + assert 'Read cache from' in res + assert 'is different from its last cache' not in res + + # 引用的class有变化 + with check_time_elapse(2, op='lt'): + cmd = f'python {__file__} --cache_fp {cache_fp} --test_type {test_type} --turn 1' + res = get_subprocess_results(cmd) + assert 'is different from its last cache' in res + + finally: + synchronize_safe_rm(cache_fp) + + def test_duplicate_keyword(self): + with pytest.raises(RuntimeError): + @cache_results(None) + def func_verbose(a, _verbose): + pass + + func_verbose(0, 1) + with pytest.raises(RuntimeError): + @cache_results(None) + def func_cache(a, _cache_fp): + pass + + func_cache(1, 2) + with pytest.raises(RuntimeError): + @cache_results(None) + def func_refresh(a, _refresh): + pass + + func_refresh(1, 2) + + with pytest.raises(RuntimeError): + @cache_results(None) + def func_refresh(a, _check_hash): + pass + + func_refresh(1, 2) + + def test_create_cache_dir(self): + @cache_results('demo/demo.pkl') + def cache(): + return 1, 2 + + try: + results = cache() + assert (1, 2) == results + finally: + synchronize_safe_rm('demo/') + + def test_result_none_error(self): + @cache_results('demo.pkl') + def cache(): + pass + + try: + with pytest.raises(RuntimeError): + results = cache() + finally: + synchronize_safe_rm('demo.pkl') + + +if __name__ == '__main__': + import argparse + parser = argparse.ArgumentParser() + parser.add_argument('--test_type', type=str, default='refer_class_method_change') + parser.add_argument('--turn', type=int, default=1) + parser.add_argument('--cache_fp', type=str, default='demo.pkl') + args = parser.parse_args() + + test_type = args.test_type + cache_fp = args.cache_fp + turn = args.turn + + if test_type == 'func_refer_fun_change': + if turn == 0: + def demo(): + b = 1 + return b + else: + def demo(): + b = 2 + return b + + @cache_results(cache_fp) + def demo_refer_other_func(): + time.sleep(3) + b = demo() + return b + + res = demo_refer_other_func() + + if test_type == 'refer_class_method_change': + print(f"Turn:{turn}") + if turn == 0: + from helper_for_cache_results_1 import Demo + else: + from helper_for_cache_results_2 import Demo + + demo = Demo() + # import pdb + # pdb.set_trace() + @cache_results(cache_fp) + def demo_func(): + time.sleep(3) + b = demo.demo() + return b + + res = demo_func() + diff --git a/tests/core/utils/test_distributed.py b/tests/core/utils/test_distributed.py new file mode 100644 index 00000000..017f412d --- /dev/null +++ b/tests/core/utils/test_distributed.py @@ -0,0 +1,91 @@ +import os + +from fastNLP.envs.distributed import rank_zero_call, all_rank_call +from tests.helpers.utils import re_run_current_cmd_for_torch, Capturing, magic_argv_env_context + + +@rank_zero_call +def write_something(): + print(os.environ.get('RANK', '0')*5, flush=True) + + +def write_other_thing(): + print(os.environ.get('RANK', '0')*5, flush=True) + + +class PaddleTest: + # @x54-729 + def test_rank_zero_call(self): + pass + + def test_all_rank_run(self): + pass + + +class JittorTest: + # @x54-729 + def test_rank_zero_call(self): + pass + + def test_all_rank_run(self): + pass + + +class TestTorch: + @magic_argv_env_context + def test_rank_zero_call(self): + os.environ['MASTER_ADDR'] = '127.0.0.1' + os.environ['MASTER_PORT'] = '29500' + if 'LOCAL_RANK' not in os.environ and 'RANK' not in os.environ and 'WORLD_SIZE' not in os.environ: + os.environ['LOCAL_RANK'] = '0' + os.environ['RANK'] = '0' + os.environ['WORLD_SIZE'] = '2' + re_run_current_cmd_for_torch(1, output_from_new_proc='all') + with Capturing() as output: + write_something() + output = output[0] + + if os.environ['LOCAL_RANK'] == '0': + assert '00000' in output and '11111' not in output + else: + assert '00000' not in output and '11111' not in output + + with Capturing() as output: + rank_zero_call(write_other_thing)() + + output = output[0] + if os.environ['LOCAL_RANK'] == '0': + assert '00000' in output and '11111' not in output + else: + assert '00000' not in output and '11111' not in output + + @magic_argv_env_context + def test_all_rank_run(self): + os.environ['MASTER_ADDR'] = '127.0.0.1' + os.environ['MASTER_PORT'] = '29500' + if 'LOCAL_RANK' not in os.environ and 'RANK' not in os.environ and 'WORLD_SIZE' not in os.environ: + os.environ['LOCAL_RANK'] = '0' + os.environ['RANK'] = '0' + os.environ['WORLD_SIZE'] = '2' + re_run_current_cmd_for_torch(1, output_from_new_proc='all') + # torch.distributed.init_process_group(backend='nccl') + # torch.distributed.barrier() + with all_rank_call(): + with Capturing(no_del=True) as output: + write_something() + output = output[0] + + if os.environ['LOCAL_RANK'] == '0': + assert '00000' in output + else: + assert '11111' in output + + with all_rank_call(): + with Capturing(no_del=True) as output: + rank_zero_call(write_other_thing)() + + output = output[0] + if os.environ['LOCAL_RANK'] == '0': + assert '00000' in output + else: + assert '11111' in output \ No newline at end of file diff --git a/tests/core/utils/test_paddle_utils.py b/tests/core/utils/test_paddle_utils.py new file mode 100644 index 00000000..344c0ed9 --- /dev/null +++ b/tests/core/utils/test_paddle_utils.py @@ -0,0 +1,200 @@ +import unittest + +import paddle + +from fastNLP.core.utils.paddle_utils import paddle_to, paddle_move_data_to_device + + +############################################################################ +# +# 测试仅将单个paddle张量迁移到指定设备 +# +############################################################################ + +class PaddleToDeviceTestCase(unittest.TestCase): + def test_case(self): + tensor = paddle.rand((4, 5)) + + res = paddle_to(tensor, "gpu") + self.assertTrue(res.place.is_gpu_place()) + self.assertEqual(res.place.gpu_device_id(), 0) + res = paddle_to(tensor, "cpu") + self.assertTrue(res.place.is_cpu_place()) + res = paddle_to(tensor, "gpu:2") + self.assertTrue(res.place.is_gpu_place()) + self.assertEqual(res.place.gpu_device_id(), 2) + res = paddle_to(tensor, "gpu:1") + self.assertTrue(res.place.is_gpu_place()) + self.assertEqual(res.place.gpu_device_id(), 1) + +############################################################################ +# +# 测试将参数中包含的所有paddle张量迁移到指定设备 +# +############################################################################ + +class PaddleMoveDataToDeviceTestCase(unittest.TestCase): + + def check_gpu(self, tensor, idx): + """ + 检查张量是否在指定的设备上的工具函数 + """ + + self.assertTrue(tensor.place.is_gpu_place()) + self.assertEqual(tensor.place.gpu_device_id(), idx) + + def check_cpu(self, tensor): + """ + 检查张量是否在cpu上的工具函数 + """ + + self.assertTrue(tensor.place.is_cpu_place()) + + def test_tensor_transfer(self): + """ + 测试单个张量的迁移 + """ + + paddle_tensor = paddle.rand((3, 4, 5)).cpu() + res = paddle_move_data_to_device(paddle_tensor, device=None, data_device=None) + self.check_cpu(res) + + res = paddle_move_data_to_device(paddle_tensor, device="gpu:0", data_device=None) + self.check_gpu(res, 0) + + res = paddle_move_data_to_device(paddle_tensor, device="gpu:1", data_device=None) + self.check_gpu(res, 1) + + res = paddle_move_data_to_device(paddle_tensor, device="gpu:0", data_device="cpu") + self.check_gpu(res, 0) + + res = paddle_move_data_to_device(paddle_tensor, device=None, data_device="gpu:0") + self.check_gpu(res, 0) + + res = paddle_move_data_to_device(paddle_tensor, device=None, data_device="gpu:1") + self.check_gpu(res, 1) + + def test_list_transfer(self): + """ + 测试张量列表的迁移 + """ + + paddle_list = [paddle.rand((6, 4, 2)) for i in range(10)] + res = paddle_move_data_to_device(paddle_list, device=None, data_device="gpu:1") + self.assertIsInstance(res, list) + for r in res: + self.check_gpu(r, 1) + + res = paddle_move_data_to_device(paddle_list, device="cpu", data_device="gpu:1") + self.assertIsInstance(res, list) + for r in res: + self.check_cpu(r) + + res = paddle_move_data_to_device(paddle_list, device="gpu:0", data_device=None) + self.assertIsInstance(res, list) + for r in res: + self.check_gpu(r, 0) + + res = paddle_move_data_to_device(paddle_list, device="gpu:1", data_device="cpu") + self.assertIsInstance(res, list) + for r in res: + self.check_gpu(r, 1) + + def test_tensor_tuple_transfer(self): + """ + 测试张量元组的迁移 + """ + + paddle_list = [paddle.rand((6, 4, 2)) for i in range(10)] + paddle_tuple = tuple(paddle_list) + res = paddle_move_data_to_device(paddle_tuple, device=None, data_device="gpu:1") + self.assertIsInstance(res, tuple) + for r in res: + self.check_gpu(r, 1) + + res = paddle_move_data_to_device(paddle_tuple, device="cpu", data_device="gpu:1") + self.assertIsInstance(res, tuple) + for r in res: + self.check_cpu(r) + + res = paddle_move_data_to_device(paddle_tuple, device="gpu:0", data_device=None) + self.assertIsInstance(res, tuple) + for r in res: + self.check_gpu(r, 0) + + res = paddle_move_data_to_device(paddle_tuple, device="gpu:1", data_device="cpu") + self.assertIsInstance(res, tuple) + for r in res: + self.check_gpu(r, 1) + + def test_dict_transfer(self): + """ + 测试字典结构的迁移 + """ + + paddle_dict = { + "tensor": paddle.rand((3, 4)), + "list": [paddle.rand((6, 4, 2)) for i in range(10)], + "dict":{ + "list": [paddle.rand((6, 4, 2)) for i in range(10)], + "tensor": paddle.rand((3, 4)) + }, + "int": 2, + "string": "test string" + } + + res = paddle_move_data_to_device(paddle_dict, device="gpu:0", data_device=None) + self.assertIsInstance(res, dict) + self.check_gpu(res["tensor"], 0) + self.assertIsInstance(res["list"], list) + for t in res["list"]: + self.check_gpu(t, 0) + self.assertIsInstance(res["int"], int) + self.assertIsInstance(res["string"], str) + self.assertIsInstance(res["dict"], dict) + self.assertIsInstance(res["dict"]["list"], list) + for t in res["dict"]["list"]: + self.check_gpu(t, 0) + self.check_gpu(res["dict"]["tensor"], 0) + + res = paddle_move_data_to_device(paddle_dict, device="gpu:0", data_device="cpu") + self.assertIsInstance(res, dict) + self.check_gpu(res["tensor"], 0) + self.assertIsInstance(res["list"], list) + for t in res["list"]: + self.check_gpu(t, 0) + self.assertIsInstance(res["int"], int) + self.assertIsInstance(res["string"], str) + self.assertIsInstance(res["dict"], dict) + self.assertIsInstance(res["dict"]["list"], list) + for t in res["dict"]["list"]: + self.check_gpu(t, 0) + self.check_gpu(res["dict"]["tensor"], 0) + + res = paddle_move_data_to_device(paddle_dict, device=None, data_device="gpu:1") + self.assertIsInstance(res, dict) + self.check_gpu(res["tensor"], 1) + self.assertIsInstance(res["list"], list) + for t in res["list"]: + self.check_gpu(t, 1) + self.assertIsInstance(res["int"], int) + self.assertIsInstance(res["string"], str) + self.assertIsInstance(res["dict"], dict) + self.assertIsInstance(res["dict"]["list"], list) + for t in res["dict"]["list"]: + self.check_gpu(t, 1) + self.check_gpu(res["dict"]["tensor"], 1) + + res = paddle_move_data_to_device(paddle_dict, device="cpu", data_device="gpu:0") + self.assertIsInstance(res, dict) + self.check_cpu(res["tensor"]) + self.assertIsInstance(res["list"], list) + for t in res["list"]: + self.check_cpu(t) + self.assertIsInstance(res["int"], int) + self.assertIsInstance(res["string"], str) + self.assertIsInstance(res["dict"], dict) + self.assertIsInstance(res["dict"]["list"], list) + for t in res["dict"]["list"]: + self.check_cpu(t) + self.check_cpu(res["dict"]["tensor"]) diff --git a/tests/core/utils/test_torch_paddle_utils.py b/tests/core/utils/test_torch_paddle_utils.py new file mode 100644 index 00000000..d5c61e4f --- /dev/null +++ b/tests/core/utils/test_torch_paddle_utils.py @@ -0,0 +1,205 @@ +import unittest + +import paddle +import torch + +from fastNLP.core.utils.torch_paddle_utils import torch_paddle_move_data_to_device + +############################################################################ +# +# 测试将参数中包含的所有torch和paddle张量迁移到指定设备 +# +############################################################################ + +class TorchPaddleMoveDataToDeviceTestCase(unittest.TestCase): + + def check_gpu(self, tensor, idx): + """ + 检查张量是否在指定显卡上的工具函数 + """ + + if isinstance(tensor, paddle.Tensor): + self.assertTrue(tensor.place.is_gpu_place()) + self.assertEqual(tensor.place.gpu_device_id(), idx) + elif isinstance(tensor, torch.Tensor): + self.assertTrue(tensor.is_cuda) + self.assertEqual(tensor.device.index, idx) + + def check_cpu(self, tensor): + if isinstance(tensor, paddle.Tensor): + self.assertTrue(tensor.place.is_cpu_place()) + elif isinstance(tensor, torch.Tensor): + self.assertFalse(tensor.is_cuda) + + def test_tensor_transfer(self): + """ + 测试迁移单个张量 + """ + + paddle_tensor = paddle.rand((3, 4, 5)).cpu() + res = torch_paddle_move_data_to_device(paddle_tensor, device=None, data_device=None) + self.check_cpu(res) + + res = torch_paddle_move_data_to_device(paddle_tensor, device="gpu:0", data_device=None) + self.check_gpu(res, 0) + + res = torch_paddle_move_data_to_device(paddle_tensor, device="gpu:1", data_device=None) + self.check_gpu(res, 1) + + res = torch_paddle_move_data_to_device(paddle_tensor, device="cuda:0", data_device="cpu") + self.check_gpu(res, 0) + + res = torch_paddle_move_data_to_device(paddle_tensor, device=None, data_device="gpu:0") + self.check_gpu(res, 0) + + res = torch_paddle_move_data_to_device(paddle_tensor, device=None, data_device="cuda:1") + self.check_gpu(res, 1) + + torch_tensor = torch.rand(3, 4, 5) + res = torch_paddle_move_data_to_device(torch_tensor, device=None, data_device=None) + self.check_cpu(res) + + res = torch_paddle_move_data_to_device(torch_tensor, device="gpu:0", data_device=None) + print(res.device) + self.check_gpu(res, 0) + + res = torch_paddle_move_data_to_device(torch_tensor, device="gpu:1", data_device=None) + self.check_gpu(res, 1) + + res = torch_paddle_move_data_to_device(torch_tensor, device="gpu:0", data_device="cpu") + self.check_gpu(res, 0) + + res = torch_paddle_move_data_to_device(torch_tensor, device=None, data_device="gpu:0") + self.check_gpu(res, 0) + + res = torch_paddle_move_data_to_device(torch_tensor, device=None, data_device="gpu:1") + self.check_gpu(res, 1) + + def test_list_transfer(self): + """ + 测试迁移张量的列表 + """ + + paddle_list = [paddle.rand((6, 4, 2)) for i in range(5)] + [torch.rand((6, 4, 2)) for i in range(5)] + res = torch_paddle_move_data_to_device(paddle_list, device=None, data_device="gpu:1") + self.assertIsInstance(res, list) + for r in res: + self.check_gpu(r, 1) + + res = torch_paddle_move_data_to_device(paddle_list, device="cpu", data_device="gpu:1") + self.assertIsInstance(res, list) + for r in res: + self.check_cpu(r) + + res = torch_paddle_move_data_to_device(paddle_list, device="gpu:0", data_device=None) + self.assertIsInstance(res, list) + for r in res: + self.check_gpu(r, 0) + + res = torch_paddle_move_data_to_device(paddle_list, device="gpu:1", data_device="cpu") + self.assertIsInstance(res, list) + for r in res: + self.check_gpu(r, 1) + + def test_tensor_tuple_transfer(self): + """ + 测试迁移张量的元组 + """ + + paddle_list = [paddle.rand((6, 4, 2)) for i in range(10)] + [torch.rand((6, 4, 2)) for i in range(5)] + paddle_tuple = tuple(paddle_list) + res = torch_paddle_move_data_to_device(paddle_tuple, device=None, data_device="gpu:1") + self.assertIsInstance(res, tuple) + for r in res: + self.check_gpu(r, 1) + + res = torch_paddle_move_data_to_device(paddle_tuple, device="cpu", data_device="gpu:1") + self.assertIsInstance(res, tuple) + for r in res: + self.check_cpu(r) + + res = torch_paddle_move_data_to_device(paddle_tuple, device="gpu:0", data_device=None) + self.assertIsInstance(res, tuple) + for r in res: + self.check_gpu(r, 0) + + res = torch_paddle_move_data_to_device(paddle_tuple, device="gpu:1", data_device="cpu") + self.assertIsInstance(res, tuple) + for r in res: + self.check_gpu(r, 1) + + def test_dict_transfer(self): + """ + 测试迁移复杂的字典结构 + """ + + paddle_dict = { + "torch_tensor": torch.rand((3, 4)), + "torch_list": [torch.rand((6, 4, 2)) for i in range(10)], + "dict":{ + "list": [paddle.rand((6, 4, 2)) for i in range(5)] + [torch.rand((6, 4, 2)) for i in range(5)], + "torch_tensor": torch.rand((3, 4)), + "paddle_tensor": paddle.rand((3, 4)) + }, + "paddle_tensor": paddle.rand((3, 4)), + "list": [paddle.rand((6, 4, 2)) for i in range(10)] , + "int": 2, + "string": "test string" + } + + res = torch_paddle_move_data_to_device(paddle_dict, device="gpu:0", data_device=None) + self.assertIsInstance(res, dict) + self.check_gpu(res["torch_tensor"], 0) + self.check_gpu(res["paddle_tensor"], 0) + self.assertIsInstance(res["torch_list"], list) + for t in res["torch_list"]: + self.check_gpu(t, 0) + self.assertIsInstance(res["list"], list) + for t in res["list"]: + self.check_gpu(t, 0) + self.assertIsInstance(res["int"], int) + self.assertIsInstance(res["string"], str) + self.assertIsInstance(res["dict"], dict) + self.assertIsInstance(res["dict"]["list"], list) + for t in res["dict"]["list"]: + self.check_gpu(t, 0) + self.check_gpu(res["dict"]["torch_tensor"], 0) + self.check_gpu(res["dict"]["paddle_tensor"], 0) + + res = torch_paddle_move_data_to_device(paddle_dict, device=None, data_device="gpu:1") + self.assertIsInstance(res, dict) + self.check_gpu(res["torch_tensor"], 1) + self.check_gpu(res["paddle_tensor"], 1) + self.assertIsInstance(res["torch_list"], list) + for t in res["torch_list"]: + self.check_gpu(t, 1) + self.assertIsInstance(res["list"], list) + for t in res["list"]: + self.check_gpu(t, 1) + self.assertIsInstance(res["int"], int) + self.assertIsInstance(res["string"], str) + self.assertIsInstance(res["dict"], dict) + self.assertIsInstance(res["dict"]["list"], list) + for t in res["dict"]["list"]: + self.check_gpu(t, 1) + self.check_gpu(res["dict"]["torch_tensor"], 1) + self.check_gpu(res["dict"]["paddle_tensor"], 1) + + res = torch_paddle_move_data_to_device(paddle_dict, device="cpu", data_device="gpu:0") + self.assertIsInstance(res, dict) + self.check_cpu(res["torch_tensor"]) + self.check_cpu(res["paddle_tensor"]) + self.assertIsInstance(res["torch_list"], list) + for t in res["torch_list"]: + self.check_cpu(t) + self.assertIsInstance(res["list"], list) + for t in res["list"]: + self.check_cpu(t) + self.assertIsInstance(res["int"], int) + self.assertIsInstance(res["string"], str) + self.assertIsInstance(res["dict"], dict) + self.assertIsInstance(res["dict"]["list"], list) + for t in res["dict"]["list"]: + self.check_cpu(t) + self.check_cpu(res["dict"]["torch_tensor"]) + self.check_cpu(res["dict"]["paddle_tensor"])