From 007678b6d9416033680fc183c5e5a077bd2bae90 Mon Sep 17 00:00:00 2001 From: x54-729 <17307130121@fudan.edu.cn> Date: Tue, 3 May 2022 14:37:32 +0000 Subject: [PATCH] =?UTF-8?q?=E5=88=A0=E9=99=A4=20paddle=5Fdriver/test=5Futi?= =?UTF-8?q?ls.py=20=E4=B8=AD=20get=5Fdevice=5Ffrom=5Fvisible=20=E7=9A=84?= =?UTF-8?q?=E6=B5=8B=E8=AF=95=E4=BE=8B?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- tests/core/dataset/test_dataset.py | 18 ----------------- .../core/drivers/paddle_driver/test_utils.py | 20 ------------------- 2 files changed, 38 deletions(-) diff --git a/tests/core/dataset/test_dataset.py b/tests/core/dataset/test_dataset.py index 8ff64d04..a2540ecf 100644 --- a/tests/core/dataset/test_dataset.py +++ b/tests/core/dataset/test_dataset.py @@ -370,29 +370,11 @@ class TestDataSetMethods: assert os.path.exists("1.csv") == True os.remove("1.csv") - def test_add_collate_fn(self): - ds = DataSet({'x': [1, 2, 3], 'y': [4, 5, 6]}) - - def collate_fn(item): - return item - - ds.add_collate_fn(collate_fn) - - def test_get_collator(self): - from typing import Callable - ds = DataSet({'x': [1, 2, 3], 'y': [4, 5, 6]}) - collate_fn = ds.get_collator() - assert isinstance(collate_fn, Callable) == True - def test_add_seq_len(self): ds = DataSet({'x': [[1, 2], [2, 3, 4], [3]], 'y': [4, 5, 6]}) ds.add_seq_len('x') print(ds) - def test_set_target(self): - ds = DataSet({'x': [[1, 2], [2, 3, 4], [3]], 'y': [4, 5, 6]}) - ds.set_target('x') - class TestFieldArrayInit: """ diff --git a/tests/core/drivers/paddle_driver/test_utils.py b/tests/core/drivers/paddle_driver/test_utils.py index 3b0fb9e0..66dc23c4 100644 --- a/tests/core/drivers/paddle_driver/test_utils.py +++ b/tests/core/drivers/paddle_driver/test_utils.py @@ -1,8 +1,6 @@ -import os import pytest from fastNLP.core.drivers.paddle_driver.utils import ( - get_device_from_visible, replace_batch_sampler, replace_sampler, ) @@ -14,24 +12,6 @@ if _NEED_IMPORT_PADDLE: from tests.helpers.datasets.paddle_data import PaddleNormalDataset -@pytest.mark.parametrize( - ("user_visible_devices, cuda_visible_devices, device, output_type, correct"), - ( - ("0,1,2,3,4,5,6,7", "0", "cpu", str, "cpu"), - ("0,1,2,3,4,5,6,7", "0", "cpu", int, "cpu"), - ("0,1,2,3,4,5,6,7", "3,4,5", "gpu:4", int, 1), - ("0,1,2,3,4,5,6,7", "3,4,5", "gpu:5", str, "gpu:2"), - ("3,4,5,6", "3,5", 0, int, 0), - ("3,6,7,8", "6,7,8", "gpu:2", str, "gpu:1"), - ) -) -@pytest.mark.paddle -def test_get_device_from_visible_str(user_visible_devices, cuda_visible_devices, device, output_type, correct): - os.environ["CUDA_VISIBLE_DEVICES"] = cuda_visible_devices - os.environ["USER_CUDA_VISIBLE_DEVICES"] = user_visible_devices - res = get_device_from_visible(device, output_type) - assert res == correct - @pytest.mark.paddle def test_replace_batch_sampler(): dataset = PaddleNormalDataset(10)