|
|
@@ -166,7 +166,7 @@ class DataBundle: |
|
|
|
dataset.set_target(field_name, flag=flag, use_1st_ins_infer_dim_type=use_1st_ins_infer_dim_type) |
|
|
|
return self |
|
|
|
|
|
|
|
def set_pad_val(self, field_name, pad_val, ignore_miss_dataset=True): |
|
|
|
def set_pad_val(self, field_name, pad_val, ignore_miss_dataset=True): |
|
|
|
r""" |
|
|
|
将DataBundle中所有的DataSet中名为field_name的Field的padding值设置为pad_val. |
|
|
|
|
|
|
@@ -282,7 +282,7 @@ class DataBundle: |
|
|
|
""" |
|
|
|
return list(self.datasets.keys()) |
|
|
|
|
|
|
|
def get_vocab_names(self)->List[str]: |
|
|
|
def get_vocab_names(self) -> List[str]: |
|
|
|
r""" |
|
|
|
返回DataBundle中Vocabulary的名称 |
|
|
|
|
|
|
@@ -304,9 +304,9 @@ class DataBundle: |
|
|
|
for field_name, vocab in self.vocabs.items(): |
|
|
|
yield field_name, vocab |
|
|
|
|
|
|
|
def apply_field(self, func, field_name: str, new_field_name: str, ignore_miss_dataset=True, **kwargs): |
|
|
|
def apply_field(self, func, field_name: str, new_field_name: str, ignore_miss_dataset=True, **kwargs): |
|
|
|
r""" |
|
|
|
对DataBundle中所有的dataset使用apply_field方法 |
|
|
|
对 :class:`~fastNLP.io.DataBundle` 中所有的dataset使用 :meth:`~fastNLP.DataSet.apply_field` 方法 |
|
|
|
|
|
|
|
:param callable func: input是instance中名为 `field_name` 的field的内容。 |
|
|
|
:param str field_name: 传入func的是哪个field。 |
|
|
@@ -329,8 +329,41 @@ class DataBundle: |
|
|
|
raise KeyError(f"{field_name} not found DataSet:{name}.") |
|
|
|
return self |
|
|
|
|
|
|
|
def apply(self, func, new_field_name:str, **kwargs): |
|
|
|
def apply_field_more(self, func, field_name, modify_fields=True, ignore_miss_dataset=True, **kwargs): |
|
|
|
r""" |
|
|
|
对 :class:`~fastNLP.io.DataBundle` 中所有的 dataset 使用 :meth:`~fastNLP.DataSet.apply_field_more` 方法 |
|
|
|
|
|
|
|
.. note:: |
|
|
|
``apply_field_more`` 与 ``apply_field`` 的区别参考 :meth:`fastNLP.DataSet.apply_more` 中关于 ``apply_more`` 与 |
|
|
|
``apply`` 区别的介绍。 |
|
|
|
|
|
|
|
:param callable func: 参数是 ``DataSet`` 中的 ``Instance`` ,返回值是一个字典,key 是field 的名字,value 是对应的结果 |
|
|
|
:param str field_name: 传入func的是哪个field。 |
|
|
|
:param bool modify_fields: 是否用结果修改 `DataSet` 中的 `Field`, 默认为 True |
|
|
|
:param bool ignore_miss_dataset: 当某个field名称在某个dataset不存在时,如果为True,则直接忽略该DataSet; |
|
|
|
如果为False,则报错 |
|
|
|
:param optional kwargs: 支持输入is_input, is_target, ignore_type |
|
|
|
|
|
|
|
1. is_input: bool, 如果为True则将被修改的field设置为input |
|
|
|
|
|
|
|
2. is_target: bool, 如果为True则将被修改的field设置为target |
|
|
|
|
|
|
|
3. ignore_type: bool, 如果为True则将被修改的field的ignore_type设置为true, 忽略其类型 |
|
|
|
|
|
|
|
:return Dict[str:Dict[str:Field]]: 返回一个字典套字典,第一层的 key 是 dataset 的名字,第二层的 key 是 field 的名字 |
|
|
|
""" |
|
|
|
res = {} |
|
|
|
for name, dataset in self.datasets.items(): |
|
|
|
if dataset.has_field(field_name=field_name): |
|
|
|
res[name] = dataset.apply_field_more(func=func, field_name=field_name, modify_fields=modify_fields, **kwargs) |
|
|
|
elif not ignore_miss_dataset: |
|
|
|
raise KeyError(f"{field_name} not found DataSet:{name} .") |
|
|
|
return res |
|
|
|
|
|
|
|
def apply(self, func, new_field_name: str, **kwargs): |
|
|
|
r""" |
|
|
|
对 :class:`~fastNLP.io.DataBundle` 中所有的 dataset 使用 :meth:`~fastNLP.DataSet.apply` 方法 |
|
|
|
|
|
|
|
对DataBundle中所有的dataset使用apply方法 |
|
|
|
|
|
|
|
:param callable func: input是instance中名为 `field_name` 的field的内容。 |
|
|
@@ -348,6 +381,31 @@ class DataBundle: |
|
|
|
dataset.apply(func, new_field_name=new_field_name, **kwargs) |
|
|
|
return self |
|
|
|
|
|
|
|
def apply_more(self, func, modify_fields=True, **kwargs): |
|
|
|
r""" |
|
|
|
对 :class:`~fastNLP.io.DataBundle` 中所有的 dataset 使用 :meth:`~fastNLP.DataSet.apply_more` 方法 |
|
|
|
|
|
|
|
.. note:: |
|
|
|
``apply_more`` 与 ``apply`` 的区别参考 :meth:`fastNLP.DataSet.apply_more` 中关于 ``apply_more`` 与 |
|
|
|
``apply`` 区别的介绍。 |
|
|
|
|
|
|
|
:param callable func: 参数是 ``DataSet`` 中的 ``Instance`` ,返回值是一个字典,key 是field 的名字,value 是对应的结果 |
|
|
|
:param bool modify_fields: 是否用结果修改 ``DataSet`` 中的 ``Field`` , 默认为 True |
|
|
|
:param optional kwargs: 支持输入is_input,is_target,ignore_type |
|
|
|
|
|
|
|
1. is_input: bool, 如果为True则将被修改的的field设置为input |
|
|
|
|
|
|
|
2. is_target: bool, 如果为True则将被修改的的field设置为target |
|
|
|
|
|
|
|
3. ignore_type: bool, 如果为True则将被修改的的field的ignore_type设置为true, 忽略其类型 |
|
|
|
|
|
|
|
:return Dict[str:Dict[str:Field]]: 返回一个字典套字典,第一层的 key 是 dataset 的名字,第二层的 key 是 field 的名字 |
|
|
|
""" |
|
|
|
res = {} |
|
|
|
for name, dataset in self.datasets.items(): |
|
|
|
res[name] = dataset.apply_more(func, modify_fields=modify_fields, **kwargs) |
|
|
|
return res |
|
|
|
|
|
|
|
def add_collate_fn(self, fn, name=None): |
|
|
|
r""" |
|
|
|
向所有DataSet增加collate_fn, collate_fn详见 :class:`~fastNLP.DataSet` 中相关说明. |
|
|
@@ -380,5 +438,3 @@ class DataBundle: |
|
|
|
for name, vocab in self.vocabs.items(): |
|
|
|
_str += '\t{} has {} entries.\n'.format(name, len(vocab)) |
|
|
|
return _str |
|
|
|
|
|
|
|
|