From 67aee2c897ed396aae9582dc8bffac72d90976a5 Mon Sep 17 00:00:00 2001 From: x54-729 <17307130121@fudan.edu.cn> Date: Fri, 7 Oct 2022 14:51:04 +0800 Subject: [PATCH] =?UTF-8?q?1.=E4=BF=AE=E6=94=B9=20DataSet.drop=EF=BC=8C?= =?UTF-8?q?=E4=BD=BF=E4=B9=8B=E5=9C=A8=20inplace=20=E4=B8=BA=20False=20?= =?UTF-8?q?=E6=97=B6=E4=B9=9F=E8=83=BD=E8=BF=94=E5=9B=9E=E7=A9=BA=20FieldA?= =?UTF-8?q?rray=20=E7=9A=84=E6=95=B0=E6=8D=AE=202.=E6=B7=BB=E5=8A=A0=20Fie?= =?UTF-8?q?ldArray=20=E7=9A=84=E8=BF=AD=E4=BB=A3=E6=96=B9=E6=B3=95=203.=20?= =?UTF-8?q?=E5=A2=9E=E5=8A=A0=20DataBunlde=20=E7=9A=84=20drop=20=E5=92=8C?= =?UTF-8?q?=20add=5Fseq=5Flen=20=E5=87=BD=E6=95=B0=E5=8F=8A=E6=B5=8B?= =?UTF-8?q?=E8=AF=95?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- fastNLP/core/dataset/dataset.py | 8 +++-- fastNLP/core/dataset/field.py | 4 +++ fastNLP/io/data_bundle.py | 34 ++++++++++++++++++ tests/io/test_data_bundle.py | 64 +++++++++++++++++++++++++++++++++ 4 files changed, 108 insertions(+), 2 deletions(-) create mode 100644 tests/io/test_data_bundle.py diff --git a/fastNLP/core/dataset/dataset.py b/fastNLP/core/dataset/dataset.py index 438d84b6..d57113aa 100644 --- a/fastNLP/core/dataset/dataset.py +++ b/fastNLP/core/dataset/dataset.py @@ -864,9 +864,13 @@ class DataSet: results = [ins for ins in self if not func(ins)] if len(results) != 0: dataset = DataSet(results) - return dataset else: - return DataSet() + dataset = DataSet() + for name in self.field_arrays.keys(): + empty_field = FieldArray(name, [None]) + empty_field.content = [] + dataset.field_arrays[name] = empty_field + return dataset def split(self, ratio: float, shuffle=True): r""" diff --git a/fastNLP/core/dataset/field.py b/fastNLP/core/dataset/field.py index e9795885..a272c1ee 100644 --- a/fastNLP/core/dataset/field.py +++ b/fastNLP/core/dataset/field.py @@ -47,6 +47,10 @@ class FieldArray: """ self.content.pop(index) + def __iter__(self): + for idx in range(len(self)): + yield self[idx] + def __getitem__(self, indices: Union[int, List[int]]): return self.get(indices) diff --git a/fastNLP/io/data_bundle.py b/fastNLP/io/data_bundle.py index 1a4dff28..f118002f 100644 --- a/fastNLP/io/data_bundle.py +++ b/fastNLP/io/data_bundle.py @@ -354,6 +354,40 @@ class DataBundle: progress_bar=progress_bar, progress_desc=progress_desc) return res + def add_seq_len(self, field_name: str, new_field_name='seq_len', ignore_miss_dataset: bool = True): + r""" + 将使用 :func:`len` 直接对每个 dataset 的 ``field_name`` 中每个元素作用,将其结果作为 sequence length, 并放入 + ``new_field_name`` 这个 field。 + + :param field_name: 需要处理的 field_name + :param new_field_name: 新的 field_name + :param ignore_miss_dataset: 如果为 ``True`` ,则当 ``field_name`` 在某个 dataset 内不存在时,直接忽略该 dataset, + 如果为 ``False`` 则会报错。 + :return: + """ + return self.apply_field(len, field_name, new_field_name=new_field_name, ignore_miss_dataset=ignore_miss_dataset) + + def drop(self, func: Callable, inplace=True): + r""" + 删除某些 Instance。 需要注意的是 ``func`` 接受一个 Instance ,返回 bool 值。返回值为 ``True`` 时, + 该 Instance 会被移除或者不会包含在返回的 DataBundle 中。 + + :param func: 接受一个 Instance 作为参数,返回 bool 值。为 ``True`` 时删除该 instance + :param inplace: 是否在当前 DataBundle 中直接删除 instance;如果为 False,将返回一个新的 DataBundle。 + + :return: DataSet + """ + if inplace: + for name, dataset in self.datasets.items(): + dataset.drop(func, inplace) + return self + else: + data_bundle = DataBundle(vocabs=self.vocabs) + for name, dataset in self.datasets.items(): + res = dataset.drop(func, inplace) + data_bundle.set_dataset(res, name) + return data_bundle + def set_pad(self, field_name, pad_val=0, dtype=None, backend=None, pad_fn=None) -> "DataBundle": """ 如果需要对某个 field 的内容进行特殊的调整,请使用这个函数。 diff --git a/tests/io/test_data_bundle.py b/tests/io/test_data_bundle.py new file mode 100644 index 00000000..4c756694 --- /dev/null +++ b/tests/io/test_data_bundle.py @@ -0,0 +1,64 @@ +import pytest + +from fastNLP.core.dataset import DataSet +from fastNLP.io.data_bundle import DataBundle + +def test_add_seq_len(): + dataset1 = DataSet({ + "x": [[0,1,2], [5,3,2,3], [5,21,5,10], [3,6,8,1]] + }) + dataset2 = DataSet({ + "x": [[0,1,2,3,4], [5,3,2,3], [5,20,45,1,98], [3,6,8,3,6,31]] + }) + dataset3 = DataSet({ + "x": [[0,1,2,7,5,2], [5,3], [0], [3,6,8]] + }) + data_bundle = DataBundle(datasets={ + "dataset1": dataset1, + "dataset2": dataset2, + "dataset3": dataset3 + }) + data_bundle.add_seq_len("x") + print(data_bundle.get_dataset("dataset1")) + for i, data in enumerate(data_bundle.get_dataset("dataset1")): + print(data["seq_len"], dataset1["x"][i]) + assert data["seq_len"] == len(dataset1["x"][i]) + for i, data in enumerate(data_bundle.get_dataset("dataset2")): + assert data["seq_len"] == len(dataset2["x"][i]) + for i, data in enumerate(data_bundle.get_dataset("dataset3")): + assert data["seq_len"] == len(dataset3["x"][i]) + +@pytest.mark.parametrize("inplace", [True, False]) +def test_drop(inplace): + dataset1 = DataSet({ + "x": [0, 1, 1, 4, 2, 1, 0, 1, 1, 6, 7, 1] + }) + dataset2 = DataSet({ + "x": [0, 0, 0, 0, 0] + }) + dataset3 = DataSet({ + "x": [1, 1, 1, 1, 1, 2, 3, 4] + }) + data_bundle = DataBundle(datasets={ + "dataset1": dataset1, + "dataset2": dataset2, + "dataset3": dataset3 + }) + res = data_bundle.drop(lambda x: x["x"] == 0, inplace) + if inplace: + assert res is data_bundle + else: + assert not (res is data_bundle) + assert data_bundle.get_dataset("dataset1")["x"] == dataset1["x"] + assert data_bundle.get_dataset("dataset2")["x"] == dataset2["x"] + assert data_bundle.get_dataset("dataset3")["x"] == dataset3["x"] + + dataset1_drop = [1, 1, 4, 2, 1, 1, 1, 6, 7, 1] + for i, data in enumerate(res.get_dataset("dataset1")["x"]): + assert data == dataset1_drop[i] + dataset2_drop = [] + for i, data in enumerate(res.get_dataset("dataset2")["x"]): + assert data == dataset2_drop[i] + dataset3_drop = [1, 1, 1, 1, 1, 2, 3, 4] + for i, data in enumerate(res.get_dataset("dataset3")["x"]): + assert data == dataset3_drop[i] \ No newline at end of file