|
|
@@ -1,13 +1,12 @@ |
|
|
|
|
|
|
|
import os |
|
|
|
|
|
|
|
from nltk import Tree |
|
|
|
from typing import Union, Dict |
|
|
|
|
|
|
|
from fastNLP.core.const import Const |
|
|
|
from fastNLP.core.vocabulary import Vocabulary |
|
|
|
from fastNLP.io.base_loader import DataInfo |
|
|
|
from fastNLP.io.dataset_loader import JsonLoader, DataSetLoader |
|
|
|
from fastNLP.io.dataset_loader import JsonLoader, DataSetLoader, CSVLoader |
|
|
|
from fastNLP.io.file_utils import _get_base_url, cached_path, PRETRAINED_BERT_MODEL_DIR |
|
|
|
from fastNLP.modules.encoder._bert import BertTokenizer |
|
|
|
|
|
|
@@ -35,7 +34,7 @@ class MatchingLoader(DataSetLoader): |
|
|
|
|
|
|
|
def process(self, paths: Union[str, Dict[str, str]], dataset_name: str=None, |
|
|
|
to_lower=False, seq_len_type: str=None, bert_tokenizer: str=None, |
|
|
|
get_index=True, set_input: Union[list, str, bool]=True, |
|
|
|
cut_text: int = None, get_index=True, set_input: Union[list, str, bool]=True, |
|
|
|
set_target: Union[list, str, bool] = True, concat: Union[str, list, bool]=None, ) -> DataInfo: |
|
|
|
""" |
|
|
|
:param paths: str或者Dict[str, str]。如果是str,则为数据集所在的文件夹或者是全路径文件名:如果是文件夹, |
|
|
@@ -48,6 +47,7 @@ class MatchingLoader(DataSetLoader): |
|
|
|
提供一个0/1的mask矩阵作为句子长度; ``bert`` :提供segment_type_id(第一个句子为0,第二个句子为1)和 |
|
|
|
attention mask矩阵(0/1的mask矩阵)。默认值为None,即不提供seq_len |
|
|
|
:param str bert_tokenizer: bert tokenizer所使用的词表所在的文件夹路径 |
|
|
|
:param int cut_text: 将长于cut_text的内容截掉。默认为None,即不截。 |
|
|
|
:param bool get_index: 是否需要根据词表将文本转为index |
|
|
|
:param set_input: 如果为True,则会自动将相关的field(名字里含有Const.INPUT的)设置为input,如果为False |
|
|
|
则不会将任何field设置为input。如果传入str或者List[str],则会根据传入的内容将相对应的field设置为input, |
|
|
@@ -161,6 +161,13 @@ class MatchingLoader(DataSetLoader): |
|
|
|
data_set.apply(lambda x: [1] * len(x[Const.INPUT_LENS(0)]), |
|
|
|
new_field_name=Const.INPUT_LENS(1), is_input=auto_set_input) |
|
|
|
|
|
|
|
if cut_text is not None: |
|
|
|
for data_name, data_set in data_info.datasets.items(): |
|
|
|
for fields in data_set.get_field_names(): |
|
|
|
if (Const.INPUT in fields) or ((Const.INPUT_LEN in fields) and (seq_len_type != 'seq_len')): |
|
|
|
data_set.apply(lambda x: x[fields][: cut_text], new_field_name=fields, |
|
|
|
is_input=auto_set_input) |
|
|
|
|
|
|
|
data_set_list = [d for n, d in data_info.datasets.items()] |
|
|
|
assert len(data_set_list) > 0, f'There are NO data sets in data info!' |
|
|
|
|
|
|
@@ -216,32 +223,104 @@ class SNLILoader(MatchingLoader, JsonLoader): |
|
|
|
|
|
|
|
def __init__(self, paths: dict=None): |
|
|
|
fields = { |
|
|
|
'sentence1_parse': Const.INPUTS(0), |
|
|
|
'sentence2_parse': Const.INPUTS(1), |
|
|
|
'sentence1_binary_parse': Const.INPUTS(0), |
|
|
|
'sentence2_binary_parse': Const.INPUTS(1), |
|
|
|
'gold_label': Const.TARGET, |
|
|
|
} |
|
|
|
paths = paths if paths is not None else { |
|
|
|
'train': 'snli_1.0_train.jsonl', |
|
|
|
'dev': 'snli_1.0_dev.jsonl', |
|
|
|
'test': 'snli_1.0_test.jsonl'} |
|
|
|
# super(SNLILoader, self).__init__(fields=fields, paths=paths) |
|
|
|
MatchingLoader.__init__(self, paths=paths) |
|
|
|
JsonLoader.__init__(self, fields=fields) |
|
|
|
|
|
|
|
def _load(self, path): |
|
|
|
# ds = super(SNLILoader, self)._load(path) |
|
|
|
ds = JsonLoader._load(self, path) |
|
|
|
|
|
|
|
def parse_tree(x): |
|
|
|
t = Tree.fromstring(x) |
|
|
|
return t.leaves() |
|
|
|
parentheses_table = str.maketrans({'(': None, ')': None}) |
|
|
|
|
|
|
|
ds.apply(lambda ins: parse_tree( |
|
|
|
ins[Const.INPUTS(0)]), new_field_name=Const.INPUTS(0)) |
|
|
|
ds.apply(lambda ins: parse_tree( |
|
|
|
ins[Const.INPUTS(1)]), new_field_name=Const.INPUTS(1)) |
|
|
|
ds.apply(lambda ins: ins[Const.INPUTS(0)].translate(parentheses_table).strip().split(), |
|
|
|
new_field_name=Const.INPUTS(0)) |
|
|
|
ds.apply(lambda ins: ins[Const.INPUTS(1)].translate(parentheses_table).strip().split(), |
|
|
|
new_field_name=Const.INPUTS(1)) |
|
|
|
ds.drop(lambda x: x[Const.TARGET] == '-') |
|
|
|
return ds |
|
|
|
|
|
|
|
|
|
|
|
class RTELoader(MatchingLoader, CSVLoader): |
|
|
|
""" |
|
|
|
别名::class:`fastNLP.io.RTELoader` :class:`fastNLP.io.dataset_loader.RTELoader` |
|
|
|
|
|
|
|
读取RTE数据集,读取的DataSet包含fields:: |
|
|
|
|
|
|
|
words1: list(str),第一句文本, premise |
|
|
|
words2: list(str), 第二句文本, hypothesis |
|
|
|
target: str, 真实标签 |
|
|
|
|
|
|
|
数据来源: |
|
|
|
""" |
|
|
|
|
|
|
|
def __init__(self, paths: dict=None): |
|
|
|
paths = paths if paths is not None else { |
|
|
|
'train': 'train.tsv', |
|
|
|
'dev': 'dev.tsv', |
|
|
|
# 'test': 'test.tsv' # test set has not label |
|
|
|
} |
|
|
|
MatchingLoader.__init__(self, paths=paths) |
|
|
|
self.fields = { |
|
|
|
'sentence1': Const.INPUTS(0), |
|
|
|
'sentence2': Const.INPUTS(1), |
|
|
|
'label': Const.TARGET, |
|
|
|
} |
|
|
|
CSVLoader.__init__(self, sep='\t') |
|
|
|
|
|
|
|
def _load(self, path): |
|
|
|
ds = CSVLoader._load(self, path) |
|
|
|
|
|
|
|
for k, v in self.fields.items(): |
|
|
|
ds.rename_field(k, v) |
|
|
|
for fields in ds.get_all_fields(): |
|
|
|
if Const.INPUT in fields: |
|
|
|
ds.apply(lambda x: x[fields].strip().split(), new_field_name=fields) |
|
|
|
|
|
|
|
return ds |
|
|
|
|
|
|
|
|
|
|
|
class QNLILoader(MatchingLoader, CSVLoader): |
|
|
|
""" |
|
|
|
别名::class:`fastNLP.io.QNLILoader` :class:`fastNLP.io.dataset_loader.QNLILoader` |
|
|
|
|
|
|
|
读取QNLI数据集,读取的DataSet包含fields:: |
|
|
|
|
|
|
|
words1: list(str),第一句文本, premise |
|
|
|
words2: list(str), 第二句文本, hypothesis |
|
|
|
target: str, 真实标签 |
|
|
|
|
|
|
|
数据来源: |
|
|
|
""" |
|
|
|
|
|
|
|
def __init__(self, paths: dict=None): |
|
|
|
paths = paths if paths is not None else { |
|
|
|
'train': 'train.tsv', |
|
|
|
'dev': 'dev.tsv', |
|
|
|
# 'test': 'test.tsv' # test set has not label |
|
|
|
} |
|
|
|
MatchingLoader.__init__(self, paths=paths) |
|
|
|
self.fields = { |
|
|
|
'question': Const.INPUTS(0), |
|
|
|
'sentence': Const.INPUTS(1), |
|
|
|
'label': Const.TARGET, |
|
|
|
} |
|
|
|
CSVLoader.__init__(self, sep='\t') |
|
|
|
|
|
|
|
def _load(self, path): |
|
|
|
ds = CSVLoader._load(self, path) |
|
|
|
|
|
|
|
for k, v in self.fields.items(): |
|
|
|
ds.rename_field(k, v) |
|
|
|
for fields in ds.get_all_fields(): |
|
|
|
if Const.INPUT in fields: |
|
|
|
ds.apply(lambda x: x[fields].strip().split(), new_field_name=fields) |
|
|
|
|
|
|
|
return ds |
|
|
|
|