@@ -864,9 +864,13 @@ class DataSet: | |||
results = [ins for ins in self if not func(ins)] | |||
if len(results) != 0: | |||
dataset = DataSet(results) | |||
return dataset | |||
else: | |||
return DataSet() | |||
dataset = DataSet() | |||
for name in self.field_arrays.keys(): | |||
empty_field = FieldArray(name, [None]) | |||
empty_field.content = [] | |||
dataset.field_arrays[name] = empty_field | |||
return dataset | |||
def split(self, ratio: float, shuffle=True): | |||
r""" | |||
@@ -47,6 +47,10 @@ class FieldArray: | |||
""" | |||
self.content.pop(index) | |||
def __iter__(self): | |||
for idx in range(len(self)): | |||
yield self[idx] | |||
def __getitem__(self, indices: Union[int, List[int]]): | |||
return self.get(indices) | |||
@@ -354,6 +354,40 @@ class DataBundle: | |||
progress_bar=progress_bar, progress_desc=progress_desc) | |||
return res | |||
def add_seq_len(self, field_name: str, new_field_name='seq_len', ignore_miss_dataset: bool = True): | |||
r""" | |||
将使用 :func:`len` 直接对每个 dataset 的 ``field_name`` 中每个元素作用,将其结果作为 sequence length, 并放入 | |||
``new_field_name`` 这个 field。 | |||
:param field_name: 需要处理的 field_name | |||
:param new_field_name: 新的 field_name | |||
:param ignore_miss_dataset: 如果为 ``True`` ,则当 ``field_name`` 在某个 dataset 内不存在时,直接忽略该 dataset, | |||
如果为 ``False`` 则会报错。 | |||
:return: | |||
""" | |||
return self.apply_field(len, field_name, new_field_name=new_field_name, ignore_miss_dataset=ignore_miss_dataset) | |||
def drop(self, func: Callable, inplace=True): | |||
r""" | |||
删除某些 Instance。 需要注意的是 ``func`` 接受一个 Instance ,返回 bool 值。返回值为 ``True`` 时, | |||
该 Instance 会被移除或者不会包含在返回的 DataBundle 中。 | |||
:param func: 接受一个 Instance 作为参数,返回 bool 值。为 ``True`` 时删除该 instance | |||
:param inplace: 是否在当前 DataBundle 中直接删除 instance;如果为 False,将返回一个新的 DataBundle。 | |||
:return: DataSet | |||
""" | |||
if inplace: | |||
for name, dataset in self.datasets.items(): | |||
dataset.drop(func, inplace) | |||
return self | |||
else: | |||
data_bundle = DataBundle(vocabs=self.vocabs) | |||
for name, dataset in self.datasets.items(): | |||
res = dataset.drop(func, inplace) | |||
data_bundle.set_dataset(res, name) | |||
return data_bundle | |||
def set_pad(self, field_name, pad_val=0, dtype=None, backend=None, pad_fn=None) -> "DataBundle": | |||
""" | |||
如果需要对某个 field 的内容进行特殊的调整,请使用这个函数。 | |||
@@ -0,0 +1,64 @@ | |||
import pytest | |||
from fastNLP.core.dataset import DataSet | |||
from fastNLP.io.data_bundle import DataBundle | |||
def test_add_seq_len(): | |||
dataset1 = DataSet({ | |||
"x": [[0,1,2], [5,3,2,3], [5,21,5,10], [3,6,8,1]] | |||
}) | |||
dataset2 = DataSet({ | |||
"x": [[0,1,2,3,4], [5,3,2,3], [5,20,45,1,98], [3,6,8,3,6,31]] | |||
}) | |||
dataset3 = DataSet({ | |||
"x": [[0,1,2,7,5,2], [5,3], [0], [3,6,8]] | |||
}) | |||
data_bundle = DataBundle(datasets={ | |||
"dataset1": dataset1, | |||
"dataset2": dataset2, | |||
"dataset3": dataset3 | |||
}) | |||
data_bundle.add_seq_len("x") | |||
print(data_bundle.get_dataset("dataset1")) | |||
for i, data in enumerate(data_bundle.get_dataset("dataset1")): | |||
print(data["seq_len"], dataset1["x"][i]) | |||
assert data["seq_len"] == len(dataset1["x"][i]) | |||
for i, data in enumerate(data_bundle.get_dataset("dataset2")): | |||
assert data["seq_len"] == len(dataset2["x"][i]) | |||
for i, data in enumerate(data_bundle.get_dataset("dataset3")): | |||
assert data["seq_len"] == len(dataset3["x"][i]) | |||
@pytest.mark.parametrize("inplace", [True, False]) | |||
def test_drop(inplace): | |||
dataset1 = DataSet({ | |||
"x": [0, 1, 1, 4, 2, 1, 0, 1, 1, 6, 7, 1] | |||
}) | |||
dataset2 = DataSet({ | |||
"x": [0, 0, 0, 0, 0] | |||
}) | |||
dataset3 = DataSet({ | |||
"x": [1, 1, 1, 1, 1, 2, 3, 4] | |||
}) | |||
data_bundle = DataBundle(datasets={ | |||
"dataset1": dataset1, | |||
"dataset2": dataset2, | |||
"dataset3": dataset3 | |||
}) | |||
res = data_bundle.drop(lambda x: x["x"] == 0, inplace) | |||
if inplace: | |||
assert res is data_bundle | |||
else: | |||
assert not (res is data_bundle) | |||
assert data_bundle.get_dataset("dataset1")["x"] == dataset1["x"] | |||
assert data_bundle.get_dataset("dataset2")["x"] == dataset2["x"] | |||
assert data_bundle.get_dataset("dataset3")["x"] == dataset3["x"] | |||
dataset1_drop = [1, 1, 4, 2, 1, 1, 1, 6, 7, 1] | |||
for i, data in enumerate(res.get_dataset("dataset1")["x"]): | |||
assert data == dataset1_drop[i] | |||
dataset2_drop = [] | |||
for i, data in enumerate(res.get_dataset("dataset2")["x"]): | |||
assert data == dataset2_drop[i] | |||
dataset3_drop = [1, 1, 1, 1, 1, 2, 3, 4] | |||
for i, data in enumerate(res.get_dataset("dataset3")["x"]): | |||
assert data == dataset3_drop[i] |