Browse Source

DataSet 增加了 apply_more 和 apply_filed_more 两个新方法

tags/v0.5.5
ChenXin 5 years ago
parent
commit
26ee41f923
2 changed files with 139 additions and 22 deletions
  1. +102
    -22
      fastNLP/core/dataset.py
  2. +37
    -0
      test/core/test_dataset.py

+ 102
- 22
fastNLP/core/dataset.py View File

@@ -284,10 +284,11 @@


""" """
__all__ = [ __all__ = [
"DataSet"
"DataSet",
] ]


import _pickle as pickle import _pickle as pickle
from inspect import isfunction
from copy import deepcopy from copy import deepcopy


import numpy as np import numpy as np
@@ -305,6 +306,12 @@ from .utils import pretty_table_printer
from .collect_fn import Collector from .collect_fn import Collector




class ApplyResultException(Exception):
def __init__(self, msg, index=None):
super().__init__(msg)
self.msg = msg
self.index = index # 标示在哪个数据遭遇到问题了
class DataSet(object): class DataSet(object):
""" """
fastNLP的数据容器,详细的使用方法见文档 :mod:`fastNLP.core.dataset` fastNLP的数据容器,详细的使用方法见文档 :mod:`fastNLP.core.dataset`
@@ -780,23 +787,35 @@ class DataSet(object):
assert len(self) != 0, "Null DataSet cannot use apply_field()." assert len(self) != 0, "Null DataSet cannot use apply_field()."
if field_name not in self: if field_name not in self:
raise KeyError("DataSet has no field named `{}`.".format(field_name)) raise KeyError("DataSet has no field named `{}`.".format(field_name))
results = []
idx = -1
try:
for idx, ins in enumerate(self._inner_iter()):
results.append(func(ins[field_name]))
except Exception as e:
if idx != -1:
logger.error("Exception happens at the `{}`th(from 1) instance.".format(idx + 1))
raise e
if not (new_field_name is None) and len(list(filter(lambda x: x is not None, results))) == 0: # all None
raise ValueError("{} always return None.".format(_get_func_signature(func=func)))
return self.apply(func, new_field_name, _apply_field=field_name, **kwargs)


if new_field_name is not None:
self._add_apply_field(results, new_field_name, kwargs)
def apply_field_more(self, func, field_name, modify_fields=True, **kwargs):
"""
将 ``DataSet`` 中的每个 ``Instance`` 中的名为 `field_name` 的field 传给 func,并获取它的返回值。
func 可以返回一个或多个 field 上的结果。
.. note::
``apply_field_more`` 与 ``apply_field`` 的区别参考 :meth:`~fastNLP.DataSet.apply_more` 中关于 ``apply_more`` 与
``apply`` 区别的介绍。
:param callable func: 参数是 ``DataSet`` 中的 ``Instance`` ,返回值是一个字典,key 是field 的名字,value 是对应的结果
:param str field_name: 传入func的是哪个field。
:param bool modify_fields: 是否用结果修改 `DataSet` 中的 `Field`, 默认为 True
:param optional kwargs: 支持输入is_input,is_target,ignore_type


return results
1. is_input: bool, 如果为True则将被修改的field设置为input

2. is_target: bool, 如果为True则将被修改的field设置为target

3. ignore_type: bool, 如果为True则将被修改的field的ignore_type设置为true, 忽略其类型


:return Dict[int:Field]: 返回一个字典
"""
assert len(self) != 0, "Null DataSet cannot use apply_field()."
if field_name not in self:
raise KeyError("DataSet has no field named `{}`.".format(field_name))
return self.apply_more(func, modify_fields, _apply_field=field_name, **kwargs)
def _add_apply_field(self, results, new_field_name, kwargs): def _add_apply_field(self, results, new_field_name, kwargs):
""" """
将results作为加入到新的field中,field名称为new_field_name 将results作为加入到新的field中,field名称为new_field_name
@@ -829,12 +848,73 @@ class DataSet(object):
is_target=extra_param.get("is_target", None), is_target=extra_param.get("is_target", None),
ignore_type=extra_param.get("ignore_type", False)) ignore_type=extra_param.get("ignore_type", False))


def apply_more(self, func, modify_fields=True, **kwargs):
"""
将 ``DataSet`` 中每个 ``Instance`` 传入到func中,并获取它的返回值。func可以返回一个或多个 field 上的结果。
.. note::
``apply_more`` 与 ``apply`` 的区别:
1. ``apply_more`` 可以返回多个 field 的结果, ``apply`` 只可以返回一个field 的结果;
2. ``apply_more`` 的返回值是一个字典,每个 key-value 对中的 key 表示 field 的名字,value 表示计算结果;
3. ``apply_more`` 默认修改 ``DataSet`` 中的 field ,``apply`` 默认不修改。

:param callable func: 参数是 ``DataSet`` 中的 ``Instance`` ,返回值是一个字典,key 是field 的名字,value 是对应的结果
:param bool modify_fields: 是否用结果修改 ``DataSet`` 中的 ``Field`` , 默认为 True
:param optional kwargs: 支持输入is_input,is_target,ignore_type

1. is_input: bool, 如果为True则将被修改的的field设置为input

2. is_target: bool, 如果为True则将被修改的的field设置为target

3. ignore_type: bool, 如果为True则将被修改的的field的ignore_type设置为true, 忽略其类型

:return Dict[int:Field]: 返回一个字典
"""
# 返回 dict , 检查是否一直相同
assert isfunction(func), "The func you provide is not callable."
assert len(self) != 0, "Null DataSet cannot use apply()."
idx = -1
try:
results = {}
for idx, ins in enumerate(self._inner_iter()):
if "_apply_field" in kwargs:
res = func(ins[kwargs["_apply_field"]])
else:
res = func(ins)
if not isinstance(res, dict):
raise ApplyResultException("The result of func is not a dict", idx)
if idx == 0:
for key, value in res.items():
results[key] = [value]
else:
for key, value in res.items():
if key not in results:
raise ApplyResultException("apply results have different fields", idx)
results[key].append(value)
if len(res) != len(results):
raise ApplyResultException("apply results have different fields", idx)
except Exception as e:
if idx != -1:
if isinstance(e, ApplyResultException):
logger.error(e.msg)
logger.error("Exception happens at the `{}`th instance.".format(idx))
raise e
if modify_fields is True:
for field, result in results.items():
self._add_apply_field(result, field, kwargs)
return results

def apply(self, func, new_field_name=None, **kwargs): def apply(self, func, new_field_name=None, **kwargs):
""" """
将DataSet中每个instance传入到func中,并获取它的返回值. 将DataSet中每个instance传入到func中,并获取它的返回值.


:param callable func: 参数是DataSet中的Instance
:param None,str new_field_name: 将func返回的内容放入到new_field_name这个field中,如果名称与已有的field相同,则覆
:param callable func: 参数是 ``DataSet`` 中的 ``Instance``
:param None,str new_field_name: 将func返回的内容放入到 `new_field_name` 这个field中,如果名称与已有的field相同,则覆
盖之前的field。如果为None则不创建新的field。 盖之前的field。如果为None则不创建新的field。
:param optional kwargs: 支持输入is_input,is_target,ignore_type :param optional kwargs: 支持输入is_input,is_target,ignore_type


@@ -846,21 +926,21 @@ class DataSet(object):
:return List[Any]: 里面的元素为func的返回值,所以list长度为DataSet的长度 :return List[Any]: 里面的元素为func的返回值,所以list长度为DataSet的长度
""" """
assert isfunction(func), "The func you provide is not callable."
assert len(self) != 0, "Null DataSet cannot use apply()." assert len(self) != 0, "Null DataSet cannot use apply()."
idx = -1 idx = -1
try: try:
results = [] results = []
for idx, ins in enumerate(self._inner_iter()): for idx, ins in enumerate(self._inner_iter()):
results.append(func(ins))
if "_apply_field" in kwargs:
results.append(func(ins[kwargs["_apply_field"]]))
else:
results.append(func(ins))
except BaseException as e: except BaseException as e:
if idx != -1: if idx != -1:
logger.error("Exception happens at the `{}`th instance.".format(idx)) logger.error("Exception happens at the `{}`th instance.".format(idx))
raise e raise e


# results = [func(ins) for ins in self._inner_iter()]
if not (new_field_name is None) and len(list(filter(lambda x: x is not None, results))) == 0: # all None
raise ValueError("{} always return None.".format(_get_func_signature(func=func)))

if new_field_name is not None: if new_field_name is not None:
self._add_apply_field(results, new_field_name, kwargs) self._add_apply_field(results, new_field_name, kwargs)




+ 37
- 0
test/core/test_dataset.py View File

@@ -3,6 +3,7 @@ import sys
import unittest import unittest


from fastNLP import DataSet from fastNLP import DataSet
from fastNLP.core.dataset import ApplyResultException
from fastNLP import FieldArray from fastNLP import FieldArray
from fastNLP import Instance from fastNLP import Instance
from fastNLP.io import CSVLoader from fastNLP.io import CSVLoader
@@ -143,6 +144,42 @@ class TestDataSetMethods(unittest.TestCase):
with self.assertRaises(TypeError): with self.assertRaises(TypeError):
ds.apply(modify_inplace) ds.apply(modify_inplace)


def test_apply_more(self):
T = DataSet({"a": [1, 2, 3], "b": [2, 4, 5]})
func_1 = lambda x: {"c": x["a"] * 2, "d": x["a"] ** 2}
func_2 = lambda x: {"c": x * 3, "d": x ** 3}
def func_err_1(x):
if x["a"] == 1:
return {"e": x["a"] * 2, "f": x["a"] ** 2}
else:
return {"e": x["a"] * 2}
def func_err_2(x):
if x == 1:
return {"e": x * 2, "f": x ** 2}
else:
return {"e": x * 2}
T.apply_more(func_1)
self.assertEqual(list(T["c"]), [2, 4, 6])
self.assertEqual(list(T["d"]), [1, 4, 9])
res = T.apply_field_more(func_2, "a", modify_fields=False)
self.assertEqual(list(T["c"]), [2, 4, 6])
self.assertEqual(list(T["d"]), [1, 4, 9])
self.assertEqual(list(res["c"]), [3, 6, 9])
self.assertEqual(list(res["d"]), [1, 8, 27])
with self.assertRaises(ApplyResultException) as e:
T.apply_more(func_err_1)
print(e)
with self.assertRaises(ApplyResultException) as e:
T.apply_field_more(func_err_2, "a")
print(e)

def test_drop(self): def test_drop(self):
ds = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6], [7, 8, 9, 0]] * 20}) ds = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6], [7, 8, 9, 0]] * 20})
ds.drop(lambda ins: len(ins["y"]) < 3, inplace=True) ds.drop(lambda ins: len(ins["y"]) < 3, inplace=True)


Loading…
Cancel
Save