Browse Source

修改测试用例unittest为pytest

tags/v1.0.0alpha
MorningForest 3 years ago
parent
commit
ffcf3ddcd3
4 changed files with 164 additions and 153 deletions
  1. +34
    -22
      fastNLP/core/dataloaders/torch_dataloader/fdl.py
  2. +2
    -2
      tests/core/dataloaders/paddle_dataloader/test_fdl.py
  3. +2
    -2
      tests/core/dataloaders/torch_dataloader/test_fdl.py
  4. +126
    -127
      tests/core/dataset/test_dataset.py

+ 34
- 22
fastNLP/core/dataloaders/torch_dataloader/fdl.py View File

@@ -24,6 +24,7 @@ class _FDataSet:
对Dataset的封装,主要是修改dataset的__getitem__函数,增加返回下标idx,值得注意的是dataset需要实现__getattribute__函数才能在_FDataset
中调用dataset的方法
"""

def __init__(self, dataset) -> None:
self.dataset = dataset

@@ -45,6 +46,7 @@ class TorchDataLoader(DataLoader):
提供给使用pytorch框架的DataLoader函数,若是配套使用FastNLP的dataset则可以自动使用AutoCollate函数对数据进行自动padding操作,用户也可以通过
提供的方法调节设置collate_fn的若干参数。
"""

def __init__(self, dataset, batch_size: int = 1,
shuffle: bool = False, sampler: Optional["Sampler[int]"] = None,
batch_sampler: Optional["Sampler[Sequence[int]]"] = None,
@@ -175,17 +177,17 @@ class TorchDataLoader(DataLoader):


def prepare_torch_dataloader(ds_or_db: Union[DataSet, DataBundle, Sequence[DataSet], Mapping[str, DataSet]],
batch_size: int = 1,
shuffle: bool = False, sampler: Optional["Sampler[int]"] = None,
batch_sampler: Optional["Sampler[Sequence[int]]"] = None,
num_workers: int = 0, collate_fn: Optional[Callable] = None,
pin_memory: bool = False, drop_last: bool = False,
timeout: float = 0, worker_init_fn: Optional[Callable] = None,
multiprocessing_context=None, generator=None, prefetch_factor: int = 2,
persistent_workers: bool = False, non_train_sampler: Optional["Sampler[int]"] = None,
non_train_batch_size: int = 16, as_numpy: bool = False,
input_fields: Union[List, str] = None)\
-> Union[TorchDataLoader, Dict[str, TorchDataLoader], Sequence[TorchDataLoader]]:
batch_size: int = 1,
shuffle: bool = False, sampler: Optional["Sampler[int]"] = None,
batch_sampler: Optional["Sampler[Sequence[int]]"] = None,
num_workers: int = 0, collate_fn: Optional[Callable] = None,
pin_memory: bool = False, drop_last: bool = False,
timeout: float = 0, worker_init_fn: Optional[Callable] = None,
multiprocessing_context=None, generator=None, prefetch_factor: int = 2,
persistent_workers: bool = False, non_train_sampler: Optional["Sampler[int]"] = None,
non_train_batch_size: int = 16, as_numpy: bool = False,
input_fields: Union[List, str, None] = None) \
-> Union[TorchDataLoader, Dict[str, TorchDataLoader], Sequence[TorchDataLoader]]:
"""
传入dataset或者data_bundle后,将其处理返回相对应的FdataLoader实例化对象

@@ -221,7 +223,8 @@ def prepare_torch_dataloader(ds_or_db: Union[DataSet, DataBundle, Sequence[DataS
multiprocessing_context=multiprocessing_context, generator=generator,
prefetch_factor=prefetch_factor, persistent_workers=persistent_workers,
as_numpy=as_numpy)
dl.set_input(*input_fields)
if input_fields:
dl.set_input(*input_fields)
return dl

elif isinstance(ds_or_db, DataBundle):
@@ -233,17 +236,21 @@ def prepare_torch_dataloader(ds_or_db: Union[DataSet, DataBundle, Sequence[DataS
num_workers=num_workers, collate_fn=collate_fn, pin_memory=pin_memory,
drop_last=drop_last, timeout=timeout, worker_init_fn=worker_init_fn,
multiprocessing_context=multiprocessing_context, generator=generator,
prefetch_factor=prefetch_factor, persistent_workers=persistent_workers,
prefetch_factor=prefetch_factor,
persistent_workers=persistent_workers,
as_numpy=as_numpy)
else:
dl_bundle[name] = TorchDataLoader(dataset=ds, batch_size=non_train_batch_size,
shuffle=shuffle, sampler=non_train_sampler, batch_sampler=batch_sampler,
shuffle=shuffle, sampler=non_train_sampler,
batch_sampler=batch_sampler,
num_workers=num_workers, collate_fn=collate_fn, pin_memory=pin_memory,
drop_last=drop_last, timeout=timeout, worker_init_fn=worker_init_fn,
multiprocessing_context=multiprocessing_context, generator=generator,
prefetch_factor=prefetch_factor, persistent_workers=persistent_workers,
prefetch_factor=prefetch_factor,
persistent_workers=persistent_workers,
as_numpy=as_numpy)
dl_bundle[name].set_input(*input_fields)
if input_fields:
dl_bundle[name].set_input(*input_fields)
return dl_bundle

elif isinstance(ds_or_db, Sequence):
@@ -269,8 +276,9 @@ def prepare_torch_dataloader(ds_or_db: Union[DataSet, DataBundle, Sequence[DataS
prefetch_factor=prefetch_factor, persistent_workers=persistent_workers,
as_numpy=as_numpy)
)
for dl in dl_bundle:
dl.set_input(*input_fields)
if input_fields:
for dl in dl_bundle:
dl.set_input(*input_fields)
return dl_bundle

elif isinstance(ds_or_db, Mapping):
@@ -282,18 +290,22 @@ def prepare_torch_dataloader(ds_or_db: Union[DataSet, DataBundle, Sequence[DataS
num_workers=num_workers, collate_fn=collate_fn, pin_memory=pin_memory,
drop_last=drop_last, timeout=timeout, worker_init_fn=worker_init_fn,
multiprocessing_context=multiprocessing_context, generator=generator,
prefetch_factor=prefetch_factor, persistent_workers=persistent_workers,
prefetch_factor=prefetch_factor,
persistent_workers=persistent_workers,
as_numpy=as_numpy)
else:
dl_bundle[name] = TorchDataLoader(dataset=ds, batch_size=non_train_batch_size,
shuffle=shuffle, sampler=non_train_sampler, batch_sampler=batch_sampler,
shuffle=shuffle, sampler=non_train_sampler,
batch_sampler=batch_sampler,
num_workers=num_workers, collate_fn=collate_fn, pin_memory=pin_memory,
drop_last=drop_last, timeout=timeout, worker_init_fn=worker_init_fn,
multiprocessing_context=multiprocessing_context, generator=generator,
prefetch_factor=prefetch_factor, persistent_workers=persistent_workers,
prefetch_factor=prefetch_factor,
persistent_workers=persistent_workers,
as_numpy=as_numpy)

dl_bundle[name].set_input(*input_fields)
if input_fields:
dl_bundle[name].set_input(*input_fields)

return dl_bundle
else:


+ 2
- 2
tests/core/dataloaders/paddle_dataloader/test_fdl.py View File

@@ -1,4 +1,4 @@
import unittest
import pytest

from fastNLP.core.dataloaders.paddle_dataloader.fdl import PaddleDataLoader
from fastNLP.core.dataset import DataSet
@@ -17,7 +17,7 @@ class RandomDataset(Dataset):
return 10


class TestPaddle(unittest.TestCase):
class TestPaddle:

def test_init(self):
# ds = DataSet({'x': [[1, 2], [2, 3, 4], [1]] * 10, 'y': [0, 1, 1] * 10})


+ 2
- 2
tests/core/dataloaders/torch_dataloader/test_fdl.py View File

@@ -1,11 +1,11 @@
import unittest
import pytest

from fastNLP.core.dataloaders.torch_dataloader import TorchDataLoader, prepare_torch_dataloader
from fastNLP.core.dataset import DataSet
from fastNLP.io.data_bundle import DataBundle


class TestFdl(unittest.TestCase):
class TestFdl:

def test_init_v1(self):
ds = DataSet({"x": [[1, 2], [2, 3, 4], [4, 5, 6, 7]] * 10, "y": [1, 0, 1] * 10})


+ 126
- 127
tests/core/dataset/test_dataset.py View File

@@ -1,12 +1,12 @@
import os
import unittest
import pytest

import numpy as np

from fastNLP.core.dataset import DataSet, FieldArray, Instance, ApplyResultException


class TestDataSetInit(unittest.TestCase):
class TestDataSetInit:
"""初始化DataSet的办法有以下几种:
1) 用dict:
1.1) 二维list DataSet({"x": [[1, 2], [3, 4]]})
@@ -24,46 +24,46 @@ class TestDataSetInit(unittest.TestCase):
def test_init_v1(self):
# 一维list
ds = DataSet([Instance(x=[1, 2, 3, 4], y=[5, 6])] * 40)
self.assertTrue("x" in ds.field_arrays and "y" in ds.field_arrays)
self.assertEqual(ds.field_arrays["x"].content, [[1, 2, 3, 4], ] * 40)
self.assertEqual(ds.field_arrays["y"].content, [[5, 6], ] * 40)
assert ("x" in ds.field_arrays and "y" in ds.field_arrays) == True
assert ds.field_arrays["x"].content == [[1, 2, 3, 4], ] * 40
assert ds.field_arrays["y"].content == [[5, 6], ] * 40

def test_init_v2(self):
# 用dict
ds = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6]] * 40})
self.assertTrue("x" in ds.field_arrays and "y" in ds.field_arrays)
self.assertEqual(ds.field_arrays["x"].content, [[1, 2, 3, 4], ] * 40)
self.assertEqual(ds.field_arrays["y"].content, [[5, 6], ] * 40)
assert ("x" in ds.field_arrays and "y" in ds.field_arrays) == True
assert ds.field_arrays["x"].content == [[1, 2, 3, 4], ] * 40
assert ds.field_arrays["y"].content == [[5, 6], ] * 40

def test_init_assert(self):
with self.assertRaises(AssertionError):
with pytest.raises(AssertionError):
_ = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6]] * 100})
with self.assertRaises(AssertionError):
with pytest.raises(AssertionError):
_ = DataSet([[1, 2, 3, 4]] * 10)
with self.assertRaises(ValueError):
with pytest.raises(ValueError):
_ = DataSet(0.00001)


class TestDataSetMethods(unittest.TestCase):
class TestDataSetMethods:
def test_append(self):
dd = DataSet()
for _ in range(3):
dd.append(Instance(x=[1, 2, 3, 4], y=[5, 6]))
self.assertEqual(len(dd), 3)
self.assertEqual(dd.field_arrays["x"].content, [[1, 2, 3, 4]] * 3)
self.assertEqual(dd.field_arrays["y"].content, [[5, 6]] * 3)
assert len(dd) == 3
assert dd.field_arrays["x"].content == [[1, 2, 3, 4]] * 3
assert dd.field_arrays["y"].content == [[5, 6]] * 3

def test_add_field(self):
dd = DataSet()
dd.add_field("x", [[1, 2, 3]] * 10)
dd.add_field("y", [[1, 2, 3, 4]] * 10)
dd.add_field("z", [[5, 6]] * 10)
self.assertEqual(len(dd), 10)
self.assertEqual(dd.field_arrays["x"].content, [[1, 2, 3]] * 10)
self.assertEqual(dd.field_arrays["y"].content, [[1, 2, 3, 4]] * 10)
self.assertEqual(dd.field_arrays["z"].content, [[5, 6]] * 10)
assert len(dd) == 10
assert dd.field_arrays["x"].content == [[1, 2, 3]] * 10
assert dd.field_arrays["y"].content == [[1, 2, 3, 4]] * 10
assert dd.field_arrays["z"].content == [[5, 6]] * 10

with self.assertRaises(RuntimeError):
with pytest.raises(RuntimeError):
dd.add_field("??", [[1, 2]] * 40)

def test_delete_field(self):
@@ -71,8 +71,8 @@ class TestDataSetMethods(unittest.TestCase):
dd.add_field("x", [[1, 2, 3]] * 10)
dd.add_field("y", [[1, 2, 3, 4]] * 10)
dd.delete_field("x")
self.assertFalse("x" in dd.field_arrays)
self.assertTrue("y" in dd.field_arrays)
assert ("x" in dd.field_arrays) == False
assert "y" in dd.field_arrays

def test_delete_instance(self):
dd = DataSet()
@@ -80,30 +80,30 @@ class TestDataSetMethods(unittest.TestCase):
dd.add_field("x", [[1, 2, 3]] * old_length)
dd.add_field("y", [[1, 2, 3, 4]] * old_length)
dd.delete_instance(0)
self.assertEqual(len(dd), old_length - 1)
assert len(dd) == old_length - 1
dd.delete_instance(0)
self.assertEqual(len(dd), old_length - 2)
assert len(dd) == old_length - 2

def test_getitem(self):
ds = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6]] * 40})
ins_1, ins_0 = ds[0], ds[1]
self.assertTrue(isinstance(ins_1, Instance) and isinstance(ins_0, Instance))
self.assertEqual(ins_1["x"], [1, 2, 3, 4])
self.assertEqual(ins_1["y"], [5, 6])
self.assertEqual(ins_0["x"], [1, 2, 3, 4])
self.assertEqual(ins_0["y"], [5, 6])
assert isinstance(ins_1, Instance) and isinstance(ins_0, Instance) == True
assert ins_1["x"] == [1, 2, 3, 4]
assert ins_1["y"] == [5, 6]
assert ins_0["x"] == [1, 2, 3, 4]
assert ins_0["y"] == [5, 6]

sub_ds = ds[:10]
self.assertTrue(isinstance(sub_ds, DataSet))
self.assertEqual(len(sub_ds), 10)
assert isinstance(sub_ds, DataSet) == True
assert len(sub_ds) == 10

sub_ds_1 = ds[[10, 0, 2, 3]]
self.assertTrue(isinstance(sub_ds_1, DataSet))
self.assertEqual(len(sub_ds_1), 4)
assert isinstance(sub_ds_1, DataSet) == True
assert len(sub_ds_1) == 4

field_array = ds['x']
self.assertTrue(isinstance(field_array, FieldArray))
self.assertEqual(len(field_array), 40)
assert isinstance(field_array, FieldArray) == True
assert len(field_array) == 40

def test_setitem(self):
ds = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6]] * 40})
@@ -120,73 +120,73 @@ class TestDataSetMethods(unittest.TestCase):
assert ds[2]['x'] == ins1['x'] and ds[2]['y'] == ins1['y']

def test_get_item_error(self):
with self.assertRaises(RuntimeError):
with pytest.raises(RuntimeError):
ds = DataSet({"x": [[1, 2, 3, 4]] * 10, "y": [[5, 6]] * 10})
_ = ds[40:]

with self.assertRaises(KeyError):
with pytest.raises(KeyError):
ds = DataSet({"x": [[1, 2, 3, 4]] * 10, "y": [[5, 6]] * 10})
_ = ds["kom"]

def test_len_(self):
ds = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6]] * 40})
self.assertEqual(len(ds), 40)
assert len(ds) == 40

ds = DataSet()
self.assertEqual(len(ds), 0)
assert len(ds) == 0

def test_add_fieldarray(self):
ds = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6]] * 40})
ds.add_fieldarray('z', FieldArray('z', [[7, 8]]*40))
self.assertEqual(ds['z'].content, [[7, 8]]*40)
ds.add_fieldarray('z', FieldArray('z', [[7, 8]] * 40))
assert ds['z'].content == [[7, 8]] * 40

with self.assertRaises(RuntimeError):
ds.add_fieldarray('z', FieldArray('z', [[7, 8]]*10))
with pytest.raises(RuntimeError):
ds.add_fieldarray('z', FieldArray('z', [[7, 8]] * 10))

with self.assertRaises(TypeError):
with pytest.raises(TypeError):
ds.add_fieldarray('z', [1, 2, 4])

def test_copy_field(self):
ds = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6]] * 40})
ds.copy_field('x', 'z')
self.assertEqual(ds['x'].content, ds['z'].content)
assert ds['x'].content == ds['z'].content

def test_has_field(self):
ds = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6]] * 40})
self.assertTrue(ds.has_field('x'))
self.assertFalse(ds.has_field('z'))
assert ds.has_field('x') == True
assert ds.has_field('z') == False

def test_get_field(self):
ds = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6]] * 40})
with self.assertRaises(KeyError):
with pytest.raises(KeyError):
ds.get_field('z')
x_array = ds.get_field('x')
self.assertEqual(x_array.content, [[1, 2, 3, 4]] * 40)
assert x_array.content == [[1, 2, 3, 4]] * 40

def test_get_all_fields(self):
ds = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6]] * 40})
field_arrays = ds.get_all_fields()
self.assertEqual(field_arrays["x"], [[1, 2, 3, 4]] * 40)
self.assertEqual(field_arrays['y'], [[5, 6]] * 40)
assert field_arrays["x"].content == [[1, 2, 3, 4]] * 40
assert field_arrays['y'].content == [[5, 6]] * 40

def test_get_field_names(self):
ds = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6]] * 40})
field_names = ds.get_field_names()
self.assertTrue('x' in field_names)
self.assertTrue('y' in field_names)
assert 'x' in field_names
assert 'y' in field_names

def test_apply(self):
ds = DataSet({"x": [[1, 2, 3, 4]] * 4000, "y": [[5, 6]] * 4000})
ds.apply(lambda ins: ins["x"][::-1], new_field_name="rx", progress_desc='rx')
self.assertTrue("rx" in ds.field_arrays)
self.assertEqual(ds.field_arrays["rx"].content[0], [4, 3, 2, 1])
assert ("rx" in ds.field_arrays) == True
assert ds.field_arrays["rx"].content[0] == [4, 3, 2, 1]

ds.apply(lambda ins: len(ins["y"]), new_field_name="y", show_progress_bar=False)
self.assertEqual(ds.field_arrays["y"].content[0], 2)
assert ds.field_arrays["y"].content[0] == 2

res = ds.apply(lambda ins: len(ins["x"]), num_proc=0, progress_desc="len")
self.assertTrue(isinstance(res, list) and len(res) > 0)
self.assertTrue(res[0], 4)
assert (isinstance(res, list) and len(res) > 0) == True
assert res[0] == 4

ds.apply(lambda ins: (len(ins["x"]), "hahaha"), new_field_name="k")
# expect no exception raised
@@ -206,6 +206,7 @@ class TestDataSetMethods(unittest.TestCase):

def modify_inplace(instance):
instance['words'] = 1

ds.apply(modify_inplace)
# with self.assertRaises(TypeError):
# ds.apply(modify_inplace)
@@ -230,48 +231,48 @@ class TestDataSetMethods(unittest.TestCase):

T.apply_more(func_1)
# print(T['c'][0, 1, 2])
self.assertEqual(list(T["c"].content), [2, 4, 6])
self.assertEqual(list(T["d"].content), [1, 4, 9])
assert list(T["c"].content) == [2, 4, 6]
assert list(T["d"].content) == [1, 4, 9]

res = T.apply_field_more(func_2, "a", modify_fields=False)
self.assertEqual(list(T["c"].content), [2, 4, 6])
self.assertEqual(list(T["d"].content), [1, 4, 9])
self.assertEqual(list(res["c"]), [3, 6, 9])
self.assertEqual(list(res["d"]), [1, 8, 27])
assert list(T["c"].content) == [2, 4, 6]
assert list(T["d"].content) == [1, 4, 9]
assert list(res["c"]) == [3, 6, 9]
assert list(res["d"]) == [1, 8, 27]

with self.assertRaises(ApplyResultException) as e:
with pytest.raises(ApplyResultException) as e:
T.apply_more(func_err_1)
print(e)

with self.assertRaises(ApplyResultException) as e:
with pytest.raises(ApplyResultException) as e:
T.apply_field_more(func_err_2, "a")
print(e)

def test_drop(self):
ds = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6], [7, 8, 9, 0]] * 20})
ds.drop(lambda ins: len(ins["y"]) < 3, inplace=True)
self.assertEqual(len(ds), 20)
assert len(ds) == 20

def test_contains(self):
ds = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6]] * 40})
self.assertTrue("x" in ds)
self.assertTrue("y" in ds)
self.assertFalse("z" in ds)
assert ("x" in ds) == True
assert ("y" in ds) == True
assert ("z" in ds) == False

def test_rename_field(self):
ds = DataSet({"x": [[1, 2, 3, 4]] * 10, "y": [[5, 6]] * 10})
ds.rename_field("x", "xx")
self.assertTrue("xx" in ds)
self.assertFalse("x" in ds)
assert ("xx" in ds) == True
assert ("x" in ds) == False

with self.assertRaises(KeyError):
with pytest.raises(KeyError):
ds.rename_field("yyy", "oo")

def test_split(self):
ds = DataSet({"x": [[1, 2, 3, 4]] * 10, "y": [[5, 6]] * 10})
d1, d2 = ds.split(0.1)
self.assertEqual(len(d1), len(ds)*0.9)
self.assertEqual(len(d2), len(ds)*0.1)
assert len(d2) == (len(ds) * 0.9)
assert len(d1) == (len(ds) * 0.1)

def test_add_field_v2(self):
ds = DataSet({"x": [3, 4]})
@@ -282,14 +283,14 @@ class TestDataSetMethods(unittest.TestCase):
def test_save_load(self):
ds = DataSet({"x": [[1, 2, 3, 4]] * 10, "y": [[5, 6]] * 10})
ds.save("./my_ds.pkl")
self.assertTrue(os.path.exists("./my_ds.pkl"))
assert os.path.exists("./my_ds.pkl") == True

ds_1 = DataSet.load("./my_ds.pkl")
os.remove("my_ds.pkl")

def test_add_null(self):
ds = DataSet()
with self.assertRaises(RuntimeError) as RE:
with pytest.raises(RuntimeError) as RE:
ds.add_field('test', [])

def test_concat(self):
@@ -301,16 +302,16 @@ class TestDataSetMethods(unittest.TestCase):
ds2 = DataSet({"x": [[4, 3, 2, 1] for _ in range(10)], "y": [[6, 5] for _ in range(10)]})
ds3 = ds1.concat(ds2)

self.assertEqual(len(ds3), 20)
assert len(ds3) == 20

self.assertListEqual(ds1[9]['x'], [1, 2, 3, 4])
self.assertListEqual(ds1[10]['x'], [4, 3, 2, 1])
assert ds1[9]['x'] == [1, 2, 3, 4]
assert ds1[10]['x'] == [4, 3, 2, 1]

ds2[0]['x'][0] = 100
self.assertEqual(ds3[10]['x'][0], 4) # 不改变copy后的field了
assert ds3[10]['x'][0] == 4 # 不改变copy后的field了

ds3[10]['x'][0] = -100
self.assertEqual(ds2[0]['x'][0], 100) # 不改变copy前的field了
assert ds2[0]['x'][0] == 100 # 不改变copy前的field了

# 测试inplace
ds1 = DataSet({"x": [[1, 2, 3, 4] for i in range(10)], "y": [[5, 6] for i in range(10)]})
@@ -318,19 +319,19 @@ class TestDataSetMethods(unittest.TestCase):
ds3 = ds1.concat(ds2, inplace=True)

ds2[0]['x'][0] = 100
self.assertEqual(ds3[10]['x'][0], 4) # 不改变copy后的field了
assert ds3[10]['x'][0] == 4 # 不改变copy后的field了

ds3[10]['x'][0] = -100
self.assertEqual(ds2[0]['x'][0], 100) # 不改变copy前的field了
assert ds2[0]['x'][0] == 100 # 不改变copy前的field了

ds3[0]['x'][0] = 100
self.assertEqual(ds1[0]['x'][0], 100) # 改变copy前的field了
assert ds1[0]['x'][0] == 100 # 改变copy前的field了

# 测试mapping
ds1 = DataSet({"x": [[1, 2, 3, 4] for i in range(10)], "y": [[5, 6] for i in range(10)]})
ds2 = DataSet({"X": [[4, 3, 2, 1] for i in range(10)], "Y": [[6, 5] for i in range(10)]})
ds3 = ds1.concat(ds2, field_mapping={'X': 'x', 'Y': 'y'})
self.assertEqual(len(ds3), 20)
assert len(ds3) == 20

# 测试忽略掉多余的
ds1 = DataSet({"x": [[1, 2, 3, 4] for i in range(10)], "y": [[5, 6] for i in range(10)]})
@@ -340,7 +341,7 @@ class TestDataSetMethods(unittest.TestCase):
# 测试报错
ds1 = DataSet({"x": [[1, 2, 3, 4] for i in range(10)], "y": [[5, 6] for i in range(10)]})
ds2 = DataSet({"X": [[4, 3, 2, 1] for i in range(10)]})
with self.assertRaises(RuntimeError):
with pytest.raises(RuntimeError):
ds3 = ds1.concat(ds2, field_mapping={'X': 'x'})

def test_instance_field_disappear_bug(self):
@@ -348,7 +349,7 @@ class TestDataSetMethods(unittest.TestCase):
data.copy_field(field_name='raw_chars', new_field_name='chars')
_data = data[:1]
for field_name in ['raw_chars', 'target', 'chars']:
self.assertTrue(_data.has_field(field_name))
assert _data.has_field(field_name) == True

def test_from_pandas(self):
import pandas as pd
@@ -356,8 +357,8 @@ class TestDataSetMethods(unittest.TestCase):
df = pd.DataFrame({'x': [1, 2, 3], 'y': [4, 5, 6]})
ds = DataSet.from_pandas(df)
print(ds)
self.assertEqual(ds['x'].content, [1, 2, 3])
self.assertEqual(ds['y'].content, [4, 5, 6])
assert ds['x'].content == [1, 2, 3]
assert ds['y'].content == [4, 5, 6]

def test_to_pandas(self):
ds = DataSet({'x': [1, 2, 3], 'y': [4, 5, 6]})
@@ -366,7 +367,7 @@ class TestDataSetMethods(unittest.TestCase):
def test_to_csv(self):
ds = DataSet({'x': [1, 2, 3], 'y': [4, 5, 6]})
ds.to_csv("1.csv")
self.assertTrue(os.path.exists("1.csv"))
assert os.path.exists("1.csv") == True
os.remove("1.csv")

def test_add_collate_fn(self):
@@ -374,27 +375,26 @@ class TestDataSetMethods(unittest.TestCase):

def collate_fn(item):
return item
ds.add_collate_fn(collate_fn)

self.assertEqual(len(ds.collate_fns.collators), 2)
ds.add_collate_fn(collate_fn)

def test_get_collator(self):
from typing import Callable
ds = DataSet({'x': [1, 2, 3], 'y': [4, 5, 6]})
collate_fn = ds.get_collator()
self.assertEqual(isinstance(collate_fn, Callable), True)
assert isinstance(collate_fn, Callable) == True

def test_add_seq_len(self):
ds = DataSet({'x': [[1, 2], [2, 3 , 4], [3]], 'y': [4, 5, 6]})
ds = DataSet({'x': [[1, 2], [2, 3, 4], [3]], 'y': [4, 5, 6]})
ds.add_seq_len('x')
print(ds)

def test_set_target(self):
ds = DataSet({'x': [[1, 2], [2, 3 , 4], [3]], 'y': [4, 5, 6]})
ds = DataSet({'x': [[1, 2], [2, 3, 4], [3]], 'y': [4, 5, 6]})
ds.set_target('x')


class TestFieldArrayInit(unittest.TestCase):
class TestFieldArrayInit:
"""
1) 如果DataSet使用dict初始化,那么在add_field中会构造FieldArray:
1.1) 二维list DataSet({"x": [[1, 2], [3, 4]]})
@@ -442,7 +442,6 @@ class TestFieldArrayInit(unittest.TestCase):
# list of array
fa = FieldArray("x", [np.array([[1, 2], [3, 4]]), np.array([[1, 2], [3, 4]])])


def test_init_v8(self):
# 二维list
val = np.array([[1, 2], [3, 4]])
@@ -450,78 +449,78 @@ class TestFieldArrayInit(unittest.TestCase):
fa.append(val)


class TestFieldArray(unittest.TestCase):
class TestFieldArray:
def test_main(self):
fa = FieldArray("x", [1, 2, 3, 4, 5])
self.assertEqual(len(fa), 5)
assert len(fa) == 5
fa.append(6)
self.assertEqual(len(fa), 6)
assert len(fa) == 6

self.assertEqual(fa[-1], 6)
self.assertEqual(fa[0], 1)
assert fa[-1] == 6
assert fa[0] == 1
fa[-1] = 60
self.assertEqual(fa[-1], 60)
assert fa[-1] == 60

self.assertEqual(fa.get(0), 1)
self.assertTrue(isinstance(fa.get([0, 1, 2]), np.ndarray))
self.assertListEqual(list(fa.get([0, 1, 2])), [1, 2, 3])
assert fa.get(0) == 1
assert isinstance(fa.get([0, 1, 2]), np.ndarray) == True
assert list(fa.get([0, 1, 2])) == [1, 2, 3]

def test_getitem_v1(self):
fa = FieldArray("y", [[1.1, 2.2, 3.3, 4.4, 5.5], [1.0, 2.0, 3.0, 4.0, 5.0]])
self.assertEqual(fa[0], [1.1, 2.2, 3.3, 4.4, 5.5])
assert fa[0] == [1.1, 2.2, 3.3, 4.4, 5.5]
ans = fa[[0, 1]]
self.assertTrue(isinstance(ans, np.ndarray))
self.assertTrue(isinstance(ans[0], np.ndarray))
self.assertEqual(ans[0].tolist(), [1.1, 2.2, 3.3, 4.4, 5.5])
self.assertEqual(ans[1].tolist(), [1, 2, 3, 4, 5])
self.assertEqual(ans.dtype, np.float64)
assert isinstance(ans, np.ndarray) == True
assert isinstance(ans[0], np.ndarray) == True
assert ans[0].tolist() == [1.1, 2.2, 3.3, 4.4, 5.5]
assert ans[1].tolist() == [1, 2, 3, 4, 5]
assert ans.dtype == np.float64

def test_getitem_v2(self):
x = np.random.rand(10, 5)
fa = FieldArray("my_field", x)
indices = [0, 1, 3, 4, 6]
for a, b in zip(fa[indices], x[indices]):
self.assertListEqual(a.tolist(), b.tolist())
assert a.tolist() == b.tolist()

def test_append(self):
fa = FieldArray("y", [[1.1, 2.2, 3.3, 4.4, 5.5], [1.0, 2.0, 3.0, 4.0, 5.0]])
fa.append([1.2, 2.3, 3.4, 4.5, 5.6])
self.assertEqual(len(fa), 3)
self.assertEqual(fa[2], [1.2, 2.3, 3.4, 4.5, 5.6])
assert len(fa) == 3
assert fa[2] == [1.2, 2.3, 3.4, 4.5, 5.6]

def test_pop(self):
fa = FieldArray("y", [[1.1, 2.2, 3.3, 4.4, 5.5], [1.0, 2.0, 3.0, 4.0, 5.0]])
fa.pop(0)
self.assertEqual(len(fa), 1)
self.assertEqual(fa[0], [1.0, 2.0, 3.0, 4.0, 5.0])
assert len(fa) == 1
assert fa[0] == [1.0, 2.0, 3.0, 4.0, 5.0]
fa[0] = [1.1, 2.2, 3.3, 4.4, 5.5]
self.assertEqual(fa[0], [1.1, 2.2, 3.3, 4.4, 5.5])
assert fa[0] == [1.1, 2.2, 3.3, 4.4, 5.5]


class TestCase(unittest.TestCase):
class TestCase:

def test_init(self):
fields = {"x": [1, 2, 3], "y": [4, 5, 6]}
ins = Instance(x=[1, 2, 3], y=[4, 5, 6])
self.assertTrue(isinstance(ins.fields, dict))
self.assertEqual(ins.fields, fields)
assert isinstance(ins.fields, dict) == True
assert ins.fields == fields

ins = Instance(**fields)
self.assertEqual(ins.fields, fields)
assert ins.fields == fields

def test_add_field(self):
fields = {"x": [1, 2, 3], "y": [4, 5, 6]}
ins = Instance(**fields)
ins.add_field("z", [1, 1, 1])
fields.update({"z": [1, 1, 1]})
self.assertEqual(ins.fields, fields)
assert ins.fields == fields

def test_get_item(self):
fields = {"x": [1, 2, 3], "y": [4, 5, 6], "z": [1, 1, 1]}
ins = Instance(**fields)
self.assertEqual(ins["x"], [1, 2, 3])
self.assertEqual(ins["y"], [4, 5, 6])
self.assertEqual(ins["z"], [1, 1, 1])
assert ins["x"] == [1, 2, 3]
assert ins["y"] == [4, 5, 6]
assert ins["z"] == [1, 1, 1]

def test_repr(self):
fields = {"x": [1, 2, 3], "y": [4, 5, 6], "z": [1, 1, 1]}


Loading…
Cancel
Save