Browse Source

添加了dataset模块

tags/v1.0.0alpha
MorningForest 3 years ago
parent
commit
6c992b31d7
4 changed files with 1128 additions and 0 deletions
  1. +10
    -0
      fastNLP/core/dataset/__init__.py
  2. +818
    -0
      fastNLP/core/dataset/dataset.py
  3. +229
    -0
      fastNLP/core/dataset/field.py
  4. +71
    -0
      fastNLP/core/dataset/instance.py

+ 10
- 0
fastNLP/core/dataset/__init__.py View File

@@ -0,0 +1,10 @@
# Public API of the fastNLP.core.dataset subpackage.
__all__ = [
    'DataSet',
    'FieldArray',
    'Instance',
    'ApplyResultException'
]

from .dataset import DataSet, ApplyResultException
from .field import FieldArray
from .instance import Instance

+ 818
- 0
fastNLP/core/dataset/dataset.py View File

@@ -0,0 +1,818 @@
r"""

"""
__all__ = [
"DataSet",
"ApplyResultException"
]

import _pickle as pickle
from copy import deepcopy
from typing import Optional, List, Callable, Union, Dict, Any
from functools import partial
import warnings

import numpy as np
from threading import Thread

try:
import multiprocess as mp
from multiprocess import RLock
except:
pass

from .field import FieldArray
from .instance import Instance
from fastNLP.core.utils.utils import pretty_table_printer, deprecated
from fastNLP.core.collators import AutoCollator
from fastNLP.core.utils.rich_progress import f_rich_progress
from fastNLP.core.collators.collator import _MultiCollator


class ApplyResultException(Exception):
    """Raised when the per-instance results of an ``apply*`` call disagree.

    :param msg: human-readable description of the problem.
    :param index: index of the instance where the problem occurred (or None).
    """

    def __init__(self, msg, index=None):
        super().__init__(msg)
        self.index = index  # which instance triggered the problem
        self.msg = msg


def _apply_single(ds=None, _apply_field=None, func: Optional[Callable] = None, show_progress_bar: bool = True,
                  pipe=None, desc: str = None, proc_id: int = 0) -> list:
    """
    Apply ``func`` over one dataset (or dataset shard); callable directly or
    from a multiprocessing pool worker.

    :param ds: the dataset (or shard) to iterate.
    :param _apply_field: if given, only this field's value is passed to
        ``func``; otherwise the whole instance is passed.
    :param func: the user-supplied callable.
    :param show_progress_bar: whether to show this worker's own progress bar.
    :param pipe: child end of a multiprocessing pipe; one message is sent per
        processed instance so the parent can drive the main progress bar.
    :param desc: progress-bar description, default "Main".
    :param proc_id: index of this worker process. BUGFIX: ``_apply_process``
        passes ``proc_id`` to pool workers, but the original signature did not
        accept it, so every worker died with a TypeError. Currently unused;
        reserved for per-process progress descriptions.
    :return: list of per-instance results, in iteration order.
    """
    if show_progress_bar:
        desc = desc if desc else "Main"
        pg_main = f_rich_progress.add_task(description=desc, total=len(ds), visible=show_progress_bar)
    results = []
    idx = -1  # tracks the last instance reached, for error reporting

    try:
        for idx, ins in enumerate(ds):
            if _apply_field is not None:
                results.append(func(ins[_apply_field]))
            else:
                results.append(func(ins))
            if pipe is not None:
                # Notify the parent process that one more instance is done.
                pipe.send([idx + 1])
            if show_progress_bar:
                f_rich_progress.update(pg_main, advance=1)
    except BaseException as e:
        if idx != -1:
            print("Exception happens at the `{}`th instance.".format(idx))
        raise e
    finally:
        if show_progress_bar:
            f_rich_progress.destroy_task(pg_main)
    return results


def _progress_bar(parent, total_len: int, desc: str = None, show_progress_bar: bool = True) -> None:
    """
    Drive the main progress bar on behalf of the worker processes.

    :param parent: parent end of the multiprocessing pipe; each worker sends
        one message per processed instance.
    :param total_len: total number of instances across all workers.
    :param desc: progress-bar description, default "Main".
    :param show_progress_bar: whether the bar is visible.
    """
    main_task = f_rich_progress.add_task(description=desc or "Main",
                                         total=total_len, visible=show_progress_bar)
    done = 0
    while True:
        msg = parent.recv()[0]
        if msg is not None:
            f_rich_progress.update(main_task, advance=1)
            done += 1
        if done == total_len:
            break


class DataSet:
    r"""
    fastNLP's tabular data container; see :mod:`fastNLP.core.dataset` for usage.
    """

    def __init__(self, data: Union[List[Instance], Dict[str, List[Any]], None] = None):
        r"""
        :param data: a dict mapping each field name to an equally long list of
            values, or a list of :class:`~fastNLP.Instance` objects that all
            share the same fields, or None for an empty dataset.
        """
        self.field_arrays = {}
        self.collate_fns: _MultiCollator = _MultiCollator(AutoCollator(as_numpy=False))
        if data is None:
            return
        if isinstance(data, Dict):
            # Every column must contribute the same number of rows.
            lengths = {len(value) for value in data.values()}
            assert len(lengths) == 1, "Arrays must all be same length."
            for key, value in data.items():
                self.add_field(field_name=key, fields=value)
        elif isinstance(data, List):
            for ins in data:
                assert isinstance(ins, Instance), "Must be Instance type, not {}.".format(type(ins))
                self.append(ins)
        else:
            raise ValueError("data only be dict or list type.")

def __contains__(self, item):
    # ``name in dataset`` checks whether a field with that name exists.
    return item in self.field_arrays

def __iter__(self):
    """Yield one :class:`Instance` per row, in index order."""
    for row in range(len(self)):
        yield self[row]

def _inner_iter(self):
    """Yield lightweight, read-only row views with lazy per-field access."""

    class Iter_ptr:
        """Read-only pointer to row ``idx`` of ``dataset``."""

        def __init__(self, dataset, idx):
            self.dataset = dataset
            self.idx = idx

        def __getitem__(self, item):
            assert item in self.dataset.field_arrays, "no such field:{} in Instance {}".format(item, self.dataset[
                self.idx])
            assert self.idx < len(self.dataset.field_arrays[item]), "index:{} out of range".format(self.idx)
            return self.dataset.field_arrays[item][self.idx]

        def __setitem__(self, key, value):
            raise TypeError("You cannot modify value directly.")

        def items(self):
            # Materialize the full row only when all fields are requested.
            return self.dataset[self.idx].items()

        def __repr__(self):
            return repr(self.dataset[self.idx])

    return (Iter_ptr(self, i) for i in range(len(self)))

def __getitem__(self, idx: Union[int, slice, str, list]):
    r"""
    Index into the dataset.

    :param idx: an int returns a single :class:`Instance`; a slice returns a
        new DataSet with the selected rows; a str returns the named
        :class:`FieldArray`; a list of ints returns a new DataSet with the
        selected rows.
    :raises RuntimeError: slice start is out of range.
    :raises KeyError: unknown field name, or unsupported index type.
    """
    if isinstance(idx, int):
        return Instance(**{name: self.field_arrays[name][idx] for name in self.field_arrays})
    elif isinstance(idx, slice):
        # BUGFIX: -len(self) is a valid start (it denotes the first element),
        # so only starts strictly below -len(self) are rejected. The original
        # used `<=` and refused a legitimate negative start.
        if idx.start is not None and (idx.start >= len(self) or idx.start < -len(self)):
            raise RuntimeError(f"Start index {idx.start} out of range 0-{len(self) - 1}")
        data_set = DataSet()
        for field_name, field in self.field_arrays.items():
            data_set.add_field(field_name=field_name, fields=field.content[idx])
        return data_set
    elif isinstance(idx, str):
        if idx not in self:
            raise KeyError("No such field called {} in DataSet.".format(idx))
        return self.field_arrays[idx]
    elif isinstance(idx, list):
        dataset = DataSet()
        for i in idx:
            assert isinstance(i, int), "Only int index allowed."
            dataset.append(self[i])
        return dataset
    else:
        raise KeyError("Unrecognized type {} for idx in __getitem__ method".format(type(idx)))

def __getattribute__(self, item):
    # Explicit passthrough to the default lookup; kept so that __getattr__
    # below is only consulted after normal attribute lookup fails.
    return object.__getattribute__(self, item)

def __getattr__(self, item):
    """Expose fields as attributes: ``ds.words`` returns the 'words' FieldArray.

    Only called when normal attribute lookup fails.
    """
    # Guard against infinite recursion during unpickling, when
    # `field_arrays` itself may not have been set yet.
    if item == "field_arrays":
        raise AttributeError(item)
    if isinstance(item, str) and item in self.field_arrays:
        return self.field_arrays[item]
    # BUGFIX: the original fell off the end and implicitly returned None,
    # which made hasattr() report True for every name and hid typos.
    raise AttributeError("'DataSet' object has no attribute '{}'".format(item))

def __setstate__(self, state):
    # Restore the full instance dict from a pickle (see __getstate__).
    self.__dict__ = state

def __getstate__(self):
    # Pickle the full instance dict; paired with __setstate__.
    return self.__dict__

def __len__(self):
    r"""
    Number of instances in the dataset; 0 when no field exists yet.
    All fields are kept equally long, so any one of them gives the answer.
    """
    if not self.field_arrays:
        return 0
    any_field = next(iter(self.field_arrays.values()))
    return len(any_field)

def __repr__(self):
    # Render the dataset as a pretty ASCII table.
    return str(pretty_table_printer(self))

def append(self, instance: Instance) -> None:
    r"""
    Append one instance to the dataset.

    :param instance: must carry exactly the same fields as the dataset;
        any fields are accepted while the dataset is still empty.
    """
    if not self.field_arrays:
        # First sample: create one FieldArray per field, each seeded with a
        # single-element list.
        for name, value in instance.items():
            self.field_arrays[name] = FieldArray(name, [value])
        return
    if len(self.field_arrays) != len(instance.fields):
        raise ValueError(
            "DataSet object has {} fields, but attempt to append an Instance object with {} fields."
            .format(len(self.field_arrays), len(instance.fields)))
    for name, value in instance.items():
        assert name in self.field_arrays
        try:
            self.field_arrays[name].append(value)
        except Exception as e:
            print(f"Cannot append to field:{name}.")
            raise e

def add_fieldarray(self, field_name: str, fieldarray: FieldArray) -> None:
    r"""
    Attach an existing :class:`FieldArray` to the dataset as ``field_name``.

    :param field_name: name to store the field under (overwrites
        ``fieldarray.name``).
    :param fieldarray: the column to attach; must match the dataset length.
    :raises TypeError: when ``fieldarray`` is not a FieldArray.
    :raises RuntimeError: when the lengths do not match.
    """
    if not isinstance(fieldarray, FieldArray):
        raise TypeError("Only fastNLP.FieldArray supported.")
    if len(self) != len(fieldarray):
        raise RuntimeError(f"The field to add must have the same size as dataset. "
                           f"Dataset size {len(self)} != field size {len(fieldarray)}")
    fieldarray.name = field_name
    self.field_arrays[field_name] = fieldarray

def add_field(self, field_name: str, fields: list) -> None:
    r"""
    Create (or overwrite) a field from a plain list of values. The list must
    be as long as the dataset, unless the dataset is still empty.

    :param field_name: name of the new field.
    :param fields: the per-instance values of the new field.
    :raises RuntimeError: when the length does not match a non-empty dataset.
    """
    if self.field_arrays and len(self) != len(fields):
        raise RuntimeError(f"The field to add must have the same size as dataset. "
                           f"Dataset size {len(self)} != field size {len(fields)}")
    self.field_arrays[field_name] = FieldArray(field_name, fields)

def delete_instance(self, index: int):
    r"""
    Remove the ``index``-th instance (0-based) from every field.

    :param index: position of the instance to delete.
    :return: self
    :raises IndexError: when ``index`` is out of range.
    """
    assert isinstance(index, int), "Only integer supported."
    if len(self) <= index:
        raise IndexError("{} is too large for as DataSet with {} instances.".format(index, len(self)))
    if len(self) == 1:
        # Removing the last remaining instance empties the whole dataset.
        self.field_arrays.clear()
        return self
    for field in self.field_arrays.values():
        field.pop(index)
    return self

def delete_field(self, field_name: str):
    r"""
    Remove the field named ``field_name`` from the dataset.

    :param field_name: name of the field to remove.
    :return: self
    :raises KeyError: when no such field exists.
    """
    if not self.has_field(field_name):
        raise KeyError(f"Field:{field_name} not found in DataSet.")
    self.field_arrays.pop(field_name)
    return self

def copy_field(self, field_name: str, new_field_name: str):
    r"""
    Deep-copy the field ``field_name`` into a new field ``new_field_name``.

    :param field_name: source field.
    :param new_field_name: name given to the copy.
    :return: self
    :raises KeyError: when the source field does not exist.
    """
    if not self.has_field(field_name):
        raise KeyError(f"Field:{field_name} not found in DataSet.")
    copied = deepcopy(self.get_field(field_name))
    copied.name = new_field_name
    self.add_fieldarray(field_name=new_field_name, fieldarray=copied)
    return self

def has_field(self, field_name: str) -> bool:
    r"""
    Tell whether a field named ``field_name`` exists.

    :param field_name: name to look up; non-str values never match.
    :return: True if the field exists, False otherwise.
    """
    return isinstance(field_name, str) and field_name in self.field_arrays

def get_field(self, field_name: str) -> FieldArray:
    r"""
    Return the :class:`FieldArray` stored under ``field_name``.

    :param field_name: name of the field.
    :return: the field's :class:`~fastNLP.FieldArray`.
    :raises KeyError: when no such field exists.
    """
    if field_name in self.field_arrays:
        return self.field_arrays[field_name]
    raise KeyError("Field name {} not found in DataSet".format(field_name))

def get_all_fields(self) -> dict:
    r"""
    Return the live mapping of field name -> :class:`~fastNLP.FieldArray`
    (not a copy; mutations affect the dataset).

    :return: the dataset's field dict.
    """
    return self.field_arrays

def get_field_names(self) -> list:
    r"""
    Return all field names, sorted alphabetically.

    :return: sorted list of field names.
    """
    return sorted(self.field_arrays)

def get_length(self) -> int:
    r"""
    Return the number of instances in the dataset (same as ``len(self)``).

    :return: instance count.
    """
    return self.__len__()

def rename_field(self, field_name: str, new_field_name: str):
    r"""
    Rename an existing field.

    :param field_name: current name of the field.
    :param new_field_name: name the field is renamed to.
    :return: self
    :raises KeyError: when no field called ``field_name`` exists.
    """
    if field_name not in self.field_arrays:
        raise KeyError("DataSet has no field named {}.".format(field_name))
    moved = self.field_arrays.pop(field_name)
    moved.name = new_field_name
    self.field_arrays[new_field_name] = moved
    return self

def apply_field(self, func: Union[Callable], field_name: str = None,
                new_field_name: str = None, num_proc: int = 0,
                progress_desc: str = None, show_progress_bar: bool = True):
    r"""
    Apply ``func`` to the value of field ``field_name`` of every instance.

    :param func: called with the value of ``field_name`` for each instance.
    :param field_name: name of the field whose values are passed to ``func``.
    :param new_field_name: if not None, store the results under this name,
        overwriting any existing field of the same name.
    :param num_proc: number of worker processes; 0 runs in this process.
    :param progress_desc: description shown on the progress bar (default "Main").
    :param show_progress_bar: whether to display a progress bar.
    :return: list of per-instance results from ``func``.
    :raises KeyError: when ``field_name`` does not exist.
    """
    assert len(self) != 0, "Null DataSet cannot use apply_field()."
    if not self.has_field(field_name=field_name):
        raise KeyError("DataSet has no field named `{}`.".format(field_name))

    # The original wrapped this in `try: ... except BaseException: raise`,
    # a no-op that only obscured the traceback; removed.
    results = self._apply_process(num_proc=num_proc, func=func, show_progress_bar=show_progress_bar,
                                  progress_desc=progress_desc, _apply_field=field_name)

    if new_field_name is not None:
        self.add_field(field_name=new_field_name, fields=results)
    return results

def apply_field_more(self, func: Callable = None, field_name: str = None,
                     modify_fields: bool = True, num_proc: int = 0,
                     progress_desc: str = None, show_progress_bar: bool = True):
    r"""
    Apply ``func`` to the value of field ``field_name`` of every instance;
    ``func`` must return a dict mapping field names to per-instance results.

    .. note::
        See :meth:`~fastNLP.DataSet.apply_more` for how the ``*_more``
        variants differ from ``apply``/``apply_field``.

    :param func: callable taking the field value and returning a dict
        ``{field_name: value}``.
    :param field_name: which field's values are passed to ``func``.
    :param modify_fields: if True (default), write the result columns back
        into the dataset.
    :param num_proc: number of worker processes; 0 runs in this process.
    :param show_progress_bar: whether to display a progress bar.
    :param progress_desc: description shown on the progress bar.
    :return: dict mapping each result field name to its list of values.
    :raises ApplyResultException: when the dicts returned by ``func`` do not
        all share exactly the same keys.
    """
    assert len(self) != 0, "Null DataSet cannot use apply_field()."
    if not self.has_field(field_name=field_name):
        raise KeyError("DataSet has no field named `{}`.".format(field_name))
    idx = -1
    results = {}
    apply_out = self._apply_process(num_proc, func, progress_desc=progress_desc,
                                    show_progress_bar=show_progress_bar, _apply_field=field_name)
    # The first result fixes the expected key set; all others must match it.
    if not isinstance(apply_out[0], dict):
        raise Exception("The result of func is not a dict")

    for key, value in apply_out[0].items():
        results[key] = [value]
    try:
        for idx, per_out in enumerate(apply_out[1:]):
            # BUGFIX: compare key sets symmetrically. The original only
            # detected *missing* keys; *extra* keys crashed later with an
            # opaque KeyError instead of this explicit exception.
            if set(per_out.keys()) != set(results.keys()):
                raise ApplyResultException("apply results have different fields", idx + 1)
            for key, value in per_out.items():
                results[key].append(value)
    except Exception as e:
        if idx != -1:
            if isinstance(e, ApplyResultException):
                print(e.msg)
            print("Exception happens at the `{}`th instance.".format(idx + 1))
        raise e

    if modify_fields is True:
        for field, result in results.items():
            self.add_field(field_name=field, fields=result)

    return results

def _apply_process(self, num_proc: int = 0, func: Callable = None,
                   show_progress_bar: bool = True, _apply_field: str = None,
                   progress_desc: str = 'Main') -> list:
    """
    Run ``func`` over the whole dataset, optionally fanning out to workers.

    :param num_proc: number of worker processes; 0 runs in this process.
    :param func: user function, called per instance (or per field value when
        ``_apply_field`` is given).
    :param _apply_field: name of the field whose values are fed to ``func``.
    :param show_progress_bar: whether to display a progress bar.
    :param progress_desc: description shown on the progress bar, default 'Main'.
    :return: list of per-instance results, in dataset order.
    """
    if num_proc == 0:
        return _apply_single(ds=self, _apply_field=_apply_field, func=func,
                             desc=progress_desc, show_progress_bar=show_progress_bar)

    # TODO: give each subprocess its own desc, e.g. "Process 1" / "Process 2".
    results = []
    if num_proc > len(self):
        num_proc = len(self)
        print(
            f"num_proc must be <= {len(self)}. Reducing num_proc to {num_proc} for dataset of size {len(self)}."
        )
    # Shard the dataset: the first `len(self) % num_proc` shards receive one
    # extra row so every row is assigned exactly once.
    shard_len = len(self) // num_proc
    num_left_sample = len(self) % num_proc
    start = 0
    shard_data = []
    for _i in range(num_proc):
        end = shard_len + int(_i < num_left_sample) + start
        shard_data.append(self[start:end])
        start = end
    # A pipe plus a helper thread lets the main progress bar update live.
    parent, child = mp.Pipe()
    main_thread = Thread(target=_progress_bar, args=(parent, len(self), progress_desc,
                                                     show_progress_bar))
    partial_single_map = partial(_apply_single, _apply_field=_apply_field, func=func,
                                 pipe=child, show_progress_bar=False)
    main_thread.start()
    pool = mp.Pool(processes=num_proc)
    # BUGFIX: the original also forwarded a `proc_id` kwarg that
    # _apply_single() did not accept, so every worker raised TypeError.
    pool_outs = [pool.apply_async(partial_single_map, kwds={'ds': ds})
                 for ds in shard_data]
    pool.close()
    pool.join()
    main_thread.join()

    for async_result in pool_outs:
        results.extend(async_result.get())
    return results

def apply_more(self, func: Callable = None, modify_fields: bool = True,
               num_proc: int = 0, progress_desc: str = '', show_progress_bar: bool = True):
    r"""
    Apply ``func`` to every :class:`Instance`; ``func`` must return a dict
    mapping field names to per-instance results.

    .. note::
        Differences between ``apply_more`` and ``apply``:

        1. ``apply_more`` can produce several fields at once; ``apply``
           produces at most one;

        2. ``apply_more`` returns a dict whose keys are field names and whose
           values are the computed columns;

        3. ``apply_more`` writes results back into the dataset by default;
           ``apply`` does not.

    :param func: callable taking an ``Instance`` and returning a dict
        ``{field_name: value}``.
    :param modify_fields: if True (default), write the result columns back
        into the dataset.
    :param num_proc: number of worker processes; 0 runs in this process.
    :param show_progress_bar: whether to display a progress bar.
    :param progress_desc: description shown on the progress bar.
    :return: dict mapping each result field name to its list of values.
    :raises ApplyResultException: when the dicts returned by ``func`` do not
        all share exactly the same keys.
    """
    assert callable(func), "The func you provide is not callable."
    assert len(self) != 0, "Null DataSet cannot use apply()."
    assert num_proc >= 0, "num_proc must >= 0"
    idx = -1

    results = {}
    apply_out = self._apply_process(num_proc, func, progress_desc=progress_desc,
                                    show_progress_bar=show_progress_bar)
    # The first result fixes the expected key set; all others must match it.
    if not isinstance(apply_out[0], dict):
        raise Exception("The result of func is not a dict")

    for key, value in apply_out[0].items():
        results[key] = [value]
    try:
        for idx, per_out in enumerate(apply_out[1:]):
            # BUGFIX: compare key sets symmetrically. The original only
            # detected *missing* keys; *extra* keys crashed later with an
            # opaque KeyError instead of this explicit exception.
            if set(per_out.keys()) != set(results.keys()):
                raise ApplyResultException("apply results have different fields", idx + 1)
            for key, value in per_out.items():
                results[key].append(value)
    except Exception as e:
        if idx != -1:
            if isinstance(e, ApplyResultException):
                print(e.msg)
            print("Exception happens at the `{}`th instance.".format(idx + 1))
        raise e

    if modify_fields is True:
        for field, result in results.items():
            self.add_field(field_name=field, fields=result)

    return results

def apply(self, func: Callable = None, new_field_name: str = None,
          num_proc: int = 0, show_progress_bar: bool = True, progress_desc: str = ''):
    """
    Apply ``func`` to every :class:`Instance` and collect the return values.

    :param func: called with each ``Instance``; its return value (a single
        value, not a dict — use :meth:`apply_more` for dicts) becomes one
        entry of the result list. (The original docstring wrongly claimed a
        dict return, copy-pasted from ``apply_more``.)
    :param new_field_name: if not None, store the results under this name,
        overwriting any existing field of the same name.
    :param num_proc: number of worker processes; 0 runs in this process.
    :param show_progress_bar: whether to display a progress bar.
    :param progress_desc: description shown on the progress bar, default ''.
    :return: list of per-instance results.
    """
    assert callable(func), "The func you provide is not callable."
    assert len(self) != 0, "Null DataSet cannot use apply()."
    assert num_proc >= 0, "num_proc must be an integer >= 0."
    # The original wrapped this in `try: ... except BaseException: raise`,
    # a no-op that only obscured the traceback; removed.
    results = self._apply_process(num_proc=num_proc, func=func, show_progress_bar=show_progress_bar,
                                  progress_desc=progress_desc)

    if new_field_name is not None:
        self.add_field(field_name=new_field_name, fields=results)

    return results

def add_seq_len(self, field_name: str, new_field_name='seq_len'):
    r"""
    Apply ``len()`` to every cell of ``field_name`` and store the lengths as
    a new field (default name ``seq_len``).

    :param field_name: field whose cell lengths are measured.
    :param new_field_name: name of the field receiving the lengths.
    :return: self
    :raises KeyError: when ``field_name`` does not exist.
    """
    if not self.has_field(field_name=field_name):
        raise KeyError(f"Field:{field_name} not found.")
    self.apply_field(len, field_name, new_field_name=new_field_name)
    return self

def drop(self, func: Callable, inplace=True):
    r"""
    Remove every instance for which ``func`` returns True.

    :param func: called with each Instance; a True return drops that instance.
    :param inplace: if True, delete in this dataset and return self; if
        False, leave this dataset untouched and return a new DataSet.
    :return: self (inplace) or a new DataSet with the surviving instances.
    """
    survivors = [ins for ins in self if not func(ins)]
    if inplace:
        for name in self.field_arrays:
            self.field_arrays[name].content = [ins[name] for ins in survivors]
        return self
    if survivors:
        return DataSet(survivors)
    return DataSet()

def split(self, ratio: float, shuffle=True):
    r"""
    Split into two new DataSets: the first receives ``ratio`` of the rows,
    the second the remaining ``1 - ratio``.

    :param ratio: fraction (0 < ratio < 1) of rows going to the first set.
    :param shuffle: shuffle the row order before splitting; if False the
        first set is simply the leading ``ratio`` portion in current order.
    :return: (first DataSet, second DataSet)
    :raises IndexError: when the first split would be empty.
    """
    assert len(self) > 1, f'DataSet with {len(self)} instance cannot be split.'
    assert isinstance(ratio, float)
    assert 0 < ratio < 1
    all_indices = list(range(len(self)))
    if shuffle:
        np.random.shuffle(all_indices)
    split = int(ratio * len(self))
    if split == 0:
        error_msg = f'Dev DataSet has {split} instance after split.'
        print(error_msg)
        raise IndexError(error_msg)
    dev_set = DataSet()
    for i in all_indices[:split]:
        dev_set.append(self[i])
    train_set = DataSet()
    for i in all_indices[split:]:
        train_set.append(self[i])

    return dev_set, train_set

def save(self, path: str) -> None:
    r"""
    Pickle this DataSet to ``path``.

    :param str path: file path the pickled DataSet is written to.
    """
    with open(path, 'wb') as f:
        pickle.dump(self, f)

@staticmethod
def load(path: str):
    r"""
    Load a DataSet previously written by :meth:`save`.

    :param str path: path of the pickled DataSet file.
    :return: the loaded :class:`~fastNLP.DataSet`.
    """
    with open(path, 'rb') as f:
        d = pickle.load(f)
        assert isinstance(d, DataSet), "The object is not DataSet, but {}.".format(type(d))
    return d

def concat(self, dataset: 'DataSet', inplace: bool = True, field_mapping: Dict = None) -> 'DataSet':
    """
    Append all rows of ``dataset`` to this dataset. Both datasets must share
    the same fields (after applying ``field_mapping``); extra fields in the
    incoming dataset are ignored, missing ones raise. The combined dataset
    keeps this dataset's collators.

    :param dataset: the DataSet whose rows are appended.
    :param inplace: if True extend this dataset, otherwise work on a deep copy.
    :param field_mapping: optional dict mapping the incoming dataset's field
        names to this dataset's field names.
    :return: the combined DataSet.
    :raises RuntimeError: when the incoming dataset lacks required fields.
    """
    assert isinstance(dataset, DataSet), "Can only concat two datasets."

    this_names = set(self.get_field_names())
    other_names = dataset.get_field_names()
    reverse_field_mapping = {}
    if field_mapping is not None:
        # Translate the incoming names, and remember how to map them back
        # when pulling the actual column contents below.
        other_names = [field_mapping.get(fn, fn) for fn in other_names]
        reverse_field_mapping = {v: k for k, v in field_mapping.items()}

    fn_not_seen = list(this_names - set(other_names))
    if fn_not_seen:
        raise RuntimeError(f"The following fields are not provided in the dataset:{fn_not_seen}")

    target = self if inplace else deepcopy(self)
    for fn in this_names:
        source = dataset.get_field(reverse_field_mapping.get(fn, fn)).content
        target.get_field(fn).content.extend(deepcopy(source))
    return target

@classmethod
def from_pandas(cls, df):
    """
    Build a DataSet from a ``pandas.DataFrame`` (one field per column).

    :param df: the source DataFrame.
    :return: a new DataSet.
    """
    return cls(df.to_dict(orient='list'))

def to_pandas(self):
    """
    Convert the dataset into a ``pandas.DataFrame`` (one column per field).

    :return: pandas.DataFrame
    """
    import pandas as pd
    columns = {name: array.content for name, array in self.field_arrays.items()}
    return pd.DataFrame.from_dict(columns)

# TODO: should this return something (the path, or the DataFrame)?
def to_csv(self, path: str) -> None:
    """
    Write the dataset to ``path`` as a utf-8 encoded CSV file
    (via :meth:`to_pandas`).

    :param path: destination file path.
    :return: None
    """

    df = self.to_pandas()
    df.to_csv(path, encoding="utf-8")

def add_collate_fn(self, collate_fn: Callable) -> None:
    """
    Register an additional collate function; it is appended after the
    collators already present.

    :param collate_fn: callable to add to the collator chain.
    """
    self.collate_fns.add_collator(collate_fn)

def set_collate_fn(self, collate_fn: Callable) -> None:
    """
    Replace ALL existing collate functions (including the auto collator)
    with the single ``collate_fn``.

    :param collate_fn: the new, sole collate function.
    """
    self.collate_fns = _MultiCollator(collate_fn)

def set_pad_val(self, *field_names, val: Optional[int] = 0) -> None:
    """
    Set the padding value (default 0) for each listed field. Only effective
    while an AutoCollator is present; ``val=None`` means the listed fields
    should not be padded at all.

    :param field_names: names of existing fields.
    :param val: padding value, default 0; None disables padding attempts.
    """
    for name in field_names:
        self.collate_fns.set_pad_val(name, val=val)

def set_input(self, *field_names) -> None:
    """
    Mark the listed fields as inputs: they are forwarded to the AutoCollator,
    while unmarked fields are filtered out by default.

    :param field_names: field names to expose to the collator.
    """
    self.collate_fns.set_input(*field_names)

def get_collator(self) -> _MultiCollator:
    """
    Return the collator chain bound to this dataset, including the auto
    collator.

    :return: the dataset's :class:`_MultiCollator`.
    """
    return self.collate_fns

@deprecated()
def set_target(self, *field_names) -> None:
    """
    Deprecated alias of :meth:`set_input`: the given field names are simply
    forwarded to the AutoCollator as inputs.

    :param field_names: field names to expose to the collator.
    """
    self.collate_fns.set_input(*field_names)


class IterableDataset:
    # Placeholder for a future streaming (iterable-style) dataset.
    pass


+ 229
- 0
fastNLP/core/dataset/field.py View File

@@ -0,0 +1,229 @@
r"""
.. todo::
doc
"""
__all__ = [
'FieldArray'
]

from collections import Counter
from typing import Any, Union, List, Callable

import numpy as np


class FieldArray:
    """Column container: holds the value of one field for every instance."""

    def __init__(self, name: str, content):
        """
        :param name: name of the field.
        :param content: a non-empty iterable of cell values; converted to a
            list.
        :raises RuntimeError: when ``content`` is empty.
        """
        if len(content) == 0:
            raise RuntimeError("Empty fieldarray is not allowed.")
        try:
            cells = list(content)
        except BaseException as e:
            print(f"Cannot convert content(of type:{type(content)}) into list.")
            raise e
        self.name = name
        self.content = cells

def append(self, val: Any) -> None:
    r"""
    Append ``val`` as a new cell at the end of this FieldArray.

    :param val: value to append.
    """
    self.content.append(val)

def pop(self, index: int) -> None:
    r"""
    Delete the cell at position ``index`` (0-based).

    :param index: position of the cell to remove.
    """
    self.content.pop(index)

def __getitem__(self, indices: Union[int, List[int]]):
    # Index access delegates to get(): a single int or a list of ints.
    return self.get(indices)

def __setitem__(self, idx: int, val: Any):
    # Only plain int indices are supported; -1 is normalized to the last slot.
    assert isinstance(idx, int)
    if idx == -1:
        idx = len(self) - 1
    assert 0 <= idx < len(self), f"0<= idx <{len(self)}, but idx is {idx}"
    self.content[idx] = val

def get(self, indices: Union[int, List[int]]):
    r"""
    Return the cell(s) at ``indices``.

    :param indices: a single int (``-1`` means the last cell) or a list of
        ints.
    :return: the single cell for an int index, or an ``np.ndarray`` of the
        selected cells for a list of indices.
    """
    if isinstance(indices, int):
        if indices == -1:
            indices = len(self) - 1
        assert 0 <= indices < len(self)
        return self.content[indices]
    # The original wrapped this in `try: ... except BaseException: raise`,
    # a no-op that only obscured the traceback; removed.
    contents = [self.content[i] for i in indices]
    return np.array(contents)

def __len__(self):
    r"""
    Return the number of cells stored in this FieldArray.

    :return int length:
    """
    return len(self.content)

def split(self, sep: str = None, inplace: bool = True):
    r"""
    Call ``.split(sep)`` on every cell; only meaningful for str cells.

    :param sep: separator; None means whitespace splitting (plain
        ``str.split()``).
    :param inplace: if True replace this field's content and return self;
        otherwise return the new list.
    :return: self (inplace) or List[List[str]].
    """
    split_cells = []
    for index, cell in enumerate(self.content):
        try:
            split_cells.append(cell.split(sep))
        except Exception as e:
            print(f"Exception happens when process value in index {index}.")
            raise e
    return self._after_process(split_cells, inplace=inplace)

def int(self, inplace: bool = True):
    r"""
    Apply ``int()`` to every cell. Supports (1) flat str cells
    (``['1', '2', ...]``) and (2) list cells whose elements are converted
    one by one (``[['1', '2'], ['3'], ...]``).

    :param inplace: if True replace this field's content and return self;
        otherwise return the new list.
    :return: List[int], List[List[int]], or self.
    """
    converted = []
    for i, cell in enumerate(self.content):
        try:
            converted.append([int(v) for v in cell] if isinstance(cell, list) else int(cell))
        except Exception as e:
            print(f"Exception happens when process value in index {i}.")
            raise e
    return self._after_process(converted, inplace=inplace)

def float(self, inplace=True):
    r"""
    Apply ``float()`` to every cell. Supports (1) flat str cells
    (``['1', '2', ...]``) and (2) list cells whose elements are converted
    one by one (``[['1', '2'], ['3'], ...]``).

    :param inplace: if True replace this field's content and return self;
        otherwise return the new list.
    :return: List[float], List[List[float]], or self.
    """
    converted = []
    for i, cell in enumerate(self.content):
        try:
            converted.append([float(v) for v in cell] if isinstance(cell, list) else float(cell))
        except Exception as e:
            print(f"Exception happens when process value in index {i}.")
            raise e
    return self._after_process(converted, inplace=inplace)

def bool(self, inplace=True):
    r"""
    Apply ``bool()`` to every cell. Supports (1) flat cells
    (``['1', '2', ...]``) and (2) list cells whose elements are converted
    one by one (``[['1', '2'], ['3'], ...]``).

    :param inplace: if True replace this field's content and return self;
        otherwise return the new list.
    :return: List[bool], List[List[bool]], or self.
    """
    converted = []
    for i, cell in enumerate(self.content):
        try:
            converted.append([bool(v) for v in cell] if isinstance(cell, list) else bool(cell))
        except Exception as e:
            print(f"Exception happens when process value in index {i}.")
            raise e

    return self._after_process(converted, inplace=inplace)

def lower(self, inplace=True):
    r"""
    Apply ``str.lower()`` to every cell. Supports (1) flat str cells
    (``['A', 'B', ...]``) and (2) list cells whose elements are lowered one
    by one (``[['A', 'b'], ['C'], ...]``).

    :param inplace: if True replace this field's content and return self;
        otherwise return the new list.
    :return: List[str], List[List[str]], or self.
    """
    lowered = []
    for i, cell in enumerate(self.content):
        try:
            lowered.append([v.lower() for v in cell] if isinstance(cell, list) else cell.lower())
        except Exception as e:
            print(f"Exception happens when process value in index {i}.")
            raise e
    return self._after_process(lowered, inplace=inplace)

def upper(self, inplace=True):
    r"""
    Apply ``str.upper()`` to every cell. Supports (1) flat str cells
    (``['a', 'b', ...]``) and (2) list cells whose elements are uppercased
    one by one (``[['a', 'B'], ['c'], ...]``).
    (The original docstring said "lower()" — a copy-paste slip.)

    :param inplace: if True replace this field's content and return self;
        otherwise return the new list.
    :return: List[str], List[List[str]], or self.
    """
    uppered = []
    for i, cell in enumerate(self.content):
        try:
            uppered.append([v.upper() for v in cell] if isinstance(cell, list) else cell.upper())
        except Exception as e:
            print(f"Exception happens when process value in index {i}.")
            raise e
    return self._after_process(uppered, inplace=inplace)

def value_count(self):
    r"""
    Count the occurrences of each distinct value in this field; mostly used
    to count labels. Nested lists are descended into; strings are counted
    whole rather than per character.

    :return: Counter mapping value -> number of occurrences.
    """
    from collections.abc import Iterable

    count = Counter()

    def cum(cells):
        # BUGFIX: the original tested `isinstance(cells, Callable)`. Lists
        # are not callable, so nested cells were never recursed into and
        # counting them raised TypeError (unhashable). Iterable is the
        # intended check.
        if isinstance(cells, Iterable) and not isinstance(cells, str):
            for cell_ in cells:
                cum(cell_)
        else:
            count[cells] += 1

    for cell in self.content:
        cum(cell)
    return count

def _after_process(self, new_contents: list, inplace: bool):
    r"""
    Finish a bulk transformation: either swap in the new content and return
    self, or hand back the transformed list untouched.

    :param new_contents: the transformed cells.
    :param inplace: whether to replace this field's content.
    :return: self (inplace) or ``new_contents``.
    """
    if not inplace:
        return new_contents
    self.content = new_contents
    return self

+ 71
- 0
fastNLP/core/dataset/instance.py View File

@@ -0,0 +1,71 @@
r"""
instance 模块实现了Instance 类在fastNLP中对应sample。一个sample可以认为是一个Instance类型的对象。
便于理解的例子可以参考文档 :mod:`fastNLP.core.dataset` 中的表格

"""

__all__ = [
"Instance"
]

from fastNLP.core.utils.utils import pretty_table_printer


class Instance:
    r"""
    One sample in fastNLP. Usually used together with
    :class:`~fastNLP.DataSet`; fields are supplied as keyword arguments::

        ins = Instance(words=["hello", "world"], label=1)
    """

    def __init__(self, **fields):
        # Mapping of field name -> field value for this sample.
        self.fields = fields

    def add_field(self, field_name: str, field: any):
        r"""
        Add (or overwrite) one field on this instance.

        :param field_name: name of the field.
        :param field: value of the field.
        """
        self.fields[field_name] = field

    def items(self):
        r"""
        Return a view of ``(field_name, field_value)`` pairs.

        :return: an iterable view of the fields.
        """
        return self.fields.items()

    def keys(self):
        r"""
        Return a view of the field names.

        :return: an iterable view of the names.
        """
        return self.fields.keys()

    def values(self):
        r"""
        Return a view of the field values.

        :return: an iterable view of the values.
        """
        return self.fields.values()

    def __contains__(self, item):
        return item in self.fields

    def __getitem__(self, name):
        if name not in self.fields:
            raise KeyError("{} not found".format(name))
        return self.fields[name]

    def __setitem__(self, name, field):
        return self.add_field(name, field)

    def __repr__(self):
        return str(pretty_table_printer(self))

Loading…
Cancel
Save