|
|
@@ -1,3 +1,18 @@ |
|
|
|
""" |
|
|
|
fastNLP.core.DataSet的介绍文档 |
|
|
|
|
|
|
|
DataSet是fastNLP中用于承载数据的容器。可以将DataSet看做是一个表格,每一行是一个instance(或sample),每一列是一个feature。 |
|
|
|
|
|
|
|
csv-table:: |
|
|
|
:header: "Field1", "Field2", "Field3" |
|
|
|
:widths:20, 10, 10 |
|
|
|
|
|
|
|
"This is the first instance", ['This', 'is', 'the', 'first', 'instance'], 5 |
|
|
|
"Second instance", ['Second', 'instance'], 2 |
|
|
|
|
|
|
|
""" |
|
|
|
|
|
|
|
|
|
|
|
import _pickle as pickle |
|
|
|
|
|
|
|
import numpy as np |
|
|
@@ -31,7 +46,7 @@ class DataSet(object): |
|
|
|
length_set.add(len(value)) |
|
|
|
assert len(length_set) == 1, "Arrays must all be same length." |
|
|
|
for key, value in data.items(): |
|
|
|
self.add_field(name=key, fields=value) |
|
|
|
self.add_field(field_name=key, fields=value) |
|
|
|
elif isinstance(data, list): |
|
|
|
for ins in data: |
|
|
|
assert isinstance(ins, Instance), "Must be Instance type, not {}.".format(type(ins)) |
|
|
@@ -88,7 +103,7 @@ class DataSet(object): |
|
|
|
raise RuntimeError(f"Start index {idx.start} out of range 0-{len(self)-1}") |
|
|
|
data_set = DataSet() |
|
|
|
for field in self.field_arrays.values(): |
|
|
|
data_set.add_field(name=field.name, fields=field.content[idx], padder=field.padder, |
|
|
|
data_set.add_field(field_name=field.name, fields=field.content[idx], padder=field.padder, |
|
|
|
is_input=field.is_input, is_target=field.is_target, ignore_type=field.ignore_type) |
|
|
|
return data_set |
|
|
|
elif isinstance(idx, str): |
|
|
@@ -131,7 +146,7 @@ class DataSet(object): |
|
|
|
return "DataSet(" + self.__inner_repr__() + ")" |
|
|
|
|
|
|
|
def append(self, ins): |
|
|
|
"""Add an instance to the DataSet. |
|
|
|
"""将一个instance对象append到DataSet后面。 |
|
|
|
If the DataSet is not empty, the instance must have the same field names as the rest instances in the DataSet. |
|
|
|
|
|
|
|
:param ins: an Instance object |
|
|
@@ -151,57 +166,60 @@ class DataSet(object): |
|
|
|
assert name in self.field_arrays |
|
|
|
self.field_arrays[name].append(field) |
|
|
|
|
|
|
|
def add_field(self, name, fields, padder=None, is_input=False, is_target=False, ignore_type=False): |
|
|
|
"""Add a new field to the DataSet. |
|
|
|
def add_field(self, field_name, fields, padder=AutoPadder(), is_input=False, is_target=False, ignore_type=False): |
|
|
|
"""新增一个field |
|
|
|
|
|
|
|
:param str name: the name of the field. |
|
|
|
:param fields: a list of int, float, or other objects. |
|
|
|
:param padder: PadBase对象,如何对该Field进行padding。如果为None则使用 |
|
|
|
:param bool is_input: whether this field is model input. |
|
|
|
:param bool is_target: whether this field is label or target. |
|
|
|
:param bool ignore_type: If True, do not perform type check. (Default: False) |
|
|
|
:param str field_name: 新增的field的名称 |
|
|
|
:param list fields: 需要新增的field的内容 |
|
|
|
:param None, Padder padder: 如果为None,则不进行pad。 |
|
|
|
:param bool is_input: 新加入的field是否是input |
|
|
|
:param bool is_target: 新加入的field是否是target |
|
|
|
:param bool ignore_type: 是否忽略对新加入的field的类型检查 |
|
|
|
""" |
|
|
|
if padder is None: |
|
|
|
padder = AutoPadder(pad_val=0) |
|
|
|
|
|
|
|
if len(self.field_arrays) != 0: |
|
|
|
if len(self) != len(fields): |
|
|
|
raise RuntimeError(f"The field to append must have the same size as dataset. " |
|
|
|
f"Dataset size {len(self)} != field size {len(fields)}") |
|
|
|
self.field_arrays[name] = FieldArray(name, fields, is_target=is_target, is_input=is_input, |
|
|
|
padder=padder, ignore_type=ignore_type) |
|
|
|
self.field_arrays[field_name] = FieldArray(field_name, fields, is_target=is_target, is_input=is_input, |
|
|
|
padder=padder, ignore_type=ignore_type) |
|
|
|
|
|
|
|
def delete_field(self, name): |
|
|
|
"""Delete a field based on the field name. |
|
|
|
def delete_field(self, field_name): |
|
|
|
"""删除field |
|
|
|
|
|
|
|
:param name: the name of the field to be deleted. |
|
|
|
:param str field_name: 需要删除的field的名称. |
|
|
|
""" |
|
|
|
self.field_arrays.pop(name) |
|
|
|
self.field_arrays.pop(field_name) |
|
|
|
|
|
|
|
def get_field(self, field_name): |
|
|
|
"""获取field_name这个field |
|
|
|
|
|
|
|
:param str field_name: field的名称 |
|
|
|
:return: FieldArray |
|
|
|
""" |
|
|
|
if field_name not in self.field_arrays: |
|
|
|
raise KeyError("Field name {} not found in DataSet".format(field_name)) |
|
|
|
return self.field_arrays[field_name] |
|
|
|
|
|
|
|
def get_all_fields(self): |
|
|
|
"""Return all the fields with their names. |
|
|
|
"""返回一个dict,key为field_name, value为对应的FieldArray |
|
|
|
|
|
|
|
:return field_arrays: the internal data structure of DataSet. |
|
|
|
:return: dict: |
|
|
|
""" |
|
|
|
return self.field_arrays |
|
|
|
|
|
|
|
def get_length(self): |
|
|
|
"""Fetch the length of the dataset. |
|
|
|
"""获取DataSet的元素数量 |
|
|
|
|
|
|
|
:return length: |
|
|
|
:return: int length: |
|
|
|
""" |
|
|
|
return len(self) |
|
|
|
|
|
|
|
def rename_field(self, old_name, new_name): |
|
|
|
"""Rename a field. |
|
|
|
"""将某个field重新命名. |
|
|
|
|
|
|
|
:param str old_name: |
|
|
|
:param str new_name: |
|
|
|
:param str old_name: 原来的field名称 |
|
|
|
:param str new_name: 修改为new_name |
|
|
|
""" |
|
|
|
if old_name in self.field_arrays: |
|
|
|
self.field_arrays[new_name] = self.field_arrays.pop(old_name) |
|
|
@@ -216,8 +234,8 @@ class DataSet(object): |
|
|
|
dataset.set_target('labels', 'seq_len') # 将labels和seq_len这两个field的target属性设置为True |
|
|
|
dataset.set_target('labels', 'seq_lens', flag=False) # 将labels和seq_len的target属性设置为False |
|
|
|
|
|
|
|
:param field_names: str, field的名称 |
|
|
|
:param flag: bool, 将field_name的target状态设置为flag |
|
|
|
:param str field_names: field的名称 |
|
|
|
:param bool flag: 将field_name的target状态设置为flag |
|
|
|
""" |
|
|
|
assert isinstance(flag, bool), "Only bool type supported." |
|
|
|
for name in field_names: |
|
|
@@ -233,8 +251,8 @@ class DataSet(object): |
|
|
|
dataset.set_input('words', 'seq_len') # 将words和seq_len这两个field的input属性设置为True |
|
|
|
dataset.set_input('words', flag=False) # 将words这个field的input属性设置为False |
|
|
|
|
|
|
|
:param field_names: str, field的名称 |
|
|
|
:param flag: bool, 将field_name的input状态设置为flag |
|
|
|
:param str field_names: field的名称 |
|
|
|
:param bool flag: 将field_name的input状态设置为flag |
|
|
|
""" |
|
|
|
for name in field_names: |
|
|
|
if name in self.field_arrays: |
|
|
@@ -245,8 +263,8 @@ class DataSet(object): |
|
|
|
def set_ignore_type(self, *field_names, flag=True): |
|
|
|
"""将field_names的ignore_type设置为flag状态 |
|
|
|
|
|
|
|
:param field_names: str, field的名称 |
|
|
|
:param flag: bool, |
|
|
|
:param str field_names: field的名称 |
|
|
|
:param bool flag: 将field_name的ignore_type状态设置为flag |
|
|
|
:return: |
|
|
|
""" |
|
|
|
assert isinstance(flag, bool), "Only bool type supported." |
|
|
@@ -264,8 +282,8 @@ class DataSet(object): |
|
|
|
padder = EngChar2DPadder() |
|
|
|
dataset.set_padder('chars', padder) # 则chars这个field会使用EngChar2DPadder进行pad操作 |
|
|
|
|
|
|
|
:param field_name: str, 设置field的padding方式为padder |
|
|
|
:param padder: (None, PadderBase). 设置为None即删除padder, 即对该field不进行padding操作. |
|
|
|
:param str field_name: 设置field的padding方式为padder |
|
|
|
:param None, Padder padder: 设置为None即删除padder, 即对该field不进行pad操作. |
|
|
|
:return: |
|
|
|
""" |
|
|
|
if field_name not in self.field_arrays: |
|
|
@@ -275,8 +293,8 @@ class DataSet(object): |
|
|
|
def set_pad_val(self, field_name, pad_val): |
|
|
|
"""为某个field设置对应的pad_val. |
|
|
|
|
|
|
|
:param field_name: str,修改该field的pad_val |
|
|
|
:param pad_val: int,该field的padder会以pad_val作为padding index |
|
|
|
:param str field_name: 修改该field的pad_val |
|
|
|
:param int pad_val: 该field的padder会以pad_val作为padding index |
|
|
|
:return: |
|
|
|
""" |
|
|
|
if field_name not in self.field_arrays: |
|
|
@@ -286,7 +304,7 @@ class DataSet(object): |
|
|
|
def get_input_name(self): |
|
|
|
"""返回所有is_input被设置为True的field名称 |
|
|
|
|
|
|
|
:return list, 里面的元素为被设置为input的field名称 |
|
|
|
:return: list, 里面的元素为被设置为input的field名称 |
|
|
|
""" |
|
|
|
return [name for name, field in self.field_arrays.items() if field.is_input] |
|
|
|
|
|
|
@@ -300,15 +318,22 @@ class DataSet(object): |
|
|
|
def apply_field(self, func, field_name, new_field_name=None, **kwargs): |
|
|
|
"""将DataSet中的每个instance中的`field_name`这个field传给func,并获取它的返回值. |
|
|
|
|
|
|
|
:param func: Callable, input是instance的`field_name`这个field. |
|
|
|
:param field_name: str, 传入func的是哪个field. |
|
|
|
:param new_field_name: (str, None). 如果不是None,将func的返回值放入这个名为`new_field_name`的新field中,如果名称与已有 |
|
|
|
的field相同,则覆盖之前的field. |
|
|
|
:param **kwargs: 合法的参数有以下三个 |
|
|
|
(1) is_input: bool, 如果为True则将`new_field_name`这个field设置为input |
|
|
|
(2) is_target: bool, 如果为True则将`new_field_name`这个field设置为target |
|
|
|
(3) ignore_type: bool, 如果为True则将`new_field_name`这个field的ignore_type设置为true, 忽略其类型 |
|
|
|
:return: List[], 里面的元素为func的返回值,所以list长度为DataSet的长度 |
|
|
|
:param callable func: input是instance的`field_name`这个field. |
|
|
|
:param str field_name: 传入func的是哪个field. |
|
|
|
:param str, None new_field_name: 将func返回的内容放入到什么field中 |
|
|
|
|
|
|
|
1. str, 将func的返回值放入这个名为`new_field_name`的新field中,如果名称与已有的field相 |
|
|
|
同,则覆盖之前的field |
|
|
|
|
|
|
|
2. None, 不创建新的field |
|
|
|
:param kwargs: 合法的参数有以下三个 |
|
|
|
|
|
|
|
1. is_input: bool, 如果为True则将`new_field_name`的field设置为input |
|
|
|
|
|
|
|
2. is_target: bool, 如果为True则将`new_field_name`的field设置为target |
|
|
|
|
|
|
|
3. ignore_type: bool, 如果为True则将`new_field_name`的field的ignore_type设置为true, 忽略其类型 |
|
|
|
:return: list(Any), 里面的元素为func的返回值,所以list长度为DataSet的长度 |
|
|
|
|
|
|
|
""" |
|
|
|
assert len(self)!=0, "Null DataSet cannot use apply()." |
|
|
@@ -334,9 +359,9 @@ class DataSet(object): |
|
|
|
def _add_apply_field(self, results, new_field_name, kwargs): |
|
|
|
"""将results作为加入到新的field中,field名称为new_field_name |
|
|
|
|
|
|
|
:param results: List[], 一般是apply*()之后的结果 |
|
|
|
:param new_field_name: str, 新加入的field的名称 |
|
|
|
:param kwargs: dict, 用户apply*()时传入的自定义参数 |
|
|
|
:param list(str) results: 一般是apply*()之后的结果 |
|
|
|
:param str new_field_name: 新加入的field的名称 |
|
|
|
:param dict kwargs: 用户apply*()时传入的自定义参数 |
|
|
|
:return: |
|
|
|
""" |
|
|
|
extra_param = {} |
|
|
@@ -355,23 +380,30 @@ class DataSet(object): |
|
|
|
extra_param['is_target'] = old_field.is_target |
|
|
|
if 'ignore_type' not in extra_param: |
|
|
|
extra_param['ignore_type'] = old_field.ignore_type |
|
|
|
self.add_field(name=new_field_name, fields=results, is_input=extra_param["is_input"], |
|
|
|
self.add_field(field_name=new_field_name, fields=results, is_input=extra_param["is_input"], |
|
|
|
is_target=extra_param["is_target"], ignore_type=extra_param['ignore_type']) |
|
|
|
else: |
|
|
|
self.add_field(name=new_field_name, fields=results, is_input=extra_param.get("is_input", None), |
|
|
|
self.add_field(field_name=new_field_name, fields=results, is_input=extra_param.get("is_input", None), |
|
|
|
is_target=extra_param.get("is_target", None), |
|
|
|
ignore_type=extra_param.get("ignore_type", False)) |
|
|
|
|
|
|
|
def apply(self, func, new_field_name=None, **kwargs): |
|
|
|
"""将DataSet中每个instance传入到func中,并获取它的返回值. |
|
|
|
|
|
|
|
:param func: Callable, 参数是DataSet中的instance |
|
|
|
:param new_field_name: (None, str). (1) None, 不创建新的field; (2) str,将func的返回值放入这个名为 |
|
|
|
`new_field_name`的新field中,如果名称与已有的field相同,则覆盖之前的field; |
|
|
|
:param callable func: 参数是DataSet中的instance |
|
|
|
:param str, None new_field_name: 将func返回的内容放入到什么field中 |
|
|
|
|
|
|
|
1. str, 将func的返回值放入这个名为`new_field_name`的新field中,如果名称与已有的field相 |
|
|
|
同,则覆盖之前的field |
|
|
|
|
|
|
|
2. None, 不创建新的field |
|
|
|
:param kwargs: 合法的参数有以下三个 |
|
|
|
(1) is_input: bool, 如果为True则将`new_field_name`的field设置为input |
|
|
|
(2) is_target: bool, 如果为True则将`new_field_name`的field设置为target |
|
|
|
(3) ignore_type: bool, 如果为True则将`new_field_name`的field的ignore_type设置为true, 忽略其类型 |
|
|
|
|
|
|
|
1. is_input: bool, 如果为True则将`new_field_name`的field设置为input |
|
|
|
|
|
|
|
2. is_target: bool, 如果为True则将`new_field_name`的field设置为target |
|
|
|
|
|
|
|
3. ignore_type: bool, 如果为True则将`new_field_name`的field的ignore_type设置为true, 忽略其类型 |
|
|
|
:return: List[], 里面的元素为func的返回值,所以list长度为DataSet的长度 |
|
|
|
""" |
|
|
|
assert len(self)!=0, "Null DataSet cannot use apply()." |
|
|
@@ -396,10 +428,10 @@ class DataSet(object): |
|
|
|
def drop(self, func, inplace=True): |
|
|
|
"""func接受一个instance,返回bool值,返回值为True时,该instance会被删除。 |
|
|
|
|
|
|
|
:param func: Callable, 接受一个instance作为参数,返回bool值。为True时删除该instance |
|
|
|
:param inplace: bool, 是否在当前DataSet中直接删除instance。如果为False,返回值为一个删除了相应instance的新的DataSet |
|
|
|
:param callable func: 接受一个instance作为参数,返回bool值。为True时删除该instance |
|
|
|
:param bool inplace: 是否在当前DataSet中直接删除instance。如果为False,返回值为一个删除了相应instance的新的DataSet |
|
|
|
|
|
|
|
:return: DataSet. |
|
|
|
:return: DataSet |
|
|
|
""" |
|
|
|
if inplace: |
|
|
|
results = [ins for ins in self._inner_iter() if not func(ins)] |
|
|
@@ -408,16 +440,16 @@ class DataSet(object): |
|
|
|
return self |
|
|
|
else: |
|
|
|
results = [ins for ins in self if not func(ins)] |
|
|
|
data = DataSet(results) |
|
|
|
dataset = DataSet(results) |
|
|
|
for field_name, field in self.field_arrays.items(): |
|
|
|
data.field_arrays[field_name].to(field) |
|
|
|
return data |
|
|
|
dataset.field_arrays[field_name].to(field) |
|
|
|
return dataset |
|
|
|
|
|
|
|
def split(self, ratio): |
|
|
|
"""将DataSet按照ratio的比例拆分,返回两个DataSet |
|
|
|
|
|
|
|
:param ratio: float, 0<ratio<1, 返回的第一个DataSet拥有ratio这么多数据,第二个DataSet拥有(1-ratio)这么多数据 |
|
|
|
:return (DataSet, DataSet) |
|
|
|
:param float ratio: 0<ratio<1, 返回的第一个DataSet拥有ratio这么多数据,第二个DataSet拥有(1-ratio)这么多数据 |
|
|
|
:return: [DataSet, DataSet] |
|
|
|
""" |
|
|
|
assert isinstance(ratio, float) |
|
|
|
assert 0 < ratio < 1 |
|
|
@@ -480,7 +512,7 @@ class DataSet(object): |
|
|
|
def save(self, path): |
|
|
|
"""保存DataSet. |
|
|
|
|
|
|
|
:param path: str, 将DataSet存在哪个路径 |
|
|
|
:param str path: 将DataSet存在哪个路径 |
|
|
|
""" |
|
|
|
with open(path, 'wb') as f: |
|
|
|
pickle.dump(self, f) |
|
|
@@ -489,8 +521,8 @@ class DataSet(object): |
|
|
|
def load(path): |
|
|
|
"""从保存的DataSet pickle路径中读取DataSet |
|
|
|
|
|
|
|
:param path: str, 读取路径 |
|
|
|
:return DataSet: |
|
|
|
:param str path: 从哪里读取DataSet |
|
|
|
:return: DataSet |
|
|
|
""" |
|
|
|
with open(path, 'rb') as f: |
|
|
|
d = pickle.load(f) |
|
|
|