diff --git a/reproduction/matching/data/MatchingDataLoader.py b/reproduction/matching/data/MatchingDataLoader.py index 20a63d75..9d948ec1 100644 --- a/reproduction/matching/data/MatchingDataLoader.py +++ b/reproduction/matching/data/MatchingDataLoader.py @@ -34,7 +34,8 @@ class MatchingLoader(DataSetLoader): def process(self, paths: Union[str, Dict[str, str]], dataset_name: str=None, to_lower=False, seq_len_type: str=None, bert_tokenizer: str=None, - cut_text: int = None, get_index=True, set_input: Union[list, str, bool]=True, + cut_text: int = None, get_index=True, auto_pad_length: int=None, + auto_pad_token: str='', set_input: Union[list, str, bool]=True, set_target: Union[list, str, bool] = True, concat: Union[str, list, bool]=None, ) -> DataInfo: """ :param paths: str或者Dict[str, str]。如果是str,则为数据集所在的文件夹或者是全路径文件名:如果是文件夹, @@ -49,6 +50,8 @@ class MatchingLoader(DataSetLoader): :param str bert_tokenizer: bert tokenizer所使用的词表所在的文件夹路径 :param int cut_text: 将长于cut_text的内容截掉。默认为None,即不截。 :param bool get_index: 是否需要根据词表将文本转为index + :param int auto_pad_length: 是否需要将文本自动pad到一定长度(超过这个长度的文本将会被截掉),默认为不会自动pad + :param str auto_pad_token: 自动pad的内容 :param set_input: 如果为True,则会自动将相关的field(名字里含有Const.INPUT的)设置为input,如果为False 则不会将任何field设置为input。如果传入str或者List[str],则会根据传入的内容将相对应的field设置为input, 于此同时其他field不会被设置为input。默认值为True。 @@ -169,6 +172,9 @@ class MatchingLoader(DataSetLoader): data_set.apply(lambda x: [1] * len(x[Const.INPUT_LENS(0)]), new_field_name=Const.INPUT_LENS(1), is_input=auto_set_input) + if auto_pad_length is not None: + cut_text = min(auto_pad_length, cut_text if cut_text is not None else 0) + if cut_text is not None: for data_name, data_set in data_info.datasets.items(): for fields in data_set.get_field_names(): @@ -180,7 +186,7 @@ class MatchingLoader(DataSetLoader): assert len(data_set_list) > 0, f'There are NO data sets in data info!' if bert_tokenizer is None: - words_vocab = Vocabulary() + words_vocab = Vocabulary(padding=auto_pad_token) words_vocab = words_vocab.from_dataset(*[d for n, d in data_info.datasets.items() if 'train' in n], field_name=[n for n in data_set_list[0].get_field_names() if (Const.INPUT in n)], @@ -202,6 +208,17 @@ class MatchingLoader(DataSetLoader): data_set.apply(lambda x: target_vocab.to_index(x[Const.TARGET]), new_field_name=Const.TARGET, is_input=auto_set_input, is_target=auto_set_target) + if auto_pad_length is not None: + for data_name, data_set in data_info.datasets.items(): + for fields in data_set.get_field_names(): + if Const.INPUT in fields: + data_set.apply(lambda x: x[fields] + [words_vocab.padding] * (auto_pad_length - len(x[fields])), + new_field_name=fields, is_input=auto_set_input) + elif (Const.INPUT_LEN in fields) and (seq_len_type != 'seq_len'): + data_set.apply(lambda x: x[fields] + [words_vocab.to_index(words_vocab.padding)] * + (auto_pad_length - len(x[fields])), new_field_name=fields, + is_input=auto_set_input) + for data_name, data_set in data_info.datasets.items(): if isinstance(set_input, list): data_set.set_input(*[inputs for inputs in set_input if inputs in data_set.get_field_names()])