
Merge branch 'dev0.5.0' of https://github.com/fastnlp/fastNLP into dev0.5.0

tags/v0.4.10
yh 6 years ago
commit 8b8d184026
1 changed file with 93 additions and 14 deletions

reproduction/matching/data/MatchingDataLoader.py  (+93, -14)

@@ -1,13 +1,12 @@
 import os

-from nltk import Tree
 from typing import Union, Dict

 from fastNLP.core.const import Const
 from fastNLP.core.vocabulary import Vocabulary
 from fastNLP.io.base_loader import DataInfo
-from fastNLP.io.dataset_loader import JsonLoader, DataSetLoader
+from fastNLP.io.dataset_loader import JsonLoader, DataSetLoader, CSVLoader
 from fastNLP.io.file_utils import _get_base_url, cached_path, PRETRAINED_BERT_MODEL_DIR
 from fastNLP.modules.encoder._bert import BertTokenizer

@@ -35,7 +34,7 @@ class MatchingLoader(DataSetLoader):

     def process(self, paths: Union[str, Dict[str, str]], dataset_name: str=None,
                 to_lower=False, seq_len_type: str=None, bert_tokenizer: str=None,
-                get_index=True, set_input: Union[list, str, bool]=True,
+                cut_text: int = None, get_index=True, set_input: Union[list, str, bool]=True,
                 set_target: Union[list, str, bool] = True, concat: Union[str, list, bool]=None, ) -> DataInfo:
         """
         :param paths: str or Dict[str, str]. If a str, it is the folder where the dataset lives or a full-path file name: if a folder,
@@ -48,6 +47,7 @@ class MatchingLoader(DataSetLoader):
             provide a 0/1 mask matrix as the sentence length; ``bert``: provide segment_type_id (0 for the first sentence, 1 for the second) and
             an attention mask matrix (a 0/1 mask matrix). Defaults to None, i.e. seq_len is not provided
         :param str bert_tokenizer: path of the folder containing the vocabulary used by the bert tokenizer
+        :param int cut_text: truncate content longer than cut_text. Defaults to None, i.e. no truncation.
         :param bool get_index: whether the text should be converted to indices via the vocabulary
         :param set_input: if True, the relevant fields (those whose names contain Const.INPUT) are automatically set as input; if False,
             no field is set as input. If a str or List[str] is passed, the corresponding fields are set as input accordingly,
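
Editor's note: for orientation, a minimal calling sketch for the new parameter (the folder path and the limit 512 are placeholders, not from this commit):

    # hypothetical call; process() lives on MatchingLoader and is inherited by the loaders below
    data_info = SNLILoader().process(
        paths='path/to/snli',     # placeholder folder containing the snli_1.0_*.jsonl files
        to_lower=True,
        seq_len_type='seq_len',
        cut_text=512,             # drop everything past the first 512 tokens per input field
        get_index=True,
    )
    print(data_info.datasets['train'][0])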
@@ -161,6 +161,13 @@ class MatchingLoader(DataSetLoader):
                 data_set.apply(lambda x: [1] * len(x[Const.INPUT_LENS(0)]),
                                new_field_name=Const.INPUT_LENS(1), is_input=auto_set_input)

+        if cut_text is not None:
+            for data_name, data_set in data_info.datasets.items():
+                for fields in data_set.get_field_names():
+                    if (Const.INPUT in fields) or ((Const.INPUT_LEN in fields) and (seq_len_type != 'seq_len')):
+                        data_set.apply(lambda x: x[fields][: cut_text], new_field_name=fields,
+                                       is_input=auto_set_input)
+
         data_set_list = [d for n, d in data_info.datasets.items()]
         assert len(data_set_list) > 0, f'There are NO data sets in data info!'
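
Editor's note: read hedgingly, the hunk above clips every field whose name contains Const.INPUT, plus mask-style length fields, while scalar 'seq_len' fields are left alone (slicing an int would fail). A toy illustration with invented values:

    # toy values, not from the commit: token indices plus a 0/1 mask field
    cut_text = 4
    words, mask = [3, 17, 9, 42, 8, 5], [1, 1, 1, 1, 1, 1]
    words, mask = words[:cut_text], mask[:cut_text]
    assert words == [3, 17, 9, 42] and mask == [1, 1, 1, 1]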


@@ -216,32 +223,104 @@ class SNLILoader(MatchingLoader, JsonLoader):

     def __init__(self, paths: dict=None):
         fields = {
-            'sentence1_parse': Const.INPUTS(0),
-            'sentence2_parse': Const.INPUTS(1),
+            'sentence1_binary_parse': Const.INPUTS(0),
+            'sentence2_binary_parse': Const.INPUTS(1),
             'gold_label': Const.TARGET,
         }
         paths = paths if paths is not None else {
             'train': 'snli_1.0_train.jsonl',
             'dev': 'snli_1.0_dev.jsonl',
             'test': 'snli_1.0_test.jsonl'}
-        # super(SNLILoader, self).__init__(fields=fields, paths=paths)
         MatchingLoader.__init__(self, paths=paths)
         JsonLoader.__init__(self, fields=fields)

     def _load(self, path):
-        # ds = super(SNLILoader, self)._load(path)
         ds = JsonLoader._load(self, path)

-        def parse_tree(x):
-            t = Tree.fromstring(x)
-            return t.leaves()
+        parentheses_table = str.maketrans({'(': None, ')': None})

-        ds.apply(lambda ins: parse_tree(
-            ins[Const.INPUTS(0)]), new_field_name=Const.INPUTS(0))
-        ds.apply(lambda ins: parse_tree(
-            ins[Const.INPUTS(1)]), new_field_name=Const.INPUTS(1))
+        ds.apply(lambda ins: ins[Const.INPUTS(0)].translate(parentheses_table).strip().split(),
+                 new_field_name=Const.INPUTS(0))
+        ds.apply(lambda ins: ins[Const.INPUTS(1)].translate(parentheses_table).strip().split(),
+                 new_field_name=Const.INPUTS(1))
         ds.drop(lambda x: x[Const.TARGET] == '-')
         return ds
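
Editor's note: the removed parse_tree helper round-tripped the SNLI binary parse through an nltk Tree just to read off its leaves; deleting the parentheses and splitting yields the same tokens. A standalone sketch with a made-up parse string:

    # toy binary-parse string in SNLI's bracketed format (invented example)
    parse = '( ( a man ) ( is ( eating pizza ) ) )'
    table = str.maketrans({'(': None, ')': None})
    tokens = parse.translate(table).strip().split()
    assert tokens == ['a', 'man', 'is', 'eating', 'pizza']  # the tokens parse_tree used to return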




+class RTELoader(MatchingLoader, CSVLoader):
+    """
+    Alias: :class:`fastNLP.io.RTELoader` :class:`fastNLP.io.dataset_loader.RTELoader`
+
+    Reads the RTE dataset; the loaded DataSet contains the fields::
+
+        words1: list(str), the first sentence, the premise
+        words2: list(str), the second sentence, the hypothesis
+        target: str, the gold label
+
+    Data source:
+    """
+
+    def __init__(self, paths: dict=None):
+        paths = paths if paths is not None else {
+            'train': 'train.tsv',
+            'dev': 'dev.tsv',
+            # 'test': 'test.tsv'  # the test set has no labels
+        }
+        MatchingLoader.__init__(self, paths=paths)
+        self.fields = {
+            'sentence1': Const.INPUTS(0),
+            'sentence2': Const.INPUTS(1),
+            'label': Const.TARGET,
+        }
+        CSVLoader.__init__(self, sep='\t')
+
+    def _load(self, path):
+        ds = CSVLoader._load(self, path)
+
+        for k, v in self.fields.items():
+            ds.rename_field(k, v)
+        for fields in ds.get_all_fields():
+            if Const.INPUT in fields:
+                ds.apply(lambda x: x[fields].strip().split(), new_field_name=fields)
+
+        return ds
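
Editor's note: a hedged walk-through of _load on one invented TSV row: rename_field maps the GLUE column names onto the Const field names (words1, words2, target, per the docstring above), then each input sentence is whitespace-tokenized in place:

    # invented row; mirrors the rename + strip().split() above
    row = {'sentence1': 'A man is eating.', 'sentence2': 'Someone eats.', 'label': 'entailment'}
    renamed = {'words1': row['sentence1'], 'words2': row['sentence2'], 'target': row['label']}
    for name in ('words1', 'words2'):   # only fields whose names contain Const.INPUT
        renamed[name] = renamed[name].strip().split()
    # renamed['words1'] == ['A', 'man', 'is', 'eating.']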


+class QNLILoader(MatchingLoader, CSVLoader):
+    """
+    Alias: :class:`fastNLP.io.QNLILoader` :class:`fastNLP.io.dataset_loader.QNLILoader`
+
+    Reads the QNLI dataset; the loaded DataSet contains the fields::
+
+        words1: list(str), the first text, the question
+        words2: list(str), the second text, the candidate answer sentence
+        target: str, the gold label
+
+    Data source:
+    """
+
+    def __init__(self, paths: dict=None):
+        paths = paths if paths is not None else {
+            'train': 'train.tsv',
+            'dev': 'dev.tsv',
+            # 'test': 'test.tsv'  # the test set has no labels
+        }
+        MatchingLoader.__init__(self, paths=paths)
+        self.fields = {
+            'question': Const.INPUTS(0),
+            'sentence': Const.INPUTS(1),
+            'label': Const.TARGET,
+        }
+        CSVLoader.__init__(self, sep='\t')
+
+    def _load(self, path):
+        ds = CSVLoader._load(self, path)
+
+        for k, v in self.fields.items():
+            ds.rename_field(k, v)
+        for fields in ds.get_all_fields():
+            if Const.INPUT in fields:
+                ds.apply(lambda x: x[fields].strip().split(), new_field_name=fields)
+
+        return ds
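
Editor's note: lastly, a hedged sketch of driving the two new loaders (folder paths are placeholders). Since their default paths dicts omit the unlabeled test split, downstream code should expect only 'train' and 'dev':

    # placeholder folders holding the GLUE train.tsv / dev.tsv files
    rte_info = RTELoader().process(paths='path/to/RTE', to_lower=True)
    qnli_info = QNLILoader().process(paths='path/to/QNLI', to_lower=True)
    assert set(rte_info.datasets.keys()) == {'train', 'dev'}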


