diff --git a/fastNLP/io/dataset_loader.py b/fastNLP/io/dataset_loader.py index 01e6c8ed..558fe20e 100644 --- a/fastNLP/io/dataset_loader.py +++ b/fastNLP/io/dataset_loader.py @@ -16,7 +16,6 @@ __all__ = [ 'CSVLoader', 'JsonLoader', 'ConllLoader', - 'MatchingLoader', 'SNLILoader', 'SSTLoader', 'PeopleDailyCorpusLoader', diff --git a/reproduction/matching/data/MatchingDataLoader.py b/reproduction/matching/data/MatchingDataLoader.py index 305143b9..139b1d4f 100644 --- a/reproduction/matching/data/MatchingDataLoader.py +++ b/reproduction/matching/data/MatchingDataLoader.py @@ -29,8 +29,12 @@ class MatchingLoader(JsonLoader): def process(self, paths: Union[str, Dict[str, str]], dataset_name=None, to_lower=False, char_information=False, seq_len_type: str=None, - bert_tokenizer: str=None, get_index=True, set_input: Union[list, bool]=True, - set_target: Union[list, bool] = True, concat: Union[str, list, bool]=None, ) -> DataInfo: + bert_tokenizer: str=None, get_index=True, set_input: Union[list, str, bool]=True, + set_target: Union[list, str, bool] = True, concat: Union[str, list, bool]=None, ) -> DataInfo: + if isinstance(set_input, str): + set_input = [set_input] + if isinstance(set_target, str): + set_target = [set_target] if isinstance(set_input, bool): auto_set_input = set_input else: