|
|
@@ -358,8 +358,26 @@ class CNXNLILoader(Loader): |
|
|
|
super(CNXNLILoader, self).__init__() |
|
|
|
|
|
|
|
def _load(self, path: str = None): |
|
|
|
csv_loader = CSVLoader(sep='\t') |
|
|
|
ds_all = csv_loader._load(path) |
|
|
|
#csv_loader = CSVLoader(sep='\t') |
|
|
|
#ds_all = csv_loader._load(path) |
|
|
|
ds_all = DataSet() |
|
|
|
with open(path, 'r', encoding='utf-8') as f: |
|
|
|
head_name_list = f.readline().strip().split('\t') |
|
|
|
sentence1_index = head_name_list.index('sentence1') |
|
|
|
sentence2_index = head_name_list.index('sentence2') |
|
|
|
gold_label_index = head_name_list.index('gold_label') |
|
|
|
language_index = head_name_list.index(('language')) |
|
|
|
|
|
|
|
for line in f: |
|
|
|
line = line.strip() |
|
|
|
raw_instance = line.split('\t') |
|
|
|
sentence1 = raw_instance[sentence1_index] |
|
|
|
sentence2 = raw_instance[sentence2_index] |
|
|
|
gold_label = raw_instance[gold_label_index] |
|
|
|
language = raw_instance[language_index] |
|
|
|
if sentence1: |
|
|
|
ds_all.append(Instance(sentence1=sentence1, sentence2=sentence2, gold_label=gold_label, language=language)) |
|
|
|
|
|
|
|
ds_zh = DataSet() |
|
|
|
for i in ds_all: |
|
|
|
if i['language'] == 'zh': |
|
|
@@ -368,8 +386,20 @@ class CNXNLILoader(Loader): |
|
|
|
return ds_zh |
|
|
|
|
|
|
|
def _load_train(self, path: str = None): |
|
|
|
csv_loader = CSVLoader(sep='\t') |
|
|
|
ds = csv_loader._load(path) |
|
|
|
#csv_loader = CSVLoader(sep='\t') |
|
|
|
#ds = csv_loader._load(path) |
|
|
|
ds = DataSet() |
|
|
|
|
|
|
|
with open(path, 'r', encoding='utf-8') as f: |
|
|
|
next(f) |
|
|
|
for line in f: |
|
|
|
raw_instance = line.strip().split('\t') |
|
|
|
premise = raw_instance[0] |
|
|
|
hypo = raw_instance[1] |
|
|
|
label = raw_instance[-1] |
|
|
|
if premise: |
|
|
|
ds.append(Instance(premise=premise, hypo=hypo, label=label)) |
|
|
|
|
|
|
|
ds.rename_field('label', 'target') |
|
|
|
ds.rename_field('premise', 'raw_chars1') |
|
|
|
ds.rename_field('hypo', 'raw_chars2') |
|
|
|