Browse Source

move sst loader to new folder

tags/v0.4.10
yunfan 5 years ago
parent
commit
7fea175f0a
3 changed files with 258 additions and 277 deletions
  1. +164
    -23
      fastNLP/io/base_loader.py
  2. +91
    -0
      fastNLP/io/data_loader/sst.py
  3. +3
    -254
      fastNLP/io/dataset_loader.py

+ 164
- 23
fastNLP/io/base_loader.py View File

@@ -1,10 +1,14 @@
__all__ = [
"BaseLoader"
"BaseLoader",
'DataInfo',
'DataSetLoader',
]

import _pickle as pickle
import os

from typing import Union, Dict
import os
from ..core.dataset import DataSet

class BaseLoader(object):
"""
@@ -51,24 +55,161 @@ class BaseLoader(object):
return obj


class DataLoaderRegister:
_readers = {}
@classmethod
def set_reader(cls, reader_cls, read_fn_name):
# def wrapper(reader_cls):
if read_fn_name in cls._readers:
raise KeyError(
'duplicate reader: {} and {} for read_func: {}'.format(cls._readers[read_fn_name], reader_cls,
read_fn_name))
if hasattr(reader_cls, 'load'):
cls._readers[read_fn_name] = reader_cls().load
return reader_cls
@classmethod
def get_reader(cls, read_fn_name):
if read_fn_name in cls._readers:
return cls._readers[read_fn_name]
raise AttributeError('no read function: {}'.format(read_fn_name))
# TODO 这个类使用在何处?


def _download_from_url(url, path):
try:
from tqdm.auto import tqdm
except:
from ..core.utils import _pseudo_tqdm as tqdm
import requests

"""Download file"""
r = requests.get(url, headers={'User-Agent': 'Mozilla/5.0'}, stream=True)
chunk_size = 16 * 1024
total_size = int(r.headers.get('Content-length', 0))
with open(path, "wb") as file, \
tqdm(total=total_size, unit='B', unit_scale=1, desc=path.split('/')[-1]) as t:
for chunk in r.iter_content(chunk_size):
if chunk:
file.write(chunk)
t.update(len(chunk))


def _uncompress(src, dst):
import zipfile
import gzip
import tarfile
import os

def unzip(src, dst):
with zipfile.ZipFile(src, 'r') as f:
f.extractall(dst)

def ungz(src, dst):
with gzip.open(src, 'rb') as f, open(dst, 'wb') as uf:
length = 16 * 1024 # 16KB
buf = f.read(length)
while buf:
uf.write(buf)
buf = f.read(length)

def untar(src, dst):
with tarfile.open(src, 'r:gz') as f:
f.extractall(dst)

fn, ext = os.path.splitext(src)
_, ext_2 = os.path.splitext(fn)
if ext == '.zip':
unzip(src, dst)
elif ext == '.gz' and ext_2 != '.tar':
ungz(src, dst)
elif (ext == '.gz' and ext_2 == '.tar') or ext_2 == '.tgz':
untar(src, dst)
else:
raise ValueError('unsupported file {}'.format(src))


class DataInfo:
"""
经过处理的数据信息,包括一系列数据集(比如:分开的训练集、验证集和测试集)及它们所用的词表和词嵌入。

:param vocabs: 从名称(字符串)到 :class:`~fastNLP.Vocabulary` 类型的dict
:param embeddings: 从名称(字符串)到一系列 embedding 的dict,参考 :class:`~fastNLP.io.EmbedLoader`
:param datasets: 从名称(字符串)到 :class:`~fastNLP.DataSet` 类型的dict
"""

def __init__(self, vocabs: dict = None, embeddings: dict = None, datasets: dict = None):
self.vocabs = vocabs or {}
self.embeddings = embeddings or {}
self.datasets = datasets or {}


class DataSetLoader:
"""
别名::class:`fastNLP.io.DataSetLoader` :class:`fastNLP.io.dataset_loader.DataSetLoader`

定义了各种 DataSetLoader 所需的API 接口,开发者应该继承它实现各种的 DataSetLoader。

开发者至少应该编写如下内容:

- _load 函数:从一个数据文件中读取数据到一个 :class:`~fastNLP.DataSet`
- load 函数(可以使用基类的方法):从一个或多个数据文件中读取数据到一个或多个 :class:`~fastNLP.DataSet`
- process 函数:一个或多个从数据文件中读取数据,并处理成可以训练的一个或多个 :class:`~fastNLP.DataSet`

**process 函数中可以 调用load 函数或 _load 函数**

"""
URL = ''
DATA_DIR = ''

ROOT_DIR = '.fastnlp/datasets/'
UNCOMPRESS = True

def _download(self, url: str, pdir: str, uncompress=True) -> str:
"""

从 ``url`` 下载数据到 ``path``, 如果 ``uncompress`` 为 ``True`` ,自动解压。

:param url: 下载的网站
:param pdir: 下载到的目录
:param uncompress: 是否自动解压缩
:return: 数据的存放路径
"""
fn = os.path.basename(url)
path = os.path.join(pdir, fn)
"""check data exists"""
if not os.path.exists(path):
os.makedirs(pdir, exist_ok=True)
_download_from_url(url, path)
if uncompress:
dst = os.path.join(pdir, 'data')
if not os.path.exists(dst):
_uncompress(path, dst)
return dst
return path

def download(self):
return self._download(
self.URL,
os.path.join(self.ROOT_DIR, self.DATA_DIR),
uncompress=self.UNCOMPRESS)

def load(self, paths: Union[str, Dict[str, str]]) -> Union[DataSet, Dict[str, DataSet]]:
"""
从指定一个或多个路径中的文件中读取数据,返回一个或多个数据集 :class:`~fastNLP.DataSet` 。
如果处理多个路径,传入的 dict 中的 key 与返回的 dict 中的 key 保存一致。

:param Union[str, Dict[str, str]] paths: 文件路径
:return: :class:`~fastNLP.DataSet` 类的对象或存储多个 :class:`~fastNLP.DataSet` 的字典
"""
if isinstance(paths, str):
return self._load(paths)
return {name: self._load(path) for name, path in paths.items()}

def _load(self, path: str) -> DataSet:
"""从指定路径的文件中读取数据,返回 :class:`~fastNLP.DataSet` 类型的对象

:param str path: 文件路径
:return: 一个 :class:`~fastNLP.DataSet` 类型的对象
"""
raise NotImplementedError

def process(self, paths: Union[str, Dict[str, str]], **options) -> DataInfo:
"""
对于特定的任务和数据集,读取并处理数据,返回处理DataInfo类对象或字典。

从指定一个或多个路径中的文件中读取数据,DataInfo对象中可以包含一个或多个数据集 。
如果处理多个路径,传入的 dict 的 key 与返回DataInfo中的 dict 中的 key 保存一致。

返回的 :class:`DataInfo` 对象有如下属性:

- vocabs: 由从数据集中获取的词表组成的字典,每个词表
- embeddings: (可选) 数据集对应的词嵌入
- datasets: 一个dict,包含一系列 :class:`~fastNLP.DataSet` 类型的对象。其中 field 的命名参考 :mod:`~fastNLP.core.const`

:param paths: 原始数据读取的路径
:param options: 根据不同的任务和数据集,设计自己的参数
:return: 返回一个 DataInfo
"""
raise NotImplementedError

+ 91
- 0
fastNLP/io/data_loader/sst.py View File

@@ -0,0 +1,91 @@
from typing import Iterable
from nltk import Tree
from ..base_loader import DataInfo, DataSetLoader
from ...core.vocabulary import VocabularyOption, Vocabulary
from ...core.dataset import DataSet
from ...core.instance import Instance
from ..embed_loader import EmbeddingOption, EmbedLoader


class SSTLoader(DataSetLoader):
"""
别名::class:`fastNLP.io.SSTLoader` :class:`fastNLP.io.dataset_loader.SSTLoader`

读取SST数据集, DataSet包含fields::

words: list(str) 需要分类的文本
target: str 文本的标签

数据来源: https://nlp.stanford.edu/sentiment/trainDevTestTrees_PTB.zip

:param subtree: 是否将数据展开为子树,扩充数据量. Default: ``False``
:param fine_grained: 是否使用SST-5标准,若 ``False`` , 使用SST-2。Default: ``False``
"""

def __init__(self, subtree=False, fine_grained=False):
self.subtree = subtree

tag_v = {'0': 'very negative', '1': 'negative', '2': 'neutral',
'3': 'positive', '4': 'very positive'}
if not fine_grained:
tag_v['0'] = tag_v['1']
tag_v['4'] = tag_v['3']
self.tag_v = tag_v

def _load(self, path):
"""

:param str path: 存储数据的路径
:return: 一个 :class:`~fastNLP.DataSet` 类型的对象
"""
datalist = []
with open(path, 'r', encoding='utf-8') as f:
datas = []
for l in f:
datas.extend([(s, self.tag_v[t])
for s, t in self._get_one(l, self.subtree)])
ds = DataSet()
for words, tag in datas:
ds.append(Instance(words=words, target=tag))
return ds

@staticmethod
def _get_one(data, subtree):
tree = Tree.fromstring(data)
if subtree:
return [(t.leaves(), t.label()) for t in tree.subtrees()]
return [(tree.leaves(), tree.label())]

def process(self,
paths,
train_ds: Iterable[str] = None,
src_vocab_op: VocabularyOption = None,
tgt_vocab_op: VocabularyOption = None,
src_embed_op: EmbeddingOption = None):
input_name, target_name = 'words', 'target'
src_vocab = Vocabulary() if src_vocab_op is None else Vocabulary(**src_vocab_op)
tgt_vocab = Vocabulary() if tgt_vocab_op is None else Vocabulary(**tgt_vocab_op)

info = DataInfo(datasets=self.load(paths))
_train_ds = [info.datasets[name]
for name in train_ds] if train_ds else info.datasets.values()
src_vocab.from_dataset(*_train_ds, field_name=input_name)
tgt_vocab.from_dataset(*_train_ds, field_name=target_name)
src_vocab.index_dataset(
*info.datasets.values(),
field_name=input_name, new_field_name=input_name)
tgt_vocab.index_dataset(
*info.datasets.values(),
field_name=target_name, new_field_name=target_name)
info.vocabs = {
input_name: src_vocab,
target_name: tgt_vocab
}

if src_embed_op is not None:
src_embed_op.vocab = src_vocab
init_emb = EmbedLoader.load_with_vocab(**src_embed_op)
info.embeddings[input_name] = init_emb

return info


+ 3
- 254
fastNLP/io/dataset_loader.py View File

@@ -13,8 +13,6 @@ dataset_loader模块实现了许多 DataSetLoader, 用于读取不同格式的
为 fastNLP 提供 DataSetLoader 的开发者请参考 :class:`~fastNLP.io.DataSetLoader` 的介绍。
"""
__all__ = [
'DataInfo',
'DataSetLoader',
'CSVLoader',
'JsonLoader',
'ConllLoader',
@@ -24,178 +22,12 @@ __all__ = [
'Conll2003Loader',
]

from nltk.tree import Tree

from nltk import Tree
from ..core.dataset import DataSet
from ..core.instance import Instance
from .file_reader import _read_csv, _read_json, _read_conll
from typing import Union, Dict, Iterable
import os
from ..core.utils import Example
from ..core.vocabulary import Vocabulary
from ..io.embed_loader import EmbedLoader
import numpy as np
from ..core.vocabulary import VocabularyOption
from .embed_loader import EmbeddingOption


def _download_from_url(url, path):
try:
from tqdm.auto import tqdm
except:
from ..core.utils import _pseudo_tqdm as tqdm
import requests

"""Download file"""
r = requests.get(url, headers={'User-Agent': 'Mozilla/5.0'}, stream=True)
chunk_size = 16 * 1024
total_size = int(r.headers.get('Content-length', 0))
with open(path, "wb") as file, \
tqdm(total=total_size, unit='B', unit_scale=1, desc=path.split('/')[-1]) as t:
for chunk in r.iter_content(chunk_size):
if chunk:
file.write(chunk)
t.update(len(chunk))


def _uncompress(src, dst):
import zipfile
import gzip
import tarfile
import os

def unzip(src, dst):
with zipfile.ZipFile(src, 'r') as f:
f.extractall(dst)

def ungz(src, dst):
with gzip.open(src, 'rb') as f, open(dst, 'wb') as uf:
length = 16 * 1024 # 16KB
buf = f.read(length)
while buf:
uf.write(buf)
buf = f.read(length)

def untar(src, dst):
with tarfile.open(src, 'r:gz') as f:
f.extractall(dst)

fn, ext = os.path.splitext(src)
_, ext_2 = os.path.splitext(fn)
if ext == '.zip':
unzip(src, dst)
elif ext == '.gz' and ext_2 != '.tar':
ungz(src, dst)
elif (ext == '.gz' and ext_2 == '.tar') or ext_2 == '.tgz':
untar(src, dst)
else:
raise ValueError('unsupported file {}'.format(src))


class DataInfo:
"""
经过处理的数据信息,包括一系列数据集(比如:分开的训练集、验证集和测试集)及它们所用的词表和词嵌入。

:param vocabs: 从名称(字符串)到 :class:`~fastNLP.Vocabulary` 类型的dict
:param embeddings: 从名称(字符串)到一系列 embedding 的dict,参考 :class:`~fastNLP.io.EmbedLoader`
:param datasets: 从名称(字符串)到 :class:`~fastNLP.DataSet` 类型的dict
"""

def __init__(self, vocabs: dict = None, embeddings: dict = None, datasets: dict = None):
self.vocabs = vocabs or {}
self.embeddings = embeddings or {}
self.datasets = datasets or {}


class DataSetLoader:
"""
别名::class:`fastNLP.io.DataSetLoader` :class:`fastNLP.io.dataset_loader.DataSetLoader`

定义了各种 DataSetLoader 所需的API 接口,开发者应该继承它实现各种的 DataSetLoader。

开发者至少应该编写如下内容:

- _load 函数:从一个数据文件中读取数据到一个 :class:`~fastNLP.DataSet`
- load 函数(可以使用基类的方法):从一个或多个数据文件中读取数据到一个或多个 :class:`~fastNLP.DataSet`
- process 函数:一个或多个从数据文件中读取数据,并处理成可以训练的一个或多个 :class:`~fastNLP.DataSet`

**process 函数中可以 调用load 函数或 _load 函数**

"""
URL = ''
DATA_DIR = ''

ROOT_DIR = '.fastnlp/datasets/'
UNCOMPRESS = True

def _download(self, url: str, pdir: str, uncompress=True) -> str:
"""

从 ``url`` 下载数据到 ``path``, 如果 ``uncompress`` 为 ``True`` ,自动解压。

:param url: 下载的网站
:param pdir: 下载到的目录
:param uncompress: 是否自动解压缩
:return: 数据的存放路径
"""
fn = os.path.basename(url)
path = os.path.join(pdir, fn)
"""check data exists"""
if not os.path.exists(path):
os.makedirs(pdir, exist_ok=True)
_download_from_url(url, path)
if uncompress:
dst = os.path.join(pdir, 'data')
if not os.path.exists(dst):
_uncompress(path, dst)
return dst
return path

def download(self):
return self._download(
self.URL,
os.path.join(self.ROOT_DIR, self.DATA_DIR),
uncompress=self.UNCOMPRESS)

def load(self, paths: Union[str, Dict[str, str]]) -> Union[DataSet, Dict[str, DataSet]]:
"""
从指定一个或多个路径中的文件中读取数据,返回一个或多个数据集 :class:`~fastNLP.DataSet` 。
如果处理多个路径,传入的 dict 中的 key 与返回的 dict 中的 key 保存一致。

:param Union[str, Dict[str, str]] paths: 文件路径
:return: :class:`~fastNLP.DataSet` 类的对象或存储多个 :class:`~fastNLP.DataSet` 的字典
"""
if isinstance(paths, str):
return self._load(paths)
return {name: self._load(path) for name, path in paths.items()}

def _load(self, path: str) -> DataSet:
"""从指定路径的文件中读取数据,返回 :class:`~fastNLP.DataSet` 类型的对象

:param str path: 文件路径
:return: 一个 :class:`~fastNLP.DataSet` 类型的对象
"""
raise NotImplementedError

def process(self, paths: Union[str, Dict[str, str]], **options) -> DataInfo:
"""
对于特定的任务和数据集,读取并处理数据,返回处理DataInfo类对象或字典。

从指定一个或多个路径中的文件中读取数据,DataInfo对象中可以包含一个或多个数据集 。
如果处理多个路径,传入的 dict 的 key 与返回DataInfo中的 dict 中的 key 保存一致。

返回的 :class:`DataInfo` 对象有如下属性:

- vocabs: 由从数据集中获取的词表组成的字典,每个词表
- embeddings: (可选) 数据集对应的词嵌入
- datasets: 一个dict,包含一系列 :class:`~fastNLP.DataSet` 类型的对象。其中 field 的命名参考 :mod:`~fastNLP.core.const`

:param paths: 原始数据读取的路径
:param options: 根据不同的任务和数据集,设计自己的参数
:return: 返回一个 DataInfo
"""
raise NotImplementedError

from .base_loader import DataSetLoader
from .data_loader.sst import SSTLoader

class PeopleDailyCorpusLoader(DataSetLoader):
"""
@@ -374,89 +206,6 @@ def _cut_long_sentence(sent, max_sample_length=200):
return cutted_sentence


class SSTLoader(DataSetLoader):
"""
别名::class:`fastNLP.io.SSTLoader` :class:`fastNLP.io.dataset_loader.SSTLoader`

读取SST数据集, DataSet包含fields::

words: list(str) 需要分类的文本
target: str 文本的标签

数据来源: https://nlp.stanford.edu/sentiment/trainDevTestTrees_PTB.zip

:param subtree: 是否将数据展开为子树,扩充数据量. Default: ``False``
:param fine_grained: 是否使用SST-5标准,若 ``False`` , 使用SST-2。Default: ``False``
"""

def __init__(self, subtree=False, fine_grained=False):
self.subtree = subtree

tag_v = {'0': 'very negative', '1': 'negative', '2': 'neutral',
'3': 'positive', '4': 'very positive'}
if not fine_grained:
tag_v['0'] = tag_v['1']
tag_v['4'] = tag_v['3']
self.tag_v = tag_v

def _load(self, path):
"""

:param str path: 存储数据的路径
:return: 一个 :class:`~fastNLP.DataSet` 类型的对象
"""
datalist = []
with open(path, 'r', encoding='utf-8') as f:
datas = []
for l in f:
datas.extend([(s, self.tag_v[t])
for s, t in self._get_one(l, self.subtree)])
ds = DataSet()
for words, tag in datas:
ds.append(Instance(words=words, target=tag))
return ds

@staticmethod
def _get_one(data, subtree):
tree = Tree.fromstring(data)
if subtree:
return [(t.leaves(), t.label()) for t in tree.subtrees()]
return [(tree.leaves(), tree.label())]

def process(self,
paths,
train_ds: Iterable[str] = None,
src_vocab_op: VocabularyOption = None,
tgt_vocab_op: VocabularyOption = None,
src_embed_op: EmbeddingOption = None):
input_name, target_name = 'words', 'target'
src_vocab = Vocabulary() if src_vocab_op is None else Vocabulary(**src_vocab_op)
tgt_vocab = Vocabulary() if tgt_vocab_op is None else Vocabulary(**tgt_vocab_op)

info = DataInfo(datasets=self.load(paths))
_train_ds = [info.datasets[name]
for name in train_ds] if train_ds else info.datasets.values()
src_vocab.from_dataset(*_train_ds, field_name=input_name)
tgt_vocab.from_dataset(*_train_ds, field_name=target_name)
src_vocab.index_dataset(
*info.datasets.values(),
field_name=input_name, new_field_name=input_name)
tgt_vocab.index_dataset(
*info.datasets.values(),
field_name=target_name, new_field_name=target_name)
info.vocabs = {
input_name: src_vocab,
target_name: tgt_vocab
}

if src_embed_op is not None:
src_embed_op.vocab = src_vocab
init_emb = EmbedLoader.load_with_vocab(**src_embed_op)
info.embeddings[input_name] = init_emb

return info


class JsonLoader(DataSetLoader):
"""
别名::class:`fastNLP.io.JsonLoader` :class:`fastNLP.io.dataset_loader.JsonLoader`


Loading…
Cancel
Save