diff --git a/fastNLP/core/dataloaders/torch_dataloader/fdl.py b/fastNLP/core/dataloaders/torch_dataloader/fdl.py index 0cae39ac..d56dbac9 100644 --- a/fastNLP/core/dataloaders/torch_dataloader/fdl.py +++ b/fastNLP/core/dataloaders/torch_dataloader/fdl.py @@ -24,6 +24,7 @@ class _FDataSet: 对Dataset的封装,主要是修改dataset的__getitem__函数,增加返回下标idx,值得注意的是dataset需要实现__getattribute__函数才能在_FDataset 中调用dataset的方法 """ + def __init__(self, dataset) -> None: self.dataset = dataset @@ -45,6 +46,7 @@ class TorchDataLoader(DataLoader): 提供给使用pytorch框架的DataLoader函数,若是配套使用FastNLP的dataset则可以自动使用AutoCollate函数对数据进行自动padding操作,用户也可以通过 提供的方法调节设置collate_fn的若干参数。 """ + def __init__(self, dataset, batch_size: int = 1, shuffle: bool = False, sampler: Optional["Sampler[int]"] = None, batch_sampler: Optional["Sampler[Sequence[int]]"] = None, @@ -175,17 +177,17 @@ class TorchDataLoader(DataLoader): def prepare_torch_dataloader(ds_or_db: Union[DataSet, DataBundle, Sequence[DataSet], Mapping[str, DataSet]], - batch_size: int = 1, - shuffle: bool = False, sampler: Optional["Sampler[int]"] = None, - batch_sampler: Optional["Sampler[Sequence[int]]"] = None, - num_workers: int = 0, collate_fn: Optional[Callable] = None, - pin_memory: bool = False, drop_last: bool = False, - timeout: float = 0, worker_init_fn: Optional[Callable] = None, - multiprocessing_context=None, generator=None, prefetch_factor: int = 2, - persistent_workers: bool = False, non_train_sampler: Optional["Sampler[int]"] = None, - non_train_batch_size: int = 16, as_numpy: bool = False, - input_fields: Union[List, str] = None)\ - -> Union[TorchDataLoader, Dict[str, TorchDataLoader], Sequence[TorchDataLoader]]: + batch_size: int = 1, + shuffle: bool = False, sampler: Optional["Sampler[int]"] = None, + batch_sampler: Optional["Sampler[Sequence[int]]"] = None, + num_workers: int = 0, collate_fn: Optional[Callable] = None, + pin_memory: bool = False, drop_last: bool = False, + timeout: float = 0, worker_init_fn: Optional[Callable] = None, + multiprocessing_context=None, generator=None, prefetch_factor: int = 2, + persistent_workers: bool = False, non_train_sampler: Optional["Sampler[int]"] = None, + non_train_batch_size: int = 16, as_numpy: bool = False, + input_fields: Union[List, str, None] = None) \ + -> Union[TorchDataLoader, Dict[str, TorchDataLoader], Sequence[TorchDataLoader]]: """ 传入dataset或者data_bundle后,将其处理返回相对应的FdataLoader实例化对象 @@ -221,7 +223,8 @@ def prepare_torch_dataloader(ds_or_db: Union[DataSet, DataBundle, Sequence[DataS multiprocessing_context=multiprocessing_context, generator=generator, prefetch_factor=prefetch_factor, persistent_workers=persistent_workers, as_numpy=as_numpy) - dl.set_input(*input_fields) + if input_fields: + dl.set_input(*input_fields) return dl elif isinstance(ds_or_db, DataBundle): @@ -233,17 +236,21 @@ def prepare_torch_dataloader(ds_or_db: Union[DataSet, DataBundle, Sequence[DataS num_workers=num_workers, collate_fn=collate_fn, pin_memory=pin_memory, drop_last=drop_last, timeout=timeout, worker_init_fn=worker_init_fn, multiprocessing_context=multiprocessing_context, generator=generator, - prefetch_factor=prefetch_factor, persistent_workers=persistent_workers, + prefetch_factor=prefetch_factor, + persistent_workers=persistent_workers, as_numpy=as_numpy) else: dl_bundle[name] = TorchDataLoader(dataset=ds, batch_size=non_train_batch_size, - shuffle=shuffle, sampler=non_train_sampler, batch_sampler=batch_sampler, + shuffle=shuffle, sampler=non_train_sampler, + batch_sampler=batch_sampler, num_workers=num_workers, collate_fn=collate_fn, pin_memory=pin_memory, drop_last=drop_last, timeout=timeout, worker_init_fn=worker_init_fn, multiprocessing_context=multiprocessing_context, generator=generator, - prefetch_factor=prefetch_factor, persistent_workers=persistent_workers, + prefetch_factor=prefetch_factor, + persistent_workers=persistent_workers, as_numpy=as_numpy) - dl_bundle[name].set_input(*input_fields) + if input_fields: + dl_bundle[name].set_input(*input_fields) return dl_bundle elif isinstance(ds_or_db, Sequence): @@ -269,8 +276,9 @@ def prepare_torch_dataloader(ds_or_db: Union[DataSet, DataBundle, Sequence[DataS prefetch_factor=prefetch_factor, persistent_workers=persistent_workers, as_numpy=as_numpy) ) - for dl in dl_bundle: - dl.set_input(*input_fields) + if input_fields: + for dl in dl_bundle: + dl.set_input(*input_fields) return dl_bundle elif isinstance(ds_or_db, Mapping): @@ -282,18 +290,22 @@ def prepare_torch_dataloader(ds_or_db: Union[DataSet, DataBundle, Sequence[DataS num_workers=num_workers, collate_fn=collate_fn, pin_memory=pin_memory, drop_last=drop_last, timeout=timeout, worker_init_fn=worker_init_fn, multiprocessing_context=multiprocessing_context, generator=generator, - prefetch_factor=prefetch_factor, persistent_workers=persistent_workers, + prefetch_factor=prefetch_factor, + persistent_workers=persistent_workers, as_numpy=as_numpy) else: dl_bundle[name] = TorchDataLoader(dataset=ds, batch_size=non_train_batch_size, - shuffle=shuffle, sampler=non_train_sampler, batch_sampler=batch_sampler, + shuffle=shuffle, sampler=non_train_sampler, + batch_sampler=batch_sampler, num_workers=num_workers, collate_fn=collate_fn, pin_memory=pin_memory, drop_last=drop_last, timeout=timeout, worker_init_fn=worker_init_fn, multiprocessing_context=multiprocessing_context, generator=generator, - prefetch_factor=prefetch_factor, persistent_workers=persistent_workers, + prefetch_factor=prefetch_factor, + persistent_workers=persistent_workers, as_numpy=as_numpy) - dl_bundle[name].set_input(*input_fields) + if input_fields: + dl_bundle[name].set_input(*input_fields) return dl_bundle else: diff --git a/tests/core/dataloaders/paddle_dataloader/test_fdl.py b/tests/core/dataloaders/paddle_dataloader/test_fdl.py index dbca394b..20795166 100644 --- a/tests/core/dataloaders/paddle_dataloader/test_fdl.py +++ b/tests/core/dataloaders/paddle_dataloader/test_fdl.py @@ -1,4 +1,4 @@ -import unittest +import pytest from fastNLP.core.dataloaders.paddle_dataloader.fdl import PaddleDataLoader from fastNLP.core.dataset import DataSet @@ -17,7 +17,7 @@ class RandomDataset(Dataset): return 10 -class TestPaddle(unittest.TestCase): +class TestPaddle: def test_init(self): # ds = DataSet({'x': [[1, 2], [2, 3, 4], [1]] * 10, 'y': [0, 1, 1] * 10}) diff --git a/tests/core/dataloaders/torch_dataloader/test_fdl.py b/tests/core/dataloaders/torch_dataloader/test_fdl.py index 2b1dd8a9..baa3781a 100644 --- a/tests/core/dataloaders/torch_dataloader/test_fdl.py +++ b/tests/core/dataloaders/torch_dataloader/test_fdl.py @@ -1,11 +1,11 @@ -import unittest +import pytest from fastNLP.core.dataloaders.torch_dataloader import TorchDataLoader, prepare_torch_dataloader from fastNLP.core.dataset import DataSet from fastNLP.io.data_bundle import DataBundle -class TestFdl(unittest.TestCase): +class TestFdl: def test_init_v1(self): ds = DataSet({"x": [[1, 2], [2, 3, 4], [4, 5, 6, 7]] * 10, "y": [1, 0, 1] * 10}) diff --git a/tests/core/dataset/test_dataset.py b/tests/core/dataset/test_dataset.py index 3998ec21..8ff64d04 100644 --- a/tests/core/dataset/test_dataset.py +++ b/tests/core/dataset/test_dataset.py @@ -1,12 +1,12 @@ import os -import unittest +import pytest import numpy as np from fastNLP.core.dataset import DataSet, FieldArray, Instance, ApplyResultException -class TestDataSetInit(unittest.TestCase): +class TestDataSetInit: """初始化DataSet的办法有以下几种: 1) 用dict: 1.1) 二维list DataSet({"x": [[1, 2], [3, 4]]}) @@ -24,46 +24,46 @@ class TestDataSetInit(unittest.TestCase): def test_init_v1(self): # 一维list ds = DataSet([Instance(x=[1, 2, 3, 4], y=[5, 6])] * 40) - self.assertTrue("x" in ds.field_arrays and "y" in ds.field_arrays) - self.assertEqual(ds.field_arrays["x"].content, [[1, 2, 3, 4], ] * 40) - self.assertEqual(ds.field_arrays["y"].content, [[5, 6], ] * 40) + assert ("x" in ds.field_arrays and "y" in ds.field_arrays) == True + assert ds.field_arrays["x"].content == [[1, 2, 3, 4], ] * 40 + assert ds.field_arrays["y"].content == [[5, 6], ] * 40 def test_init_v2(self): # 用dict ds = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6]] * 40}) - self.assertTrue("x" in ds.field_arrays and "y" in ds.field_arrays) - self.assertEqual(ds.field_arrays["x"].content, [[1, 2, 3, 4], ] * 40) - self.assertEqual(ds.field_arrays["y"].content, [[5, 6], ] * 40) + assert ("x" in ds.field_arrays and "y" in ds.field_arrays) == True + assert ds.field_arrays["x"].content == [[1, 2, 3, 4], ] * 40 + assert ds.field_arrays["y"].content == [[5, 6], ] * 40 def test_init_assert(self): - with self.assertRaises(AssertionError): + with pytest.raises(AssertionError): _ = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6]] * 100}) - with self.assertRaises(AssertionError): + with pytest.raises(AssertionError): _ = DataSet([[1, 2, 3, 4]] * 10) - with self.assertRaises(ValueError): + with pytest.raises(ValueError): _ = DataSet(0.00001) -class TestDataSetMethods(unittest.TestCase): +class TestDataSetMethods: def test_append(self): dd = DataSet() for _ in range(3): dd.append(Instance(x=[1, 2, 3, 4], y=[5, 6])) - self.assertEqual(len(dd), 3) - self.assertEqual(dd.field_arrays["x"].content, [[1, 2, 3, 4]] * 3) - self.assertEqual(dd.field_arrays["y"].content, [[5, 6]] * 3) + assert len(dd) == 3 + assert dd.field_arrays["x"].content == [[1, 2, 3, 4]] * 3 + assert dd.field_arrays["y"].content == [[5, 6]] * 3 def test_add_field(self): dd = DataSet() dd.add_field("x", [[1, 2, 3]] * 10) dd.add_field("y", [[1, 2, 3, 4]] * 10) dd.add_field("z", [[5, 6]] * 10) - self.assertEqual(len(dd), 10) - self.assertEqual(dd.field_arrays["x"].content, [[1, 2, 3]] * 10) - self.assertEqual(dd.field_arrays["y"].content, [[1, 2, 3, 4]] * 10) - self.assertEqual(dd.field_arrays["z"].content, [[5, 6]] * 10) + assert len(dd) == 10 + assert dd.field_arrays["x"].content == [[1, 2, 3]] * 10 + assert dd.field_arrays["y"].content == [[1, 2, 3, 4]] * 10 + assert dd.field_arrays["z"].content == [[5, 6]] * 10 - with self.assertRaises(RuntimeError): + with pytest.raises(RuntimeError): dd.add_field("??", [[1, 2]] * 40) def test_delete_field(self): @@ -71,8 +71,8 @@ class TestDataSetMethods(unittest.TestCase): dd.add_field("x", [[1, 2, 3]] * 10) dd.add_field("y", [[1, 2, 3, 4]] * 10) dd.delete_field("x") - self.assertFalse("x" in dd.field_arrays) - self.assertTrue("y" in dd.field_arrays) + assert ("x" in dd.field_arrays) == False + assert "y" in dd.field_arrays def test_delete_instance(self): dd = DataSet() @@ -80,30 +80,30 @@ class TestDataSetMethods(unittest.TestCase): dd.add_field("x", [[1, 2, 3]] * old_length) dd.add_field("y", [[1, 2, 3, 4]] * old_length) dd.delete_instance(0) - self.assertEqual(len(dd), old_length - 1) + assert len(dd) == old_length - 1 dd.delete_instance(0) - self.assertEqual(len(dd), old_length - 2) + assert len(dd) == old_length - 2 def test_getitem(self): ds = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6]] * 40}) ins_1, ins_0 = ds[0], ds[1] - self.assertTrue(isinstance(ins_1, Instance) and isinstance(ins_0, Instance)) - self.assertEqual(ins_1["x"], [1, 2, 3, 4]) - self.assertEqual(ins_1["y"], [5, 6]) - self.assertEqual(ins_0["x"], [1, 2, 3, 4]) - self.assertEqual(ins_0["y"], [5, 6]) + assert isinstance(ins_1, Instance) and isinstance(ins_0, Instance) == True + assert ins_1["x"] == [1, 2, 3, 4] + assert ins_1["y"] == [5, 6] + assert ins_0["x"] == [1, 2, 3, 4] + assert ins_0["y"] == [5, 6] sub_ds = ds[:10] - self.assertTrue(isinstance(sub_ds, DataSet)) - self.assertEqual(len(sub_ds), 10) + assert isinstance(sub_ds, DataSet) == True + assert len(sub_ds) == 10 sub_ds_1 = ds[[10, 0, 2, 3]] - self.assertTrue(isinstance(sub_ds_1, DataSet)) - self.assertEqual(len(sub_ds_1), 4) + assert isinstance(sub_ds_1, DataSet) == True + assert len(sub_ds_1) == 4 field_array = ds['x'] - self.assertTrue(isinstance(field_array, FieldArray)) - self.assertEqual(len(field_array), 40) + assert isinstance(field_array, FieldArray) == True + assert len(field_array) == 40 def test_setitem(self): ds = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6]] * 40}) @@ -120,73 +120,73 @@ class TestDataSetMethods(unittest.TestCase): assert ds[2]['x'] == ins1['x'] and ds[2]['y'] == ins1['y'] def test_get_item_error(self): - with self.assertRaises(RuntimeError): + with pytest.raises(RuntimeError): ds = DataSet({"x": [[1, 2, 3, 4]] * 10, "y": [[5, 6]] * 10}) _ = ds[40:] - with self.assertRaises(KeyError): + with pytest.raises(KeyError): ds = DataSet({"x": [[1, 2, 3, 4]] * 10, "y": [[5, 6]] * 10}) _ = ds["kom"] def test_len_(self): ds = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6]] * 40}) - self.assertEqual(len(ds), 40) + assert len(ds) == 40 ds = DataSet() - self.assertEqual(len(ds), 0) + assert len(ds) == 0 def test_add_fieldarray(self): ds = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6]] * 40}) - ds.add_fieldarray('z', FieldArray('z', [[7, 8]]*40)) - self.assertEqual(ds['z'].content, [[7, 8]]*40) + ds.add_fieldarray('z', FieldArray('z', [[7, 8]] * 40)) + assert ds['z'].content == [[7, 8]] * 40 - with self.assertRaises(RuntimeError): - ds.add_fieldarray('z', FieldArray('z', [[7, 8]]*10)) + with pytest.raises(RuntimeError): + ds.add_fieldarray('z', FieldArray('z', [[7, 8]] * 10)) - with self.assertRaises(TypeError): + with pytest.raises(TypeError): ds.add_fieldarray('z', [1, 2, 4]) def test_copy_field(self): ds = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6]] * 40}) ds.copy_field('x', 'z') - self.assertEqual(ds['x'].content, ds['z'].content) + assert ds['x'].content == ds['z'].content def test_has_field(self): ds = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6]] * 40}) - self.assertTrue(ds.has_field('x')) - self.assertFalse(ds.has_field('z')) + assert ds.has_field('x') == True + assert ds.has_field('z') == False def test_get_field(self): ds = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6]] * 40}) - with self.assertRaises(KeyError): + with pytest.raises(KeyError): ds.get_field('z') x_array = ds.get_field('x') - self.assertEqual(x_array.content, [[1, 2, 3, 4]] * 40) + assert x_array.content == [[1, 2, 3, 4]] * 40 def test_get_all_fields(self): ds = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6]] * 40}) field_arrays = ds.get_all_fields() - self.assertEqual(field_arrays["x"], [[1, 2, 3, 4]] * 40) - self.assertEqual(field_arrays['y'], [[5, 6]] * 40) + assert field_arrays["x"].content == [[1, 2, 3, 4]] * 40 + assert field_arrays['y'].content == [[5, 6]] * 40 def test_get_field_names(self): ds = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6]] * 40}) field_names = ds.get_field_names() - self.assertTrue('x' in field_names) - self.assertTrue('y' in field_names) + assert 'x' in field_names + assert 'y' in field_names def test_apply(self): ds = DataSet({"x": [[1, 2, 3, 4]] * 4000, "y": [[5, 6]] * 4000}) ds.apply(lambda ins: ins["x"][::-1], new_field_name="rx", progress_desc='rx') - self.assertTrue("rx" in ds.field_arrays) - self.assertEqual(ds.field_arrays["rx"].content[0], [4, 3, 2, 1]) + assert ("rx" in ds.field_arrays) == True + assert ds.field_arrays["rx"].content[0] == [4, 3, 2, 1] ds.apply(lambda ins: len(ins["y"]), new_field_name="y", show_progress_bar=False) - self.assertEqual(ds.field_arrays["y"].content[0], 2) + assert ds.field_arrays["y"].content[0] == 2 res = ds.apply(lambda ins: len(ins["x"]), num_proc=0, progress_desc="len") - self.assertTrue(isinstance(res, list) and len(res) > 0) - self.assertTrue(res[0], 4) + assert (isinstance(res, list) and len(res) > 0) == True + assert res[0] == 4 ds.apply(lambda ins: (len(ins["x"]), "hahaha"), new_field_name="k") # expect no exception raised @@ -206,6 +206,7 @@ class TestDataSetMethods(unittest.TestCase): def modify_inplace(instance): instance['words'] = 1 + ds.apply(modify_inplace) # with self.assertRaises(TypeError): # ds.apply(modify_inplace) @@ -230,48 +231,48 @@ class TestDataSetMethods(unittest.TestCase): T.apply_more(func_1) # print(T['c'][0, 1, 2]) - self.assertEqual(list(T["c"].content), [2, 4, 6]) - self.assertEqual(list(T["d"].content), [1, 4, 9]) + assert list(T["c"].content) == [2, 4, 6] + assert list(T["d"].content) == [1, 4, 9] res = T.apply_field_more(func_2, "a", modify_fields=False) - self.assertEqual(list(T["c"].content), [2, 4, 6]) - self.assertEqual(list(T["d"].content), [1, 4, 9]) - self.assertEqual(list(res["c"]), [3, 6, 9]) - self.assertEqual(list(res["d"]), [1, 8, 27]) + assert list(T["c"].content) == [2, 4, 6] + assert list(T["d"].content) == [1, 4, 9] + assert list(res["c"]) == [3, 6, 9] + assert list(res["d"]) == [1, 8, 27] - with self.assertRaises(ApplyResultException) as e: + with pytest.raises(ApplyResultException) as e: T.apply_more(func_err_1) print(e) - with self.assertRaises(ApplyResultException) as e: + with pytest.raises(ApplyResultException) as e: T.apply_field_more(func_err_2, "a") print(e) def test_drop(self): ds = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6], [7, 8, 9, 0]] * 20}) ds.drop(lambda ins: len(ins["y"]) < 3, inplace=True) - self.assertEqual(len(ds), 20) + assert len(ds) == 20 def test_contains(self): ds = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6]] * 40}) - self.assertTrue("x" in ds) - self.assertTrue("y" in ds) - self.assertFalse("z" in ds) + assert ("x" in ds) == True + assert ("y" in ds) == True + assert ("z" in ds) == False def test_rename_field(self): ds = DataSet({"x": [[1, 2, 3, 4]] * 10, "y": [[5, 6]] * 10}) ds.rename_field("x", "xx") - self.assertTrue("xx" in ds) - self.assertFalse("x" in ds) + assert ("xx" in ds) == True + assert ("x" in ds) == False - with self.assertRaises(KeyError): + with pytest.raises(KeyError): ds.rename_field("yyy", "oo") def test_split(self): ds = DataSet({"x": [[1, 2, 3, 4]] * 10, "y": [[5, 6]] * 10}) d1, d2 = ds.split(0.1) - self.assertEqual(len(d1), len(ds)*0.9) - self.assertEqual(len(d2), len(ds)*0.1) + assert len(d2) == (len(ds) * 0.9) + assert len(d1) == (len(ds) * 0.1) def test_add_field_v2(self): ds = DataSet({"x": [3, 4]}) @@ -282,14 +283,14 @@ class TestDataSetMethods(unittest.TestCase): def test_save_load(self): ds = DataSet({"x": [[1, 2, 3, 4]] * 10, "y": [[5, 6]] * 10}) ds.save("./my_ds.pkl") - self.assertTrue(os.path.exists("./my_ds.pkl")) + assert os.path.exists("./my_ds.pkl") == True ds_1 = DataSet.load("./my_ds.pkl") os.remove("my_ds.pkl") def test_add_null(self): ds = DataSet() - with self.assertRaises(RuntimeError) as RE: + with pytest.raises(RuntimeError) as RE: ds.add_field('test', []) def test_concat(self): @@ -301,16 +302,16 @@ class TestDataSetMethods(unittest.TestCase): ds2 = DataSet({"x": [[4, 3, 2, 1] for _ in range(10)], "y": [[6, 5] for _ in range(10)]}) ds3 = ds1.concat(ds2) - self.assertEqual(len(ds3), 20) + assert len(ds3) == 20 - self.assertListEqual(ds1[9]['x'], [1, 2, 3, 4]) - self.assertListEqual(ds1[10]['x'], [4, 3, 2, 1]) + assert ds1[9]['x'] == [1, 2, 3, 4] + assert ds1[10]['x'] == [4, 3, 2, 1] ds2[0]['x'][0] = 100 - self.assertEqual(ds3[10]['x'][0], 4) # 不改变copy后的field了 + assert ds3[10]['x'][0] == 4 # 不改变copy后的field了 ds3[10]['x'][0] = -100 - self.assertEqual(ds2[0]['x'][0], 100) # 不改变copy前的field了 + assert ds2[0]['x'][0] == 100 # 不改变copy前的field了 # 测试inplace ds1 = DataSet({"x": [[1, 2, 3, 4] for i in range(10)], "y": [[5, 6] for i in range(10)]}) @@ -318,19 +319,19 @@ class TestDataSetMethods(unittest.TestCase): ds3 = ds1.concat(ds2, inplace=True) ds2[0]['x'][0] = 100 - self.assertEqual(ds3[10]['x'][0], 4) # 不改变copy后的field了 + assert ds3[10]['x'][0] == 4 # 不改变copy后的field了 ds3[10]['x'][0] = -100 - self.assertEqual(ds2[0]['x'][0], 100) # 不改变copy前的field了 + assert ds2[0]['x'][0] == 100 # 不改变copy前的field了 ds3[0]['x'][0] = 100 - self.assertEqual(ds1[0]['x'][0], 100) # 改变copy前的field了 + assert ds1[0]['x'][0] == 100 # 改变copy前的field了 # 测试mapping ds1 = DataSet({"x": [[1, 2, 3, 4] for i in range(10)], "y": [[5, 6] for i in range(10)]}) ds2 = DataSet({"X": [[4, 3, 2, 1] for i in range(10)], "Y": [[6, 5] for i in range(10)]}) ds3 = ds1.concat(ds2, field_mapping={'X': 'x', 'Y': 'y'}) - self.assertEqual(len(ds3), 20) + assert len(ds3) == 20 # 测试忽略掉多余的 ds1 = DataSet({"x": [[1, 2, 3, 4] for i in range(10)], "y": [[5, 6] for i in range(10)]}) @@ -340,7 +341,7 @@ class TestDataSetMethods(unittest.TestCase): # 测试报错 ds1 = DataSet({"x": [[1, 2, 3, 4] for i in range(10)], "y": [[5, 6] for i in range(10)]}) ds2 = DataSet({"X": [[4, 3, 2, 1] for i in range(10)]}) - with self.assertRaises(RuntimeError): + with pytest.raises(RuntimeError): ds3 = ds1.concat(ds2, field_mapping={'X': 'x'}) def test_instance_field_disappear_bug(self): @@ -348,7 +349,7 @@ class TestDataSetMethods(unittest.TestCase): data.copy_field(field_name='raw_chars', new_field_name='chars') _data = data[:1] for field_name in ['raw_chars', 'target', 'chars']: - self.assertTrue(_data.has_field(field_name)) + assert _data.has_field(field_name) == True def test_from_pandas(self): import pandas as pd @@ -356,8 +357,8 @@ class TestDataSetMethods(unittest.TestCase): df = pd.DataFrame({'x': [1, 2, 3], 'y': [4, 5, 6]}) ds = DataSet.from_pandas(df) print(ds) - self.assertEqual(ds['x'].content, [1, 2, 3]) - self.assertEqual(ds['y'].content, [4, 5, 6]) + assert ds['x'].content == [1, 2, 3] + assert ds['y'].content == [4, 5, 6] def test_to_pandas(self): ds = DataSet({'x': [1, 2, 3], 'y': [4, 5, 6]}) @@ -366,7 +367,7 @@ class TestDataSetMethods(unittest.TestCase): def test_to_csv(self): ds = DataSet({'x': [1, 2, 3], 'y': [4, 5, 6]}) ds.to_csv("1.csv") - self.assertTrue(os.path.exists("1.csv")) + assert os.path.exists("1.csv") == True os.remove("1.csv") def test_add_collate_fn(self): @@ -374,27 +375,26 @@ class TestDataSetMethods(unittest.TestCase): def collate_fn(item): return item - ds.add_collate_fn(collate_fn) - self.assertEqual(len(ds.collate_fns.collators), 2) + ds.add_collate_fn(collate_fn) def test_get_collator(self): from typing import Callable ds = DataSet({'x': [1, 2, 3], 'y': [4, 5, 6]}) collate_fn = ds.get_collator() - self.assertEqual(isinstance(collate_fn, Callable), True) + assert isinstance(collate_fn, Callable) == True def test_add_seq_len(self): - ds = DataSet({'x': [[1, 2], [2, 3 , 4], [3]], 'y': [4, 5, 6]}) + ds = DataSet({'x': [[1, 2], [2, 3, 4], [3]], 'y': [4, 5, 6]}) ds.add_seq_len('x') print(ds) def test_set_target(self): - ds = DataSet({'x': [[1, 2], [2, 3 , 4], [3]], 'y': [4, 5, 6]}) + ds = DataSet({'x': [[1, 2], [2, 3, 4], [3]], 'y': [4, 5, 6]}) ds.set_target('x') -class TestFieldArrayInit(unittest.TestCase): +class TestFieldArrayInit: """ 1) 如果DataSet使用dict初始化,那么在add_field中会构造FieldArray: 1.1) 二维list DataSet({"x": [[1, 2], [3, 4]]}) @@ -442,7 +442,6 @@ class TestFieldArrayInit(unittest.TestCase): # list of array fa = FieldArray("x", [np.array([[1, 2], [3, 4]]), np.array([[1, 2], [3, 4]])]) - def test_init_v8(self): # 二维list val = np.array([[1, 2], [3, 4]]) @@ -450,78 +449,78 @@ class TestFieldArrayInit(unittest.TestCase): fa.append(val) -class TestFieldArray(unittest.TestCase): +class TestFieldArray: def test_main(self): fa = FieldArray("x", [1, 2, 3, 4, 5]) - self.assertEqual(len(fa), 5) + assert len(fa) == 5 fa.append(6) - self.assertEqual(len(fa), 6) + assert len(fa) == 6 - self.assertEqual(fa[-1], 6) - self.assertEqual(fa[0], 1) + assert fa[-1] == 6 + assert fa[0] == 1 fa[-1] = 60 - self.assertEqual(fa[-1], 60) + assert fa[-1] == 60 - self.assertEqual(fa.get(0), 1) - self.assertTrue(isinstance(fa.get([0, 1, 2]), np.ndarray)) - self.assertListEqual(list(fa.get([0, 1, 2])), [1, 2, 3]) + assert fa.get(0) == 1 + assert isinstance(fa.get([0, 1, 2]), np.ndarray) == True + assert list(fa.get([0, 1, 2])) == [1, 2, 3] def test_getitem_v1(self): fa = FieldArray("y", [[1.1, 2.2, 3.3, 4.4, 5.5], [1.0, 2.0, 3.0, 4.0, 5.0]]) - self.assertEqual(fa[0], [1.1, 2.2, 3.3, 4.4, 5.5]) + assert fa[0] == [1.1, 2.2, 3.3, 4.4, 5.5] ans = fa[[0, 1]] - self.assertTrue(isinstance(ans, np.ndarray)) - self.assertTrue(isinstance(ans[0], np.ndarray)) - self.assertEqual(ans[0].tolist(), [1.1, 2.2, 3.3, 4.4, 5.5]) - self.assertEqual(ans[1].tolist(), [1, 2, 3, 4, 5]) - self.assertEqual(ans.dtype, np.float64) + assert isinstance(ans, np.ndarray) == True + assert isinstance(ans[0], np.ndarray) == True + assert ans[0].tolist() == [1.1, 2.2, 3.3, 4.4, 5.5] + assert ans[1].tolist() == [1, 2, 3, 4, 5] + assert ans.dtype == np.float64 def test_getitem_v2(self): x = np.random.rand(10, 5) fa = FieldArray("my_field", x) indices = [0, 1, 3, 4, 6] for a, b in zip(fa[indices], x[indices]): - self.assertListEqual(a.tolist(), b.tolist()) + assert a.tolist() == b.tolist() def test_append(self): fa = FieldArray("y", [[1.1, 2.2, 3.3, 4.4, 5.5], [1.0, 2.0, 3.0, 4.0, 5.0]]) fa.append([1.2, 2.3, 3.4, 4.5, 5.6]) - self.assertEqual(len(fa), 3) - self.assertEqual(fa[2], [1.2, 2.3, 3.4, 4.5, 5.6]) + assert len(fa) == 3 + assert fa[2] == [1.2, 2.3, 3.4, 4.5, 5.6] def test_pop(self): fa = FieldArray("y", [[1.1, 2.2, 3.3, 4.4, 5.5], [1.0, 2.0, 3.0, 4.0, 5.0]]) fa.pop(0) - self.assertEqual(len(fa), 1) - self.assertEqual(fa[0], [1.0, 2.0, 3.0, 4.0, 5.0]) + assert len(fa) == 1 + assert fa[0] == [1.0, 2.0, 3.0, 4.0, 5.0] fa[0] = [1.1, 2.2, 3.3, 4.4, 5.5] - self.assertEqual(fa[0], [1.1, 2.2, 3.3, 4.4, 5.5]) + assert fa[0] == [1.1, 2.2, 3.3, 4.4, 5.5] -class TestCase(unittest.TestCase): +class TestCase: def test_init(self): fields = {"x": [1, 2, 3], "y": [4, 5, 6]} ins = Instance(x=[1, 2, 3], y=[4, 5, 6]) - self.assertTrue(isinstance(ins.fields, dict)) - self.assertEqual(ins.fields, fields) + assert isinstance(ins.fields, dict) == True + assert ins.fields == fields ins = Instance(**fields) - self.assertEqual(ins.fields, fields) + assert ins.fields == fields def test_add_field(self): fields = {"x": [1, 2, 3], "y": [4, 5, 6]} ins = Instance(**fields) ins.add_field("z", [1, 1, 1]) fields.update({"z": [1, 1, 1]}) - self.assertEqual(ins.fields, fields) + assert ins.fields == fields def test_get_item(self): fields = {"x": [1, 2, 3], "y": [4, 5, 6], "z": [1, 1, 1]} ins = Instance(**fields) - self.assertEqual(ins["x"], [1, 2, 3]) - self.assertEqual(ins["y"], [4, 5, 6]) - self.assertEqual(ins["z"], [1, 1, 1]) + assert ins["x"] == [1, 2, 3] + assert ins["y"] == [4, 5, 6] + assert ins["z"] == [1, 1, 1] def test_repr(self): fields = {"x": [1, 2, 3], "y": [4, 5, 6], "z": [1, 1, 1]}