|
@@ -1,6 +1,6 @@ |
|
|
import os |
|
|
import os |
|
|
|
|
|
|
|
|
from typing import Union, Dict , List |
|
|
|
|
|
|
|
|
from typing import Union, Dict, List |
|
|
|
|
|
|
|
|
from ...core.const import Const |
|
|
from ...core.const import Const |
|
|
from ...core.vocabulary import Vocabulary |
|
|
from ...core.vocabulary import Vocabulary |
|
@@ -34,7 +34,7 @@ class MatchingLoader(DataSetLoader): |
|
|
cut_text: int = None, get_index=True, auto_pad_length: int=None, |
|
|
cut_text: int = None, get_index=True, auto_pad_length: int=None, |
|
|
auto_pad_token: str='<pad>', set_input: Union[list, str, bool]=True, |
|
|
auto_pad_token: str='<pad>', set_input: Union[list, str, bool]=True, |
|
|
set_target: Union[list, str, bool]=True, concat: Union[str, list, bool]=None, |
|
|
set_target: Union[list, str, bool]=True, concat: Union[str, list, bool]=None, |
|
|
extra_split: List[str]=List['-'], ) -> DataBundle: |
|
|
|
|
|
|
|
|
extra_split: List[str]=None, ) -> DataBundle: |
|
|
""" |
|
|
""" |
|
|
:param paths: str或者Dict[str, str]。如果是str,则为数据集所在的文件夹或者是全路径文件名:如果是文件夹, |
|
|
:param paths: str或者Dict[str, str]。如果是str,则为数据集所在的文件夹或者是全路径文件名:如果是文件夹, |
|
|
则会从self.paths里面找对应的数据集名称与文件名。如果是Dict,则为数据集名称(如train、dev、test)和 |
|
|
则会从self.paths里面找对应的数据集名称与文件名。如果是Dict,则为数据集名称(如train、dev、test)和 |
|
@@ -91,22 +91,22 @@ class MatchingLoader(DataSetLoader): |
|
|
if Const.TARGET in data_set.get_field_names(): |
|
|
if Const.TARGET in data_set.get_field_names(): |
|
|
data_set.set_target(Const.TARGET) |
|
|
data_set.set_target(Const.TARGET) |
|
|
|
|
|
|
|
|
if extra_split: |
|
|
|
|
|
|
|
|
if extra_split is not None: |
|
|
for data_name, data_set in data_info.datasets.items(): |
|
|
for data_name, data_set in data_info.datasets.items(): |
|
|
data_set.apply(lambda x: ' '.join(x[Const.INPUTS(0)]), new_field_name=Const.INPUTS(0)) |
|
|
data_set.apply(lambda x: ' '.join(x[Const.INPUTS(0)]), new_field_name=Const.INPUTS(0)) |
|
|
data_set.apply(lambda x: ' '.join(x[Const.INPUTS(1)]), new_field_name=Const.INPUTS(1)) |
|
|
data_set.apply(lambda x: ' '.join(x[Const.INPUTS(1)]), new_field_name=Const.INPUTS(1)) |
|
|
|
|
|
|
|
|
for s in extra_split: |
|
|
for s in extra_split: |
|
|
data_set.apply(lambda x: x[Const.INPUTS(0)].replace(s , ' ' + s + ' '), |
|
|
|
|
|
new_field_name=Const.INPUTS(0)) |
|
|
|
|
|
data_set.apply(lambda x: x[Const.INPUTS(0)].replace(s , ' ' + s + ' '), |
|
|
|
|
|
new_field_name=Const.INPUTS(0)) |
|
|
|
|
|
|
|
|
|
|
|
_filt = lambda x : x |
|
|
|
|
|
data_set.apply(lambda x: list(filter(_filt , x[Const.INPUTS(0)].split(' '))), |
|
|
|
|
|
new_field_name=Const.INPUTS(0), is_input=auto_set_input) |
|
|
|
|
|
data_set.apply(lambda x: list(filter(_filt , x[Const.INPUTS(1)].split(' '))), |
|
|
|
|
|
new_field_name=Const.INPUTS(1), is_input=auto_set_input) |
|
|
|
|
|
|
|
|
data_set.apply(lambda x: x[Const.INPUTS(0)].replace(s, ' ' + s + ' '), |
|
|
|
|
|
new_field_name=Const.INPUTS(0)) |
|
|
|
|
|
data_set.apply(lambda x: x[Const.INPUTS(0)].replace(s, ' ' + s + ' '), |
|
|
|
|
|
new_field_name=Const.INPUTS(0)) |
|
|
|
|
|
|
|
|
|
|
|
_filt = lambda x: x |
|
|
|
|
|
data_set.apply(lambda x: list(filter(_filt, x[Const.INPUTS(0)].split(' '))), |
|
|
|
|
|
new_field_name=Const.INPUTS(0), is_input=auto_set_input) |
|
|
|
|
|
data_set.apply(lambda x: list(filter(_filt, x[Const.INPUTS(1)].split(' '))), |
|
|
|
|
|
new_field_name=Const.INPUTS(1), is_input=auto_set_input) |
|
|
_filt = None |
|
|
_filt = None |
|
|
|
|
|
|
|
|
if to_lower: |
|
|
if to_lower: |
|
|