From 26ee41f923ee29012f7adf461169bd0c2efb8d3d Mon Sep 17 00:00:00 2001
From: ChenXin
Date: Mon, 16 Mar 2020 10:42:56 +0800
Subject: [PATCH] =?UTF-8?q?=20DataSet=20=E5=A2=9E=E5=8A=A0=E4=BA=86=20appl?=
 =?UTF-8?q?y=5Fmore=20=E5=92=8C=20apply=5Ffiled=5Fmore=20=E4=B8=A4?=
 =?UTF-8?q?=E4=B8=AA=E6=96=B0=E6=96=B9=E6=B3=95?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
Review notes (free-form text in this position is ignored by git-am):
 * Line breaks and indentation were re-created from a whitespace-collapsed
   copy of this patch; context lines are a best-effort reconstruction.
 * apply()/apply_more() now assert callable(func). The previous
   inspect.isfunction(func) check rejected builtins (e.g. len),
   functools.partial objects, bound methods and callable instances.
 * Assertion messages name the method that raised them, the return type in
   the new docstrings is corrected, and the new test checks the reported
   instance index instead of printing the context manager.
 * The blob ids on the "index" line below predate these edits.

 fastNLP/core/dataset.py   | 124 +++++++++++++++++++++++++++++++-------
 test/core/test_dataset.py |  37 ++++++++++++
 2 files changed, 139 insertions(+), 22 deletions(-)

diff --git a/fastNLP/core/dataset.py b/fastNLP/core/dataset.py
index c5210169..5025edfe 100644
--- a/fastNLP/core/dataset.py
+++ b/fastNLP/core/dataset.py
@@ -284,7 +284,7 @@
 """
 
 __all__ = [
-    "DataSet"
+    "DataSet",
 ]
 
 import _pickle as pickle
@@ -305,6 +305,13 @@ from .utils import pretty_table_printer
 from .collect_fn import Collector
 
 
+class ApplyResultException(Exception):
+    """Raised by ``apply_more`` when func returns a non-dict or inconsistent field names."""
+    def __init__(self, msg, index=None):
+        super().__init__(msg)
+        self.msg = msg
+        self.index = index  # 0-based index of the instance at which the problem occurred
+
 class DataSet(object):
     """
     fastNLP的数据容器,详细的使用方法见文档 :mod:`fastNLP.core.dataset`
@@ -780,23 +787,35 @@ class DataSet(object):
         assert len(self) != 0, "Null DataSet cannot use apply_field()."
         if field_name not in self:
             raise KeyError("DataSet has no field named `{}`.".format(field_name))
-        results = []
-        idx = -1
-        try:
-            for idx, ins in enumerate(self._inner_iter()):
-                results.append(func(ins[field_name]))
-        except Exception as e:
-            if idx != -1:
-                logger.error("Exception happens at the `{}`th(from 1) instance.".format(idx + 1))
-            raise e
-        if not (new_field_name is None) and len(list(filter(lambda x: x is not None, results))) == 0:  # all None
-            raise ValueError("{} always return None.".format(_get_func_signature(func=func)))
+        return self.apply(func, new_field_name, _apply_field=field_name, **kwargs)
 
-        if new_field_name is not None:
-            self._add_apply_field(results, new_field_name, kwargs)
+    def apply_field_more(self, func, field_name, modify_fields=True, **kwargs):
+        """
+        Pass the `field_name` field of every ``Instance`` in the ``DataSet`` to func and collect the return values.
+        func may return results for one or more fields.
+
+        .. note::
+            For the difference between ``apply_field_more`` and ``apply_field``, see the description of
+            ``apply_more`` versus ``apply`` in :meth:`~fastNLP.DataSet.apply_more`.
+
+        :param callable func: takes the value of `field_name`; returns a dict mapping field names to results
+        :param str field_name: name of the field whose value is passed to func
+        :param bool modify_fields: whether to write the results back into the ``DataSet``, True by default
+        :param optional kwargs: supports is_input, is_target and ignore_type
 
-        return results
+            1. is_input: bool, if True the modified fields are set as input
+
+            2. is_target: bool, if True the modified fields are set as target
+
+            3. ignore_type: bool, if True ignore_type is set on the modified fields, i.e. their type is ignored
 
+        :return Dict[str, list]: maps every field name returned by func to the list of per-instance results
+        """
+        assert len(self) != 0, "Null DataSet cannot use apply_field_more()."
+        if field_name not in self:
+            raise KeyError("DataSet has no field named `{}`.".format(field_name))
+        return self.apply_more(func, modify_fields, _apply_field=field_name, **kwargs)
+
     def _add_apply_field(self, results, new_field_name, kwargs):
         """
         将results作为加入到新的field中,field名称为new_field_name
@@ -829,12 +848,73 @@ class DataSet(object):
                            is_target=extra_param.get("is_target", None),
                            ignore_type=extra_param.get("ignore_type", False))
 
+    def apply_more(self, func, modify_fields=True, **kwargs):
+        """
+        Pass every ``Instance`` of the ``DataSet`` to func and collect its results; func may return several fields.
+
+        .. note::
+            Differences between ``apply_more`` and ``apply``:
+
+            1. ``apply_more`` can return results for several fields, ``apply`` for only one;
+
+            2. ``apply_more`` returns a dict whose keys are field names and whose values are the computed results;
+
+            3. ``apply_more`` modifies the fields of the ``DataSet`` by default, ``apply`` does not.
+
+        :param callable func: takes an ``Instance`` of the ``DataSet``; returns a dict mapping field names to results
+        :param bool modify_fields: whether to write the results back into the ``DataSet``, True by default
+        :param optional kwargs: supports is_input, is_target and ignore_type
+
+            1. is_input: bool, if True the modified fields are set as input
+
+            2. is_target: bool, if True the modified fields are set as target
+
+            3. ignore_type: bool, if True ignore_type is set on the modified fields, i.e. their type is ignored
+
+        :return Dict[str, list]: maps every field name returned by func to the list of per-instance results
+        """
+        # func must return a dict, and every instance must yield the same set of keys
+        assert callable(func), "The func you provide is not callable."
+        assert len(self) != 0, "Null DataSet cannot use apply_more()."
+        idx = -1
+        try:
+            results = {}
+            for idx, ins in enumerate(self._inner_iter()):
+                if "_apply_field" in kwargs:
+                    res = func(ins[kwargs["_apply_field"]])
+                else:
+                    res = func(ins)
+                if not isinstance(res, dict):
+                    raise ApplyResultException("The result of func is not a dict", idx)
+                if idx == 0:
+                    for key, value in res.items():
+                        results[key] = [value]
+                else:
+                    for key, value in res.items():
+                        if key not in results:
+                            raise ApplyResultException("apply results have different fields", idx)
+                        results[key].append(value)
+                    if len(res) != len(results):
+                        raise ApplyResultException("apply results have different fields", idx)
+        except Exception as e:
+            if idx != -1:
+                if isinstance(e, ApplyResultException):
+                    logger.error(e.msg)
+                logger.error("Exception happens at the `{}`th instance.".format(idx))
+            raise e
+
+        if modify_fields:
+            for field, result in results.items():
+                self._add_apply_field(result, field, kwargs)
+
+        return results
+
     def apply(self, func, new_field_name=None, **kwargs):
         """
         将DataSet中每个instance传入到func中,并获取它的返回值.
 
-        :param callable func: 参数是DataSet中的Instance
-        :param None,str new_field_name: 将func返回的内容放入到new_field_name这个field中,如果名称与已有的field相同,则覆
+        :param callable func: 参数是 ``DataSet`` 中的 ``Instance``
+        :param None,str new_field_name: 将func返回的内容放入到 `new_field_name` 这个field中,如果名称与已有的field相同,则覆
             盖之前的field。如果为None则不创建新的field。
         :param optional kwargs: 支持输入is_input,is_target,ignore_type
 
@@ -846,21 +926,21 @@ class DataSet(object):
         :return List[Any]: 里面的元素为func的返回值,所以list长度为DataSet的长度
 
         """
+        assert callable(func), "The func you provide is not callable."
         assert len(self) != 0, "Null DataSet cannot use apply()."
         idx = -1
         try:
             results = []
             for idx, ins in enumerate(self._inner_iter()):
-                results.append(func(ins))
+                if "_apply_field" in kwargs:
+                    results.append(func(ins[kwargs["_apply_field"]]))
+                else:
+                    results.append(func(ins))
         except BaseException as e:
             if idx != -1:
                 logger.error("Exception happens at the `{}`th instance.".format(idx))
             raise e
 
-        # results = [func(ins) for ins in self._inner_iter()]
-        if not (new_field_name is None) and len(list(filter(lambda x: x is not None, results))) == 0:  # all None
-            raise ValueError("{} always return None.".format(_get_func_signature(func=func)))
-
         if new_field_name is not None:
             self._add_apply_field(results, new_field_name, kwargs)
 
diff --git a/test/core/test_dataset.py b/test/core/test_dataset.py
index e05148a6..d048191f 100644
--- a/test/core/test_dataset.py
+++ b/test/core/test_dataset.py
@@ -3,6 +3,7 @@ import sys
 import unittest
 
 from fastNLP import DataSet
+from fastNLP.core.dataset import ApplyResultException
 from fastNLP import FieldArray
 from fastNLP import Instance
 from fastNLP.io import CSVLoader
@@ -143,6 +144,42 @@ class TestDataSetMethods(unittest.TestCase):
         with self.assertRaises(TypeError):
             ds.apply(modify_inplace)
 
+    def test_apply_more(self):
+
+        T = DataSet({"a": [1, 2, 3], "b": [2, 4, 5]})
+        func_1 = lambda x: {"c": x["a"] * 2, "d": x["a"] ** 2}
+        func_2 = lambda x: {"c": x * 3, "d": x ** 3}
+
+        def func_err_1(x):
+            if x["a"] == 1:
+                return {"e": x["a"] * 2, "f": x["a"] ** 2}
+            else:
+                return {"e": x["a"] * 2}
+
+        def func_err_2(x):
+            if x == 1:
+                return {"e": x * 2, "f": x ** 2}
+            else:
+                return {"e": x * 2}
+
+        T.apply_more(func_1)
+        self.assertEqual(list(T["c"]), [2, 4, 6])
+        self.assertEqual(list(T["d"]), [1, 4, 9])
+
+        res = T.apply_field_more(func_2, "a", modify_fields=False)
+        self.assertEqual(list(T["c"]), [2, 4, 6])
+        self.assertEqual(list(T["d"]), [1, 4, 9])
+        self.assertEqual(list(res["c"]), [3, 6, 9])
+        self.assertEqual(list(res["d"]), [1, 8, 27])
+
+        with self.assertRaises(ApplyResultException) as e:
+            T.apply_more(func_err_1)
+        self.assertEqual(e.exception.index, 1)
+
+        with self.assertRaises(ApplyResultException) as e:
+            T.apply_field_more(func_err_2, "a")
+        self.assertEqual(e.exception.index, 1)
+
     def test_drop(self):
         ds = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6], [7, 8, 9, 0]] * 20})
         ds.drop(lambda ins: len(ins["y"]) < 3, inplace=True)