From 00cf9820a216df98b573f6a7f2bc841879b7e2ce Mon Sep 17 00:00:00 2001 From: xuyige Date: Fri, 5 Jul 2019 00:19:02 +0800 Subject: [PATCH] fix a bug in matching loader --- fastNLP/io/data_loader/matching.py | 36 ++++++++++--------- .../matching/data/MatchingDataLoader.py | 29 +++++++++++---- 2 files changed, 41 insertions(+), 24 deletions(-) diff --git a/fastNLP/io/data_loader/matching.py b/fastNLP/io/data_loader/matching.py index 70a683f2..1cde950f 100644 --- a/fastNLP/io/data_loader/matching.py +++ b/fastNLP/io/data_loader/matching.py @@ -19,7 +19,7 @@ class MatchingLoader(DataSetLoader): :param dict paths: key是数据集名称(如train、dev、test),value是对应的文件名 """ - def __init__(self, paths: dict = None): + def __init__(self, paths: dict=None): self.paths = paths def _load(self, path): @@ -30,11 +30,11 @@ class MatchingLoader(DataSetLoader): """ raise NotImplementedError - def process(self, paths: Union[str, Dict[str, str]], dataset_name: str = None, - to_lower=False, seq_len_type: str = None, bert_tokenizer: str = None, - cut_text: int = None, get_index=True, auto_pad_length: int = None, - auto_pad_token: str = '', set_input: Union[list, str, bool] = True, - set_target: Union[list, str, bool] = True, concat: Union[str, list, bool] = None, ) -> DataInfo: + def process(self, paths: Union[str, Dict[str, str]], dataset_name: str=None, + to_lower=False, seq_len_type: str=None, bert_tokenizer: str=None, + cut_text: int = None, get_index=True, auto_pad_length: int=None, + auto_pad_token: str='', set_input: Union[list, str, bool]=True, + set_target: Union[list, str, bool] = True, concat: Union[str, list, bool]=None, ) -> DataInfo: """ :param paths: str或者Dict[str, str]。如果是str,则为数据集所在的文件夹或者是全路径文件名:如果是文件夹, 则会从self.paths里面找对应的数据集名称与文件名。如果是Dict,则为数据集名称(如train、dev、test)和 @@ -171,7 +171,7 @@ class MatchingLoader(DataSetLoader): new_field_name=Const.INPUT_LENS(1), is_input=auto_set_input) if auto_pad_length is not None: - cut_text = min(auto_pad_length, cut_text if cut_text is not None else 0) + cut_text = min(auto_pad_length, cut_text if cut_text is not None else auto_pad_length) if cut_text is not None: for data_name, data_set in data_info.datasets.items(): @@ -207,10 +207,10 @@ class MatchingLoader(DataSetLoader): is_input=auto_set_input, is_target=auto_set_target) if auto_pad_length is not None: + if seq_len_type == 'seq_len': + raise RuntimeError(f'the sequence will be padded with the length {auto_pad_length}, ' + f'so the seq_len_type cannot be `{seq_len_type}`!') for data_name, data_set in data_info.datasets.items(): - if seq_len_type == 'seq_len': - raise RuntimeError(f'sequence will be padded with the length {auto_pad_length},' - f'the seq_len_type cannot be `{seq_len_type}`!') for fields in data_set.get_field_names(): if Const.INPUT in fields: data_set.apply(lambda x: x[fields] + [words_vocab.to_index(words_vocab.padding)] * @@ -242,7 +242,7 @@ class SNLILoader(MatchingLoader, JsonLoader): 数据来源: https://nlp.stanford.edu/projects/snli/snli_1.0.zip """ - def __init__(self, paths: dict = None): + def __init__(self, paths: dict=None): fields = { 'sentence1_binary_parse': Const.INPUTS(0), 'sentence2_binary_parse': Const.INPUTS(1), @@ -281,7 +281,7 @@ class RTELoader(MatchingLoader, CSVLoader): 数据来源: """ - def __init__(self, paths: dict = None): + def __init__(self, paths: dict=None): paths = paths if paths is not None else { 'train': 'train.tsv', 'dev': 'dev.tsv', @@ -299,7 +299,8 @@ class RTELoader(MatchingLoader, CSVLoader): ds = CSVLoader._load(self, path) for k, v in self.fields.items(): - ds.rename_field(k, v) + if v in ds.get_field_names(): + ds.rename_field(k, v) for fields in ds.get_all_fields(): if Const.INPUT in fields: ds.apply(lambda x: x[fields].strip().split(), new_field_name=fields) @@ -320,7 +321,7 @@ class QNLILoader(MatchingLoader, CSVLoader): 数据来源: """ - def __init__(self, paths: dict = None): + def __init__(self, paths: dict=None): paths = paths if paths is not None else { 'train': 'train.tsv', 'dev': 'dev.tsv', @@ -338,7 +339,8 @@ class QNLILoader(MatchingLoader, CSVLoader): ds = CSVLoader._load(self, path) for k, v in self.fields.items(): - ds.rename_field(k, v) + if v in ds.get_field_names(): + ds.rename_field(k, v) for fields in ds.get_all_fields(): if Const.INPUT in fields: ds.apply(lambda x: x[fields].strip().split(), new_field_name=fields) @@ -359,7 +361,7 @@ class MNLILoader(MatchingLoader, CSVLoader): 数据来源: """ - def __init__(self, paths: dict = None): + def __init__(self, paths: dict=None): paths = paths if paths is not None else { 'train': 'train.tsv', 'dev_matched': 'dev_matched.tsv', @@ -414,7 +416,7 @@ class QuoraLoader(MatchingLoader, CSVLoader): 数据来源: """ - def __init__(self, paths: dict = None): + def __init__(self, paths: dict=None): paths = paths if paths is not None else { 'train': 'train.tsv', 'dev': 'dev.tsv', diff --git a/reproduction/matching/data/MatchingDataLoader.py b/reproduction/matching/data/MatchingDataLoader.py index 43f016d6..7c32899c 100644 --- a/reproduction/matching/data/MatchingDataLoader.py +++ b/reproduction/matching/data/MatchingDataLoader.py @@ -16,12 +16,11 @@ class MatchingLoader(DataSetLoader): 别名::class:`fastNLP.io.MatchingLoader` :class:`fastNLP.io.dataset_loader.MatchingLoader` 读取Matching任务的数据集 + + :param dict paths: key是数据集名称(如train、dev、test),value是对应的文件名 """ def __init__(self, paths: dict=None): - """ - :param dict paths: key是数据集名称(如train、dev、test),value是对应的文件名 - """ self.paths = paths def _load(self, path): @@ -173,7 +172,7 @@ class MatchingLoader(DataSetLoader): new_field_name=Const.INPUT_LENS(1), is_input=auto_set_input) if auto_pad_length is not None: - cut_text = min(auto_pad_length, cut_text if cut_text is not None else 0) + cut_text = min(auto_pad_length, cut_text if cut_text is not None else auto_pad_length) if cut_text is not None: for data_name, data_set in data_info.datasets.items(): @@ -209,6 +208,9 @@ class MatchingLoader(DataSetLoader): is_input=auto_set_input, is_target=auto_set_target) if auto_pad_length is not None: + if seq_len_type == 'seq_len': + raise RuntimeError(f'the sequence will be padded with the length {auto_pad_length}, ' + f'so the seq_len_type cannot be `{seq_len_type}`!') for data_name, data_set in data_info.datasets.items(): for fields in data_set.get_field_names(): if Const.INPUT in fields: @@ -298,7 +300,8 @@ class RTELoader(MatchingLoader, CSVLoader): ds = CSVLoader._load(self, path) for k, v in self.fields.items(): - ds.rename_field(k, v) + if v in ds.get_field_names(): + ds.rename_field(k, v) for fields in ds.get_all_fields(): if Const.INPUT in fields: ds.apply(lambda x: x[fields].strip().split(), new_field_name=fields) @@ -337,7 +340,8 @@ class QNLILoader(MatchingLoader, CSVLoader): ds = CSVLoader._load(self, path) for k, v in self.fields.items(): - ds.rename_field(k, v) + if v in ds.get_field_names(): + ds.rename_field(k, v) for fields in ds.get_all_fields(): if Const.INPUT in fields: ds.apply(lambda x: x[fields].strip().split(), new_field_name=fields) @@ -349,7 +353,7 @@ class MNLILoader(MatchingLoader, CSVLoader): """ 别名::class:`fastNLP.io.MNLILoader` :class:`fastNLP.io.dataset_loader.MNLILoader` - 读取SNLI数据集,读取的DataSet包含fields:: + 读取MNLI数据集,读取的DataSet包含fields:: words1: list(str),第一句文本, premise words2: list(str), 第二句文本, hypothesis @@ -401,6 +405,17 @@ class MNLILoader(MatchingLoader, CSVLoader): class QuoraLoader(MatchingLoader, CSVLoader): + """ + 别名::class:`fastNLP.io.QuoraLoader` :class:`fastNLP.io.dataset_loader.QuoraLoader` + + 读取MNLI数据集,读取的DataSet包含fields:: + + words1: list(str),第一句文本, premise + words2: list(str), 第二句文本, hypothesis + target: str, 真实标签 + + 数据来源: + """ def __init__(self, paths: dict=None): paths = paths if paths is not None else {