@@ -864,9 +864,13 @@ class DataSet: | |||||
results = [ins for ins in self if not func(ins)] | results = [ins for ins in self if not func(ins)] | ||||
if len(results) != 0: | if len(results) != 0: | ||||
dataset = DataSet(results) | dataset = DataSet(results) | ||||
return dataset | |||||
else: | else: | ||||
return DataSet() | |||||
dataset = DataSet() | |||||
for name in self.field_arrays.keys(): | |||||
empty_field = FieldArray(name, [None]) | |||||
empty_field.content = [] | |||||
dataset.field_arrays[name] = empty_field | |||||
return dataset | |||||
def split(self, ratio: float, shuffle=True): | def split(self, ratio: float, shuffle=True): | ||||
r""" | r""" | ||||
@@ -47,6 +47,10 @@ class FieldArray: | |||||
""" | """ | ||||
self.content.pop(index) | self.content.pop(index) | ||||
def __iter__(self): | |||||
for idx in range(len(self)): | |||||
yield self[idx] | |||||
def __getitem__(self, indices: Union[int, List[int]]): | def __getitem__(self, indices: Union[int, List[int]]): | ||||
return self.get(indices) | return self.get(indices) | ||||
@@ -354,6 +354,40 @@ class DataBundle: | |||||
progress_bar=progress_bar, progress_desc=progress_desc) | progress_bar=progress_bar, progress_desc=progress_desc) | ||||
return res | return res | ||||
def add_seq_len(self, field_name: str, new_field_name='seq_len', ignore_miss_dataset: bool = True): | |||||
r""" | |||||
将使用 :func:`len` 直接对每个 dataset 的 ``field_name`` 中每个元素作用,将其结果作为 sequence length, 并放入 | |||||
``new_field_name`` 这个 field。 | |||||
:param field_name: 需要处理的 field_name | |||||
:param new_field_name: 新的 field_name | |||||
:param ignore_miss_dataset: 如果为 ``True`` ,则当 ``field_name`` 在某个 dataset 内不存在时,直接忽略该 dataset, | |||||
如果为 ``False`` 则会报错。 | |||||
:return: | |||||
""" | |||||
return self.apply_field(len, field_name, new_field_name=new_field_name, ignore_miss_dataset=ignore_miss_dataset) | |||||
def drop(self, func: Callable, inplace=True): | |||||
r""" | |||||
删除某些 Instance。 需要注意的是 ``func`` 接受一个 Instance ,返回 bool 值。返回值为 ``True`` 时, | |||||
该 Instance 会被移除或者不会包含在返回的 DataBundle 中。 | |||||
:param func: 接受一个 Instance 作为参数,返回 bool 值。为 ``True`` 时删除该 instance | |||||
:param inplace: 是否在当前 DataBundle 中直接删除 instance;如果为 False,将返回一个新的 DataBundle。 | |||||
:return: DataSet | |||||
""" | |||||
if inplace: | |||||
for name, dataset in self.datasets.items(): | |||||
dataset.drop(func, inplace) | |||||
return self | |||||
else: | |||||
data_bundle = DataBundle(vocabs=self.vocabs) | |||||
for name, dataset in self.datasets.items(): | |||||
res = dataset.drop(func, inplace) | |||||
data_bundle.set_dataset(res, name) | |||||
return data_bundle | |||||
def set_pad(self, field_name, pad_val=0, dtype=None, backend=None, pad_fn=None) -> "DataBundle": | def set_pad(self, field_name, pad_val=0, dtype=None, backend=None, pad_fn=None) -> "DataBundle": | ||||
""" | """ | ||||
如果需要对某个 field 的内容进行特殊的调整,请使用这个函数。 | 如果需要对某个 field 的内容进行特殊的调整,请使用这个函数。 | ||||
@@ -0,0 +1,64 @@ | |||||
import pytest | |||||
from fastNLP.core.dataset import DataSet | |||||
from fastNLP.io.data_bundle import DataBundle | |||||
def test_add_seq_len(): | |||||
dataset1 = DataSet({ | |||||
"x": [[0,1,2], [5,3,2,3], [5,21,5,10], [3,6,8,1]] | |||||
}) | |||||
dataset2 = DataSet({ | |||||
"x": [[0,1,2,3,4], [5,3,2,3], [5,20,45,1,98], [3,6,8,3,6,31]] | |||||
}) | |||||
dataset3 = DataSet({ | |||||
"x": [[0,1,2,7,5,2], [5,3], [0], [3,6,8]] | |||||
}) | |||||
data_bundle = DataBundle(datasets={ | |||||
"dataset1": dataset1, | |||||
"dataset2": dataset2, | |||||
"dataset3": dataset3 | |||||
}) | |||||
data_bundle.add_seq_len("x") | |||||
print(data_bundle.get_dataset("dataset1")) | |||||
for i, data in enumerate(data_bundle.get_dataset("dataset1")): | |||||
print(data["seq_len"], dataset1["x"][i]) | |||||
assert data["seq_len"] == len(dataset1["x"][i]) | |||||
for i, data in enumerate(data_bundle.get_dataset("dataset2")): | |||||
assert data["seq_len"] == len(dataset2["x"][i]) | |||||
for i, data in enumerate(data_bundle.get_dataset("dataset3")): | |||||
assert data["seq_len"] == len(dataset3["x"][i]) | |||||
@pytest.mark.parametrize("inplace", [True, False]) | |||||
def test_drop(inplace): | |||||
dataset1 = DataSet({ | |||||
"x": [0, 1, 1, 4, 2, 1, 0, 1, 1, 6, 7, 1] | |||||
}) | |||||
dataset2 = DataSet({ | |||||
"x": [0, 0, 0, 0, 0] | |||||
}) | |||||
dataset3 = DataSet({ | |||||
"x": [1, 1, 1, 1, 1, 2, 3, 4] | |||||
}) | |||||
data_bundle = DataBundle(datasets={ | |||||
"dataset1": dataset1, | |||||
"dataset2": dataset2, | |||||
"dataset3": dataset3 | |||||
}) | |||||
res = data_bundle.drop(lambda x: x["x"] == 0, inplace) | |||||
if inplace: | |||||
assert res is data_bundle | |||||
else: | |||||
assert not (res is data_bundle) | |||||
assert data_bundle.get_dataset("dataset1")["x"] == dataset1["x"] | |||||
assert data_bundle.get_dataset("dataset2")["x"] == dataset2["x"] | |||||
assert data_bundle.get_dataset("dataset3")["x"] == dataset3["x"] | |||||
dataset1_drop = [1, 1, 4, 2, 1, 1, 1, 6, 7, 1] | |||||
for i, data in enumerate(res.get_dataset("dataset1")["x"]): | |||||
assert data == dataset1_drop[i] | |||||
dataset2_drop = [] | |||||
for i, data in enumerate(res.get_dataset("dataset2")["x"]): | |||||
assert data == dataset2_drop[i] | |||||
dataset3_drop = [1, 1, 1, 1, 1, 2, 3, 4] | |||||
for i, data in enumerate(res.get_dataset("dataset3")["x"]): | |||||
assert data == dataset3_drop[i] |