Browse Source

fix a bug in matching loader

tags/v0.4.10
xuyige 6 years ago
parent
commit
00cf9820a2
2 changed files with 41 additions and 24 deletions
  1. +19
    -17
      fastNLP/io/data_loader/matching.py
  2. +22
    -7
      reproduction/matching/data/MatchingDataLoader.py

+ 19
- 17
fastNLP/io/data_loader/matching.py View File

@@ -19,7 +19,7 @@ class MatchingLoader(DataSetLoader):
:param dict paths: key是数据集名称(如train、dev、test),value是对应的文件名 :param dict paths: key是数据集名称(如train、dev、test),value是对应的文件名
""" """


def __init__(self, paths: dict = None):
def __init__(self, paths: dict=None):
self.paths = paths self.paths = paths


def _load(self, path): def _load(self, path):
@@ -30,11 +30,11 @@ class MatchingLoader(DataSetLoader):
""" """
raise NotImplementedError raise NotImplementedError


def process(self, paths: Union[str, Dict[str, str]], dataset_name: str = None,
to_lower=False, seq_len_type: str = None, bert_tokenizer: str = None,
cut_text: int = None, get_index=True, auto_pad_length: int = None,
auto_pad_token: str = '<pad>', set_input: Union[list, str, bool] = True,
set_target: Union[list, str, bool] = True, concat: Union[str, list, bool] = None, ) -> DataInfo:
def process(self, paths: Union[str, Dict[str, str]], dataset_name: str=None,
to_lower=False, seq_len_type: str=None, bert_tokenizer: str=None,
cut_text: int = None, get_index=True, auto_pad_length: int=None,
auto_pad_token: str='<pad>', set_input: Union[list, str, bool]=True,
set_target: Union[list, str, bool] = True, concat: Union[str, list, bool]=None, ) -> DataInfo:
""" """
:param paths: str或者Dict[str, str]。如果是str,则为数据集所在的文件夹或者是全路径文件名:如果是文件夹, :param paths: str或者Dict[str, str]。如果是str,则为数据集所在的文件夹或者是全路径文件名:如果是文件夹,
则会从self.paths里面找对应的数据集名称与文件名。如果是Dict,则为数据集名称(如train、dev、test)和 则会从self.paths里面找对应的数据集名称与文件名。如果是Dict,则为数据集名称(如train、dev、test)和
@@ -171,7 +171,7 @@ class MatchingLoader(DataSetLoader):
new_field_name=Const.INPUT_LENS(1), is_input=auto_set_input) new_field_name=Const.INPUT_LENS(1), is_input=auto_set_input)


if auto_pad_length is not None: if auto_pad_length is not None:
cut_text = min(auto_pad_length, cut_text if cut_text is not None else 0)
cut_text = min(auto_pad_length, cut_text if cut_text is not None else auto_pad_length)


if cut_text is not None: if cut_text is not None:
for data_name, data_set in data_info.datasets.items(): for data_name, data_set in data_info.datasets.items():
@@ -207,10 +207,10 @@ class MatchingLoader(DataSetLoader):
is_input=auto_set_input, is_target=auto_set_target) is_input=auto_set_input, is_target=auto_set_target)


if auto_pad_length is not None: if auto_pad_length is not None:
if seq_len_type == 'seq_len':
raise RuntimeError(f'the sequence will be padded with the length {auto_pad_length}, '
f'so the seq_len_type cannot be `{seq_len_type}`!')
for data_name, data_set in data_info.datasets.items(): for data_name, data_set in data_info.datasets.items():
if seq_len_type == 'seq_len':
raise RuntimeError(f'sequence will be padded with the length {auto_pad_length},'
f'the seq_len_type cannot be `{seq_len_type}`!')
for fields in data_set.get_field_names(): for fields in data_set.get_field_names():
if Const.INPUT in fields: if Const.INPUT in fields:
data_set.apply(lambda x: x[fields] + [words_vocab.to_index(words_vocab.padding)] * data_set.apply(lambda x: x[fields] + [words_vocab.to_index(words_vocab.padding)] *
@@ -242,7 +242,7 @@ class SNLILoader(MatchingLoader, JsonLoader):
数据来源: https://nlp.stanford.edu/projects/snli/snli_1.0.zip 数据来源: https://nlp.stanford.edu/projects/snli/snli_1.0.zip
""" """


def __init__(self, paths: dict = None):
def __init__(self, paths: dict=None):
fields = { fields = {
'sentence1_binary_parse': Const.INPUTS(0), 'sentence1_binary_parse': Const.INPUTS(0),
'sentence2_binary_parse': Const.INPUTS(1), 'sentence2_binary_parse': Const.INPUTS(1),
@@ -281,7 +281,7 @@ class RTELoader(MatchingLoader, CSVLoader):
数据来源: 数据来源:
""" """


def __init__(self, paths: dict = None):
def __init__(self, paths: dict=None):
paths = paths if paths is not None else { paths = paths if paths is not None else {
'train': 'train.tsv', 'train': 'train.tsv',
'dev': 'dev.tsv', 'dev': 'dev.tsv',
@@ -299,7 +299,8 @@ class RTELoader(MatchingLoader, CSVLoader):
ds = CSVLoader._load(self, path) ds = CSVLoader._load(self, path)


for k, v in self.fields.items(): for k, v in self.fields.items():
ds.rename_field(k, v)
if v in ds.get_field_names():
ds.rename_field(k, v)
for fields in ds.get_all_fields(): for fields in ds.get_all_fields():
if Const.INPUT in fields: if Const.INPUT in fields:
ds.apply(lambda x: x[fields].strip().split(), new_field_name=fields) ds.apply(lambda x: x[fields].strip().split(), new_field_name=fields)
@@ -320,7 +321,7 @@ class QNLILoader(MatchingLoader, CSVLoader):
数据来源: 数据来源:
""" """


def __init__(self, paths: dict = None):
def __init__(self, paths: dict=None):
paths = paths if paths is not None else { paths = paths if paths is not None else {
'train': 'train.tsv', 'train': 'train.tsv',
'dev': 'dev.tsv', 'dev': 'dev.tsv',
@@ -338,7 +339,8 @@ class QNLILoader(MatchingLoader, CSVLoader):
ds = CSVLoader._load(self, path) ds = CSVLoader._load(self, path)


for k, v in self.fields.items(): for k, v in self.fields.items():
ds.rename_field(k, v)
if v in ds.get_field_names():
ds.rename_field(k, v)
for fields in ds.get_all_fields(): for fields in ds.get_all_fields():
if Const.INPUT in fields: if Const.INPUT in fields:
ds.apply(lambda x: x[fields].strip().split(), new_field_name=fields) ds.apply(lambda x: x[fields].strip().split(), new_field_name=fields)
@@ -359,7 +361,7 @@ class MNLILoader(MatchingLoader, CSVLoader):
数据来源: 数据来源:
""" """


def __init__(self, paths: dict = None):
def __init__(self, paths: dict=None):
paths = paths if paths is not None else { paths = paths if paths is not None else {
'train': 'train.tsv', 'train': 'train.tsv',
'dev_matched': 'dev_matched.tsv', 'dev_matched': 'dev_matched.tsv',
@@ -414,7 +416,7 @@ class QuoraLoader(MatchingLoader, CSVLoader):
数据来源: 数据来源:
""" """


def __init__(self, paths: dict = None):
def __init__(self, paths: dict=None):
paths = paths if paths is not None else { paths = paths if paths is not None else {
'train': 'train.tsv', 'train': 'train.tsv',
'dev': 'dev.tsv', 'dev': 'dev.tsv',


+ 22
- 7
reproduction/matching/data/MatchingDataLoader.py View File

@@ -16,12 +16,11 @@ class MatchingLoader(DataSetLoader):
别名::class:`fastNLP.io.MatchingLoader` :class:`fastNLP.io.dataset_loader.MatchingLoader` 别名::class:`fastNLP.io.MatchingLoader` :class:`fastNLP.io.dataset_loader.MatchingLoader`


读取Matching任务的数据集 读取Matching任务的数据集

:param dict paths: key是数据集名称(如train、dev、test),value是对应的文件名
""" """


def __init__(self, paths: dict=None): def __init__(self, paths: dict=None):
"""
:param dict paths: key是数据集名称(如train、dev、test),value是对应的文件名
"""
self.paths = paths self.paths = paths


def _load(self, path): def _load(self, path):
@@ -173,7 +172,7 @@ class MatchingLoader(DataSetLoader):
new_field_name=Const.INPUT_LENS(1), is_input=auto_set_input) new_field_name=Const.INPUT_LENS(1), is_input=auto_set_input)


if auto_pad_length is not None: if auto_pad_length is not None:
cut_text = min(auto_pad_length, cut_text if cut_text is not None else 0)
cut_text = min(auto_pad_length, cut_text if cut_text is not None else auto_pad_length)


if cut_text is not None: if cut_text is not None:
for data_name, data_set in data_info.datasets.items(): for data_name, data_set in data_info.datasets.items():
@@ -209,6 +208,9 @@ class MatchingLoader(DataSetLoader):
is_input=auto_set_input, is_target=auto_set_target) is_input=auto_set_input, is_target=auto_set_target)


if auto_pad_length is not None: if auto_pad_length is not None:
if seq_len_type == 'seq_len':
raise RuntimeError(f'the sequence will be padded with the length {auto_pad_length}, '
f'so the seq_len_type cannot be `{seq_len_type}`!')
for data_name, data_set in data_info.datasets.items(): for data_name, data_set in data_info.datasets.items():
for fields in data_set.get_field_names(): for fields in data_set.get_field_names():
if Const.INPUT in fields: if Const.INPUT in fields:
@@ -298,7 +300,8 @@ class RTELoader(MatchingLoader, CSVLoader):
ds = CSVLoader._load(self, path) ds = CSVLoader._load(self, path)


for k, v in self.fields.items(): for k, v in self.fields.items():
ds.rename_field(k, v)
if v in ds.get_field_names():
ds.rename_field(k, v)
for fields in ds.get_all_fields(): for fields in ds.get_all_fields():
if Const.INPUT in fields: if Const.INPUT in fields:
ds.apply(lambda x: x[fields].strip().split(), new_field_name=fields) ds.apply(lambda x: x[fields].strip().split(), new_field_name=fields)
@@ -337,7 +340,8 @@ class QNLILoader(MatchingLoader, CSVLoader):
ds = CSVLoader._load(self, path) ds = CSVLoader._load(self, path)


for k, v in self.fields.items(): for k, v in self.fields.items():
ds.rename_field(k, v)
if v in ds.get_field_names():
ds.rename_field(k, v)
for fields in ds.get_all_fields(): for fields in ds.get_all_fields():
if Const.INPUT in fields: if Const.INPUT in fields:
ds.apply(lambda x: x[fields].strip().split(), new_field_name=fields) ds.apply(lambda x: x[fields].strip().split(), new_field_name=fields)
@@ -349,7 +353,7 @@ class MNLILoader(MatchingLoader, CSVLoader):
""" """
别名::class:`fastNLP.io.MNLILoader` :class:`fastNLP.io.dataset_loader.MNLILoader` 别名::class:`fastNLP.io.MNLILoader` :class:`fastNLP.io.dataset_loader.MNLILoader`


读取SNLI数据集,读取的DataSet包含fields::
读取MNLI数据集,读取的DataSet包含fields::


words1: list(str),第一句文本, premise words1: list(str),第一句文本, premise
words2: list(str), 第二句文本, hypothesis words2: list(str), 第二句文本, hypothesis
@@ -401,6 +405,17 @@ class MNLILoader(MatchingLoader, CSVLoader):




class QuoraLoader(MatchingLoader, CSVLoader): class QuoraLoader(MatchingLoader, CSVLoader):
"""
别名::class:`fastNLP.io.QuoraLoader` :class:`fastNLP.io.dataset_loader.QuoraLoader`

读取Quora数据集,读取的DataSet包含fields::

words1: list(str),第一句文本, premise
words2: list(str), 第二句文本, hypothesis
target: str, 真实标签

数据来源:
"""


def __init__(self, paths: dict=None): def __init__(self, paths: dict=None):
paths = paths if paths is not None else { paths = paths if paths is not None else {


Loading…
Cancel
Save