From 8584ef9be54a9e9b55ca24afe695a4e1e545dc31 Mon Sep 17 00:00:00 2001 From: benbijituo Date: Thu, 26 Sep 2019 00:26:02 +0800 Subject: [PATCH] =?UTF-8?q?=E4=BF=AE=E6=94=B9=E4=BA=86CNXNLI=E7=9A=84=5Flo?= =?UTF-8?q?ad()=EF=BC=8C=E5=8F=AF=E4=BB=A5=E5=A4=84=E7=90=86=E7=89=B9?= =?UTF-8?q?=E6=AE=8A=E7=9A=84instance=E6=A0=BC=E5=BC=8F=E5=A6=82=E4=B8=8B?= =?UTF-8?q?=EF=BC=9A=20=E2=80=9CXXX\t"XXX\tXXX?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- fastNLP/io/loader/matching.py | 38 +++++++++++++++++++++++++++++++---- 1 file changed, 34 insertions(+), 4 deletions(-) diff --git a/fastNLP/io/loader/matching.py b/fastNLP/io/loader/matching.py index 77dcb521..0969eeef 100644 --- a/fastNLP/io/loader/matching.py +++ b/fastNLP/io/loader/matching.py @@ -358,8 +358,26 @@ class CNXNLILoader(Loader): super(CNXNLILoader, self).__init__() def _load(self, path: str = None): - csv_loader = CSVLoader(sep='\t') - ds_all = csv_loader._load(path) + #csv_loader = CSVLoader(sep='\t') + #ds_all = csv_loader._load(path) + ds_all = DataSet() + with open(path, 'r', encoding='utf-8') as f: + head_name_list = f.readline().strip().split('\t') + sentence1_index = head_name_list.index('sentence1') + sentence2_index = head_name_list.index('sentence2') + gold_label_index = head_name_list.index('gold_label') + language_index = head_name_list.index(('language')) + + for line in f: + line = line.strip() + raw_instance = line.split('\t') + sentence1 = raw_instance[sentence1_index] + sentence2 = raw_instance[sentence2_index] + gold_label = raw_instance[gold_label_index] + language = raw_instance[language_index] + if sentence1: + ds_all.append(Instance(sentence1=sentence1, sentence2=sentence2, gold_label=gold_label, language=language)) + ds_zh = DataSet() for i in ds_all: if i['language'] == 'zh': @@ -368,8 +386,20 @@ class CNXNLILoader(Loader): return ds_zh def _load_train(self, path: str = None): - csv_loader = CSVLoader(sep='\t') - ds = csv_loader._load(path) + #csv_loader = CSVLoader(sep='\t') + #ds = csv_loader._load(path) + ds = DataSet() + + with open(path, 'r', encoding='utf-8') as f: + next(f) + for line in f: + raw_instance = line.strip().split('\t') + premise = raw_instance[0] + hypo = raw_instance[1] + label = raw_instance[-1] + if premise: + ds.append(Instance(premise=premise, hypo=hypo, label=label)) + ds.rename_field('label', 'target') ds.rename_field('premise', 'raw_chars1') ds.rename_field('hypo', 'raw_chars2')