| @@ -56,7 +56,6 @@ class MNLILoader(Loader): | |||||
| with open(path, 'r', encoding='utf-8') as f: | with open(path, 'r', encoding='utf-8') as f: | ||||
| f.readline() # 跳过header | f.readline() # 跳过header | ||||
| if path.endswith("test_matched.tsv") or path.endswith('test_mismatched.tsv'): | if path.endswith("test_matched.tsv") or path.endswith('test_mismatched.tsv'): | ||||
| warnings.warn("RTE's test file has no target.") | |||||
| warnings.warn("MNLI's test file has no target.") | warnings.warn("MNLI's test file has no target.") | ||||
| for line in f: | for line in f: | ||||
| line = line.strip() | line = line.strip() | ||||
| @@ -64,8 +63,9 @@ class MNLILoader(Loader): | |||||
| parts = line.split('\t') | parts = line.split('\t') | ||||
| raw_words1 = parts[8] | raw_words1 = parts[8] | ||||
| raw_words2 = parts[9] | raw_words2 = parts[9] | ||||
| idx = int(parts[0]) | |||||
| if raw_words1 and raw_words2: | if raw_words1 and raw_words2: | ||||
| ds.append(Instance(raw_words1=raw_words1, raw_words2=raw_words2)) | |||||
| ds.append(Instance(raw_words1=raw_words1, raw_words2=raw_words2, index=idx)) | |||||
| else: | else: | ||||
| for line in f: | for line in f: | ||||
| line = line.strip() | line = line.strip() | ||||
| @@ -74,8 +74,9 @@ class MNLILoader(Loader): | |||||
| raw_words1 = parts[8] | raw_words1 = parts[8] | ||||
| raw_words2 = parts[9] | raw_words2 = parts[9] | ||||
| target = parts[-1] | target = parts[-1] | ||||
| idx = int(parts[0]) | |||||
| if raw_words1 and raw_words2 and target: | if raw_words1 and raw_words2 and target: | ||||
| ds.append(Instance(raw_words1=raw_words1, raw_words2=raw_words2, target=target)) | |||||
| ds.append(Instance(raw_words1=raw_words1, raw_words2=raw_words2, target=target, index=idx)) | |||||
| return ds | return ds | ||||
| def load(self, paths: str = None): | def load(self, paths: str = None): | ||||