Browse Source

Merge pull request #231 from benbijituo/dev0.5.0

修改了CNXNLI的_load(),可以处理特殊的instance格式如下:
tags/v0.4.10
yhcc GitHub 5 years ago
parent
commit
f1c1010a0f
No known key found for this signature in database. GPG Key ID: 4AEE18F83AFDEB23
1 changed file with 34 additions and 4 deletions
  1. +34
    -4
      fastNLP/io/loader/matching.py

+ 34
- 4
fastNLP/io/loader/matching.py View File

@@ -358,8 +358,26 @@ class CNXNLILoader(Loader):
super(CNXNLILoader, self).__init__()

def _load(self, path: str = None):
csv_loader = CSVLoader(sep='\t')
ds_all = csv_loader._load(path)
#csv_loader = CSVLoader(sep='\t')
#ds_all = csv_loader._load(path)
ds_all = DataSet()
with open(path, 'r', encoding='utf-8') as f:
head_name_list = f.readline().strip().split('\t')
sentence1_index = head_name_list.index('sentence1')
sentence2_index = head_name_list.index('sentence2')
gold_label_index = head_name_list.index('gold_label')
language_index = head_name_list.index(('language'))

for line in f:
line = line.strip()
raw_instance = line.split('\t')
sentence1 = raw_instance[sentence1_index]
sentence2 = raw_instance[sentence2_index]
gold_label = raw_instance[gold_label_index]
language = raw_instance[language_index]
if sentence1:
ds_all.append(Instance(sentence1=sentence1, sentence2=sentence2, gold_label=gold_label, language=language))

ds_zh = DataSet()
for i in ds_all:
if i['language'] == 'zh':
@@ -368,8 +386,20 @@ class CNXNLILoader(Loader):
return ds_zh

def _load_train(self, path: str = None):
csv_loader = CSVLoader(sep='\t')
ds = csv_loader._load(path)
#csv_loader = CSVLoader(sep='\t')
#ds = csv_loader._load(path)
ds = DataSet()

with open(path, 'r', encoding='utf-8') as f:
next(f)
for line in f:
raw_instance = line.strip().split('\t')
premise = raw_instance[0]
hypo = raw_instance[1]
label = raw_instance[-1]
if premise:
ds.append(Instance(premise=premise, hypo=hypo, label=label))

ds.rename_field('label', 'target')
ds.rename_field('premise', 'raw_chars1')
ds.rename_field('hypo', 'raw_chars2')


Loading…
Cancel
Save