From 6a8f50e73e81a5175eb5bd9d3d24ba615ce7d901 Mon Sep 17 00:00:00 2001 From: ChenXin Date: Wed, 22 May 2019 15:06:10 +0800 Subject: [PATCH] =?UTF-8?q?=E5=A2=9E=E5=8A=A0=E4=BA=86=20DataSet=20Loader?= =?UTF-8?q?=20=E7=9A=84=E6=96=87=E6=A1=A3?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- fastNLP/io/dataset_loader.py | 123 +++++++++++++++++++++++------------ 1 file changed, 83 insertions(+), 40 deletions(-) diff --git a/fastNLP/io/dataset_loader.py b/fastNLP/io/dataset_loader.py index 19600ef7..32cca88f 100644 --- a/fastNLP/io/dataset_loader.py +++ b/fastNLP/io/dataset_loader.py @@ -1,6 +1,6 @@ """ dataset_loader模块实现了许多 DataSetLoader, 用于读取不同格式的数据, 并返回 `DataSet` , -得到的 :class:`~fastNLP.DataSet` 对象可以直接传入 :class:`~fastNLP.Trainer`, :class:`~fastNLP.Tester`, 用于模型的训练和测试。 +得到的 :class:`~fastNLP.DataSet` 对象可以直接传入 :class:`~fastNLP.Trainer` 和 :class:`~fastNLP.Tester`, 用于模型的训练和测试。 以SNLI数据集为例:: loader = SNLILoader() @@ -9,8 +9,11 @@ dataset_loader模块实现了许多 DataSetLoader, 用于读取不同格式的 test_ds = loader.load('path/to/test') # ... do stuff + +为 fastNLP 提供 DataSetLoader 的开发者请参考 :class:`~fastNLP.io.DataSetLoader` 的介绍。 """ __all__ = [ + 'DataInfo', 'DataSetLoader', 'CSVLoader', 'JsonLoader', @@ -26,7 +29,7 @@ from nltk.tree import Tree from ..core.dataset import DataSet from ..core.instance import Instance from .file_reader import _read_csv, _read_json, _read_conll -from typing import Union +from typing import Union, Dict import os @@ -36,7 +39,7 @@ def _download_from_url(url, path): except: from ..core.utils import _pseudo_tqdm as tqdm import requests - + """Download file""" r = requests.get(url, headers={'User-Agent': 'Mozilla/5.0'}, stream=True) chunk_size = 16 * 1024 @@ -55,11 +58,11 @@ def _uncompress(src, dst): import gzip import tarfile import os - + def unzip(src, dst): with zipfile.ZipFile(src, 'r') as f: f.extractall(dst) - + def ungz(src, dst): with gzip.open(src, 'rb') as f, open(dst, 'wb') as uf: length = 16 * 1024 # 16KB @@ -67,11 +70,11 @@ def _uncompress(src, dst): while buf: uf.write(buf) buf = f.read(length) - + def untar(src, dst): with tarfile.open(src, 'r:gz') as f: f.extractall(dst) - + fn, ext = os.path.splitext(src) _, ext_2 = os.path.splitext(fn) if ext == '.zip': @@ -84,7 +87,15 @@ def _uncompress(src, dst): raise ValueError('unsupported file {}'.format(src)) -class DataInfo(): +class DataInfo: + """ + 经过处理的数据信息,包括一系列数据集(比如:分开的训练集、验证集和测试集)及它们所用的词表和词嵌入。 + + :param vocabs: 从名称(字符串)到 :class:`~fastNLP.Vocabulary` 类型的dict + :param embeddings: 从名称(字符串)到一系列 embedding 的dict,参考 :class:`~fastNLP.io.EmbedLoader` + :param datasets: 从名称(字符串)到 :class:`~fastNLP.DataSet` 类型的dict + """ + def __init__(self, vocabs: dict = None, embeddings: dict = None, datasets: dict = None): self.vocabs = vocabs or {} self.embeddings = embeddings or {} @@ -95,11 +106,27 @@ class DataSetLoader: """ 别名::class:`fastNLP.io.DataSetLoader` :class:`fastNLP.io.dataset_loader.DataSetLoader` - 所有 DataSetLoader 的 API 接口,你可以继承它实现自己的 DataSetLoader + 定义了各种 DataSetLoader 所需的API 接口,开发者应该继承它实现各种的 DataSetLoader。 + + 开发者至少应该编写如下内容: + + - _load 函数:从一个数据文件中读取数据到一个 :class:`~fastNLP.DataSet` + - load 函数(可以使用基类的方法):从一个或多个数据文件中读取数据到一个或多个 :class:`~fastNLP.DataSet` + - process 函数:一个或多个从数据文件中读取数据,并处理成可以训练的一个或多个 :class:`~fastNLP.DataSet` + + **process 函数中可以 调用load 函数或 _load 函数** + """ + def _download(self, url: str, path: str, uncompress=True) -> str: - """从 ``url`` 下载数据到 ``path``, 如果 ``uncompress`` 为 ``True`` ,自动解压。 - 返回数据的路径。 + """ + + 从 ``url`` 下载数据到 ``path``, 如果 ``uncompress`` 为 ``True`` ,自动解压。 + + :param url: 下载的网站 + :param path: 下载到的目录 + :param uncompress: 是否自动解压缩 + :return: 数据的存放路径 """ pdir = os.path.dirname(path) os.makedirs(pdir, exist_ok=True) @@ -109,27 +136,43 @@ class DataSetLoader: _uncompress(path, dst) return dst return path + + def load(self, paths: Union[str, Dict[str, str]]) -> Union[DataSet, Dict[str, DataSet]]: + """ + 从指定一个或多个路径中的文件中读取数据,返回一个或多个数据集 :class:`~fastNLP.DataSet` 。 + 如果处理多个路径,传入的 dict 中的 key 与返回的 dict 中的 key 保存一致。 - def load(self, paths: Union[str, dict]) -> Union[DataSet, dict]: - """从指定一个或多个 ``paths`` 的文件中读取数据,返回DataSet - - :param str or dict paths: 文件路径 - :return: 一个存储 :class:`~fastNLP.DataSet` 的字典 + :param Union[str, Dict[str, str]] paths: 文件路径 + :return: :class:`~fastNLP.DataSet` 类的对象或存储多个 :class:`~fastNLP.DataSet` 的字典 """ if isinstance(paths, str): return self._load(paths) return {name: self._load(path) for name, path in paths.items()} - + def _load(self, path: str) -> DataSet: - """从指定 ``path`` 的文件中读取数据,返回DataSet + """从指定路径的文件中读取数据,返回 :class:`~fastNLP.DataSet` 类型的对象 :param str path: 文件路径 :return: 一个 :class:`~fastNLP.DataSet` 类型的对象 """ raise NotImplementedError - - def process(self, paths: Union[str, dict], **options) -> Union[DataInfo, dict]: - """读取并处理数据,返回处理结果 + + def process(self, paths: Union[str, Dict[str, str]], **options) -> DataInfo: + """ + 对于特定的任务和数据集,读取并处理数据,返回处理DataInfo类对象或字典。 + + 从指定一个或多个路径中的文件中读取数据,DataInfo对象中可以包含一个或多个数据集 。 + 如果处理多个路径,传入的 dict 的 key 与返回DataInfo中的 dict 中的 key 保存一致。 + + 返回的 :class:`DataInfo` 对象有如下属性: + + - vocabs: 由从数据集中获取的词表组成的字典,每个词表 + - embeddings: (可选) 数据集对应的词嵌入 + - datasets: 一个dict,包含一系列 :class:`~fastNLP.DataSet` 类型的对象。其中 field 的命名参考 :mod:`~fastNLP.core.const` + + :param paths: 原始数据读取的路径 + :param options: 根据不同的任务和数据集,设计自己的参数 + :return: 返回一个 DataInfo """ raise NotImplementedError @@ -140,12 +183,12 @@ class PeopleDailyCorpusLoader(DataSetLoader): 读取人民日报数据集 """ - + def __init__(self, pos=True, ner=True): super(PeopleDailyCorpusLoader, self).__init__() self.pos = pos self.ner = ner - + def _load(self, data_path): with open(data_path, "r", encoding="utf-8") as f: sents = f.readlines() @@ -190,7 +233,7 @@ class PeopleDailyCorpusLoader(DataSetLoader): example.append(sent_ner) examples.append(example) return self.convert(examples) - + def convert(self, data): """ @@ -241,7 +284,7 @@ class ConllLoader(DataSetLoader): :param indexes: 需要保留的数据列下标,从0开始。若为 ``None`` ,则所有列都保留。Default: ``None`` :param dropna: 是否忽略非法数据,若 ``False`` ,遇到非法数据时抛出 ``ValueError`` 。Default: ``False`` """ - + def __init__(self, headers, indexes=None, dropna=False): super(ConllLoader, self).__init__() if not isinstance(headers, (list, tuple)): @@ -255,7 +298,7 @@ class ConllLoader(DataSetLoader): if len(indexes) != len(headers): raise ValueError self.indexes = indexes - + def _load(self, path): ds = DataSet() for idx, data in _read_conll(path, indexes=self.indexes, dropna=self.dropna): @@ -273,7 +316,7 @@ class Conll2003Loader(ConllLoader): 关于数据集的更多信息,参考: https://sites.google.com/site/ermasoftware/getting-started/ne-tagging-conll2003-data """ - + def __init__(self): headers = [ 'tokens', 'pos', 'chunks', 'ner', @@ -325,17 +368,17 @@ class SSTLoader(DataSetLoader): :param subtree: 是否将数据展开为子树,扩充数据量. Default: ``False`` :param fine_grained: 是否使用SST-5标准,若 ``False`` , 使用SST-2。Default: ``False`` """ - + def __init__(self, subtree=False, fine_grained=False): self.subtree = subtree - + tag_v = {'0': 'very negative', '1': 'negative', '2': 'neutral', '3': 'positive', '4': 'very positive'} if not fine_grained: tag_v['0'] = tag_v['1'] tag_v['4'] = tag_v['3'] self.tag_v = tag_v - + def _load(self, path): """ @@ -352,7 +395,7 @@ class SSTLoader(DataSetLoader): for words, tag in datas: ds.append(Instance(words=words, target=tag)) return ds - + @staticmethod def _get_one(data, subtree): tree = Tree.fromstring(data) @@ -374,7 +417,7 @@ class JsonLoader(DataSetLoader): :param bool dropna: 是否忽略非法数据,若 ``True`` 则忽略,若 ``False`` ,在遇到非法数据时,抛出 ``ValueError`` . Default: ``False`` """ - + def __init__(self, fields=None, dropna=False): super(JsonLoader, self).__init__() self.dropna = dropna @@ -385,7 +428,7 @@ class JsonLoader(DataSetLoader): for k, v in fields.items(): self.fields[k] = k if v is None else v self.fields_list = list(self.fields.keys()) - + def _load(self, path): ds = DataSet() for idx, d in _read_json(path, fields=self.fields_list, dropna=self.dropna): @@ -409,7 +452,7 @@ class SNLILoader(JsonLoader): 数据来源: https://nlp.stanford.edu/projects/snli/snli_1.0.zip """ - + def __init__(self): fields = { 'sentence1_parse': 'words1', @@ -417,14 +460,14 @@ class SNLILoader(JsonLoader): 'gold_label': 'target', } super(SNLILoader, self).__init__(fields=fields) - + def _load(self, path): ds = super(SNLILoader, self)._load(path) - + def parse_tree(x): t = Tree.fromstring(x) return t.leaves() - + ds.apply(lambda ins: parse_tree( ins['words1']), new_field_name='words1') ds.apply(lambda ins: parse_tree( @@ -445,12 +488,12 @@ class CSVLoader(DataSetLoader): :param bool dropna: 是否忽略非法数据,若 ``True`` 则忽略,若 ``False`` ,在遇到非法数据时,抛出 ``ValueError`` . Default: ``False`` """ - + def __init__(self, headers=None, sep=",", dropna=False): self.headers = headers self.sep = sep self.dropna = dropna - + def _load(self, path): ds = DataSet() for idx, data in _read_csv(path, headers=self.headers, @@ -465,7 +508,7 @@ def _add_seg_tag(data): :param data: list of ([word], [pos], [heads], [head_tags]) :return: list of ([word], [pos]) """ - + _processed = [] for word_list, pos_list, _, _ in data: new_sample = []