
Change test cases from unittest to pytest

tags/v1.0.0alpha
MorningForest, 3 years ago
commit ffcf3ddcd3
4 changed files with 164 additions and 153 deletions

  1. fastNLP/core/dataloaders/torch_dataloader/fdl.py (+34 -22)
  2. tests/core/dataloaders/paddle_dataloader/test_fdl.py (+2 -2)
  3. tests/core/dataloaders/torch_dataloader/test_fdl.py (+2 -2)
  4. tests/core/dataset/test_dataset.py (+126 -127)
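
The test changes follow the usual unittest-to-pytest migration pattern: test classes drop the unittest.TestCase base, self.assertX(...) calls become plain assert statements, and self.assertRaises(...) becomes pytest.raises(...). A minimal before/after sketch of that pattern (illustrative only, not code from this commit; the class and method names are made up):

    import pytest

    from fastNLP.core.dataset import DataSet


    # Before (unittest style):
    #
    #   class TestDataSet(unittest.TestCase):
    #       def test_len(self):
    #           ds = DataSet({"x": [[1, 2]] * 10})
    #           self.assertEqual(len(ds), 10)
    #           with self.assertRaises(KeyError):
    #               ds.get_field("missing")

    # After (pytest style): no TestCase base class, bare assert statements,
    # and pytest.raises as the exception context manager.
    class TestDataSet:
        def test_len(self):
            ds = DataSet({"x": [[1, 2]] * 10})
            assert len(ds) == 10
            with pytest.raises(KeyError):
                ds.get_field("missing")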

fastNLP/core/dataloaders/torch_dataloader/fdl.py (+34 -22)

@@ -24,6 +24,7 @@ class _FDataSet:
     对Dataset的封装,主要是修改dataset的__getitem__函数,增加返回下标idx,值得注意的是dataset需要实现__getattribute__函数才能在_FDataset
     中调用dataset的方法
     """
+
     def __init__(self, dataset) -> None:
         self.dataset = dataset


@@ -45,6 +46,7 @@ class TorchDataLoader(DataLoader):
     提供给使用pytorch框架的DataLoader函数,若是配套使用FastNLP的dataset则可以自动使用AutoCollate函数对数据进行自动padding操作,用户也可以通过
     提供的方法调节设置collate_fn的若干参数。
     """
+
     def __init__(self, dataset, batch_size: int = 1,
                  shuffle: bool = False, sampler: Optional["Sampler[int]"] = None,
                  batch_sampler: Optional["Sampler[Sequence[int]]"] = None,
@@ -175,17 +177,17 @@ class TorchDataLoader(DataLoader):
 
 
 def prepare_torch_dataloader(ds_or_db: Union[DataSet, DataBundle, Sequence[DataSet], Mapping[str, DataSet]],
-                             batch_size: int = 1,
-                             shuffle: bool = False, sampler: Optional["Sampler[int]"] = None,
-                             batch_sampler: Optional["Sampler[Sequence[int]]"] = None,
-                             num_workers: int = 0, collate_fn: Optional[Callable] = None,
-                             pin_memory: bool = False, drop_last: bool = False,
-                             timeout: float = 0, worker_init_fn: Optional[Callable] = None,
-                             multiprocessing_context=None, generator=None, prefetch_factor: int = 2,
-                             persistent_workers: bool = False, non_train_sampler: Optional["Sampler[int]"] = None,
-                             non_train_batch_size: int = 16, as_numpy: bool = False,
-                             input_fields: Union[List, str] = None)\
-        -> Union[TorchDataLoader, Dict[str, TorchDataLoader], Sequence[TorchDataLoader]]:
+                             batch_size: int = 1,
+                             shuffle: bool = False, sampler: Optional["Sampler[int]"] = None,
+                             batch_sampler: Optional["Sampler[Sequence[int]]"] = None,
+                             num_workers: int = 0, collate_fn: Optional[Callable] = None,
+                             pin_memory: bool = False, drop_last: bool = False,
+                             timeout: float = 0, worker_init_fn: Optional[Callable] = None,
+                             multiprocessing_context=None, generator=None, prefetch_factor: int = 2,
+                             persistent_workers: bool = False, non_train_sampler: Optional["Sampler[int]"] = None,
+                             non_train_batch_size: int = 16, as_numpy: bool = False,
+                             input_fields: Union[List, str, None] = None) \
+        -> Union[TorchDataLoader, Dict[str, TorchDataLoader], Sequence[TorchDataLoader]]:
     """
     传入dataset或者data_bundle后,将其处理返回相对应的FdataLoader实例化对象


@@ -221,7 +223,8 @@ def prepare_torch_dataloader(ds_or_db: Union[DataSet, DataBundle, Sequence[DataS
                              multiprocessing_context=multiprocessing_context, generator=generator,
                              prefetch_factor=prefetch_factor, persistent_workers=persistent_workers,
                              as_numpy=as_numpy)
-        dl.set_input(*input_fields)
+        if input_fields:
+            dl.set_input(*input_fields)
         return dl
 
     elif isinstance(ds_or_db, DataBundle):
@@ -233,17 +236,21 @@ def prepare_torch_dataloader(ds_or_db: Union[DataSet, DataBundle, Sequence[DataS
                                                   num_workers=num_workers, collate_fn=collate_fn, pin_memory=pin_memory,
                                                   drop_last=drop_last, timeout=timeout, worker_init_fn=worker_init_fn,
                                                   multiprocessing_context=multiprocessing_context, generator=generator,
-                                                  prefetch_factor=prefetch_factor, persistent_workers=persistent_workers,
+                                                  prefetch_factor=prefetch_factor,
+                                                  persistent_workers=persistent_workers,
                                                   as_numpy=as_numpy)
             else:
                 dl_bundle[name] = TorchDataLoader(dataset=ds, batch_size=non_train_batch_size,
-                                                  shuffle=shuffle, sampler=non_train_sampler, batch_sampler=batch_sampler,
+                                                  shuffle=shuffle, sampler=non_train_sampler,
+                                                  batch_sampler=batch_sampler,
                                                   num_workers=num_workers, collate_fn=collate_fn, pin_memory=pin_memory,
                                                   drop_last=drop_last, timeout=timeout, worker_init_fn=worker_init_fn,
                                                   multiprocessing_context=multiprocessing_context, generator=generator,
-                                                  prefetch_factor=prefetch_factor, persistent_workers=persistent_workers,
+                                                  prefetch_factor=prefetch_factor,
+                                                  persistent_workers=persistent_workers,
                                                   as_numpy=as_numpy)
-            dl_bundle[name].set_input(*input_fields)
+            if input_fields:
+                dl_bundle[name].set_input(*input_fields)
         return dl_bundle
 
     elif isinstance(ds_or_db, Sequence):
@@ -269,8 +276,9 @@ def prepare_torch_dataloader(ds_or_db: Union[DataSet, DataBundle, Sequence[DataS
                                 prefetch_factor=prefetch_factor, persistent_workers=persistent_workers,
                                 as_numpy=as_numpy)
             )
-        for dl in dl_bundle:
-            dl.set_input(*input_fields)
+        if input_fields:
+            for dl in dl_bundle:
+                dl.set_input(*input_fields)
         return dl_bundle
 
     elif isinstance(ds_or_db, Mapping):
@@ -282,18 +290,22 @@ def prepare_torch_dataloader(ds_or_db: Union[DataSet, DataBundle, Sequence[DataS
                                                   num_workers=num_workers, collate_fn=collate_fn, pin_memory=pin_memory,
                                                   drop_last=drop_last, timeout=timeout, worker_init_fn=worker_init_fn,
                                                   multiprocessing_context=multiprocessing_context, generator=generator,
-                                                  prefetch_factor=prefetch_factor, persistent_workers=persistent_workers,
+                                                  prefetch_factor=prefetch_factor,
+                                                  persistent_workers=persistent_workers,
                                                   as_numpy=as_numpy)
             else:
                 dl_bundle[name] = TorchDataLoader(dataset=ds, batch_size=non_train_batch_size,
-                                                  shuffle=shuffle, sampler=non_train_sampler, batch_sampler=batch_sampler,
+                                                  shuffle=shuffle, sampler=non_train_sampler,
+                                                  batch_sampler=batch_sampler,
                                                   num_workers=num_workers, collate_fn=collate_fn, pin_memory=pin_memory,
                                                   drop_last=drop_last, timeout=timeout, worker_init_fn=worker_init_fn,
                                                   multiprocessing_context=multiprocessing_context, generator=generator,
-                                                  prefetch_factor=prefetch_factor, persistent_workers=persistent_workers,
+                                                  prefetch_factor=prefetch_factor,
+                                                  persistent_workers=persistent_workers,
                                                   as_numpy=as_numpy)
 
-            dl_bundle[name].set_input(*input_fields)
+            if input_fields:
+                dl_bundle[name].set_input(*input_fields)
 
         return dl_bundle
     else:
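
Besides splitting some long keyword-argument lines, the fdl.py change makes input_fields explicitly optional (Union[List, str, None]) and only calls set_input(*input_fields) when a value is actually passed, so prepare_torch_dataloader no longer assumes the caller always supplies input fields. A rough usage sketch based on the signature shown above (the dataset contents here are made up for illustration):

    from fastNLP.core.dataset import DataSet
    from fastNLP.core.dataloaders.torch_dataloader import prepare_torch_dataloader

    ds = DataSet({"x": [[1, 2], [2, 3, 4]] * 10, "y": [1, 0] * 10})

    # input_fields given: the returned TorchDataLoader gets set_input("x", "y") applied.
    dl = prepare_torch_dataloader(ds, batch_size=4, shuffle=True, input_fields=["x", "y"])

    # input_fields left as None: with the guard added in this commit, set_input is
    # simply skipped instead of attempting to unpack None.
    dl_plain = prepare_torch_dataloader(ds, batch_size=4)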


tests/core/dataloaders/paddle_dataloader/test_fdl.py (+2 -2)

@@ -1,4 +1,4 @@
-import unittest
+import pytest
 
 from fastNLP.core.dataloaders.paddle_dataloader.fdl import PaddleDataLoader
 from fastNLP.core.dataset import DataSet
@@ -17,7 +17,7 @@ class RandomDataset(Dataset):
         return 10
 
 
-class TestPaddle(unittest.TestCase):
+class TestPaddle:
 
     def test_init(self):
         # ds = DataSet({'x': [[1, 2], [2, 3, 4], [1]] * 10, 'y': [0, 1, 1] * 10})


tests/core/dataloaders/torch_dataloader/test_fdl.py (+2 -2)

@@ -1,11 +1,11 @@
-import unittest
+import pytest
 
 from fastNLP.core.dataloaders.torch_dataloader import TorchDataLoader, prepare_torch_dataloader
 from fastNLP.core.dataset import DataSet
 from fastNLP.io.data_bundle import DataBundle
 
 
-class TestFdl(unittest.TestCase):
+class TestFdl:
 
     def test_init_v1(self):
         ds = DataSet({"x": [[1, 2], [2, 3, 4], [4, 5, 6, 7]] * 10, "y": [1, 0, 1] * 10})


tests/core/dataset/test_dataset.py (+126 -127)

@@ -1,12 +1,12 @@
 import os
-import unittest
+import pytest
 
 import numpy as np
 
 from fastNLP.core.dataset import DataSet, FieldArray, Instance, ApplyResultException
 
 
-class TestDataSetInit(unittest.TestCase):
+class TestDataSetInit:
     """初始化DataSet的办法有以下几种:
     1) 用dict:
         1.1) 二维list  DataSet({"x": [[1, 2], [3, 4]]})
@@ -24,46 +24,46 @@ class TestDataSetInit(unittest.TestCase):
     def test_init_v1(self):
         # 一维list
         ds = DataSet([Instance(x=[1, 2, 3, 4], y=[5, 6])] * 40)
-        self.assertTrue("x" in ds.field_arrays and "y" in ds.field_arrays)
-        self.assertEqual(ds.field_arrays["x"].content, [[1, 2, 3, 4], ] * 40)
-        self.assertEqual(ds.field_arrays["y"].content, [[5, 6], ] * 40)
+        assert ("x" in ds.field_arrays and "y" in ds.field_arrays) == True
+        assert ds.field_arrays["x"].content == [[1, 2, 3, 4], ] * 40
+        assert ds.field_arrays["y"].content == [[5, 6], ] * 40
 
     def test_init_v2(self):
         # 用dict
         ds = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6]] * 40})
-        self.assertTrue("x" in ds.field_arrays and "y" in ds.field_arrays)
-        self.assertEqual(ds.field_arrays["x"].content, [[1, 2, 3, 4], ] * 40)
-        self.assertEqual(ds.field_arrays["y"].content, [[5, 6], ] * 40)
+        assert ("x" in ds.field_arrays and "y" in ds.field_arrays) == True
+        assert ds.field_arrays["x"].content == [[1, 2, 3, 4], ] * 40
+        assert ds.field_arrays["y"].content == [[5, 6], ] * 40
 
     def test_init_assert(self):
-        with self.assertRaises(AssertionError):
+        with pytest.raises(AssertionError):
             _ = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6]] * 100})
-        with self.assertRaises(AssertionError):
+        with pytest.raises(AssertionError):
             _ = DataSet([[1, 2, 3, 4]] * 10)
-        with self.assertRaises(ValueError):
+        with pytest.raises(ValueError):
             _ = DataSet(0.00001)
 
 
-class TestDataSetMethods(unittest.TestCase):
+class TestDataSetMethods:
     def test_append(self):
         dd = DataSet()
         for _ in range(3):
             dd.append(Instance(x=[1, 2, 3, 4], y=[5, 6]))
-        self.assertEqual(len(dd), 3)
-        self.assertEqual(dd.field_arrays["x"].content, [[1, 2, 3, 4]] * 3)
-        self.assertEqual(dd.field_arrays["y"].content, [[5, 6]] * 3)
+        assert len(dd) == 3
+        assert dd.field_arrays["x"].content == [[1, 2, 3, 4]] * 3
+        assert dd.field_arrays["y"].content == [[5, 6]] * 3
 
     def test_add_field(self):
         dd = DataSet()
         dd.add_field("x", [[1, 2, 3]] * 10)
         dd.add_field("y", [[1, 2, 3, 4]] * 10)
         dd.add_field("z", [[5, 6]] * 10)
-        self.assertEqual(len(dd), 10)
-        self.assertEqual(dd.field_arrays["x"].content, [[1, 2, 3]] * 10)
-        self.assertEqual(dd.field_arrays["y"].content, [[1, 2, 3, 4]] * 10)
-        self.assertEqual(dd.field_arrays["z"].content, [[5, 6]] * 10)
+        assert len(dd) == 10
+        assert dd.field_arrays["x"].content == [[1, 2, 3]] * 10
+        assert dd.field_arrays["y"].content == [[1, 2, 3, 4]] * 10
+        assert dd.field_arrays["z"].content == [[5, 6]] * 10
 
-        with self.assertRaises(RuntimeError):
+        with pytest.raises(RuntimeError):
             dd.add_field("??", [[1, 2]] * 40)
 
     def test_delete_field(self):
@@ -71,8 +71,8 @@ class TestDataSetMethods(unittest.TestCase):
dd.add_field("x", [[1, 2, 3]] * 10) dd.add_field("x", [[1, 2, 3]] * 10)
dd.add_field("y", [[1, 2, 3, 4]] * 10) dd.add_field("y", [[1, 2, 3, 4]] * 10)
dd.delete_field("x") dd.delete_field("x")
self.assertFalse("x" in dd.field_arrays)
self.assertTrue("y" in dd.field_arrays)
assert ("x" in dd.field_arrays) == False
assert "y" in dd.field_arrays


def test_delete_instance(self): def test_delete_instance(self):
dd = DataSet() dd = DataSet()
@@ -80,30 +80,30 @@ class TestDataSetMethods(unittest.TestCase):
dd.add_field("x", [[1, 2, 3]] * old_length) dd.add_field("x", [[1, 2, 3]] * old_length)
dd.add_field("y", [[1, 2, 3, 4]] * old_length) dd.add_field("y", [[1, 2, 3, 4]] * old_length)
dd.delete_instance(0) dd.delete_instance(0)
self.assertEqual(len(dd), old_length - 1)
assert len(dd) == old_length - 1
dd.delete_instance(0) dd.delete_instance(0)
self.assertEqual(len(dd), old_length - 2)
assert len(dd) == old_length - 2


def test_getitem(self): def test_getitem(self):
ds = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6]] * 40}) ds = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6]] * 40})
ins_1, ins_0 = ds[0], ds[1] ins_1, ins_0 = ds[0], ds[1]
self.assertTrue(isinstance(ins_1, Instance) and isinstance(ins_0, Instance))
self.assertEqual(ins_1["x"], [1, 2, 3, 4])
self.assertEqual(ins_1["y"], [5, 6])
self.assertEqual(ins_0["x"], [1, 2, 3, 4])
self.assertEqual(ins_0["y"], [5, 6])
assert isinstance(ins_1, Instance) and isinstance(ins_0, Instance) == True
assert ins_1["x"] == [1, 2, 3, 4]
assert ins_1["y"] == [5, 6]
assert ins_0["x"] == [1, 2, 3, 4]
assert ins_0["y"] == [5, 6]


sub_ds = ds[:10] sub_ds = ds[:10]
self.assertTrue(isinstance(sub_ds, DataSet))
self.assertEqual(len(sub_ds), 10)
assert isinstance(sub_ds, DataSet) == True
assert len(sub_ds) == 10


sub_ds_1 = ds[[10, 0, 2, 3]] sub_ds_1 = ds[[10, 0, 2, 3]]
self.assertTrue(isinstance(sub_ds_1, DataSet))
self.assertEqual(len(sub_ds_1), 4)
assert isinstance(sub_ds_1, DataSet) == True
assert len(sub_ds_1) == 4


field_array = ds['x'] field_array = ds['x']
self.assertTrue(isinstance(field_array, FieldArray))
self.assertEqual(len(field_array), 40)
assert isinstance(field_array, FieldArray) == True
assert len(field_array) == 40


def test_setitem(self): def test_setitem(self):
ds = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6]] * 40}) ds = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6]] * 40})
@@ -120,73 +120,73 @@ class TestDataSetMethods(unittest.TestCase):
         assert ds[2]['x'] == ins1['x'] and ds[2]['y'] == ins1['y']
 
     def test_get_item_error(self):
-        with self.assertRaises(RuntimeError):
+        with pytest.raises(RuntimeError):
             ds = DataSet({"x": [[1, 2, 3, 4]] * 10, "y": [[5, 6]] * 10})
             _ = ds[40:]
 
-        with self.assertRaises(KeyError):
+        with pytest.raises(KeyError):
             ds = DataSet({"x": [[1, 2, 3, 4]] * 10, "y": [[5, 6]] * 10})
             _ = ds["kom"]
 
     def test_len_(self):
         ds = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6]] * 40})
-        self.assertEqual(len(ds), 40)
+        assert len(ds) == 40
 
         ds = DataSet()
-        self.assertEqual(len(ds), 0)
+        assert len(ds) == 0
 
     def test_add_fieldarray(self):
         ds = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6]] * 40})
-        ds.add_fieldarray('z', FieldArray('z', [[7, 8]]*40))
-        self.assertEqual(ds['z'].content, [[7, 8]]*40)
+        ds.add_fieldarray('z', FieldArray('z', [[7, 8]] * 40))
+        assert ds['z'].content == [[7, 8]] * 40
 
-        with self.assertRaises(RuntimeError):
-            ds.add_fieldarray('z', FieldArray('z', [[7, 8]]*10))
+        with pytest.raises(RuntimeError):
+            ds.add_fieldarray('z', FieldArray('z', [[7, 8]] * 10))
 
-        with self.assertRaises(TypeError):
+        with pytest.raises(TypeError):
             ds.add_fieldarray('z', [1, 2, 4])
 
     def test_copy_field(self):
         ds = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6]] * 40})
         ds.copy_field('x', 'z')
-        self.assertEqual(ds['x'].content, ds['z'].content)
+        assert ds['x'].content == ds['z'].content
 
     def test_has_field(self):
         ds = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6]] * 40})
-        self.assertTrue(ds.has_field('x'))
-        self.assertFalse(ds.has_field('z'))
+        assert ds.has_field('x') == True
+        assert ds.has_field('z') == False
 
     def test_get_field(self):
         ds = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6]] * 40})
-        with self.assertRaises(KeyError):
+        with pytest.raises(KeyError):
             ds.get_field('z')
         x_array = ds.get_field('x')
-        self.assertEqual(x_array.content, [[1, 2, 3, 4]] * 40)
+        assert x_array.content == [[1, 2, 3, 4]] * 40
 
     def test_get_all_fields(self):
         ds = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6]] * 40})
         field_arrays = ds.get_all_fields()
-        self.assertEqual(field_arrays["x"], [[1, 2, 3, 4]] * 40)
-        self.assertEqual(field_arrays['y'], [[5, 6]] * 40)
+        assert field_arrays["x"].content == [[1, 2, 3, 4]] * 40
+        assert field_arrays['y'].content == [[5, 6]] * 40
 
     def test_get_field_names(self):
         ds = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6]] * 40})
         field_names = ds.get_field_names()
-        self.assertTrue('x' in field_names)
-        self.assertTrue('y' in field_names)
+        assert 'x' in field_names
+        assert 'y' in field_names
 
     def test_apply(self):
         ds = DataSet({"x": [[1, 2, 3, 4]] * 4000, "y": [[5, 6]] * 4000})
         ds.apply(lambda ins: ins["x"][::-1], new_field_name="rx", progress_desc='rx')
-        self.assertTrue("rx" in ds.field_arrays)
-        self.assertEqual(ds.field_arrays["rx"].content[0], [4, 3, 2, 1])
+        assert ("rx" in ds.field_arrays) == True
+        assert ds.field_arrays["rx"].content[0] == [4, 3, 2, 1]
 
         ds.apply(lambda ins: len(ins["y"]), new_field_name="y", show_progress_bar=False)
-        self.assertEqual(ds.field_arrays["y"].content[0], 2)
+        assert ds.field_arrays["y"].content[0] == 2
 
         res = ds.apply(lambda ins: len(ins["x"]), num_proc=0, progress_desc="len")
-        self.assertTrue(isinstance(res, list) and len(res) > 0)
-        self.assertTrue(res[0], 4)
+        assert (isinstance(res, list) and len(res) > 0) == True
+        assert res[0] == 4
 
         ds.apply(lambda ins: (len(ins["x"]), "hahaha"), new_field_name="k")
         # expect no exception raised
@@ -206,6 +206,7 @@ class TestDataSetMethods(unittest.TestCase):
 
         def modify_inplace(instance):
             instance['words'] = 1
+
         ds.apply(modify_inplace)
         # with self.assertRaises(TypeError):
         #     ds.apply(modify_inplace)
@@ -230,48 +231,48 @@ class TestDataSetMethods(unittest.TestCase):
 
         T.apply_more(func_1)
         # print(T['c'][0, 1, 2])
-        self.assertEqual(list(T["c"].content), [2, 4, 6])
-        self.assertEqual(list(T["d"].content), [1, 4, 9])
+        assert list(T["c"].content) == [2, 4, 6]
+        assert list(T["d"].content) == [1, 4, 9]
 
         res = T.apply_field_more(func_2, "a", modify_fields=False)
-        self.assertEqual(list(T["c"].content), [2, 4, 6])
-        self.assertEqual(list(T["d"].content), [1, 4, 9])
-        self.assertEqual(list(res["c"]), [3, 6, 9])
-        self.assertEqual(list(res["d"]), [1, 8, 27])
+        assert list(T["c"].content) == [2, 4, 6]
+        assert list(T["d"].content) == [1, 4, 9]
+        assert list(res["c"]) == [3, 6, 9]
+        assert list(res["d"]) == [1, 8, 27]
 
-        with self.assertRaises(ApplyResultException) as e:
+        with pytest.raises(ApplyResultException) as e:
             T.apply_more(func_err_1)
             print(e)
 
-        with self.assertRaises(ApplyResultException) as e:
+        with pytest.raises(ApplyResultException) as e:
             T.apply_field_more(func_err_2, "a")
             print(e)
 
     def test_drop(self):
         ds = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6], [7, 8, 9, 0]] * 20})
         ds.drop(lambda ins: len(ins["y"]) < 3, inplace=True)
-        self.assertEqual(len(ds), 20)
+        assert len(ds) == 20
 
     def test_contains(self):
         ds = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6]] * 40})
-        self.assertTrue("x" in ds)
-        self.assertTrue("y" in ds)
-        self.assertFalse("z" in ds)
+        assert ("x" in ds) == True
+        assert ("y" in ds) == True
+        assert ("z" in ds) == False
 
     def test_rename_field(self):
         ds = DataSet({"x": [[1, 2, 3, 4]] * 10, "y": [[5, 6]] * 10})
         ds.rename_field("x", "xx")
-        self.assertTrue("xx" in ds)
-        self.assertFalse("x" in ds)
+        assert ("xx" in ds) == True
+        assert ("x" in ds) == False
 
-        with self.assertRaises(KeyError):
+        with pytest.raises(KeyError):
             ds.rename_field("yyy", "oo")
 
     def test_split(self):
         ds = DataSet({"x": [[1, 2, 3, 4]] * 10, "y": [[5, 6]] * 10})
         d1, d2 = ds.split(0.1)
-        self.assertEqual(len(d1), len(ds)*0.9)
-        self.assertEqual(len(d2), len(ds)*0.1)
+        assert len(d2) == (len(ds) * 0.9)
+        assert len(d1) == (len(ds) * 0.1)
 
     def test_add_field_v2(self):
         ds = DataSet({"x": [3, 4]})
@@ -282,14 +283,14 @@ class TestDataSetMethods(unittest.TestCase):
     def test_save_load(self):
         ds = DataSet({"x": [[1, 2, 3, 4]] * 10, "y": [[5, 6]] * 10})
         ds.save("./my_ds.pkl")
-        self.assertTrue(os.path.exists("./my_ds.pkl"))
+        assert os.path.exists("./my_ds.pkl") == True
 
         ds_1 = DataSet.load("./my_ds.pkl")
         os.remove("my_ds.pkl")
 
     def test_add_null(self):
         ds = DataSet()
-        with self.assertRaises(RuntimeError) as RE:
+        with pytest.raises(RuntimeError) as RE:
             ds.add_field('test', [])
 
     def test_concat(self):
@@ -301,16 +302,16 @@ class TestDataSetMethods(unittest.TestCase):
ds2 = DataSet({"x": [[4, 3, 2, 1] for _ in range(10)], "y": [[6, 5] for _ in range(10)]}) ds2 = DataSet({"x": [[4, 3, 2, 1] for _ in range(10)], "y": [[6, 5] for _ in range(10)]})
ds3 = ds1.concat(ds2) ds3 = ds1.concat(ds2)


self.assertEqual(len(ds3), 20)
assert len(ds3) == 20


self.assertListEqual(ds1[9]['x'], [1, 2, 3, 4])
self.assertListEqual(ds1[10]['x'], [4, 3, 2, 1])
assert ds1[9]['x'] == [1, 2, 3, 4]
assert ds1[10]['x'] == [4, 3, 2, 1]


ds2[0]['x'][0] = 100 ds2[0]['x'][0] = 100
self.assertEqual(ds3[10]['x'][0], 4) # 不改变copy后的field了
assert ds3[10]['x'][0] == 4 # 不改变copy后的field了


ds3[10]['x'][0] = -100 ds3[10]['x'][0] = -100
self.assertEqual(ds2[0]['x'][0], 100) # 不改变copy前的field了
assert ds2[0]['x'][0] == 100 # 不改变copy前的field了


# 测试inplace # 测试inplace
ds1 = DataSet({"x": [[1, 2, 3, 4] for i in range(10)], "y": [[5, 6] for i in range(10)]}) ds1 = DataSet({"x": [[1, 2, 3, 4] for i in range(10)], "y": [[5, 6] for i in range(10)]})
@@ -318,19 +319,19 @@ class TestDataSetMethods(unittest.TestCase):
         ds3 = ds1.concat(ds2, inplace=True)
 
         ds2[0]['x'][0] = 100
-        self.assertEqual(ds3[10]['x'][0], 4)  # 不改变copy后的field了
+        assert ds3[10]['x'][0] == 4  # 不改变copy后的field了
 
         ds3[10]['x'][0] = -100
-        self.assertEqual(ds2[0]['x'][0], 100)  # 不改变copy前的field了
+        assert ds2[0]['x'][0] == 100  # 不改变copy前的field了
 
         ds3[0]['x'][0] = 100
-        self.assertEqual(ds1[0]['x'][0], 100)  # 改变copy前的field了
+        assert ds1[0]['x'][0] == 100  # 改变copy前的field了
 
         # 测试mapping
         ds1 = DataSet({"x": [[1, 2, 3, 4] for i in range(10)], "y": [[5, 6] for i in range(10)]})
         ds2 = DataSet({"X": [[4, 3, 2, 1] for i in range(10)], "Y": [[6, 5] for i in range(10)]})
         ds3 = ds1.concat(ds2, field_mapping={'X': 'x', 'Y': 'y'})
-        self.assertEqual(len(ds3), 20)
+        assert len(ds3) == 20
 
         # 测试忽略掉多余的
         ds1 = DataSet({"x": [[1, 2, 3, 4] for i in range(10)], "y": [[5, 6] for i in range(10)]})
@@ -340,7 +341,7 @@ class TestDataSetMethods(unittest.TestCase):
         # 测试报错
         ds1 = DataSet({"x": [[1, 2, 3, 4] for i in range(10)], "y": [[5, 6] for i in range(10)]})
         ds2 = DataSet({"X": [[4, 3, 2, 1] for i in range(10)]})
-        with self.assertRaises(RuntimeError):
+        with pytest.raises(RuntimeError):
             ds3 = ds1.concat(ds2, field_mapping={'X': 'x'})
 
     def test_instance_field_disappear_bug(self):
@@ -348,7 +349,7 @@ class TestDataSetMethods(unittest.TestCase):
         data.copy_field(field_name='raw_chars', new_field_name='chars')
         _data = data[:1]
         for field_name in ['raw_chars', 'target', 'chars']:
-            self.assertTrue(_data.has_field(field_name))
+            assert _data.has_field(field_name) == True
 
     def test_from_pandas(self):
         import pandas as pd
@@ -356,8 +357,8 @@ class TestDataSetMethods(unittest.TestCase):
         df = pd.DataFrame({'x': [1, 2, 3], 'y': [4, 5, 6]})
         ds = DataSet.from_pandas(df)
         print(ds)
-        self.assertEqual(ds['x'].content, [1, 2, 3])
-        self.assertEqual(ds['y'].content, [4, 5, 6])
+        assert ds['x'].content == [1, 2, 3]
+        assert ds['y'].content == [4, 5, 6]
 
     def test_to_pandas(self):
         ds = DataSet({'x': [1, 2, 3], 'y': [4, 5, 6]})
@@ -366,7 +367,7 @@ class TestDataSetMethods(unittest.TestCase):
     def test_to_csv(self):
         ds = DataSet({'x': [1, 2, 3], 'y': [4, 5, 6]})
         ds.to_csv("1.csv")
-        self.assertTrue(os.path.exists("1.csv"))
+        assert os.path.exists("1.csv") == True
         os.remove("1.csv")
 
     def test_add_collate_fn(self):
@@ -374,27 +375,26 @@ class TestDataSetMethods(unittest.TestCase):
 
         def collate_fn(item):
             return item
-        ds.add_collate_fn(collate_fn)
 
-        self.assertEqual(len(ds.collate_fns.collators), 2)
+        ds.add_collate_fn(collate_fn)
 
     def test_get_collator(self):
         from typing import Callable
         ds = DataSet({'x': [1, 2, 3], 'y': [4, 5, 6]})
         collate_fn = ds.get_collator()
-        self.assertEqual(isinstance(collate_fn, Callable), True)
+        assert isinstance(collate_fn, Callable) == True
 
     def test_add_seq_len(self):
-        ds = DataSet({'x': [[1, 2], [2, 3 , 4], [3]], 'y': [4, 5, 6]})
+        ds = DataSet({'x': [[1, 2], [2, 3, 4], [3]], 'y': [4, 5, 6]})
         ds.add_seq_len('x')
         print(ds)
 
     def test_set_target(self):
-        ds = DataSet({'x': [[1, 2], [2, 3 , 4], [3]], 'y': [4, 5, 6]})
+        ds = DataSet({'x': [[1, 2], [2, 3, 4], [3]], 'y': [4, 5, 6]})
         ds.set_target('x')
 
 
-class TestFieldArrayInit(unittest.TestCase):
+class TestFieldArrayInit:
     """
     1) 如果DataSet使用dict初始化,那么在add_field中会构造FieldArray:
         1.1) 二维list  DataSet({"x": [[1, 2], [3, 4]]})
@@ -442,7 +442,6 @@ class TestFieldArrayInit(unittest.TestCase):
         # list of array
         fa = FieldArray("x", [np.array([[1, 2], [3, 4]]), np.array([[1, 2], [3, 4]])])
 
-
     def test_init_v8(self):
         # 二维list
         val = np.array([[1, 2], [3, 4]])
@@ -450,78 +449,78 @@ class TestFieldArrayInit(unittest.TestCase):
         fa.append(val)
 
 
-class TestFieldArray(unittest.TestCase):
+class TestFieldArray:
     def test_main(self):
         fa = FieldArray("x", [1, 2, 3, 4, 5])
-        self.assertEqual(len(fa), 5)
+        assert len(fa) == 5
         fa.append(6)
-        self.assertEqual(len(fa), 6)
+        assert len(fa) == 6
 
-        self.assertEqual(fa[-1], 6)
-        self.assertEqual(fa[0], 1)
+        assert fa[-1] == 6
+        assert fa[0] == 1
         fa[-1] = 60
-        self.assertEqual(fa[-1], 60)
+        assert fa[-1] == 60
 
-        self.assertEqual(fa.get(0), 1)
-        self.assertTrue(isinstance(fa.get([0, 1, 2]), np.ndarray))
-        self.assertListEqual(list(fa.get([0, 1, 2])), [1, 2, 3])
+        assert fa.get(0) == 1
+        assert isinstance(fa.get([0, 1, 2]), np.ndarray) == True
+        assert list(fa.get([0, 1, 2])) == [1, 2, 3]
 
     def test_getitem_v1(self):
         fa = FieldArray("y", [[1.1, 2.2, 3.3, 4.4, 5.5], [1.0, 2.0, 3.0, 4.0, 5.0]])
-        self.assertEqual(fa[0], [1.1, 2.2, 3.3, 4.4, 5.5])
+        assert fa[0] == [1.1, 2.2, 3.3, 4.4, 5.5]
         ans = fa[[0, 1]]
-        self.assertTrue(isinstance(ans, np.ndarray))
-        self.assertTrue(isinstance(ans[0], np.ndarray))
-        self.assertEqual(ans[0].tolist(), [1.1, 2.2, 3.3, 4.4, 5.5])
-        self.assertEqual(ans[1].tolist(), [1, 2, 3, 4, 5])
-        self.assertEqual(ans.dtype, np.float64)
+        assert isinstance(ans, np.ndarray) == True
+        assert isinstance(ans[0], np.ndarray) == True
+        assert ans[0].tolist() == [1.1, 2.2, 3.3, 4.4, 5.5]
+        assert ans[1].tolist() == [1, 2, 3, 4, 5]
+        assert ans.dtype == np.float64
 
     def test_getitem_v2(self):
         x = np.random.rand(10, 5)
         fa = FieldArray("my_field", x)
         indices = [0, 1, 3, 4, 6]
         for a, b in zip(fa[indices], x[indices]):
-            self.assertListEqual(a.tolist(), b.tolist())
+            assert a.tolist() == b.tolist()
 
     def test_append(self):
         fa = FieldArray("y", [[1.1, 2.2, 3.3, 4.4, 5.5], [1.0, 2.0, 3.0, 4.0, 5.0]])
         fa.append([1.2, 2.3, 3.4, 4.5, 5.6])
-        self.assertEqual(len(fa), 3)
-        self.assertEqual(fa[2], [1.2, 2.3, 3.4, 4.5, 5.6])
+        assert len(fa) == 3
+        assert fa[2] == [1.2, 2.3, 3.4, 4.5, 5.6]
 
     def test_pop(self):
         fa = FieldArray("y", [[1.1, 2.2, 3.3, 4.4, 5.5], [1.0, 2.0, 3.0, 4.0, 5.0]])
         fa.pop(0)
-        self.assertEqual(len(fa), 1)
-        self.assertEqual(fa[0], [1.0, 2.0, 3.0, 4.0, 5.0])
+        assert len(fa) == 1
+        assert fa[0] == [1.0, 2.0, 3.0, 4.0, 5.0]
         fa[0] = [1.1, 2.2, 3.3, 4.4, 5.5]
-        self.assertEqual(fa[0], [1.1, 2.2, 3.3, 4.4, 5.5])
+        assert fa[0] == [1.1, 2.2, 3.3, 4.4, 5.5]
 
 
-class TestCase(unittest.TestCase):
+class TestCase:
 
     def test_init(self):
         fields = {"x": [1, 2, 3], "y": [4, 5, 6]}
         ins = Instance(x=[1, 2, 3], y=[4, 5, 6])
-        self.assertTrue(isinstance(ins.fields, dict))
-        self.assertEqual(ins.fields, fields)
+        assert isinstance(ins.fields, dict) == True
+        assert ins.fields == fields
 
         ins = Instance(**fields)
-        self.assertEqual(ins.fields, fields)
+        assert ins.fields == fields
 
     def test_add_field(self):
         fields = {"x": [1, 2, 3], "y": [4, 5, 6]}
         ins = Instance(**fields)
         ins.add_field("z", [1, 1, 1])
         fields.update({"z": [1, 1, 1]})
-        self.assertEqual(ins.fields, fields)
+        assert ins.fields == fields
 
     def test_get_item(self):
         fields = {"x": [1, 2, 3], "y": [4, 5, 6], "z": [1, 1, 1]}
         ins = Instance(**fields)
-        self.assertEqual(ins["x"], [1, 2, 3])
-        self.assertEqual(ins["y"], [4, 5, 6])
-        self.assertEqual(ins["z"], [1, 1, 1])
+        assert ins["x"] == [1, 2, 3]
+        assert ins["y"] == [4, 5, 6]
+        assert ins["z"] == [1, 1, 1]
 
     def test_repr(self):
         fields = {"x": [1, 2, 3], "y": [4, 5, 6], "z": [1, 1, 1]}

