|
@@ -86,7 +86,8 @@ class MatchingLoader(DataSetLoader): |
|
|
if auto_set_input: |
|
|
if auto_set_input: |
|
|
data_set.set_input(Const.INPUTS(0), Const.INPUTS(1)) |
|
|
data_set.set_input(Const.INPUTS(0), Const.INPUTS(1)) |
|
|
if auto_set_target: |
|
|
if auto_set_target: |
|
|
data_set.set_target(Const.TARGET) |
|
|
|
|
|
|
|
|
if Const.TARGET in data_set.get_field_names(): |
|
|
|
|
|
data_set.set_target(Const.TARGET) |
|
|
|
|
|
|
|
|
if to_lower: |
|
|
if to_lower: |
|
|
for data_name, data_set in data_info.datasets.items(): |
|
|
for data_name, data_set in data_info.datasets.items(): |
|
@@ -107,6 +108,13 @@ class MatchingLoader(DataSetLoader): |
|
|
else: |
|
|
else: |
|
|
raise ValueError(f"Cannot recognize BERT tokenizer from {bert_tokenizer}.") |
|
|
raise ValueError(f"Cannot recognize BERT tokenizer from {bert_tokenizer}.") |
|
|
|
|
|
|
|
|
|
|
|
words_vocab = Vocabulary(padding='[PAD]', unknown='[UNK]') |
|
|
|
|
|
with open(os.path.join(model_dir, 'vocab.txt'), 'r') as f: |
|
|
|
|
|
lines = f.readlines() |
|
|
|
|
|
lines = [line.strip() for line in lines] |
|
|
|
|
|
words_vocab.add_word_lst(lines) |
|
|
|
|
|
words_vocab.build_vocab() |
|
|
|
|
|
|
|
|
tokenizer = BertTokenizer.from_pretrained(model_dir) |
|
|
tokenizer = BertTokenizer.from_pretrained(model_dir) |
|
|
|
|
|
|
|
|
for data_name, data_set in data_info.datasets.items(): |
|
|
for data_name, data_set in data_info.datasets.items(): |
|
@@ -171,14 +179,7 @@ class MatchingLoader(DataSetLoader): |
|
|
data_set_list = [d for n, d in data_info.datasets.items()] |
|
|
data_set_list = [d for n, d in data_info.datasets.items()] |
|
|
assert len(data_set_list) > 0, f'There are NO data sets in data info!' |
|
|
assert len(data_set_list) > 0, f'There are NO data sets in data info!' |
|
|
|
|
|
|
|
|
if bert_tokenizer is not None: |
|
|
|
|
|
words_vocab = Vocabulary(padding='[PAD]', unknown='[UNK]') |
|
|
|
|
|
with open(os.path.join(model_dir, 'vocab.txt'), 'r') as f: |
|
|
|
|
|
lines = f.readlines() |
|
|
|
|
|
lines = [line.strip() for line in lines] |
|
|
|
|
|
words_vocab.add_word_lst(lines) |
|
|
|
|
|
words_vocab.build_vocab() |
|
|
|
|
|
else: |
|
|
|
|
|
|
|
|
if bert_tokenizer is None: |
|
|
words_vocab = Vocabulary() |
|
|
words_vocab = Vocabulary() |
|
|
words_vocab = words_vocab.from_dataset(*[d for n, d in data_info.datasets.items() if 'train' in n], |
|
|
words_vocab = words_vocab.from_dataset(*[d for n, d in data_info.datasets.items() if 'train' in n], |
|
|
field_name=[n for n in data_set_list[0].get_field_names() |
|
|
field_name=[n for n in data_set_list[0].get_field_names() |
|
@@ -186,7 +187,8 @@ class MatchingLoader(DataSetLoader): |
|
|
no_create_entry_dataset=[d for n, d in data_info.datasets.items() |
|
|
no_create_entry_dataset=[d for n, d in data_info.datasets.items() |
|
|
if 'train' not in n]) |
|
|
if 'train' not in n]) |
|
|
target_vocab = Vocabulary(padding=None, unknown=None) |
|
|
target_vocab = Vocabulary(padding=None, unknown=None) |
|
|
target_vocab = target_vocab.from_dataset(*data_set_list, field_name=Const.TARGET) |
|
|
|
|
|
|
|
|
target_vocab = target_vocab.from_dataset(*[d for n, d in data_info.datasets.items() if 'train' in n], |
|
|
|
|
|
field_name=Const.TARGET) |
|
|
data_info.vocabs = {Const.INPUT: words_vocab, Const.TARGET: target_vocab} |
|
|
data_info.vocabs = {Const.INPUT: words_vocab, Const.TARGET: target_vocab} |
|
|
|
|
|
|
|
|
if get_index: |
|
|
if get_index: |
|
@@ -196,14 +198,15 @@ class MatchingLoader(DataSetLoader): |
|
|
data_set.apply(lambda x: [words_vocab.to_index(w) for w in x[fields]], new_field_name=fields, |
|
|
data_set.apply(lambda x: [words_vocab.to_index(w) for w in x[fields]], new_field_name=fields, |
|
|
is_input=auto_set_input) |
|
|
is_input=auto_set_input) |
|
|
|
|
|
|
|
|
data_set.apply(lambda x: target_vocab.to_index(x[Const.TARGET]), new_field_name=Const.TARGET, |
|
|
|
|
|
is_input=auto_set_input, is_target=auto_set_target) |
|
|
|
|
|
|
|
|
if Const.TARGET in data_set.get_field_names(): |
|
|
|
|
|
data_set.apply(lambda x: target_vocab.to_index(x[Const.TARGET]), new_field_name=Const.TARGET, |
|
|
|
|
|
is_input=auto_set_input, is_target=auto_set_target) |
|
|
|
|
|
|
|
|
for data_name, data_set in data_info.datasets.items(): |
|
|
for data_name, data_set in data_info.datasets.items(): |
|
|
if isinstance(set_input, list): |
|
|
if isinstance(set_input, list): |
|
|
data_set.set_input(*set_input) |
|
|
|
|
|
|
|
|
data_set.set_input(*[inputs for inputs in set_input if inputs in data_set.get_field_names()]) |
|
|
if isinstance(set_target, list): |
|
|
if isinstance(set_target, list): |
|
|
data_set.set_target(*set_target) |
|
|
|
|
|
|
|
|
data_set.set_target(*[target for target in set_target if target in data_set.get_field_names()]) |
|
|
|
|
|
|
|
|
return data_info |
|
|
return data_info |
|
|
|
|
|
|
|
@@ -324,3 +327,65 @@ class QNLILoader(MatchingLoader, CSVLoader): |
|
|
|
|
|
|
|
|
return ds |
|
|
return ds |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class MNLILoader(MatchingLoader, CSVLoader):
    """
    Alias: :class:`fastNLP.io.MNLILoader` :class:`fastNLP.io.dataset_loader.MNLILoader`

    Loads MNLI-style TSV files. Each loaded DataSet contains the fields::

        words1: list(str), the first sentence (premise)
        words2: list(str), the second sentence (hypothesis)
        target: str, the gold label

    Data source: not recorded here.
    """

    def __init__(self, paths: dict=None):
        """
        :param paths: mapping from split name to TSV file name. When ``None``,
            the default train / dev / test (matched and mismatched) file names
            are used.
        """
        if paths is None:
            default_splits = ('train', 'dev_matched', 'dev_mismatched',
                              'test_matched', 'test_mismatched')
            # Every default file is simply "<split name>.tsv"
            paths = {split: f'{split}.tsv' for split in default_splits}
        MatchingLoader.__init__(self, paths=paths)
        CSVLoader.__init__(self, sep='\t')
        # Raw TSV column name -> field name used in the loaded DataSet
        self.fields = {
            'sentence1_binary_parse': Const.INPUTS(0),
            'sentence2_binary_parse': Const.INPUTS(1),
            'gold_label': Const.TARGET,
        }

    def _load(self, path):
        """
        Read one TSV file and normalise its sentence and label fields.

        :param path: path of the TSV file, handed to ``CSVLoader._load``
        :return: the loaded data set with renamed, tokenised sentence fields
        """
        data = CSVLoader._load(self, path)

        # Rename only the columns that are actually present in this file
        for raw_name, field_name in self.fields.items():
            if raw_name not in data.get_field_names():
                continue
            data.rename_field(raw_name, field_name)

        # Deletes every '(' and ')' character in a single translate() pass
        bracket_remover = str.maketrans('', '', '()')

        def make_tokenizer(field_name):
            # Bind the field name now so each closure reads its own field
            return lambda ins: ins[field_name].translate(bracket_remover).strip().split()

        for sentence_field in (Const.INPUTS(0), Const.INPUTS(1)):
            data.apply(make_tokenizer(sentence_field), new_field_name=sentence_field)

        # Discard instances whose label is the placeholder '-'
        if Const.TARGET in data.get_field_names():
            data.drop(lambda ins: ins[Const.TARGET] == '-')
        return data
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class QuoraLoader(MatchingLoader, CSVLoader):
    """
    Loader for tab-separated Quora sentence-pair files.

    The files are read with the fixed column headers
    ``(Const.TARGET, Const.INPUTS(0), Const.INPUTS(1), 'pairID')``.
    """

    def __init__(self, paths: dict=None):
        """
        :param paths: mapping from split name to TSV file name. When ``None``,
            the default train / dev / test file names are used.
        """
        if paths is None:
            # Every default file is simply "<split name>.tsv"
            paths = {split: f'{split}.tsv' for split in ('train', 'dev', 'test')}
        MatchingLoader.__init__(self, paths=paths)
        column_headers = (Const.TARGET, Const.INPUTS(0), Const.INPUTS(1), 'pairID')
        CSVLoader.__init__(self, sep='\t', headers=column_headers)

    def _load(self, path):
        """
        Read one TSV file through ``CSVLoader._load`` and return it unchanged.

        :param path: path of the TSV file to read
        :return: whatever ``CSVLoader._load`` returns for ``path``
        """
        # NOTE(review): explicit delegation presumably bypasses MatchingLoader
        # in the MRO; confirm before removing this override.
        return CSVLoader._load(self, path)