diff --git a/tests/core/dataset/test_dataset.py b/tests/core/dataset/test_dataset.py index 8ff64d04..a2540ecf 100644 --- a/tests/core/dataset/test_dataset.py +++ b/tests/core/dataset/test_dataset.py @@ -370,29 +370,11 @@ class TestDataSetMethods: assert os.path.exists("1.csv") == True os.remove("1.csv") - def test_add_collate_fn(self): - ds = DataSet({'x': [1, 2, 3], 'y': [4, 5, 6]}) - - def collate_fn(item): - return item - - ds.add_collate_fn(collate_fn) - - def test_get_collator(self): - from typing import Callable - ds = DataSet({'x': [1, 2, 3], 'y': [4, 5, 6]}) - collate_fn = ds.get_collator() - assert isinstance(collate_fn, Callable) == True - def test_add_seq_len(self): ds = DataSet({'x': [[1, 2], [2, 3, 4], [3]], 'y': [4, 5, 6]}) ds.add_seq_len('x') print(ds) - def test_set_target(self): - ds = DataSet({'x': [[1, 2], [2, 3, 4], [3]], 'y': [4, 5, 6]}) - ds.set_target('x') - class TestFieldArrayInit: """ diff --git a/tests/core/drivers/paddle_driver/test_utils.py b/tests/core/drivers/paddle_driver/test_utils.py index 3b0fb9e0..66dc23c4 100644 --- a/tests/core/drivers/paddle_driver/test_utils.py +++ b/tests/core/drivers/paddle_driver/test_utils.py @@ -1,8 +1,6 @@ -import os import pytest from fastNLP.core.drivers.paddle_driver.utils import ( - get_device_from_visible, replace_batch_sampler, replace_sampler, ) @@ -14,24 +12,6 @@ if _NEED_IMPORT_PADDLE: from tests.helpers.datasets.paddle_data import PaddleNormalDataset -@pytest.mark.parametrize( - ("user_visible_devices, cuda_visible_devices, device, output_type, correct"), - ( - ("0,1,2,3,4,5,6,7", "0", "cpu", str, "cpu"), - ("0,1,2,3,4,5,6,7", "0", "cpu", int, "cpu"), - ("0,1,2,3,4,5,6,7", "3,4,5", "gpu:4", int, 1), - ("0,1,2,3,4,5,6,7", "3,4,5", "gpu:5", str, "gpu:2"), - ("3,4,5,6", "3,5", 0, int, 0), - ("3,6,7,8", "6,7,8", "gpu:2", str, "gpu:1"), - ) -) -@pytest.mark.paddle -def test_get_device_from_visible_str(user_visible_devices, cuda_visible_devices, device, output_type, correct): - os.environ["CUDA_VISIBLE_DEVICES"] = cuda_visible_devices - os.environ["USER_CUDA_VISIBLE_DEVICES"] = user_visible_devices - res = get_device_from_visible(device, output_type) - assert res == correct - @pytest.mark.paddle def test_replace_batch_sampler(): dataset = PaddleNormalDataset(10)