diff --git a/fastNLP/io/loader/matching.py b/fastNLP/io/loader/matching.py index 77dcb521..0969eeef 100644 --- a/fastNLP/io/loader/matching.py +++ b/fastNLP/io/loader/matching.py @@ -358,8 +358,26 @@ class CNXNLILoader(Loader): super(CNXNLILoader, self).__init__() def _load(self, path: str = None): - csv_loader = CSVLoader(sep='\t') - ds_all = csv_loader._load(path) + #csv_loader = CSVLoader(sep='\t') + #ds_all = csv_loader._load(path) + ds_all = DataSet() + with open(path, 'r', encoding='utf-8') as f: + head_name_list = f.readline().strip().split('\t') + sentence1_index = head_name_list.index('sentence1') + sentence2_index = head_name_list.index('sentence2') + gold_label_index = head_name_list.index('gold_label') + language_index = head_name_list.index(('language')) + + for line in f: + line = line.strip() + raw_instance = line.split('\t') + sentence1 = raw_instance[sentence1_index] + sentence2 = raw_instance[sentence2_index] + gold_label = raw_instance[gold_label_index] + language = raw_instance[language_index] + if sentence1: + ds_all.append(Instance(sentence1=sentence1, sentence2=sentence2, gold_label=gold_label, language=language)) + ds_zh = DataSet() for i in ds_all: if i['language'] == 'zh': @@ -368,8 +386,20 @@ class CNXNLILoader(Loader): return ds_zh def _load_train(self, path: str = None): - csv_loader = CSVLoader(sep='\t') - ds = csv_loader._load(path) + #csv_loader = CSVLoader(sep='\t') + #ds = csv_loader._load(path) + ds = DataSet() + + with open(path, 'r', encoding='utf-8') as f: + next(f) + for line in f: + raw_instance = line.strip().split('\t') + premise = raw_instance[0] + hypo = raw_instance[1] + label = raw_instance[-1] + if premise: + ds.append(Instance(premise=premise, hypo=hypo, label=label)) + ds.rename_field('label', 'target') ds.rename_field('premise', 'raw_chars1') ds.rename_field('hypo', 'raw_chars2')