| @@ -864,9 +864,13 @@ class DataSet: | |||
| results = [ins for ins in self if not func(ins)] | |||
| if len(results) != 0: | |||
| dataset = DataSet(results) | |||
| return dataset | |||
| else: | |||
| return DataSet() | |||
| dataset = DataSet() | |||
| for name in self.field_arrays.keys(): | |||
| empty_field = FieldArray(name, [None]) | |||
| empty_field.content = [] | |||
| dataset.field_arrays[name] = empty_field | |||
| return dataset | |||
| def split(self, ratio: float, shuffle=True): | |||
| r""" | |||
| @@ -47,6 +47,10 @@ class FieldArray: | |||
| """ | |||
| self.content.pop(index) | |||
def __iter__(self):
    """Yield the stored elements one by one by delegating to ``self[idx]``."""
    # Snapshot the length once, matching ``range(len(self))`` semantics.
    index = 0
    total = len(self)
    while index < total:
        yield self[index]
        index += 1
def __getitem__(self, indices: Union[int, List[int]]):
    """Support subscript access (``field[i]`` / ``field[[i, j]]``) by delegating to :meth:`get`."""
    result = self.get(indices)
    return result
| @@ -354,6 +354,40 @@ class DataBundle: | |||
| progress_bar=progress_bar, progress_desc=progress_desc) | |||
| return res | |||
def add_seq_len(self, field_name: str, new_field_name='seq_len', ignore_miss_dataset: bool = True):
    r"""
    Apply :func:`len` to every element of ``field_name`` in each contained dataset
    and store the result as the sequence length under ``new_field_name``.

    :param field_name: name of the field whose elements are measured
    :param new_field_name: name of the field that receives the computed lengths
    :param ignore_miss_dataset: when ``True``, datasets that lack ``field_name``
        are silently skipped; when ``False``, an error is raised for them.
    :return: the result of the underlying :meth:`apply_field` call
    """
    return self.apply_field(
        len,
        field_name,
        new_field_name=new_field_name,
        ignore_miss_dataset=ignore_miss_dataset,
    )
def drop(self, func: Callable, inplace=True):
    r"""
    Remove instances from the contained datasets. ``func`` receives an
    ``Instance`` and returns a bool; instances for which it returns ``True``
    are dropped (or excluded from the returned bundle).

    :param func: callable taking an ``Instance``; return ``True`` to drop it
    :param inplace: drop directly inside this DataBundle when ``True``;
        otherwise build and return a new DataBundle with filtered copies.
    :return: DataSet
    """
    if not inplace:
        # Build a fresh bundle that shares the vocabularies but holds
        # filtered copies of every dataset.
        new_bundle = DataBundle(vocabs=self.vocabs)
        for ds_name, ds in self.datasets.items():
            new_bundle.set_dataset(ds.drop(func, inplace), ds_name)
        return new_bundle
    # In-place: mutate each dataset directly and return self for chaining.
    for ds in self.datasets.values():
        ds.drop(func, inplace)
    return self
| def set_pad(self, field_name, pad_val=0, dtype=None, backend=None, pad_fn=None) -> "DataBundle": | |||
| """ | |||
| 如果需要对某个 field 的内容进行特殊的调整,请使用这个函数。 | |||
| @@ -0,0 +1,64 @@ | |||
| import pytest | |||
| from fastNLP.core.dataset import DataSet | |||
| from fastNLP.io.data_bundle import DataBundle | |||
def test_add_seq_len():
    """``add_seq_len`` must add a ``seq_len`` field equal to ``len(x)`` in every dataset."""
    # Source rows keyed by dataset name; driving the checks from this mapping
    # removes the triplicated assertion loops and the leftover debug prints.
    sources = {
        "dataset1": [[0, 1, 2], [5, 3, 2, 3], [5, 21, 5, 10], [3, 6, 8, 1]],
        "dataset2": [[0, 1, 2, 3, 4], [5, 3, 2, 3], [5, 20, 45, 1, 98], [3, 6, 8, 3, 6, 31]],
        "dataset3": [[0, 1, 2, 7, 5, 2], [5, 3], [0], [3, 6, 8]],
    }
    data_bundle = DataBundle(
        datasets={name: DataSet({"x": rows}) for name, rows in sources.items()}
    )

    data_bundle.add_seq_len("x")

    for name, rows in sources.items():
        for row, instance in zip(rows, data_bundle.get_dataset(name)):
            assert instance["seq_len"] == len(row)
@pytest.mark.parametrize("inplace", [True, False])
def test_drop(inplace):
    """``drop`` must remove every instance whose ``x`` equals 0, in place or in a copy."""
    dataset1 = DataSet({"x": [0, 1, 1, 4, 2, 1, 0, 1, 1, 6, 7, 1]})
    dataset2 = DataSet({"x": [0, 0, 0, 0, 0]})
    dataset3 = DataSet({"x": [1, 1, 1, 1, 1, 2, 3, 4]})
    data_bundle = DataBundle(datasets={
        "dataset1": dataset1,
        "dataset2": dataset2,
        "dataset3": dataset3,
    })

    res = data_bundle.drop(lambda ins: ins["x"] == 0, inplace)

    if inplace:
        assert res is data_bundle
    else:
        assert res is not data_bundle
        # The original bundle must be untouched when a copy was requested.
        assert data_bundle.get_dataset("dataset1")["x"] == dataset1["x"]
        assert data_bundle.get_dataset("dataset2")["x"] == dataset2["x"]
        assert data_bundle.get_dataset("dataset3")["x"] == dataset3["x"]

    # Compare full field contents: the previous element-wise enumerate loops
    # silently passed when the result was SHORTER than expected (an empty
    # result never entered the loop), so drop-too-many bugs went undetected.
    assert list(res.get_dataset("dataset1")["x"]) == [1, 1, 4, 2, 1, 1, 1, 6, 7, 1]
    assert list(res.get_dataset("dataset2")["x"]) == []
    assert list(res.get_dataset("dataset3")["x"]) == [1, 1, 1, 1, 1, 2, 3, 4]