diff --git a/fastNLP/io/dataset_loader.py b/fastNLP/io/dataset_loader.py
new file mode 100644
index 00000000..558fe20e
--- /dev/null
+++ b/fastNLP/io/dataset_loader.py
@@ -0,0 +1,333 @@
+"""
+The dataset_loader module implements a number of DataSetLoader classes that read data in different
+formats and return a `DataSet`. The resulting :class:`~fastNLP.DataSet` object can be passed directly
+to :class:`~fastNLP.Trainer` and :class:`~fastNLP.Tester` for training and testing a model.
+Taking the SNLI dataset as an example::
+
+    loader = SNLILoader()
+    train_ds = loader.load('path/to/train')
+    dev_ds = loader.load('path/to/dev')
+    test_ds = loader.load('path/to/test')
+
+    # ... do stuff
+
+Developers who want to contribute a DataSetLoader to fastNLP should refer to the documentation of
+:class:`~fastNLP.io.DataSetLoader`.
+"""
+__all__ = [
+    'CSVLoader',
+    'JsonLoader',
+    'ConllLoader',
+    'SNLILoader',
+    'SSTLoader',
+    'PeopleDailyCorpusLoader',
+    'Conll2003Loader',
+]
+
+import os
+from nltk import Tree
+from typing import Union, Dict
+from ..core.vocabulary import Vocabulary
+from ..core.dataset import DataSet
+from ..core.instance import Instance
+from .file_reader import _read_csv, _read_json, _read_conll
+from .base_loader import DataSetLoader, DataInfo
+from .data_loader.sst import SSTLoader
+from ..core.const import Const
+from ..modules.encoder._bert import BertTokenizer
+
+
+class PeopleDailyCorpusLoader(DataSetLoader):
+    """
+    Alias: :class:`fastNLP.io.PeopleDailyCorpusLoader` :class:`fastNLP.io.dataset_loader.PeopleDailyCorpusLoader`
+
+    Reads the People's Daily (人民日报) corpus.
+
+    :param bool pos: whether to keep the part-of-speech tags as a ``pos_tags`` field. Default: ``True``
+    :param bool ner: whether to keep the named-entity tags as a ``ner`` field. Default: ``True``
+    """
+
+    def __init__(self, pos=True, ner=True):
+        super(PeopleDailyCorpusLoader, self).__init__()
+        self.pos = pos
+        self.ner = ner
+
+    def _load(self, data_path):
+        with open(data_path, "r", encoding="utf-8") as f:
+            sents = f.readlines()
+        examples = []
+        for sent in sents:
+            if len(sent) <= 2:
+                continue
+            inside_ne = False
+            sent_pos_tag = []
+            sent_words = []
+            sent_ner = []
+            words = sent.strip().split()[1:]
+            for word in words:
+                if "[" in word and "]" in word:
+                    ner_tag = "U"
+                    print(word)
+                elif "[" in word:
+                    inside_ne = True
+                    ner_tag = "B"
+                    word = word[1:]
+                elif "]" in word:
+                    ner_tag = "L"
+                    word = word[:word.index("]")]
+                    if inside_ne is True:
+                        inside_ne = False
+                    else:
+                        raise RuntimeError("only ] appears!")
+                else:
+                    if inside_ne is True:
+                        ner_tag = "I"
+                    else:
+                        ner_tag = "O"
+                tmp = word.split("/")
+                token, pos = tmp[0], tmp[1]
+                sent_ner.append(ner_tag)
+                sent_pos_tag.append(pos)
+                sent_words.append(token)
+            example = [sent_words]
+            if self.pos is True:
+                example.append(sent_pos_tag)
+            if self.ner is True:
+                example.append(sent_ner)
+            examples.append(example)
+        return self.convert(examples)
+
+    def convert(self, data):
+        """
+        Convert the parsed examples into a :class:`~fastNLP.DataSet`.
+
+        :param data: a python built-in object (list of examples)
+        :return: a :class:`~fastNLP.DataSet` object
+        """
+        data_set = DataSet()
+        for item in data:
+            sent_words = item[0]
+            if self.pos is True and self.ner is True:
+                instance = Instance(
+                    words=sent_words, pos_tags=item[1], ner=item[2])
+            elif self.pos is True:
+                instance = Instance(words=sent_words, pos_tags=item[1])
+            elif self.ner is True:
+                instance = Instance(words=sent_words, ner=item[1])
+            else:
+                instance = Instance(words=sent_words)
+            data_set.append(instance)
+        data_set.apply(lambda ins: len(ins["words"]), new_field_name="seq_len")
+        return data_set
+
+
数据中以"-DOCSTART-"开头的行将被忽略,因为 + 该符号在conll 2003中被用为文档分割符。 + + 列号从0开始, 每列对应内容为:: + + Column Type + 0 Document ID + 1 Part number + 2 Word number + 3 Word itself + 4 Part-of-Speech + 5 Parse bit + 6 Predicate lemma + 7 Predicate Frameset ID + 8 Word sense + 9 Speaker/Author + 10 Named Entities + 11:N Predicate Arguments + N Coreference + + :param headers: 每一列数据的名称,需为List or Tuple of str。``header`` 与 ``indexes`` 一一对应 + :param indexes: 需要保留的数据列下标,从0开始。若为 ``None`` ,则所有列都保留。Default: ``None`` + :param dropna: 是否忽略非法数据,若 ``False`` ,遇到非法数据时抛出 ``ValueError`` 。Default: ``False`` + """ + + def __init__(self, headers, indexes=None, dropna=False): + super(ConllLoader, self).__init__() + if not isinstance(headers, (list, tuple)): + raise TypeError( + 'invalid headers: {}, should be list of strings'.format(headers)) + self.headers = headers + self.dropna = dropna + if indexes is None: + self.indexes = list(range(len(self.headers))) + else: + if len(indexes) != len(headers): + raise ValueError + self.indexes = indexes + + def _load(self, path): + ds = DataSet() + for idx, data in _read_conll(path, indexes=self.indexes, dropna=self.dropna): + ins = {h: data[i] for i, h in enumerate(self.headers)} + ds.append(Instance(**ins)) + return ds + + +class Conll2003Loader(ConllLoader): + """ + 别名::class:`fastNLP.io.Conll2003Loader` :class:`fastNLP.io.dataset_loader.Conll2003Loader` + + 读取Conll2003数据 + + 关于数据集的更多信息,参考: + https://sites.google.com/site/ermasoftware/getting-started/ne-tagging-conll2003-data + """ + + def __init__(self): + headers = [ + 'tokens', 'pos', 'chunks', 'ner', + ] + super(Conll2003Loader, self).__init__(headers=headers) + + +def _cut_long_sentence(sent, max_sample_length=200): + """ + 将长于max_sample_length的sentence截成多段,只会在有空格的地方发生截断。 + 所以截取的句子可能长于或者短于max_sample_length + + :param sent: str. + :param max_sample_length: int. + :return: list of str. + """ + sent_no_space = sent.replace(' ', '') + cutted_sentence = [] + if len(sent_no_space) > max_sample_length: + parts = sent.strip().split() + new_line = '' + length = 0 + for part in parts: + length += len(part) + new_line += part + ' ' + if length > max_sample_length: + new_line = new_line[:-1] + cutted_sentence.append(new_line) + length = 0 + new_line = '' + if new_line != '': + cutted_sentence.append(new_line[:-1]) + else: + cutted_sentence.append(sent) + return cutted_sentence + + +class JsonLoader(DataSetLoader): + """ + 别名::class:`fastNLP.io.JsonLoader` :class:`fastNLP.io.dataset_loader.JsonLoader` + + 读取json格式数据.数据必须按行存储,每行是一个包含各类属性的json对象 + + :param dict fields: 需要读入的json属性名称, 和读入后在DataSet中存储的field_name + ``fields`` 的 `key` 必须是json对象的属性名. ``fields`` 的 `value` 为读入后在DataSet存储的 `field_name` , + `value` 也可为 ``None`` , 这时读入后的 `field_name` 与json对象对应属性同名 + ``fields`` 可为 ``None`` , 这时,json对象所有属性都保存在DataSet中. Default: ``None`` + :param bool dropna: 是否忽略非法数据,若 ``True`` 则忽略,若 ``False`` ,在遇到非法数据时,抛出 ``ValueError`` . 
+class JsonLoader(DataSetLoader):
+    """
+    Alias: :class:`fastNLP.io.JsonLoader` :class:`fastNLP.io.dataset_loader.JsonLoader`
+
+    Reads data in JSON format. The data must be stored line by line, each line being a JSON
+    object with the desired attributes.
+
+    :param dict fields: the JSON attribute names to read, together with the field_name under
+        which each is stored in the DataSet. The `key` of ``fields`` must be an attribute name
+        of the JSON object; the `value` is the `field_name` used in the DataSet after loading.
+        A `value` may also be ``None``, in which case the loaded `field_name` is the same as the
+        JSON attribute name. ``fields`` itself may be ``None``, in which case every attribute of
+        the JSON object is kept in the DataSet. Default: ``None``
+    :param bool dropna: whether to ignore malformed data. If ``True``, malformed entries are
+        skipped; if ``False``, a ``ValueError`` is raised when malformed data is encountered.
+        Default: ``False``
+    """
+
+    def __init__(self, fields=None, dropna=False):
+        super(JsonLoader, self).__init__()
+        self.dropna = dropna
+        self.fields = None
+        self.fields_list = None
+        if fields:
+            self.fields = {}
+            for k, v in fields.items():
+                self.fields[k] = k if v is None else v
+            self.fields_list = list(self.fields.keys())
+
+    def _load(self, path):
+        ds = DataSet()
+        for idx, d in _read_json(path, fields=self.fields_list, dropna=self.dropna):
+            if self.fields:
+                ins = {self.fields[k]: v for k, v in d.items()}
+            else:
+                ins = d
+            ds.append(Instance(**ins))
+        return ds
+
+
+class SNLILoader(JsonLoader):
+    """
+    Alias: :class:`fastNLP.io.SNLILoader` :class:`fastNLP.io.dataset_loader.SNLILoader`
+
+    Reads the SNLI dataset. The loaded DataSet contains the following fields::
+
+        words1: list(str), the first sentence (premise)
+        words2: list(str), the second sentence (hypothesis)
+        target: str, the gold label
+
+    Data source: https://nlp.stanford.edu/projects/snli/snli_1.0.zip
+    """
+
+    def __init__(self):
+        fields = {
+            'sentence1_parse': Const.INPUTS(0),
+            'sentence2_parse': Const.INPUTS(1),
+            'gold_label': Const.TARGET,
+        }
+        super(SNLILoader, self).__init__(fields=fields)
+
+    def _load(self, path):
+        ds = super(SNLILoader, self)._load(path)
+
+        def parse_tree(x):
+            t = Tree.fromstring(x)
+            return t.leaves()
+
+        ds.apply(lambda ins: parse_tree(
+            ins[Const.INPUTS(0)]), new_field_name=Const.INPUTS(0))
+        ds.apply(lambda ins: parse_tree(
+            ins[Const.INPUTS(1)]), new_field_name=Const.INPUTS(1))
+        ds.drop(lambda x: x[Const.TARGET] == '-')
+        return ds
+
+
+class CSVLoader(DataSetLoader):
+    """
+    Alias: :class:`fastNLP.io.CSVLoader` :class:`fastNLP.io.dataset_loader.CSVLoader`
+
+    Reads a dataset in CSV format and returns a ``DataSet``.
+
+    :param List[str] headers: the header of the CSV file, defining the name of each column,
+        i.e. the name of each `field` in the returned DataSet. If ``None``, the first line of
+        the file is treated as the ``headers``. Default: ``None``
+    :param str sep: the delimiter between columns in the CSV file. Default: ","
+    :param bool dropna: whether to ignore malformed data. If ``True``, malformed entries are
+        skipped; if ``False``, a ``ValueError`` is raised when malformed data is encountered.
+        Default: ``False``
+    """
+
+    def __init__(self, headers=None, sep=",", dropna=False):
+        super(CSVLoader, self).__init__()
+        self.headers = headers
+        self.sep = sep
+        self.dropna = dropna
+
+    def _load(self, path):
+        ds = DataSet()
+        for idx, data in _read_csv(path, headers=self.headers,
+                                   sep=self.sep, dropna=self.dropna):
+            ds.append(Instance(**data))
+        return ds
+
+
+def _add_seg_tag(data):
+    """
+    Convert word-level annotations into character-level BMES-style segmentation tags.
+
+    :param data: list of ([word], [pos], [heads], [head_tags])
+    :return: list of ([word], [pos])
+    """
+    _processed = []
+    for word_list, pos_list, _, _ in data:
+        new_sample = []
+        for word, pos in zip(word_list, pos_list):
+            if len(word) == 1:
+                new_sample.append((word, 'S-' + pos))
+            else:
+                new_sample.append((word[0], 'B-' + pos))
+                for c in word[1:-1]:
+                    new_sample.append((c, 'M-' + pos))
+                new_sample.append((word[-1], 'E-' + pos))
+        _processed.append(list(map(list, zip(*new_sample))))
+    return _processed
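
To give reviewers a quick way to exercise the CoNLL readers added above, here is a minimal usage sketch. The file paths are hypothetical, the files are assumed to be plain whitespace-separated CoNLL text, and the public ``load()`` entry point is used exactly as in the module docstring::

    loader = Conll2003Loader()
    train_ds = loader.load('path/to/conll2003/train.txt')  # hypothetical path
    print(len(train_ds))  # number of sentences read

    # the generic ConllLoader keeps only the columns you name
    pos_loader = ConllLoader(headers=['tokens', 'pos'], indexes=[0, 1])
    pos_ds = pos_loader.load('path/to/conll2003/train.txt')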
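
A similar sketch for the line-based loaders; the file names, JSON attributes, and CSV columns below are invented for illustration and would have to match the actual data::

    # each line of fake_reviews.json is assumed to look like {"text": "...", "label": "pos"}
    json_loader = JsonLoader(fields={'text': 'words', 'label': 'target'}, dropna=True)
    json_ds = json_loader.load('path/to/fake_reviews.json')

    # a CSV whose first line is a header row would pass headers=None instead
    csv_loader = CSVLoader(headers=['raw_words', 'target'], sep=',')
    csv_ds = csv_loader.load('path/to/fake_reviews.csv')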
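
The private helper ``_add_seg_tag`` is easiest to understand from a toy input; the tuple below is fabricated, and the heads/head_tags entries are unused by the function::

    data = [(['中国', '人'], ['ns', 'n'], [], [])]
    print(_add_seg_tag(data))
    # -> [[['中', '国', '人'], ['B-ns', 'E-ns', 'S-n']]]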