Browse Source

- add process for SSTLoader

tags/v0.4.10
yunfan 5 years ago
parent
commit
f5a005358c
2 changed files with 132 additions and 39 deletions
  1. +28
    -1
      fastNLP/core/utils.py
  2. +104
    -38
      fastNLP/io/dataset_loader.py

+ 28
- 1
fastNLP/core/utils.py View File

@@ -3,7 +3,8 @@ utils模块实现了 fastNLP 内部和外部所需的很多工具。其中用户
""" """
__all__ = [ __all__ = [
"cache_results", "cache_results",
"seq_len_to_mask"
"seq_len_to_mask",
"Example",
] ]


import _pickle import _pickle
@@ -21,6 +22,32 @@ _CheckRes = namedtuple('_CheckRes', ['missing', 'unused', 'duplicated', 'require
'varargs']) 'varargs'])




class Example(dict):
"""a dict can treat keys as attributes"""
def __getattr__(self, item):
try:
return self.__getitem__(item)
except KeyError:
raise AttributeError(item)

def __setattr__(self, key, value):
if key.startswith('__') and key.endswith('__'):
raise AttributeError(key)
self.__setitem__(key, value)

def __delattr__(self, item):
try:
self.pop(item)
except KeyError:
raise AttributeError(item)

def __getstate__(self):
return self

def __setstate__(self, state):
self.update(state)


def _prepare_cache_filepath(filepath): def _prepare_cache_filepath(filepath):
""" """
检查filepath是否可以作为合理的cache文件. 如果可以的话,会自动创造路径 检查filepath是否可以作为合理的cache文件. 如果可以的话,会自动创造路径


+ 104
- 38
fastNLP/io/dataset_loader.py View File

@@ -29,8 +29,12 @@ from nltk.tree import Tree
from ..core.dataset import DataSet from ..core.dataset import DataSet
from ..core.instance import Instance from ..core.instance import Instance
from .file_reader import _read_csv, _read_json, _read_conll from .file_reader import _read_csv, _read_json, _read_conll
from typing import Union, Dict
from typing import Union, Dict, Iterable
import os import os
from ..core.utils import Example
from ..core import Vocabulary
from ..io import EmbedLoader
import numpy as np




def _download_from_url(url, path): def _download_from_url(url, path):
@@ -39,7 +43,7 @@ def _download_from_url(url, path):
except: except:
from ..core.utils import _pseudo_tqdm as tqdm from ..core.utils import _pseudo_tqdm as tqdm
import requests import requests
"""Download file""" """Download file"""
r = requests.get(url, headers={'User-Agent': 'Mozilla/5.0'}, stream=True) r = requests.get(url, headers={'User-Agent': 'Mozilla/5.0'}, stream=True)
chunk_size = 16 * 1024 chunk_size = 16 * 1024
@@ -58,11 +62,11 @@ def _uncompress(src, dst):
import gzip import gzip
import tarfile import tarfile
import os import os
def unzip(src, dst): def unzip(src, dst):
with zipfile.ZipFile(src, 'r') as f: with zipfile.ZipFile(src, 'r') as f:
f.extractall(dst) f.extractall(dst)
def ungz(src, dst): def ungz(src, dst):
with gzip.open(src, 'rb') as f, open(dst, 'wb') as uf: with gzip.open(src, 'rb') as f, open(dst, 'wb') as uf:
length = 16 * 1024 # 16KB length = 16 * 1024 # 16KB
@@ -70,11 +74,11 @@ def _uncompress(src, dst):
while buf: while buf:
uf.write(buf) uf.write(buf)
buf = f.read(length) buf = f.read(length)
def untar(src, dst): def untar(src, dst):
with tarfile.open(src, 'r:gz') as f: with tarfile.open(src, 'r:gz') as f:
f.extractall(dst) f.extractall(dst)
fn, ext = os.path.splitext(src) fn, ext = os.path.splitext(src)
_, ext_2 = os.path.splitext(fn) _, ext_2 = os.path.splitext(fn)
if ext == '.zip': if ext == '.zip':
@@ -87,6 +91,34 @@ def _uncompress(src, dst):
raise ValueError('unsupported file {}'.format(src)) raise ValueError('unsupported file {}'.format(src))




class VocabularyOption(Example):
def __init__(self,
max_size=None,
min_freq=None,
padding='<pad>',
unknown='<unk>'):
super().__init__(
max_size=max_size,
min_freq=min_freq,
padding=padding,
unknown=unknown
)


class EmbeddingOption(Example):
def __init__(self,
embed_filepath=None,
dtype=np.float32,
normalize=True,
error='ignore'):
super().__init__(
embed_filepath=embed_filepath,
dtype=dtype,
normalize=normalize,
error=error
)


class DataInfo: class DataInfo:
""" """
经过处理的数据信息,包括一系列数据集(比如:分开的训练集、验证集和测试集)及它们所用的词表和词嵌入。 经过处理的数据信息,包括一系列数据集(比如:分开的训练集、验证集和测试集)及它们所用的词表和词嵌入。
@@ -95,7 +127,7 @@ class DataInfo:
:param embeddings: 从名称(字符串)到一系列 embedding 的dict,参考 :class:`~fastNLP.io.EmbedLoader` :param embeddings: 从名称(字符串)到一系列 embedding 的dict,参考 :class:`~fastNLP.io.EmbedLoader`
:param datasets: 从名称(字符串)到 :class:`~fastNLP.DataSet` 类型的dict :param datasets: 从名称(字符串)到 :class:`~fastNLP.DataSet` 类型的dict
""" """
def __init__(self, vocabs: dict = None, embeddings: dict = None, datasets: dict = None): def __init__(self, vocabs: dict = None, embeddings: dict = None, datasets: dict = None):
self.vocabs = vocabs or {} self.vocabs = vocabs or {}
self.embeddings = embeddings or {} self.embeddings = embeddings or {}
@@ -106,21 +138,21 @@ class DataSetLoader:
""" """
别名::class:`fastNLP.io.DataSetLoader` :class:`fastNLP.io.dataset_loader.DataSetLoader` 别名::class:`fastNLP.io.DataSetLoader` :class:`fastNLP.io.dataset_loader.DataSetLoader`


定义了各种 DataSetLoader (针对特定数据上的特定任务) 所需的API 接口,开发者应该继承它实现各种的 DataSetLoader。
定义了各种 DataSetLoader 所需的API 接口,开发者应该继承它实现各种的 DataSetLoader。
开发者至少应该编写如下内容: 开发者至少应该编写如下内容:
- _load 函数:从一个数据文件中读取数据到一个 :class:`~fastNLP.DataSet` - _load 函数:从一个数据文件中读取数据到一个 :class:`~fastNLP.DataSet`
- load 函数(可以使用基类的方法):从一个或多个数据文件中读取数据到一个或多个 :class:`~fastNLP.DataSet` - load 函数(可以使用基类的方法):从一个或多个数据文件中读取数据到一个或多个 :class:`~fastNLP.DataSet`
- process 函数:一个或多个从数据文件中读取数据,并处理成可以训练的一个或多个 :class:`~fastNLP.DataSet` - process 函数:一个或多个从数据文件中读取数据,并处理成可以训练的一个或多个 :class:`~fastNLP.DataSet`
**process 函数中可以 调用load 函数或 _load 函数** **process 函数中可以 调用load 函数或 _load 函数**
""" """
def _download(self, url: str, path: str, uncompress=True) -> str: def _download(self, url: str, path: str, uncompress=True) -> str:
""" """
从 ``url`` 下载数据到 ``path``, 如果 ``uncompress`` 为 ``True`` ,自动解压。 从 ``url`` 下载数据到 ``path``, 如果 ``uncompress`` 为 ``True`` ,自动解压。


:param url: 下载的网站 :param url: 下载的网站
@@ -136,7 +168,7 @@ class DataSetLoader:
_uncompress(path, dst) _uncompress(path, dst)
return dst return dst
return path return path
def load(self, paths: Union[str, Dict[str, str]]) -> Union[DataSet, Dict[str, DataSet]]: def load(self, paths: Union[str, Dict[str, str]]) -> Union[DataSet, Dict[str, DataSet]]:
""" """
从指定一个或多个路径中的文件中读取数据,返回一个或多个数据集 :class:`~fastNLP.DataSet` 。 从指定一个或多个路径中的文件中读取数据,返回一个或多个数据集 :class:`~fastNLP.DataSet` 。
@@ -148,7 +180,7 @@ class DataSetLoader:
if isinstance(paths, str): if isinstance(paths, str):
return self._load(paths) return self._load(paths)
return {name: self._load(path) for name, path in paths.items()} return {name: self._load(path) for name, path in paths.items()}
def _load(self, path: str) -> DataSet: def _load(self, path: str) -> DataSet:
"""从指定路径的文件中读取数据,返回 :class:`~fastNLP.DataSet` 类型的对象 """从指定路径的文件中读取数据,返回 :class:`~fastNLP.DataSet` 类型的对象


@@ -156,16 +188,16 @@ class DataSetLoader:
:return: 一个 :class:`~fastNLP.DataSet` 类型的对象 :return: 一个 :class:`~fastNLP.DataSet` 类型的对象
""" """
raise NotImplementedError raise NotImplementedError
def process(self, paths: Union[str, Dict[str, str]], **options) -> DataInfo: def process(self, paths: Union[str, Dict[str, str]], **options) -> DataInfo:
""" """
对于特定的任务和数据集,读取并处理数据,返回处理DataInfo类对象或字典。 对于特定的任务和数据集,读取并处理数据,返回处理DataInfo类对象或字典。
从指定一个或多个路径中的文件中读取数据,DataInfo对象中可以包含一个或多个数据集 。 从指定一个或多个路径中的文件中读取数据,DataInfo对象中可以包含一个或多个数据集 。
如果处理多个路径,传入的 dict 的 key 与返回DataInfo中的 dict 中的 key 保存一致。 如果处理多个路径,传入的 dict 的 key 与返回DataInfo中的 dict 中的 key 保存一致。


返回的 :class:`DataInfo` 对象有如下属性: 返回的 :class:`DataInfo` 对象有如下属性:
- vocabs: 由从数据集中获取的词表组成的字典,每个词表 - vocabs: 由从数据集中获取的词表组成的字典,每个词表
- embeddings: (可选) 数据集对应的词嵌入 - embeddings: (可选) 数据集对应的词嵌入
- datasets: 一个dict,包含一系列 :class:`~fastNLP.DataSet` 类型的对象。其中 field 的命名参考 :mod:`~fastNLP.core.const` - datasets: 一个dict,包含一系列 :class:`~fastNLP.DataSet` 类型的对象。其中 field 的命名参考 :mod:`~fastNLP.core.const`
@@ -183,12 +215,12 @@ class PeopleDailyCorpusLoader(DataSetLoader):


读取人民日报数据集 读取人民日报数据集
""" """
def __init__(self, pos=True, ner=True): def __init__(self, pos=True, ner=True):
super(PeopleDailyCorpusLoader, self).__init__() super(PeopleDailyCorpusLoader, self).__init__()
self.pos = pos self.pos = pos
self.ner = ner self.ner = ner
def _load(self, data_path): def _load(self, data_path):
with open(data_path, "r", encoding="utf-8") as f: with open(data_path, "r", encoding="utf-8") as f:
sents = f.readlines() sents = f.readlines()
@@ -233,7 +265,7 @@ class PeopleDailyCorpusLoader(DataSetLoader):
example.append(sent_ner) example.append(sent_ner)
examples.append(example) examples.append(example)
return self.convert(examples) return self.convert(examples)
def convert(self, data): def convert(self, data):
""" """


@@ -284,7 +316,7 @@ class ConllLoader(DataSetLoader):
:param indexes: 需要保留的数据列下标,从0开始。若为 ``None`` ,则所有列都保留。Default: ``None`` :param indexes: 需要保留的数据列下标,从0开始。若为 ``None`` ,则所有列都保留。Default: ``None``
:param dropna: 是否忽略非法数据,若 ``False`` ,遇到非法数据时抛出 ``ValueError`` 。Default: ``False`` :param dropna: 是否忽略非法数据,若 ``False`` ,遇到非法数据时抛出 ``ValueError`` 。Default: ``False``
""" """
def __init__(self, headers, indexes=None, dropna=False): def __init__(self, headers, indexes=None, dropna=False):
super(ConllLoader, self).__init__() super(ConllLoader, self).__init__()
if not isinstance(headers, (list, tuple)): if not isinstance(headers, (list, tuple)):
@@ -298,7 +330,7 @@ class ConllLoader(DataSetLoader):
if len(indexes) != len(headers): if len(indexes) != len(headers):
raise ValueError raise ValueError
self.indexes = indexes self.indexes = indexes
def _load(self, path): def _load(self, path):
ds = DataSet() ds = DataSet()
for idx, data in _read_conll(path, indexes=self.indexes, dropna=self.dropna): for idx, data in _read_conll(path, indexes=self.indexes, dropna=self.dropna):
@@ -316,7 +348,7 @@ class Conll2003Loader(ConllLoader):
关于数据集的更多信息,参考: 关于数据集的更多信息,参考:
https://sites.google.com/site/ermasoftware/getting-started/ne-tagging-conll2003-data https://sites.google.com/site/ermasoftware/getting-started/ne-tagging-conll2003-data
""" """
def __init__(self): def __init__(self):
headers = [ headers = [
'tokens', 'pos', 'chunks', 'ner', 'tokens', 'pos', 'chunks', 'ner',
@@ -368,17 +400,17 @@ class SSTLoader(DataSetLoader):
:param subtree: 是否将数据展开为子树,扩充数据量. Default: ``False`` :param subtree: 是否将数据展开为子树,扩充数据量. Default: ``False``
:param fine_grained: 是否使用SST-5标准,若 ``False`` , 使用SST-2。Default: ``False`` :param fine_grained: 是否使用SST-5标准,若 ``False`` , 使用SST-2。Default: ``False``
""" """
def __init__(self, subtree=False, fine_grained=False): def __init__(self, subtree=False, fine_grained=False):
self.subtree = subtree self.subtree = subtree
tag_v = {'0': 'very negative', '1': 'negative', '2': 'neutral', tag_v = {'0': 'very negative', '1': 'negative', '2': 'neutral',
'3': 'positive', '4': 'very positive'} '3': 'positive', '4': 'very positive'}
if not fine_grained: if not fine_grained:
tag_v['0'] = tag_v['1'] tag_v['0'] = tag_v['1']
tag_v['4'] = tag_v['3'] tag_v['4'] = tag_v['3']
self.tag_v = tag_v self.tag_v = tag_v
def _load(self, path): def _load(self, path):
""" """


@@ -395,7 +427,7 @@ class SSTLoader(DataSetLoader):
for words, tag in datas: for words, tag in datas:
ds.append(Instance(words=words, target=tag)) ds.append(Instance(words=words, target=tag))
return ds return ds
@staticmethod @staticmethod
def _get_one(data, subtree): def _get_one(data, subtree):
tree = Tree.fromstring(data) tree = Tree.fromstring(data)
@@ -403,6 +435,40 @@ class SSTLoader(DataSetLoader):
return [(t.leaves(), t.label()) for t in tree.subtrees()] return [(t.leaves(), t.label()) for t in tree.subtrees()]
return [(tree.leaves(), tree.label())] return [(tree.leaves(), tree.label())]


def process(self,
paths,
train_ds: Iterable[str] = None,
src_vocab_op: VocabularyOption = None,
tgt_vocab_op: VocabularyOption = None,
embed_op: EmbeddingOption = None):
input_name, target_name = 'words', 'target'
src_vocab = Vocabulary() if src_vocab_op is None else Vocabulary(**src_vocab_op)
tgt_vocab = Vocabulary() if tgt_vocab_op is None else Vocabulary(**tgt_vocab_op)

info = DataInfo(datasets=self.load(paths))
_train_ds = [info.datasets[name]
for name in train_ds] if train_ds else info.datasets.values()

src_vocab.from_dataset(_train_ds, field_name=input_name)
tgt_vocab.from_dataset(_train_ds, field_name=target_name)
src_vocab.index_dataset(
info.datasets.values(),
field_name=input_name, new_field_name=input_name)
tgt_vocab.index_dataset(
info.datasets.values(),
field_name=target_name, new_field_name=target_name)
info.vocabs = {
input_name: src_vocab,
target_name: tgt_vocab
}

if embed_op is not None:
embed_op.vocab = src_vocab
init_emb = EmbedLoader.load_with_vocab(**embed_op)
info.embeddings[input_name] = init_emb

return info



class JsonLoader(DataSetLoader): class JsonLoader(DataSetLoader):
""" """
@@ -417,7 +483,7 @@ class JsonLoader(DataSetLoader):
:param bool dropna: 是否忽略非法数据,若 ``True`` 则忽略,若 ``False`` ,在遇到非法数据时,抛出 ``ValueError`` . :param bool dropna: 是否忽略非法数据,若 ``True`` 则忽略,若 ``False`` ,在遇到非法数据时,抛出 ``ValueError`` .
Default: ``False`` Default: ``False``
""" """
def __init__(self, fields=None, dropna=False): def __init__(self, fields=None, dropna=False):
super(JsonLoader, self).__init__() super(JsonLoader, self).__init__()
self.dropna = dropna self.dropna = dropna
@@ -428,7 +494,7 @@ class JsonLoader(DataSetLoader):
for k, v in fields.items(): for k, v in fields.items():
self.fields[k] = k if v is None else v self.fields[k] = k if v is None else v
self.fields_list = list(self.fields.keys()) self.fields_list = list(self.fields.keys())
def _load(self, path): def _load(self, path):
ds = DataSet() ds = DataSet()
for idx, d in _read_json(path, fields=self.fields_list, dropna=self.dropna): for idx, d in _read_json(path, fields=self.fields_list, dropna=self.dropna):
@@ -452,7 +518,7 @@ class SNLILoader(JsonLoader):


数据来源: https://nlp.stanford.edu/projects/snli/snli_1.0.zip 数据来源: https://nlp.stanford.edu/projects/snli/snli_1.0.zip
""" """
def __init__(self): def __init__(self):
fields = { fields = {
'sentence1_parse': 'words1', 'sentence1_parse': 'words1',
@@ -460,14 +526,14 @@ class SNLILoader(JsonLoader):
'gold_label': 'target', 'gold_label': 'target',
} }
super(SNLILoader, self).__init__(fields=fields) super(SNLILoader, self).__init__(fields=fields)
def _load(self, path): def _load(self, path):
ds = super(SNLILoader, self)._load(path) ds = super(SNLILoader, self)._load(path)
def parse_tree(x): def parse_tree(x):
t = Tree.fromstring(x) t = Tree.fromstring(x)
return t.leaves() return t.leaves()
ds.apply(lambda ins: parse_tree( ds.apply(lambda ins: parse_tree(
ins['words1']), new_field_name='words1') ins['words1']), new_field_name='words1')
ds.apply(lambda ins: parse_tree( ds.apply(lambda ins: parse_tree(
@@ -488,12 +554,12 @@ class CSVLoader(DataSetLoader):
:param bool dropna: 是否忽略非法数据,若 ``True`` 则忽略,若 ``False`` ,在遇到非法数据时,抛出 ``ValueError`` . :param bool dropna: 是否忽略非法数据,若 ``True`` 则忽略,若 ``False`` ,在遇到非法数据时,抛出 ``ValueError`` .
Default: ``False`` Default: ``False``
""" """
def __init__(self, headers=None, sep=",", dropna=False): def __init__(self, headers=None, sep=",", dropna=False):
self.headers = headers self.headers = headers
self.sep = sep self.sep = sep
self.dropna = dropna self.dropna = dropna
def _load(self, path): def _load(self, path):
ds = DataSet() ds = DataSet()
for idx, data in _read_csv(path, headers=self.headers, for idx, data in _read_csv(path, headers=self.headers,
@@ -508,7 +574,7 @@ def _add_seg_tag(data):
:param data: list of ([word], [pos], [heads], [head_tags]) :param data: list of ([word], [pos], [heads], [head_tags])
:return: list of ([word], [pos]) :return: list of ([word], [pos])
""" """
_processed = [] _processed = []
for word_list, pos_list, _, _ in data: for word_list, pos_list, _, _ in data:
new_sample = [] new_sample = []


Loading…
Cancel
Save