|
|
@@ -284,10 +284,11 @@ |
|
|
|
|
|
|
|
""" |
|
|
|
__all__ = [ |
|
|
|
"DataSet" |
|
|
|
"DataSet", |
|
|
|
] |
|
|
|
|
|
|
|
import _pickle as pickle |
|
|
|
from inspect import isfunction |
|
|
|
from copy import deepcopy |
|
|
|
|
|
|
|
import numpy as np |
|
|
@@ -305,6 +306,12 @@ from .utils import pretty_table_printer |
|
|
|
from .collect_fn import Collector |
|
|
|
|
|
|
|
|
|
|
|
class ApplyResultException(Exception): |
|
|
|
def __init__(self, msg, index=None): |
|
|
|
super().__init__(msg) |
|
|
|
self.msg = msg |
|
|
|
self.index = index # 标示在哪个数据遭遇到问题了 |
|
|
|
|
|
|
|
class DataSet(object): |
|
|
|
""" |
|
|
|
fastNLP的数据容器,详细的使用方法见文档 :mod:`fastNLP.core.dataset` |
|
|
@@ -780,23 +787,35 @@ class DataSet(object): |
|
|
|
assert len(self) != 0, "Null DataSet cannot use apply_field()." |
|
|
|
if field_name not in self: |
|
|
|
raise KeyError("DataSet has no field named `{}`.".format(field_name)) |
|
|
|
results = [] |
|
|
|
idx = -1 |
|
|
|
try: |
|
|
|
for idx, ins in enumerate(self._inner_iter()): |
|
|
|
results.append(func(ins[field_name])) |
|
|
|
except Exception as e: |
|
|
|
if idx != -1: |
|
|
|
logger.error("Exception happens at the `{}`th(from 1) instance.".format(idx + 1)) |
|
|
|
raise e |
|
|
|
if not (new_field_name is None) and len(list(filter(lambda x: x is not None, results))) == 0: # all None |
|
|
|
raise ValueError("{} always return None.".format(_get_func_signature(func=func))) |
|
|
|
return self.apply(func, new_field_name, _apply_field=field_name, **kwargs) |
|
|
|
|
|
|
|
if new_field_name is not None: |
|
|
|
self._add_apply_field(results, new_field_name, kwargs) |
|
|
|
def apply_field_more(self, func, field_name, modify_fields=True, **kwargs): |
|
|
|
""" |
|
|
|
将 ``DataSet`` 中的每个 ``Instance`` 中的名为 `field_name` 的field 传给 func,并获取它的返回值。 |
|
|
|
func 可以返回一个或多个 field 上的结果。 |
|
|
|
|
|
|
|
.. note:: |
|
|
|
``apply_field_more`` 与 ``apply_field`` 的区别参考 :meth:`~fastNLP.DataSet.apply_more` 中关于 ``apply_more`` 与 |
|
|
|
``apply`` 区别的介绍。 |
|
|
|
|
|
|
|
:param callable func: 参数是 ``DataSet`` 中的 ``Instance`` ,返回值是一个字典,key 是field 的名字,value 是对应的结果 |
|
|
|
:param str field_name: 传入func的是哪个field。 |
|
|
|
:param bool modify_fields: 是否用结果修改 `DataSet` 中的 `Field`, 默认为 True |
|
|
|
:param optional kwargs: 支持输入is_input,is_target,ignore_type |
|
|
|
|
|
|
|
return results |
|
|
|
1. is_input: bool, 如果为True则将被修改的field设置为input |
|
|
|
|
|
|
|
2. is_target: bool, 如果为True则将被修改的field设置为target |
|
|
|
|
|
|
|
3. ignore_type: bool, 如果为True则将被修改的field的ignore_type设置为true, 忽略其类型 |
|
|
|
|
|
|
|
:return Dict[int:Field]: 返回一个字典 |
|
|
|
""" |
|
|
|
assert len(self) != 0, "Null DataSet cannot use apply_field()." |
|
|
|
if field_name not in self: |
|
|
|
raise KeyError("DataSet has no field named `{}`.".format(field_name)) |
|
|
|
return self.apply_more(func, modify_fields, _apply_field=field_name, **kwargs) |
|
|
|
|
|
|
|
def _add_apply_field(self, results, new_field_name, kwargs): |
|
|
|
""" |
|
|
|
将results作为加入到新的field中,field名称为new_field_name |
|
|
@@ -829,12 +848,73 @@ class DataSet(object): |
|
|
|
is_target=extra_param.get("is_target", None), |
|
|
|
ignore_type=extra_param.get("ignore_type", False)) |
|
|
|
|
|
|
|
def apply_more(self, func, modify_fields=True, **kwargs): |
|
|
|
""" |
|
|
|
将 ``DataSet`` 中每个 ``Instance`` 传入到func中,并获取它的返回值。func可以返回一个或多个 field 上的结果。 |
|
|
|
|
|
|
|
.. note:: |
|
|
|
``apply_more`` 与 ``apply`` 的区别: |
|
|
|
|
|
|
|
1. ``apply_more`` 可以返回多个 field 的结果, ``apply`` 只可以返回一个field 的结果; |
|
|
|
|
|
|
|
2. ``apply_more`` 的返回值是一个字典,每个 key-value 对中的 key 表示 field 的名字,value 表示计算结果; |
|
|
|
|
|
|
|
3. ``apply_more`` 默认修改 ``DataSet`` 中的 field ,``apply`` 默认不修改。 |
|
|
|
|
|
|
|
:param callable func: 参数是 ``DataSet`` 中的 ``Instance`` ,返回值是一个字典,key 是field 的名字,value 是对应的结果 |
|
|
|
:param bool modify_fields: 是否用结果修改 ``DataSet`` 中的 ``Field`` , 默认为 True |
|
|
|
:param optional kwargs: 支持输入is_input,is_target,ignore_type |
|
|
|
|
|
|
|
1. is_input: bool, 如果为True则将被修改的的field设置为input |
|
|
|
|
|
|
|
2. is_target: bool, 如果为True则将被修改的的field设置为target |
|
|
|
|
|
|
|
3. ignore_type: bool, 如果为True则将被修改的的field的ignore_type设置为true, 忽略其类型 |
|
|
|
|
|
|
|
:return Dict[int:Field]: 返回一个字典 |
|
|
|
""" |
|
|
|
# 返回 dict , 检查是否一直相同 |
|
|
|
assert isfunction(func), "The func you provide is not callable." |
|
|
|
assert len(self) != 0, "Null DataSet cannot use apply()." |
|
|
|
idx = -1 |
|
|
|
try: |
|
|
|
results = {} |
|
|
|
for idx, ins in enumerate(self._inner_iter()): |
|
|
|
if "_apply_field" in kwargs: |
|
|
|
res = func(ins[kwargs["_apply_field"]]) |
|
|
|
else: |
|
|
|
res = func(ins) |
|
|
|
if not isinstance(res, dict): |
|
|
|
raise ApplyResultException("The result of func is not a dict", idx) |
|
|
|
if idx == 0: |
|
|
|
for key, value in res.items(): |
|
|
|
results[key] = [value] |
|
|
|
else: |
|
|
|
for key, value in res.items(): |
|
|
|
if key not in results: |
|
|
|
raise ApplyResultException("apply results have different fields", idx) |
|
|
|
results[key].append(value) |
|
|
|
if len(res) != len(results): |
|
|
|
raise ApplyResultException("apply results have different fields", idx) |
|
|
|
except Exception as e: |
|
|
|
if idx != -1: |
|
|
|
if isinstance(e, ApplyResultException): |
|
|
|
logger.error(e.msg) |
|
|
|
logger.error("Exception happens at the `{}`th instance.".format(idx)) |
|
|
|
raise e |
|
|
|
|
|
|
|
if modify_fields is True: |
|
|
|
for field, result in results.items(): |
|
|
|
self._add_apply_field(result, field, kwargs) |
|
|
|
|
|
|
|
return results |
|
|
|
|
|
|
|
def apply(self, func, new_field_name=None, **kwargs): |
|
|
|
""" |
|
|
|
将DataSet中每个instance传入到func中,并获取它的返回值. |
|
|
|
|
|
|
|
:param callable func: 参数是DataSet中的Instance |
|
|
|
:param None,str new_field_name: 将func返回的内容放入到new_field_name这个field中,如果名称与已有的field相同,则覆 |
|
|
|
:param callable func: 参数是 ``DataSet`` 中的 ``Instance`` |
|
|
|
:param None,str new_field_name: 将func返回的内容放入到 `new_field_name` 这个field中,如果名称与已有的field相同,则覆 |
|
|
|
盖之前的field。如果为None则不创建新的field。 |
|
|
|
:param optional kwargs: 支持输入is_input,is_target,ignore_type |
|
|
|
|
|
|
@@ -846,21 +926,21 @@ class DataSet(object): |
|
|
|
|
|
|
|
:return List[Any]: 里面的元素为func的返回值,所以list长度为DataSet的长度 |
|
|
|
""" |
|
|
|
assert isfunction(func), "The func you provide is not callable." |
|
|
|
assert len(self) != 0, "Null DataSet cannot use apply()." |
|
|
|
idx = -1 |
|
|
|
try: |
|
|
|
results = [] |
|
|
|
for idx, ins in enumerate(self._inner_iter()): |
|
|
|
results.append(func(ins)) |
|
|
|
if "_apply_field" in kwargs: |
|
|
|
results.append(func(ins[kwargs["_apply_field"]])) |
|
|
|
else: |
|
|
|
results.append(func(ins)) |
|
|
|
except BaseException as e: |
|
|
|
if idx != -1: |
|
|
|
logger.error("Exception happens at the `{}`th instance.".format(idx)) |
|
|
|
raise e |
|
|
|
|
|
|
|
# results = [func(ins) for ins in self._inner_iter()] |
|
|
|
if not (new_field_name is None) and len(list(filter(lambda x: x is not None, results))) == 0: # all None |
|
|
|
raise ValueError("{} always return None.".format(_get_func_signature(func=func))) |
|
|
|
|
|
|
|
if new_field_name is not None: |
|
|
|
self._add_apply_field(results, new_field_name, kwargs) |
|
|
|
|
|
|
|