Browse Source

Merge branch 'dev0.5.0' of https://github.com/fastnlp/fastNLP into dev0.5.0

tags/v0.4.10
yh 6 years ago
parent
commit
29ca17d324
1 changed files with 19 additions and 2 deletions
  1. +19
    -2
      reproduction/matching/data/MatchingDataLoader.py

+ 19
- 2
reproduction/matching/data/MatchingDataLoader.py View File

@@ -34,7 +34,8 @@ class MatchingLoader(DataSetLoader):

def process(self, paths: Union[str, Dict[str, str]], dataset_name: str=None,
to_lower=False, seq_len_type: str=None, bert_tokenizer: str=None,
cut_text: int = None, get_index=True, set_input: Union[list, str, bool]=True,
cut_text: int = None, get_index=True, auto_pad_length: int=None,
auto_pad_token: str='<pad>', set_input: Union[list, str, bool]=True,
set_target: Union[list, str, bool] = True, concat: Union[str, list, bool]=None, ) -> DataInfo:
"""
:param paths: str或者Dict[str, str]。如果是str,则为数据集所在的文件夹或者是全路径文件名:如果是文件夹,
@@ -49,6 +50,8 @@ class MatchingLoader(DataSetLoader):
:param str bert_tokenizer: bert tokenizer所使用的词表所在的文件夹路径
:param int cut_text: 将长于cut_text的内容截掉。默认为None,即不截。
:param bool get_index: 是否需要根据词表将文本转为index
:param int auto_pad_length: 是否需要将文本自动pad到一定长度(超过这个长度的文本将会被截掉),默认为不会自动pad
:param str auto_pad_token: 自动pad的内容
:param set_input: 如果为True,则会自动将相关的field(名字里含有Const.INPUT的)设置为input,如果为False
则不会将任何field设置为input。如果传入str或者List[str],则会根据传入的内容将相对应的field设置为input,
于此同时其他field不会被设置为input。默认值为True。
@@ -169,6 +172,9 @@ class MatchingLoader(DataSetLoader):
data_set.apply(lambda x: [1] * len(x[Const.INPUT_LENS(0)]),
new_field_name=Const.INPUT_LENS(1), is_input=auto_set_input)

if auto_pad_length is not None:
cut_text = min(auto_pad_length, cut_text if cut_text is not None else 0)

if cut_text is not None:
for data_name, data_set in data_info.datasets.items():
for fields in data_set.get_field_names():
@@ -180,7 +186,7 @@ class MatchingLoader(DataSetLoader):
assert len(data_set_list) > 0, f'There are NO data sets in data info!'

if bert_tokenizer is None:
words_vocab = Vocabulary()
words_vocab = Vocabulary(padding=auto_pad_token)
words_vocab = words_vocab.from_dataset(*[d for n, d in data_info.datasets.items() if 'train' in n],
field_name=[n for n in data_set_list[0].get_field_names()
if (Const.INPUT in n)],
@@ -202,6 +208,17 @@ class MatchingLoader(DataSetLoader):
data_set.apply(lambda x: target_vocab.to_index(x[Const.TARGET]), new_field_name=Const.TARGET,
is_input=auto_set_input, is_target=auto_set_target)

if auto_pad_length is not None:
for data_name, data_set in data_info.datasets.items():
for fields in data_set.get_field_names():
if Const.INPUT in fields:
data_set.apply(lambda x: x[fields] + [words_vocab.padding] * (auto_pad_length - len(x[fields])),
new_field_name=fields, is_input=auto_set_input)
elif (Const.INPUT_LEN in fields) and (seq_len_type != 'seq_len'):
data_set.apply(lambda x: x[fields] + [words_vocab.to_index(words_vocab.padding)] *
(auto_pad_length - len(x[fields])), new_field_name=fields,
is_input=auto_set_input)

for data_name, data_set in data_info.datasets.items():
if isinstance(set_input, list):
data_set.set_input(*[inputs for inputs in set_input if inputs in data_set.get_field_names()])


Loading…
Cancel
Save