Browse Source

增加了 DataSet Loader 的文档

tags/v0.4.10
ChenXin 5 years ago
parent
commit
6a8f50e73e
1 changed files with 83 additions and 40 deletions
  1. +83
    -40
      fastNLP/io/dataset_loader.py

+ 83
- 40
fastNLP/io/dataset_loader.py View File

@@ -1,6 +1,6 @@
"""
dataset_loader模块实现了许多 DataSetLoader, 用于读取不同格式的数据, 并返回 `DataSet` ,
得到的 :class:`~fastNLP.DataSet` 对象可以直接传入 :class:`~fastNLP.Trainer`, :class:`~fastNLP.Tester`, 用于模型的训练和测试。
得到的 :class:`~fastNLP.DataSet` 对象可以直接传入 :class:`~fastNLP.Trainer` :class:`~fastNLP.Tester`, 用于模型的训练和测试。
以SNLI数据集为例::

loader = SNLILoader()
@@ -9,8 +9,11 @@ dataset_loader模块实现了许多 DataSetLoader, 用于读取不同格式的
test_ds = loader.load('path/to/test')

# ... do stuff
为 fastNLP 提供 DataSetLoader 的开发者请参考 :class:`~fastNLP.io.DataSetLoader` 的介绍。
"""
__all__ = [
'DataInfo',
'DataSetLoader',
'CSVLoader',
'JsonLoader',
@@ -26,7 +29,7 @@ from nltk.tree import Tree
from ..core.dataset import DataSet
from ..core.instance import Instance
from .file_reader import _read_csv, _read_json, _read_conll
from typing import Union
from typing import Union, Dict
import os


@@ -36,7 +39,7 @@ def _download_from_url(url, path):
except:
from ..core.utils import _pseudo_tqdm as tqdm
import requests
"""Download file"""
r = requests.get(url, headers={'User-Agent': 'Mozilla/5.0'}, stream=True)
chunk_size = 16 * 1024
@@ -55,11 +58,11 @@ def _uncompress(src, dst):
import gzip
import tarfile
import os
def unzip(src, dst):
with zipfile.ZipFile(src, 'r') as f:
f.extractall(dst)
def ungz(src, dst):
with gzip.open(src, 'rb') as f, open(dst, 'wb') as uf:
length = 16 * 1024 # 16KB
@@ -67,11 +70,11 @@ def _uncompress(src, dst):
while buf:
uf.write(buf)
buf = f.read(length)
def untar(src, dst):
with tarfile.open(src, 'r:gz') as f:
f.extractall(dst)
fn, ext = os.path.splitext(src)
_, ext_2 = os.path.splitext(fn)
if ext == '.zip':
@@ -84,7 +87,15 @@ def _uncompress(src, dst):
raise ValueError('unsupported file {}'.format(src))


class DataInfo():
class DataInfo:
"""
经过处理的数据信息,包括一系列数据集(比如:分开的训练集、验证集和测试集)及它们所用的词表和词嵌入。

:param vocabs: 从名称(字符串)到 :class:`~fastNLP.Vocabulary` 类型的dict
:param embeddings: 从名称(字符串)到一系列 embedding 的dict,参考 :class:`~fastNLP.io.EmbedLoader`
:param datasets: 从名称(字符串)到 :class:`~fastNLP.DataSet` 类型的dict
"""
def __init__(self, vocabs: dict = None, embeddings: dict = None, datasets: dict = None):
self.vocabs = vocabs or {}
self.embeddings = embeddings or {}
@@ -95,11 +106,27 @@ class DataSetLoader:
"""
别名::class:`fastNLP.io.DataSetLoader` :class:`fastNLP.io.dataset_loader.DataSetLoader`

所有 DataSetLoader 的 API 接口,你可以继承它实现自己的 DataSetLoader
定义了各种 DataSetLoader 所需的API 接口,开发者应该继承它实现各种的 DataSetLoader。
开发者至少应该编写如下内容:
- _load 函数:从一个数据文件中读取数据到一个 :class:`~fastNLP.DataSet`
- load 函数(可以使用基类的方法):从一个或多个数据文件中读取数据到一个或多个 :class:`~fastNLP.DataSet`
- process 函数:一个或多个从数据文件中读取数据,并处理成可以训练的一个或多个 :class:`~fastNLP.DataSet`
**process 函数中可以 调用load 函数或 _load 函数**
"""
def _download(self, url: str, path: str, uncompress=True) -> str:
"""从 ``url`` 下载数据到 ``path``, 如果 ``uncompress`` 为 ``True`` ,自动解压。
返回数据的路径。
"""
从 ``url`` 下载数据到 ``path``, 如果 ``uncompress`` 为 ``True`` ,自动解压。

:param url: 下载的网站
:param path: 下载到的目录
:param uncompress: 是否自动解压缩
:return: 数据的存放路径
"""
pdir = os.path.dirname(path)
os.makedirs(pdir, exist_ok=True)
@@ -109,27 +136,43 @@ class DataSetLoader:
_uncompress(path, dst)
return dst
return path
def load(self, paths: Union[str, Dict[str, str]]) -> Union[DataSet, Dict[str, DataSet]]:
"""
从指定一个或多个路径中的文件中读取数据,返回一个或多个数据集 :class:`~fastNLP.DataSet` 。
如果处理多个路径,传入的 dict 中的 key 与返回的 dict 中的 key 保存一致。

def load(self, paths: Union[str, dict]) -> Union[DataSet, dict]:
"""从指定一个或多个 ``paths`` 的文件中读取数据,返回DataSet

:param str or dict paths: 文件路径
:return: 一个存储 :class:`~fastNLP.DataSet` 的字典
:param Union[str, Dict[str, str]] paths: 文件路径
:return: :class:`~fastNLP.DataSet` 类的对象或存储多个 :class:`~fastNLP.DataSet` 的字典
"""
if isinstance(paths, str):
return self._load(paths)
return {name: self._load(path) for name, path in paths.items()}
def _load(self, path: str) -> DataSet:
"""从指定 ``path`` 的文件中读取数据,返回DataSet
"""从指定路径的文件中读取数据,返回 :class:`~fastNLP.DataSet` 类型的对象

:param str path: 文件路径
:return: 一个 :class:`~fastNLP.DataSet` 类型的对象
"""
raise NotImplementedError

def process(self, paths: Union[str, dict], **options) -> Union[DataInfo, dict]:
"""读取并处理数据,返回处理结果
def process(self, paths: Union[str, Dict[str, str]], **options) -> DataInfo:
"""
对于特定的任务和数据集,读取并处理数据,返回处理DataInfo类对象或字典。
从指定一个或多个路径中的文件中读取数据,DataInfo对象中可以包含一个或多个数据集 。
如果处理多个路径,传入的 dict 的 key 与返回DataInfo中的 dict 中的 key 保存一致。

返回的 :class:`DataInfo` 对象有如下属性:
- vocabs: 由从数据集中获取的词表组成的字典,每个词表
- embeddings: (可选) 数据集对应的词嵌入
- datasets: 一个dict,包含一系列 :class:`~fastNLP.DataSet` 类型的对象。其中 field 的命名参考 :mod:`~fastNLP.core.const`

:param paths: 原始数据读取的路径
:param options: 根据不同的任务和数据集,设计自己的参数
:return: 返回一个 DataInfo
"""
raise NotImplementedError

@@ -140,12 +183,12 @@ class PeopleDailyCorpusLoader(DataSetLoader):

读取人民日报数据集
"""
def __init__(self, pos=True, ner=True):
super(PeopleDailyCorpusLoader, self).__init__()
self.pos = pos
self.ner = ner
def _load(self, data_path):
with open(data_path, "r", encoding="utf-8") as f:
sents = f.readlines()
@@ -190,7 +233,7 @@ class PeopleDailyCorpusLoader(DataSetLoader):
example.append(sent_ner)
examples.append(example)
return self.convert(examples)
def convert(self, data):
"""

@@ -241,7 +284,7 @@ class ConllLoader(DataSetLoader):
:param indexes: 需要保留的数据列下标,从0开始。若为 ``None`` ,则所有列都保留。Default: ``None``
:param dropna: 是否忽略非法数据,若 ``False`` ,遇到非法数据时抛出 ``ValueError`` 。Default: ``False``
"""
def __init__(self, headers, indexes=None, dropna=False):
super(ConllLoader, self).__init__()
if not isinstance(headers, (list, tuple)):
@@ -255,7 +298,7 @@ class ConllLoader(DataSetLoader):
if len(indexes) != len(headers):
raise ValueError
self.indexes = indexes
def _load(self, path):
ds = DataSet()
for idx, data in _read_conll(path, indexes=self.indexes, dropna=self.dropna):
@@ -273,7 +316,7 @@ class Conll2003Loader(ConllLoader):
关于数据集的更多信息,参考:
https://sites.google.com/site/ermasoftware/getting-started/ne-tagging-conll2003-data
"""
def __init__(self):
headers = [
'tokens', 'pos', 'chunks', 'ner',
@@ -325,17 +368,17 @@ class SSTLoader(DataSetLoader):
:param subtree: 是否将数据展开为子树,扩充数据量. Default: ``False``
:param fine_grained: 是否使用SST-5标准,若 ``False`` , 使用SST-2。Default: ``False``
"""
def __init__(self, subtree=False, fine_grained=False):
self.subtree = subtree
tag_v = {'0': 'very negative', '1': 'negative', '2': 'neutral',
'3': 'positive', '4': 'very positive'}
if not fine_grained:
tag_v['0'] = tag_v['1']
tag_v['4'] = tag_v['3']
self.tag_v = tag_v
def _load(self, path):
"""

@@ -352,7 +395,7 @@ class SSTLoader(DataSetLoader):
for words, tag in datas:
ds.append(Instance(words=words, target=tag))
return ds
@staticmethod
def _get_one(data, subtree):
tree = Tree.fromstring(data)
@@ -374,7 +417,7 @@ class JsonLoader(DataSetLoader):
:param bool dropna: 是否忽略非法数据,若 ``True`` 则忽略,若 ``False`` ,在遇到非法数据时,抛出 ``ValueError`` .
Default: ``False``
"""
def __init__(self, fields=None, dropna=False):
super(JsonLoader, self).__init__()
self.dropna = dropna
@@ -385,7 +428,7 @@ class JsonLoader(DataSetLoader):
for k, v in fields.items():
self.fields[k] = k if v is None else v
self.fields_list = list(self.fields.keys())
def _load(self, path):
ds = DataSet()
for idx, d in _read_json(path, fields=self.fields_list, dropna=self.dropna):
@@ -409,7 +452,7 @@ class SNLILoader(JsonLoader):

数据来源: https://nlp.stanford.edu/projects/snli/snli_1.0.zip
"""
def __init__(self):
fields = {
'sentence1_parse': 'words1',
@@ -417,14 +460,14 @@ class SNLILoader(JsonLoader):
'gold_label': 'target',
}
super(SNLILoader, self).__init__(fields=fields)
def _load(self, path):
ds = super(SNLILoader, self)._load(path)
def parse_tree(x):
t = Tree.fromstring(x)
return t.leaves()
ds.apply(lambda ins: parse_tree(
ins['words1']), new_field_name='words1')
ds.apply(lambda ins: parse_tree(
@@ -445,12 +488,12 @@ class CSVLoader(DataSetLoader):
:param bool dropna: 是否忽略非法数据,若 ``True`` 则忽略,若 ``False`` ,在遇到非法数据时,抛出 ``ValueError`` .
Default: ``False``
"""
def __init__(self, headers=None, sep=",", dropna=False):
self.headers = headers
self.sep = sep
self.dropna = dropna
def _load(self, path):
ds = DataSet()
for idx, data in _read_csv(path, headers=self.headers,
@@ -465,7 +508,7 @@ def _add_seg_tag(data):
:param data: list of ([word], [pos], [heads], [head_tags])
:return: list of ([word], [pos])
"""
_processed = []
for word_list, pos_list, _, _ in data:
new_sample = []


Loading…
Cancel
Save