|
|
@@ -34,7 +34,8 @@ class MatchingLoader(DataSetLoader):
     def process(self, paths: Union[str, Dict[str, str]], dataset_name: str=None,
                 to_lower=False, seq_len_type: str=None, bert_tokenizer: str=None,
-                cut_text: int = None, get_index=True, set_input: Union[list, str, bool]=True,
+                cut_text: int = None, get_index=True, auto_pad_length: int=None,
+                auto_pad_token: str='<pad>', set_input: Union[list, str, bool]=True,
                 set_target: Union[list, str, bool] = True, concat: Union[str, list, bool]=None, ) -> DataInfo:
         """
         :param paths: str or Dict[str, str]. If a str, it is either the folder containing the data set files or a full file path: if it is a folder,
@@ -49,6 +50,8 @@ class MatchingLoader(DataSetLoader):
         :param str bert_tokenizer: path to the folder containing the vocabulary used by the BERT tokenizer
         :param int cut_text: truncate any text longer than cut_text. Defaults to None, i.e. no truncation.
         :param bool get_index: whether to convert the text to indices according to the vocabulary
+        :param int auto_pad_length: length to which the text is automatically padded (text longer than this will be truncated); defaults to None, i.e. no automatic padding
+        :param str auto_pad_token: the token used for automatic padding
         :param set_input: if True, every field whose name contains Const.INPUT is automatically set as input; if False,
             no field is set as input. If a str or List[str] is passed, the corresponding fields are set as input
             and all other fields are not set as input. Defaults to True.
@@ -169,6 +172,9 @@ class MatchingLoader(DataSetLoader):
                     data_set.apply(lambda x: [1] * len(x[Const.INPUT_LENS(0)]),
                                    new_field_name=Const.INPUT_LENS(1), is_input=auto_set_input)
 
+        if auto_pad_length is not None:
+            cut_text = min(auto_pad_length, cut_text if cut_text is not None else auto_pad_length)
+
         if cut_text is not None:
             for data_name, data_set in data_info.datasets.items():
                 for fields in data_set.get_field_names():
@@ -180,7 +186,7 @@ class MatchingLoader(DataSetLoader):
         assert len(data_set_list) > 0, f'There are NO data sets in data info!'
 
         if bert_tokenizer is None:
-            words_vocab = Vocabulary()
+            words_vocab = Vocabulary(padding=auto_pad_token)
             words_vocab = words_vocab.from_dataset(*[d for n, d in data_info.datasets.items() if 'train' in n],
                                                    field_name=[n for n in data_set_list[0].get_field_names()
                                                                if (Const.INPUT in n)],
@@ -202,6 +208,17 @@ class MatchingLoader(DataSetLoader):
                 data_set.apply(lambda x: target_vocab.to_index(x[Const.TARGET]), new_field_name=Const.TARGET,
                                is_input=auto_set_input, is_target=auto_set_target)
 
+        if auto_pad_length is not None:
+            for data_name, data_set in data_info.datasets.items():
+                for fields in data_set.get_field_names():
+                    if Const.INPUT in fields:
+                        data_set.apply(lambda x: x[fields] + [words_vocab.padding] * (auto_pad_length - len(x[fields])),
+                                       new_field_name=fields, is_input=auto_set_input)
+                    elif (Const.INPUT_LEN in fields) and (seq_len_type != 'seq_len'):
+                        data_set.apply(lambda x: x[fields] + [words_vocab.to_index(words_vocab.padding)] *
+                                       (auto_pad_length - len(x[fields])), new_field_name=fields,
+                                       is_input=auto_set_input)
+
         for data_name, data_set in data_info.datasets.items():
             if isinstance(set_input, list):
                 data_set.set_input(*[inputs for inputs in set_input if inputs in data_set.get_field_names()])
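Usage sketch (not part of the patch): the call below illustrates how the new auto_pad_length and
auto_pad_token arguments would be passed to MatchingLoader.process. The SNLILoader subclass, the data
path, and the pad length of 128 are assumptions for illustration, not values taken from this diff; note
that auto_pad_length also caps cut_text, so inputs longer than the pad length are cut before padding.

    loader = SNLILoader()                 # any MatchingLoader subclass; SNLILoader is an assumed example
    data_info = loader.process(
        paths='/path/to/snli_data',       # folder with the split files, or a dict of split name -> file path
        to_lower=True,
        get_index=True,                   # convert tokens to vocabulary indices
        auto_pad_length=128,              # pad every Const.INPUT field to 128 entries, cutting longer ones
        auto_pad_token='<pad>',           # registered as the padding symbol of the built Vocabulary
    )
    train_set = data_info.datasets['train']   # split key 'train' assumed; DataInfo.datasets maps names to DataSets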
|
|
|