|
@@ -19,7 +19,7 @@ class MatchingLoader(DataSetLoader): |
|
|
:param dict paths: key是数据集名称(如train、dev、test),value是对应的文件名 |
|
|
:param dict paths: key是数据集名称(如train、dev、test),value是对应的文件名 |
|
|
""" |
|
|
""" |
|
|
|
|
|
|
|
|
def __init__(self, paths: dict = None): |
|
|
|
|
|
|
|
|
def __init__(self, paths: dict=None): |
|
|
self.paths = paths |
|
|
self.paths = paths |
|
|
|
|
|
|
|
|
def _load(self, path): |
|
|
def _load(self, path): |
|
@@ -30,11 +30,11 @@ class MatchingLoader(DataSetLoader): |
|
|
""" |
|
|
""" |
|
|
raise NotImplementedError |
|
|
raise NotImplementedError |
|
|
|
|
|
|
|
|
def process(self, paths: Union[str, Dict[str, str]], dataset_name: str = None, |
|
|
|
|
|
to_lower=False, seq_len_type: str = None, bert_tokenizer: str = None, |
|
|
|
|
|
cut_text: int = None, get_index=True, auto_pad_length: int = None, |
|
|
|
|
|
auto_pad_token: str = '<pad>', set_input: Union[list, str, bool] = True, |
|
|
|
|
|
set_target: Union[list, str, bool] = True, concat: Union[str, list, bool] = None, ) -> DataInfo: |
|
|
|
|
|
|
|
|
def process(self, paths: Union[str, Dict[str, str]], dataset_name: str=None, |
|
|
|
|
|
to_lower=False, seq_len_type: str=None, bert_tokenizer: str=None, |
|
|
|
|
|
cut_text: int = None, get_index=True, auto_pad_length: int=None, |
|
|
|
|
|
auto_pad_token: str='<pad>', set_input: Union[list, str, bool]=True, |
|
|
|
|
|
set_target: Union[list, str, bool] = True, concat: Union[str, list, bool]=None, ) -> DataInfo: |
|
|
""" |
|
|
""" |
|
|
:param paths: str或者Dict[str, str]。如果是str,则为数据集所在的文件夹或者是全路径文件名:如果是文件夹, |
|
|
:param paths: str或者Dict[str, str]。如果是str,则为数据集所在的文件夹或者是全路径文件名:如果是文件夹, |
|
|
则会从self.paths里面找对应的数据集名称与文件名。如果是Dict,则为数据集名称(如train、dev、test)和 |
|
|
则会从self.paths里面找对应的数据集名称与文件名。如果是Dict,则为数据集名称(如train、dev、test)和 |
|
@@ -171,7 +171,7 @@ class MatchingLoader(DataSetLoader): |
|
|
new_field_name=Const.INPUT_LENS(1), is_input=auto_set_input) |
|
|
new_field_name=Const.INPUT_LENS(1), is_input=auto_set_input) |
|
|
|
|
|
|
|
|
if auto_pad_length is not None: |
|
|
if auto_pad_length is not None: |
|
|
cut_text = min(auto_pad_length, cut_text if cut_text is not None else 0) |
|
|
|
|
|
|
|
|
cut_text = min(auto_pad_length, cut_text if cut_text is not None else auto_pad_length) |
|
|
|
|
|
|
|
|
if cut_text is not None: |
|
|
if cut_text is not None: |
|
|
for data_name, data_set in data_info.datasets.items(): |
|
|
for data_name, data_set in data_info.datasets.items(): |
|
@@ -207,10 +207,10 @@ class MatchingLoader(DataSetLoader): |
|
|
is_input=auto_set_input, is_target=auto_set_target) |
|
|
is_input=auto_set_input, is_target=auto_set_target) |
|
|
|
|
|
|
|
|
if auto_pad_length is not None: |
|
|
if auto_pad_length is not None: |
|
|
|
|
|
if seq_len_type == 'seq_len': |
|
|
|
|
|
raise RuntimeError(f'the sequence will be padded with the length {auto_pad_length}, ' |
|
|
|
|
|
f'so the seq_len_type cannot be `{seq_len_type}`!') |
|
|
for data_name, data_set in data_info.datasets.items(): |
|
|
for data_name, data_set in data_info.datasets.items(): |
|
|
if seq_len_type == 'seq_len': |
|
|
|
|
|
raise RuntimeError(f'sequence will be padded with the length {auto_pad_length},' |
|
|
|
|
|
f'the seq_len_type cannot be `{seq_len_type}`!') |
|
|
|
|
|
for fields in data_set.get_field_names(): |
|
|
for fields in data_set.get_field_names(): |
|
|
if Const.INPUT in fields: |
|
|
if Const.INPUT in fields: |
|
|
data_set.apply(lambda x: x[fields] + [words_vocab.to_index(words_vocab.padding)] * |
|
|
data_set.apply(lambda x: x[fields] + [words_vocab.to_index(words_vocab.padding)] * |
|
@@ -242,7 +242,7 @@ class SNLILoader(MatchingLoader, JsonLoader): |
|
|
数据来源: https://nlp.stanford.edu/projects/snli/snli_1.0.zip |
|
|
数据来源: https://nlp.stanford.edu/projects/snli/snli_1.0.zip |
|
|
""" |
|
|
""" |
|
|
|
|
|
|
|
|
def __init__(self, paths: dict = None): |
|
|
|
|
|
|
|
|
def __init__(self, paths: dict=None): |
|
|
fields = { |
|
|
fields = { |
|
|
'sentence1_binary_parse': Const.INPUTS(0), |
|
|
'sentence1_binary_parse': Const.INPUTS(0), |
|
|
'sentence2_binary_parse': Const.INPUTS(1), |
|
|
'sentence2_binary_parse': Const.INPUTS(1), |
|
@@ -281,7 +281,7 @@ class RTELoader(MatchingLoader, CSVLoader): |
|
|
数据来源: |
|
|
数据来源: |
|
|
""" |
|
|
""" |
|
|
|
|
|
|
|
|
def __init__(self, paths: dict = None): |
|
|
|
|
|
|
|
|
def __init__(self, paths: dict=None): |
|
|
paths = paths if paths is not None else { |
|
|
paths = paths if paths is not None else { |
|
|
'train': 'train.tsv', |
|
|
'train': 'train.tsv', |
|
|
'dev': 'dev.tsv', |
|
|
'dev': 'dev.tsv', |
|
@@ -299,7 +299,8 @@ class RTELoader(MatchingLoader, CSVLoader): |
|
|
ds = CSVLoader._load(self, path) |
|
|
ds = CSVLoader._load(self, path) |
|
|
|
|
|
|
|
|
for k, v in self.fields.items(): |
|
|
for k, v in self.fields.items(): |
|
|
ds.rename_field(k, v) |
|
|
|
|
|
|
|
|
if v in ds.get_field_names(): |
|
|
|
|
|
ds.rename_field(k, v) |
|
|
for fields in ds.get_all_fields(): |
|
|
for fields in ds.get_all_fields(): |
|
|
if Const.INPUT in fields: |
|
|
if Const.INPUT in fields: |
|
|
ds.apply(lambda x: x[fields].strip().split(), new_field_name=fields) |
|
|
ds.apply(lambda x: x[fields].strip().split(), new_field_name=fields) |
|
@@ -320,7 +321,7 @@ class QNLILoader(MatchingLoader, CSVLoader): |
|
|
数据来源: |
|
|
数据来源: |
|
|
""" |
|
|
""" |
|
|
|
|
|
|
|
|
def __init__(self, paths: dict = None): |
|
|
|
|
|
|
|
|
def __init__(self, paths: dict=None): |
|
|
paths = paths if paths is not None else { |
|
|
paths = paths if paths is not None else { |
|
|
'train': 'train.tsv', |
|
|
'train': 'train.tsv', |
|
|
'dev': 'dev.tsv', |
|
|
'dev': 'dev.tsv', |
|
@@ -338,7 +339,8 @@ class QNLILoader(MatchingLoader, CSVLoader): |
|
|
ds = CSVLoader._load(self, path) |
|
|
ds = CSVLoader._load(self, path) |
|
|
|
|
|
|
|
|
for k, v in self.fields.items(): |
|
|
for k, v in self.fields.items(): |
|
|
ds.rename_field(k, v) |
|
|
|
|
|
|
|
|
if v in ds.get_field_names(): |
|
|
|
|
|
ds.rename_field(k, v) |
|
|
for fields in ds.get_all_fields(): |
|
|
for fields in ds.get_all_fields(): |
|
|
if Const.INPUT in fields: |
|
|
if Const.INPUT in fields: |
|
|
ds.apply(lambda x: x[fields].strip().split(), new_field_name=fields) |
|
|
ds.apply(lambda x: x[fields].strip().split(), new_field_name=fields) |
|
@@ -359,7 +361,7 @@ class MNLILoader(MatchingLoader, CSVLoader): |
|
|
数据来源: |
|
|
数据来源: |
|
|
""" |
|
|
""" |
|
|
|
|
|
|
|
|
def __init__(self, paths: dict = None): |
|
|
|
|
|
|
|
|
def __init__(self, paths: dict=None): |
|
|
paths = paths if paths is not None else { |
|
|
paths = paths if paths is not None else { |
|
|
'train': 'train.tsv', |
|
|
'train': 'train.tsv', |
|
|
'dev_matched': 'dev_matched.tsv', |
|
|
'dev_matched': 'dev_matched.tsv', |
|
@@ -414,7 +416,7 @@ class QuoraLoader(MatchingLoader, CSVLoader): |
|
|
数据来源: |
|
|
数据来源: |
|
|
""" |
|
|
""" |
|
|
|
|
|
|
|
|
def __init__(self, paths: dict = None): |
|
|
|
|
|
|
|
|
def __init__(self, paths: dict=None): |
|
|
paths = paths if paths is not None else { |
|
|
paths = paths if paths is not None else { |
|
|
'train': 'train.tsv', |
|
|
'train': 'train.tsv', |
|
|
'dev': 'dev.tsv', |
|
|
'dev': 'dev.tsv', |
|
|