diff --git a/fastNLP/io/__init__.py b/fastNLP/io/__init__.py
index 28f466a8..05d75f43 100644
--- a/fastNLP/io/__init__.py
+++ b/fastNLP/io/__init__.py
@@ -11,21 +11,35 @@
 """
 __all__ = [
     'EmbedLoader',
-    
+
+    'DataInfo',
     'DataSetLoader',
+
     'CSVLoader',
     'JsonLoader',
     'ConllLoader',
-    'SNLILoader',
-    'SSTLoader',
     'PeopleDailyCorpusLoader',
     'Conll2003Loader',
 
     'ModelLoader',
     'ModelSaver',
+
+    'SSTLoader',
+
+    'MatchingLoader',
+    'SNLILoader',
+    'MNLILoader',
+    'QNLILoader',
+    'QuoraLoader',
+    'RTELoader',
 ]
 
 from .embed_loader import EmbedLoader
-from .dataset_loader import DataSetLoader, CSVLoader, JsonLoader, ConllLoader, \
-    SNLILoader, SSTLoader, PeopleDailyCorpusLoader, Conll2003Loader
+from .base_loader import DataInfo, DataSetLoader
+from .dataset_loader import CSVLoader, JsonLoader, ConllLoader, \
+    PeopleDailyCorpusLoader, Conll2003Loader
 from .model_io import ModelLoader, ModelSaver
+
+from .data_loader.sst import SSTLoader
+from .data_loader.matching import MatchingLoader, SNLILoader, \
+    MNLILoader, QNLILoader, QuoraLoader, RTELoader
diff --git a/fastNLP/io/data_loader/matching.py b/fastNLP/io/data_loader/matching.py
new file mode 100644
index 00000000..70a683f2
--- /dev/null
+++ b/fastNLP/io/data_loader/matching.py
@@ -0,0 +1,428 @@
+import os
+
+from typing import Union, Dict
+
+from ...core.const import Const
+from ...core.vocabulary import Vocabulary
+from ...io.base_loader import DataInfo, DataSetLoader
+from ...io.dataset_loader import JsonLoader, CSVLoader
+from ...io.file_utils import _get_base_url, cached_path, PRETRAINED_BERT_MODEL_DIR
+from ...modules.encoder._bert import BertTokenizer
+
+
+class MatchingLoader(DataSetLoader):
+    """
+    别名::class:`fastNLP.io.MatchingLoader` :class:`fastNLP.io.dataset_loader.MatchingLoader`
+
+    读取Matching任务的数据集
+
+    :param dict paths: key是数据集名称(如train、dev、test),value是对应的文件名
+    """
+
+    def __init__(self, paths: dict = None):
+        self.paths = paths
+
+    def _load(self, path):
+        """
+        :param str path: 待读取数据集的路径名
+        :return: fastNLP.DataSet ds: 返回一个DataSet对象,里面必须包含3个field:其中两个分别为两个句子
+            的原始字符串文本,第三个为标签
+        """
+        raise NotImplementedError
+
+    def process(self, paths: Union[str, Dict[str, str]], dataset_name: str = None,
+                to_lower=False, seq_len_type: str = None, bert_tokenizer: str = None,
+                cut_text: int = None, get_index=True, auto_pad_length: int = None,
+                auto_pad_token: str = '<pad>', set_input: Union[list, str, bool] = True,
+                set_target: Union[list, str, bool] = True, concat: Union[str, list, bool] = None, ) -> DataInfo:
+        """
+        :param paths: str或者Dict[str, str]。如果是str,则为数据集所在的文件夹或者是全路径文件名:如果是文件夹,
+            则会从self.paths里面找对应的数据集名称与文件名。如果是Dict,则为数据集名称(如train、dev、test)和
+            对应的全路径文件名。
+        :param str dataset_name: 如果在paths里传入的是一个数据集的全路径文件名,那么可以用dataset_name来定义
+            这个数据集的名字,如果不定义则默认为train。
+        :param bool to_lower: 是否将文本自动转为小写。默认值为False。
+        :param str seq_len_type: 提供的seq_len类型,支持 ``seq_len`` :提供一个数字作为句子长度; ``mask`` :
+            提供一个0/1的mask矩阵作为句子长度; ``bert`` :提供segment_type_id(第一个句子为0,第二个句子为1)和
+            attention mask矩阵(0/1的mask矩阵)。默认值为None,即不提供seq_len
+        :param str bert_tokenizer: bert tokenizer所使用的词表所在的文件夹路径
+        :param int cut_text: 将长于cut_text的内容截掉。默认为None,即不截。
+        :param bool get_index: 是否需要根据词表将文本转为index
+        :param int auto_pad_length: 是否需要将文本自动pad到一定长度(超过这个长度的文本将会被截掉),默认为不会自动pad
+        :param str auto_pad_token: 自动pad的内容
+        :param set_input: 如果为True,则会自动将相关的field(名字里含有Const.INPUT的)设置为input,如果为False
+            则不会将任何field设置为input。如果传入str或者List[str],则会根据传入的内容将相对应的field设置为input,
+            与此同时其他field不会被设置为input。默认值为True。
+        :param set_target: set_target将控制哪些field可以被设置为target,用法与set_input一致。默认值为True。
+        :param concat: 是否需要将两个句子拼接起来。如果为False则不会拼接。如果为True则会在两个句子之间插入一个<sep>。
+            如果传入一个长度为4的list,则分别表示插在第一句开始前、第一句结束后、第二句开始前、第二句结束后的标识符。如果
+            传入字符串 ``bert`` ,则会采用bert的拼接方式,等价于['[CLS]', '[SEP]', '', '[SEP]'].
+        :return:
+        """
+        if isinstance(set_input, str):
+            set_input = [set_input]
+        if isinstance(set_target, str):
+            set_target = [set_target]
+        if isinstance(set_input, bool):
+            auto_set_input = set_input
+        else:
+            auto_set_input = False
+        if isinstance(set_target, bool):
+            auto_set_target = set_target
+        else:
+            auto_set_target = False
+        if isinstance(paths, str):
+            if os.path.isdir(paths):
+                path = {n: os.path.join(paths, self.paths[n]) for n in self.paths.keys()}
+            else:
+                path = {dataset_name if dataset_name is not None else 'train': paths}
+        else:
+            path = paths
+
+        data_info = DataInfo()
+        for data_name in path.keys():
+            data_info.datasets[data_name] = self._load(path[data_name])
+
+        for data_name, data_set in data_info.datasets.items():
+            if auto_set_input:
+                data_set.set_input(Const.INPUTS(0), Const.INPUTS(1))
+            if auto_set_target:
+                if Const.TARGET in data_set.get_field_names():
+                    data_set.set_target(Const.TARGET)
+
+        if to_lower:
+            for data_name, data_set in data_info.datasets.items():
+                data_set.apply(lambda x: [w.lower() for w in x[Const.INPUTS(0)]], new_field_name=Const.INPUTS(0),
+                               is_input=auto_set_input)
+                data_set.apply(lambda x: [w.lower() for w in x[Const.INPUTS(1)]], new_field_name=Const.INPUTS(1),
+                               is_input=auto_set_input)
+
+        if bert_tokenizer is not None:
+            if bert_tokenizer.lower() in PRETRAINED_BERT_MODEL_DIR:
+                PRETRAIN_URL = _get_base_url('bert')
+                model_name = PRETRAINED_BERT_MODEL_DIR[bert_tokenizer]
+                model_url = PRETRAIN_URL + model_name
+                model_dir = cached_path(model_url)
+                # 检查是否存在
+            elif os.path.isdir(bert_tokenizer):
+                model_dir = bert_tokenizer
+            else:
+                raise ValueError(f"Cannot recognize BERT tokenizer from {bert_tokenizer}.")
+
+            words_vocab = Vocabulary(padding='[PAD]', unknown='[UNK]')
+            with open(os.path.join(model_dir, 'vocab.txt'), 'r') as f:
+                lines = f.readlines()
+            lines = [line.strip() for line in lines]
+            words_vocab.add_word_lst(lines)
+            words_vocab.build_vocab()
+
+            tokenizer = BertTokenizer.from_pretrained(model_dir)
+
+            for data_name, data_set in data_info.datasets.items():
+                for fields in data_set.get_field_names():
+                    if Const.INPUT in fields:
+                        data_set.apply(lambda x: tokenizer.tokenize(' '.join(x[fields])), new_field_name=fields,
+                                       is_input=auto_set_input)
+
+        if isinstance(concat, bool):
+            concat = 'default' if concat else None
+        if concat is not None:
+            if isinstance(concat, str):
+                CONCAT_MAP = {'bert': ['[CLS]', '[SEP]', '', '[SEP]'],
+                              'default': ['', '<sep>', '', '']}
+                if concat.lower() in CONCAT_MAP:
+                    concat = CONCAT_MAP[concat]
+                else:
+                    concat = 4 * [concat]
+            assert len(concat) == 4, \
+                f'Please choose a list with 4 symbols which are the start of the first sentence, ' \
+                f'the end of the first sentence, the start of the second sentence, and ' \
+                f'the end of the second sentence. Your input is {concat}'
+
+            for data_name, data_set in data_info.datasets.items():
+                data_set.apply(lambda x: [concat[0]] + x[Const.INPUTS(0)] + [concat[1]] + [concat[2]] +
+                               x[Const.INPUTS(1)] + [concat[3]], new_field_name=Const.INPUT)
+                data_set.apply(lambda x: [w for w in x[Const.INPUT] if len(w) > 0], new_field_name=Const.INPUT,
+                               is_input=auto_set_input)
+
+        if seq_len_type is not None:
+            if seq_len_type == 'seq_len':
+                for data_name, data_set in data_info.datasets.items():
+                    for fields in data_set.get_field_names():
+                        if Const.INPUT in fields:
+                            data_set.apply(lambda x: len(x[fields]),
+                                           new_field_name=fields.replace(Const.INPUT, Const.INPUT_LEN),
+                                           is_input=auto_set_input)
+            elif seq_len_type == 'mask':
+                for data_name, data_set in data_info.datasets.items():
+                    for fields in data_set.get_field_names():
+                        if Const.INPUT in fields:
+                            data_set.apply(lambda x: [1] * len(x[fields]),
+                                           new_field_name=fields.replace(Const.INPUT, Const.INPUT_LEN),
+                                           is_input=auto_set_input)
+            elif seq_len_type == 'bert':
+                for data_name, data_set in data_info.datasets.items():
+                    if Const.INPUT not in data_set.get_field_names():
+                        raise KeyError(f'Field ``{Const.INPUT}`` not in {data_name} data set: '
+                                       f'got {data_set.get_field_names()}')
+                    data_set.apply(lambda x: [0] * (len(x[Const.INPUTS(0)]) + 2) + [1] * (len(x[Const.INPUTS(1)]) + 1),
+                                   new_field_name=Const.INPUT_LENS(0), is_input=auto_set_input)
+                    data_set.apply(lambda x: [1] * len(x[Const.INPUT_LENS(0)]),
+                                   new_field_name=Const.INPUT_LENS(1), is_input=auto_set_input)
+
+        if auto_pad_length is not None:
+            cut_text = min(auto_pad_length, cut_text if cut_text is not None else auto_pad_length)
+
+        if cut_text is not None:
+            for data_name, data_set in data_info.datasets.items():
+                for fields in data_set.get_field_names():
+                    if (Const.INPUT in fields) or ((Const.INPUT_LEN in fields) and (seq_len_type != 'seq_len')):
+                        data_set.apply(lambda x: x[fields][: cut_text], new_field_name=fields,
+                                       is_input=auto_set_input)
+
+        data_set_list = [d for n, d in data_info.datasets.items()]
+        assert len(data_set_list) > 0, f'There are NO data sets in data info!'
+
+        if bert_tokenizer is None:
+            words_vocab = Vocabulary(padding=auto_pad_token)
+            words_vocab = words_vocab.from_dataset(*[d for n, d in data_info.datasets.items() if 'train' in n],
+                                                   field_name=[n for n in data_set_list[0].get_field_names()
+                                                               if (Const.INPUT in n)],
+                                                   no_create_entry_dataset=[d for n, d in data_info.datasets.items()
+                                                                            if 'train' not in n])
+        target_vocab = Vocabulary(padding=None, unknown=None)
+        target_vocab = target_vocab.from_dataset(*[d for n, d in data_info.datasets.items() if 'train' in n],
+                                                 field_name=Const.TARGET)
+        data_info.vocabs = {Const.INPUT: words_vocab, Const.TARGET: target_vocab}
+
+        if get_index:
+            for data_name, data_set in data_info.datasets.items():
+                for fields in data_set.get_field_names():
+                    if Const.INPUT in fields:
+                        data_set.apply(lambda x: [words_vocab.to_index(w) for w in x[fields]], new_field_name=fields,
+                                       is_input=auto_set_input)
+
+                if Const.TARGET in data_set.get_field_names():
+                    data_set.apply(lambda x: target_vocab.to_index(x[Const.TARGET]), new_field_name=Const.TARGET,
+                                   is_input=auto_set_input, is_target=auto_set_target)
+
+        if auto_pad_length is not None:
+            for data_name, data_set in data_info.datasets.items():
+                if seq_len_type == 'seq_len':
+                    raise RuntimeError(f'sequence will be padded with the length {auto_pad_length}, '
+                                       f'the seq_len_type cannot be `{seq_len_type}`!')
+                for fields in data_set.get_field_names():
+                    if Const.INPUT in fields:
+                        data_set.apply(lambda x: x[fields] + [words_vocab.to_index(words_vocab.padding)] *
+                                       (auto_pad_length - len(x[fields])), new_field_name=fields,
+                                       is_input=auto_set_input)
+                    elif (Const.INPUT_LEN in fields) and (seq_len_type != 'seq_len'):
+                        data_set.apply(lambda x: x[fields] + [0] * (auto_pad_length - len(x[fields])),
+                                       new_field_name=fields, is_input=auto_set_input)
+
+        for data_name, data_set in data_info.datasets.items():
+            if isinstance(set_input, list):
+                data_set.set_input(*[inputs for inputs in set_input if inputs in data_set.get_field_names()])
+            if isinstance(set_target, list):
+                data_set.set_target(*[target for target in set_target if target in data_set.get_field_names()])
+
+        return data_info
+
+
+class SNLILoader(MatchingLoader, JsonLoader):
+    """
+    别名::class:`fastNLP.io.SNLILoader` :class:`fastNLP.io.dataset_loader.SNLILoader`
+
+    读取SNLI数据集,读取的DataSet包含fields::
+
+        words1: list(str),第一句文本, premise
+        words2: list(str), 第二句文本, hypothesis
+        target: str, 真实标签
+
+    数据来源: https://nlp.stanford.edu/projects/snli/snli_1.0.zip
+    """
+
+    def __init__(self, paths: dict = None):
+        fields = {
+            'sentence1_binary_parse': Const.INPUTS(0),
+            'sentence2_binary_parse': Const.INPUTS(1),
+            'gold_label': Const.TARGET,
+        }
+        paths = paths if paths is not None else {
+            'train': 'snli_1.0_train.jsonl',
+            'dev': 'snli_1.0_dev.jsonl',
+            'test': 'snli_1.0_test.jsonl'}
+        MatchingLoader.__init__(self, paths=paths)
+        JsonLoader.__init__(self, fields=fields)
+
+    def _load(self, path):
+        ds = JsonLoader._load(self, path)
+
+        parentheses_table = str.maketrans({'(': None, ')': None})
+
+        ds.apply(lambda ins: ins[Const.INPUTS(0)].translate(parentheses_table).strip().split(),
+                 new_field_name=Const.INPUTS(0))
+        ds.apply(lambda ins: ins[Const.INPUTS(1)].translate(parentheses_table).strip().split(),
+                 new_field_name=Const.INPUTS(1))
+        ds.drop(lambda x: x[Const.TARGET] == '-')
+        return ds
+
+
+class RTELoader(MatchingLoader, CSVLoader):
+    """
+    别名::class:`fastNLP.io.RTELoader` :class:`fastNLP.io.dataset_loader.RTELoader`
+
+    读取RTE数据集,读取的DataSet包含fields::
+
+        words1: list(str),第一句文本, premise
+        words2: list(str), 第二句文本, hypothesis
+        target: str, 真实标签
+
+    数据来源:
+    """
+
+    def __init__(self, paths: dict = None):
+        paths = paths if paths is not None else {
+            'train': 'train.tsv',
+            'dev': 'dev.tsv',
+            'test': 'test.tsv'  # test set has no label
+        }
+        MatchingLoader.__init__(self, paths=paths)
+        self.fields = {
+            'sentence1': Const.INPUTS(0),
+            'sentence2': Const.INPUTS(1),
+            'label': Const.TARGET,
+        }
+        CSVLoader.__init__(self, sep='\t')
+
+    def _load(self, path):
+        ds = CSVLoader._load(self, path)
+
+        for k, v in self.fields.items():
+            ds.rename_field(k, v)
+        for fields in ds.get_all_fields():
+            if Const.INPUT in fields:
+                ds.apply(lambda x: x[fields].strip().split(), new_field_name=fields)
+
+        return ds
+
+
+class QNLILoader(MatchingLoader, CSVLoader):
+    """
+    别名::class:`fastNLP.io.QNLILoader` :class:`fastNLP.io.dataset_loader.QNLILoader`
+
+    读取QNLI数据集,读取的DataSet包含fields::
+
+        words1: list(str),第一句文本, premise
+        words2: list(str), 第二句文本, hypothesis
+        target: str, 真实标签
+
+    数据来源:
+    """
+
+    def __init__(self, paths: dict = None):
+        paths = paths if paths is not None else {
+            'train': 'train.tsv',
+            'dev': 'dev.tsv',
+            'test': 'test.tsv'  # test set has no label
+        }
+        MatchingLoader.__init__(self, paths=paths)
+        self.fields = {
+            'question': Const.INPUTS(0),
+            'sentence': Const.INPUTS(1),
+            'label': Const.TARGET,
+        }
+        CSVLoader.__init__(self, sep='\t')
+
+    def _load(self, path):
+        ds = CSVLoader._load(self, path)
+
+        for k, v in self.fields.items():
+            ds.rename_field(k, v)
+        for fields in ds.get_all_fields():
+            if Const.INPUT in fields:
+                ds.apply(lambda x: x[fields].strip().split(), new_field_name=fields)
+
+        return ds
+
+
+class MNLILoader(MatchingLoader, CSVLoader):
+    """
+    别名::class:`fastNLP.io.MNLILoader` :class:`fastNLP.io.dataset_loader.MNLILoader`
+
+    读取MNLI数据集,读取的DataSet包含fields::
+
+        words1: list(str),第一句文本, premise
+        words2: list(str), 第二句文本, hypothesis
+        target: str, 真实标签
+
+    数据来源:
+    """
+
+    def __init__(self, paths: dict = None):
+        paths = paths if paths is not None else {
+            'train': 'train.tsv',
+            'dev_matched': 'dev_matched.tsv',
+            'dev_mismatched': 'dev_mismatched.tsv',
+            'test_matched': 'test_matched.tsv',
+            'test_mismatched': 'test_mismatched.tsv',
+            # 'test_0.9_matched': 'multinli_0.9_test_matched_unlabeled.txt',
+            # 'test_0.9_mismatched': 'multinli_0.9_test_mismatched_unlabeled.txt',
+
+            # test_0.9_matched与mismatched是MNLI0.9版本的(数据来源:kaggle)
+        }
+        MatchingLoader.__init__(self, paths=paths)
+        CSVLoader.__init__(self, sep='\t')
+        self.fields = {
+            'sentence1_binary_parse': Const.INPUTS(0),
+            'sentence2_binary_parse': Const.INPUTS(1),
+            'gold_label': Const.TARGET,
+        }
+
+    def _load(self, path):
+        ds = CSVLoader._load(self, path)
+
+        for k, v in self.fields.items():
+            if k in ds.get_field_names():
+                ds.rename_field(k, v)
+
+        if Const.TARGET in ds.get_field_names():
+            if ds[0][Const.TARGET] == 'hidden':
+                ds.delete_field(Const.TARGET)
+
+        parentheses_table = str.maketrans({'(': None, ')': None})
+
+        ds.apply(lambda ins: ins[Const.INPUTS(0)].translate(parentheses_table).strip().split(),
+                 new_field_name=Const.INPUTS(0))
+        ds.apply(lambda ins: ins[Const.INPUTS(1)].translate(parentheses_table).strip().split(),
+                 new_field_name=Const.INPUTS(1))
+        if Const.TARGET in ds.get_field_names():
+            ds.drop(lambda x: x[Const.TARGET] == '-')
+        return ds
+
+
+class QuoraLoader(MatchingLoader, CSVLoader):
+    """
+    别名::class:`fastNLP.io.QuoraLoader` :class:`fastNLP.io.dataset_loader.QuoraLoader`
+
+    读取Quora数据集,读取的DataSet包含fields::
+
+        words1: list(str),第一句文本, premise
+        words2: list(str), 第二句文本, hypothesis
+        target: str, 真实标签
+
+    数据来源:
+    """
+
+    def __init__(self, paths: dict = None):
+        paths = paths if paths is not None else {
+            'train': 'train.tsv',
+            'dev': 'dev.tsv',
+            'test': 'test.tsv',
+        }
+        MatchingLoader.__init__(self, paths=paths)
+        CSVLoader.__init__(self, sep='\t', headers=(Const.TARGET, Const.INPUTS(0), Const.INPUTS(1), 'pairID'))
+
+    def _load(self, path):
+        ds = CSVLoader._load(self, path)
+        return ds
diff --git a/fastNLP/io/dataset_loader.py b/fastNLP/io/dataset_loader.py
index 558fe20e..26edd8bd 100644
--- a/fastNLP/io/dataset_loader.py
+++ b/fastNLP/io/dataset_loader.py
@@ -249,42 +249,6 @@ class JsonLoader(DataSetLoader):
         return ds
 
 
-class SNLILoader(JsonLoader):
-    """
-    别名::class:`fastNLP.io.SNLILoader` :class:`fastNLP.io.dataset_loader.SNLILoader`
-
-    读取SNLI数据集,读取的DataSet包含fields::
-
-        words1: list(str),第一句文本, premise
-        words2: list(str), 第二句文本, hypothesis
-        target: str, 真实标签
-
-    数据来源: https://nlp.stanford.edu/projects/snli/snli_1.0.zip
-    """
-
-    def __init__(self):
-        fields = {
-            'sentence1_parse': Const.INPUTS(0),
-            'sentence2_parse': Const.INPUTS(1),
-            'gold_label': Const.TARGET,
-        }
-        super(SNLILoader, self).__init__(fields=fields)
-
-    def _load(self, path):
-        ds = super(SNLILoader, self)._load(path)
-
-        def parse_tree(x):
-            t = Tree.fromstring(x)
-            return t.leaves()
-
-        ds.apply(lambda ins: parse_tree(
-            ins[Const.INPUTS(0)]), new_field_name=Const.INPUTS(0))
-        ds.apply(lambda ins: parse_tree(
-            ins[Const.INPUTS(1)]), new_field_name=Const.INPUTS(1))
-        ds.drop(lambda x: x[Const.TARGET] == '-')
-        return ds
-
-
 class CSVLoader(DataSetLoader):
     """
     别名::class:`fastNLP.io.CSVLoader` :class:`fastNLP.io.dataset_loader.CSVLoader`
diff --git a/reproduction/matching/data/MatchingDataLoader.py b/reproduction/matching/data/MatchingDataLoader.py
index 9d948ec1..43f016d6 100644
--- a/reproduction/matching/data/MatchingDataLoader.py
+++ b/reproduction/matching/data/MatchingDataLoader.py
@@ -212,12 +212,12 @@ class MatchingLoader(DataSetLoader):
         for data_name, data_set in data_info.datasets.items():
             for fields in data_set.get_field_names():
                 if Const.INPUT in fields:
-                    data_set.apply(lambda x: x[fields] + [words_vocab.padding] * (auto_pad_length - len(x[fields])),
-                                   new_field_name=fields, is_input=auto_set_input)
-                elif (Const.INPUT_LEN in fields) and (seq_len_type != 'seq_len'):
                     data_set.apply(lambda x: x[fields] + [words_vocab.to_index(words_vocab.padding)] *
                                    (auto_pad_length - len(x[fields])), new_field_name=fields,
                                    is_input=auto_set_input)
+                elif (Const.INPUT_LEN in fields) and (seq_len_type != 'seq_len'):
+                    data_set.apply(lambda x: x[fields] + [0] * (auto_pad_length - len(x[fields])),
+                                   new_field_name=fields, is_input=auto_set_input)
 
         for data_name, data_set in data_info.datasets.items():
@@ -284,7 +284,7 @@ class RTELoader(MatchingLoader, CSVLoader):
         paths = paths if paths is not None else {
             'train': 'train.tsv',
             'dev': 'dev.tsv',
-            # 'test': 'test.tsv'  # test set has not label
+            'test': 'test.tsv'  # test set has not label
         }
         MatchingLoader.__init__(self, paths=paths)
         self.fields = {
@@ -323,7 +323,7 @@ class QNLILoader(MatchingLoader, CSVLoader):
         paths = paths if paths is not None else {
             'train': 'train.tsv',
             'dev': 'dev.tsv',
-            # 'test': 'test.tsv'  # test set has not label
+            'test': 'test.tsv'  # test set has not label
         }
         MatchingLoader.__init__(self, paths=paths)
         self.fields = {
@@ -367,6 +367,7 @@ class MNLILoader(MatchingLoader, CSVLoader):
             'test_mismatched': 'test_mismatched.tsv',
             # 'test_0.9_matched': 'multinli_0.9_test_matched_unlabeled.txt',
             # 'test_0.9_mismatched': 'multinli_0.9_test_mismatched_unlabeled.txt',
+            # test_0.9_matched与mismatched是MNLI0.9版本的(数据来源:kaggle)
         }
         MatchingLoader.__init__(self, paths=paths)
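
For reference, a minimal usage sketch of the new loaders (not part of the patch): it assumes SNLI has been downloaded to a local directory containing the three `snli_1.0_*.jsonl` files named in `SNLILoader`'s default paths; the directory path and keyword values below are illustrative only.

```python
from fastNLP.core.const import Const
from fastNLP.io import SNLILoader

# 'snli_1.0/' is a placeholder for a local copy of the SNLI release.
data_info = SNLILoader().process(
    paths='snli_1.0/',
    to_lower=True,           # lowercase both sentences
    seq_len_type='seq_len',  # add a sentence-length field per words* field
    get_index=True,          # map tokens to ids with the vocabulary built from train
)

train_set = data_info.datasets['train']        # words1 / words2 are now index lists
words_vocab = data_info.vocabs[Const.INPUT]    # shared word vocabulary
target_vocab = data_info.vocabs[Const.TARGET]  # label vocabulary
print(len(train_set), len(words_vocab), len(target_vocab))
```

The other loaders (`RTELoader`, `QNLILoader`, `MNLILoader`, `QuoraLoader`) follow the same pattern, differing only in their default file names and column-to-field mappings.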