Browse Source

1.修改 DataSet.drop,使之在 inplace 为 False 时也能返回空 FieldArray 的数据 2.添加 FieldArray 的迭代方法 3. 增加 DataBunlde 的 drop 和 add_seq_len 函数及测试

pull/11/head
x54-729 1 year ago
parent
commit
67aee2c897
4 changed files with 108 additions and 2 deletions
  1. +6
    -2
      fastNLP/core/dataset/dataset.py
  2. +4
    -0
      fastNLP/core/dataset/field.py
  3. +34
    -0
      fastNLP/io/data_bundle.py
  4. +64
    -0
      tests/io/test_data_bundle.py

+ 6
- 2
fastNLP/core/dataset/dataset.py View File

@@ -864,9 +864,13 @@ class DataSet:
results = [ins for ins in self if not func(ins)]
if len(results) != 0:
dataset = DataSet(results)
return dataset
else:
return DataSet()
dataset = DataSet()
for name in self.field_arrays.keys():
empty_field = FieldArray(name, [None])
empty_field.content = []
dataset.field_arrays[name] = empty_field
return dataset

def split(self, ratio: float, shuffle=True):
r"""


+ 4
- 0
fastNLP/core/dataset/field.py View File

@@ -47,6 +47,10 @@ class FieldArray:
"""
self.content.pop(index)

def __iter__(self):
for idx in range(len(self)):
yield self[idx]

def __getitem__(self, indices: Union[int, List[int]]):
return self.get(indices)



+ 34
- 0
fastNLP/io/data_bundle.py View File

@@ -354,6 +354,40 @@ class DataBundle:
progress_bar=progress_bar, progress_desc=progress_desc)
return res

def add_seq_len(self, field_name: str, new_field_name='seq_len', ignore_miss_dataset: bool = True):
r"""
将使用 :func:`len` 直接对每个 dataset 的 ``field_name`` 中每个元素作用,将其结果作为 sequence length, 并放入
``new_field_name`` 这个 field。

:param field_name: 需要处理的 field_name
:param new_field_name: 新的 field_name
:param ignore_miss_dataset: 如果为 ``True`` ,则当 ``field_name`` 在某个 dataset 内不存在时,直接忽略该 dataset,
如果为 ``False`` 则会报错。
:return:
"""
return self.apply_field(len, field_name, new_field_name=new_field_name, ignore_miss_dataset=ignore_miss_dataset)

def drop(self, func: Callable, inplace=True):
r"""
删除某些 Instance。 需要注意的是 ``func`` 接受一个 Instance ,返回 bool 值。返回值为 ``True`` 时,
该 Instance 会被移除或者不会包含在返回的 DataBundle 中。

:param func: 接受一个 Instance 作为参数,返回 bool 值。为 ``True`` 时删除该 instance
:param inplace: 是否在当前 DataBundle 中直接删除 instance;如果为 False,将返回一个新的 DataBundle。

:return: DataSet
"""
if inplace:
for name, dataset in self.datasets.items():
dataset.drop(func, inplace)
return self
else:
data_bundle = DataBundle(vocabs=self.vocabs)
for name, dataset in self.datasets.items():
res = dataset.drop(func, inplace)
data_bundle.set_dataset(res, name)
return data_bundle

def set_pad(self, field_name, pad_val=0, dtype=None, backend=None, pad_fn=None) -> "DataBundle":
"""
如果需要对某个 field 的内容进行特殊的调整,请使用这个函数。


+ 64
- 0
tests/io/test_data_bundle.py View File

@@ -0,0 +1,64 @@
import pytest

from fastNLP.core.dataset import DataSet
from fastNLP.io.data_bundle import DataBundle

def test_add_seq_len():
dataset1 = DataSet({
"x": [[0,1,2], [5,3,2,3], [5,21,5,10], [3,6,8,1]]
})
dataset2 = DataSet({
"x": [[0,1,2,3,4], [5,3,2,3], [5,20,45,1,98], [3,6,8,3,6,31]]
})
dataset3 = DataSet({
"x": [[0,1,2,7,5,2], [5,3], [0], [3,6,8]]
})
data_bundle = DataBundle(datasets={
"dataset1": dataset1,
"dataset2": dataset2,
"dataset3": dataset3
})
data_bundle.add_seq_len("x")
print(data_bundle.get_dataset("dataset1"))
for i, data in enumerate(data_bundle.get_dataset("dataset1")):
print(data["seq_len"], dataset1["x"][i])
assert data["seq_len"] == len(dataset1["x"][i])
for i, data in enumerate(data_bundle.get_dataset("dataset2")):
assert data["seq_len"] == len(dataset2["x"][i])
for i, data in enumerate(data_bundle.get_dataset("dataset3")):
assert data["seq_len"] == len(dataset3["x"][i])

@pytest.mark.parametrize("inplace", [True, False])
def test_drop(inplace):
dataset1 = DataSet({
"x": [0, 1, 1, 4, 2, 1, 0, 1, 1, 6, 7, 1]
})
dataset2 = DataSet({
"x": [0, 0, 0, 0, 0]
})
dataset3 = DataSet({
"x": [1, 1, 1, 1, 1, 2, 3, 4]
})
data_bundle = DataBundle(datasets={
"dataset1": dataset1,
"dataset2": dataset2,
"dataset3": dataset3
})
res = data_bundle.drop(lambda x: x["x"] == 0, inplace)
if inplace:
assert res is data_bundle
else:
assert not (res is data_bundle)
assert data_bundle.get_dataset("dataset1")["x"] == dataset1["x"]
assert data_bundle.get_dataset("dataset2")["x"] == dataset2["x"]
assert data_bundle.get_dataset("dataset3")["x"] == dataset3["x"]

dataset1_drop = [1, 1, 4, 2, 1, 1, 1, 6, 7, 1]
for i, data in enumerate(res.get_dataset("dataset1")["x"]):
assert data == dataset1_drop[i]
dataset2_drop = []
for i, data in enumerate(res.get_dataset("dataset2")["x"]):
assert data == dataset2_drop[i]
dataset3_drop = [1, 1, 1, 1, 1, 2, 3, 4]
for i, data in enumerate(res.get_dataset("dataset3")["x"]):
assert data == dataset3_drop[i]

Loading…
Cancel
Save