
Merge branch 'master' into pr

tags/v0.4.10
yunfan 5 years ago
parent commit 472b6885a3
2 changed files with 132 additions and 39 deletions
  1. +28 -1 fastNLP/core/utils.py
  2. +104 -38 fastNLP/io/dataset_loader.py

+28 -1 fastNLP/core/utils.py

@@ -3,7 +3,8 @@ The utils module implements many utilities needed both inside and outside fastNLP. Among them, users
"""
__all__ = [
"cache_results",
"seq_len_to_mask"
"seq_len_to_mask",
"Example",
]

import _pickle
@@ -21,6 +22,32 @@ _CheckRes = namedtuple('_CheckRes', ['missing', 'unused', 'duplicated', 'require
                                     'varargs'])


class Example(dict):
    """A dict whose keys can also be accessed as attributes."""
    def __getattr__(self, item):
        try:
            return self.__getitem__(item)
        except KeyError:
            # surface missing keys as AttributeError so hasattr() behaves normally
            raise AttributeError(item)

    def __setattr__(self, key, value):
        if key.startswith('__') and key.endswith('__'):
            # refuse to shadow dunder attributes; everything else becomes a dict key
            raise AttributeError(key)
        self.__setitem__(key, value)

    def __delattr__(self, item):
        try:
            self.pop(item)
        except KeyError:
            raise AttributeError(item)

    def __getstate__(self):
        return self

    def __setstate__(self, state):
        self.update(state)


def _prepare_cache_filepath(filepath):
"""
检查filepath是否可以作为合理的cache文件. 如果可以的话,会自动创造路径


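Note: a minimal sketch of how the new Example helper behaves; the values below are illustrative and not part of this commit.

from fastNLP.core.utils import Example

ex = Example(words=["I", "like", "it"], target="positive")
print(ex.words)        # attribute access reads the dict key -> ['I', 'like', 'it']
ex.seq_len = 3         # attribute assignment stores a dict key
print(ex["seq_len"])   # -> 3
del ex.target          # attribute deletion removes the key
print("target" in ex)  # -> False
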
+104 -38 fastNLP/io/dataset_loader.py

@@ -29,8 +29,12 @@ from nltk.tree import Tree
from ..core.dataset import DataSet
from ..core.instance import Instance
from .file_reader import _read_csv, _read_json, _read_conll
from typing import Union, Dict
from typing import Union, Dict, Iterable
import os
from ..core.utils import Example
from ..core import Vocabulary
from ..io import EmbedLoader
import numpy as np


def _download_from_url(url, path):
@@ -39,7 +43,7 @@ def _download_from_url(url, path):
    except:
        from ..core.utils import _pseudo_tqdm as tqdm
    import requests
    """Download file"""
    r = requests.get(url, headers={'User-Agent': 'Mozilla/5.0'}, stream=True)
    chunk_size = 16 * 1024
@@ -58,11 +62,11 @@ def _uncompress(src, dst):
    import gzip
    import tarfile
    import os
    def unzip(src, dst):
        with zipfile.ZipFile(src, 'r') as f:
            f.extractall(dst)
    def ungz(src, dst):
        with gzip.open(src, 'rb') as f, open(dst, 'wb') as uf:
            length = 16 * 1024  # 16KB
@@ -70,11 +74,11 @@ def _uncompress(src, dst):
            while buf:
                uf.write(buf)
                buf = f.read(length)
    def untar(src, dst):
        with tarfile.open(src, 'r:gz') as f:
            f.extractall(dst)
    fn, ext = os.path.splitext(src)
    _, ext_2 = os.path.splitext(fn)
    if ext == '.zip':
@@ -87,6 +91,34 @@ def _uncompress(src, dst):
        raise ValueError('unsupported file {}'.format(src))


class VocabularyOption(Example):
    def __init__(self,
                 max_size=None,
                 min_freq=None,
                 padding='<pad>',
                 unknown='<unk>'):
        super().__init__(
            max_size=max_size,
            min_freq=min_freq,
            padding=padding,
            unknown=unknown
        )


class EmbeddingOption(Example):
    def __init__(self,
                 embed_filepath=None,
                 dtype=np.float32,
                 normalize=True,
                 error='ignore'):
        super().__init__(
            embed_filepath=embed_filepath,
            dtype=dtype,
            normalize=normalize,
            error=error
        )

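Note: VocabularyOption and EmbeddingOption are plain Example dicts whose fields mirror the keyword arguments of Vocabulary and EmbedLoader.load_with_vocab, so they can be unpacked with **, exactly as SSTLoader.process does further down. A minimal sketch (the embedding path is illustrative; Vocabulary and EmbedLoader are imported at the top of this file):

vocab_op = VocabularyOption(max_size=30000, min_freq=2)
vocab_op.min_freq = 1                  # tweak via attribute access (Example behaviour)
vocab = Vocabulary(**vocab_op)         # unpack as keyword arguments

embed_op = EmbeddingOption(embed_filepath='glove.6B.100d.txt')  # illustrative path
embed_op.vocab = vocab                 # process() injects the vocabulary the same way
embedding = EmbedLoader.load_with_vocab(**embed_op)
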

class DataInfo:
"""
经过处理的数据信息,包括一系列数据集(比如:分开的训练集、验证集和测试集)及它们所用的词表和词嵌入。
@@ -95,7 +127,7 @@ class DataInfo:
    :param embeddings: dict mapping a name (str) to a set of embeddings, see :class:`~fastNLP.io.EmbedLoader`
    :param datasets: dict mapping a name (str) to a :class:`~fastNLP.DataSet`
    """
    def __init__(self, vocabs: dict = None, embeddings: dict = None, datasets: dict = None):
        self.vocabs = vocabs or {}
        self.embeddings = embeddings or {}
@@ -106,21 +138,21 @@ class DataSetLoader:
"""
别名::class:`fastNLP.io.DataSetLoader` :class:`fastNLP.io.dataset_loader.DataSetLoader`

定义了各种 DataSetLoader (针对特定数据上的特定任务) 所需的API 接口,开发者应该继承它实现各种的 DataSetLoader。
定义了各种 DataSetLoader 所需的API 接口,开发者应该继承它实现各种的 DataSetLoader。
开发者至少应该编写如下内容:
- _load 函数:从一个数据文件中读取数据到一个 :class:`~fastNLP.DataSet`
- load 函数(可以使用基类的方法):从一个或多个数据文件中读取数据到一个或多个 :class:`~fastNLP.DataSet`
- process 函数:一个或多个从数据文件中读取数据,并处理成可以训练的一个或多个 :class:`~fastNLP.DataSet`
**process 函数中可以 调用load 函数或 _load 函数**
"""
def _download(self, url: str, path: str, uncompress=True) -> str:
"""
从 ``url`` 下载数据到 ``path``, 如果 ``uncompress`` 为 ``True`` ,自动解压。

:param url: 下载的网站
@@ -136,7 +168,7 @@ class DataSetLoader:
            _uncompress(path, dst)
            return dst
        return path
    def load(self, paths: Union[str, Dict[str, str]]) -> Union[DataSet, Dict[str, DataSet]]:
        """
        Read data from the file(s) at one or more given paths and return one or more :class:`~fastNLP.DataSet` objects.
@@ -148,7 +180,7 @@ class DataSetLoader:
        if isinstance(paths, str):
            return self._load(paths)
        return {name: self._load(path) for name, path in paths.items()}
    def _load(self, path: str) -> DataSet:
        """Read data from the file at the given path and return an object of type :class:`~fastNLP.DataSet`

@@ -156,16 +188,16 @@ class DataSetLoader:
        :return: an object of type :class:`~fastNLP.DataSet`
        """
        raise NotImplementedError
    def process(self, paths: Union[str, Dict[str, str]], **options) -> DataInfo:
        """
        For a specific task and dataset, read and process the data and return a DataInfo object (or a dict).
        Read data from the file(s) at one or more given paths; the returned DataInfo object may contain one or more data sets.
        When multiple paths are processed, the keys of the dict passed in are kept consistent with the keys of the dict inside the returned DataInfo.

        The returned :class:`DataInfo` object has the following attributes:
        - vocabs: a dict made up of the vocabularies obtained from the data sets; each vocabulary
        - embeddings: (optional) the embeddings corresponding to the data sets
        - datasets: a dict containing a collection of :class:`~fastNLP.DataSet` objects, with field names following :mod:`~fastNLP.core.const`
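Note: the contract described above can be made concrete with a small, purely illustrative DataSetLoader subclass (not part of this commit); it assumes a tab-separated file with one "sentence<TAB>label" pair per line.

from fastNLP import DataSet, Instance, Vocabulary
from fastNLP.io.dataset_loader import DataSetLoader, DataInfo, VocabularyOption


class TabTextLoader(DataSetLoader):  # illustrative name, not in fastNLP
    def _load(self, path: str) -> DataSet:
        # read one file into a DataSet with 'words' and 'target' fields
        ds = DataSet()
        with open(path, 'r', encoding='utf-8') as f:
            for line in f:
                words, target = line.rstrip('\n').rsplit('\t', 1)
                ds.append(Instance(words=words.split(), target=target))
        return ds

    def process(self, paths: dict, src_vocab_op: VocabularyOption = None) -> DataInfo:
        # paths is a {name: path} dict, e.g. {'train': ..., 'dev': ...}
        info = DataInfo(datasets=self.load(paths))
        vocab = Vocabulary() if src_vocab_op is None else Vocabulary(**src_vocab_op)
        vocab.from_dataset(*info.datasets.values(), field_name='words')
        vocab.index_dataset(*info.datasets.values(), field_name='words')
        info.vocabs['words'] = vocab
        return info
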
@@ -183,12 +215,12 @@ class PeopleDailyCorpusLoader(DataSetLoader):

    Reads the People's Daily (人民日报) corpus.
    """
    def __init__(self, pos=True, ner=True):
        super(PeopleDailyCorpusLoader, self).__init__()
        self.pos = pos
        self.ner = ner
    def _load(self, data_path):
        with open(data_path, "r", encoding="utf-8") as f:
            sents = f.readlines()
@@ -233,7 +265,7 @@ class PeopleDailyCorpusLoader(DataSetLoader):
                example.append(sent_ner)
            examples.append(example)
        return self.convert(examples)
    def convert(self, data):
        """

@@ -284,7 +316,7 @@ class ConllLoader(DataSetLoader):
    :param indexes: 0-based indexes of the data columns to keep. If ``None``, all columns are kept. Default: ``None``
    :param dropna: whether to ignore malformed data; if ``False``, a ``ValueError`` is raised when malformed data is encountered. Default: ``False``
    """
    def __init__(self, headers, indexes=None, dropna=False):
        super(ConllLoader, self).__init__()
        if not isinstance(headers, (list, tuple)):
@@ -298,7 +330,7 @@ class ConllLoader(DataSetLoader):
            if len(indexes) != len(headers):
                raise ValueError
            self.indexes = indexes
    def _load(self, path):
        ds = DataSet()
        for idx, data in _read_conll(path, indexes=self.indexes, dropna=self.dropna):
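Note: a minimal, illustrative use of ConllLoader; the column indexes follow the CoNLL-2003 layout listed in Conll2003Loader below, and the file path is only an example.

loader = ConllLoader(headers=['tokens', 'ner'], indexes=[0, 3], dropna=True)
ds = loader.load('data/conll2003/train.txt')  # returns a DataSet with 'tokens' and 'ner' fields
print(ds[0])
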
@@ -316,7 +348,7 @@ class Conll2003Loader(ConllLoader):
    For more information about the dataset, see:
    https://sites.google.com/site/ermasoftware/getting-started/ne-tagging-conll2003-data
    """
    def __init__(self):
        headers = [
            'tokens', 'pos', 'chunks', 'ner',
@@ -368,17 +400,17 @@ class SSTLoader(DataSetLoader):
    :param subtree: whether to expand the data into subtrees to enlarge the dataset. Default: ``False``
    :param fine_grained: whether to use the SST-5 label scheme; if ``False``, SST-2 is used. Default: ``False``
    """
    def __init__(self, subtree=False, fine_grained=False):
        self.subtree = subtree
        tag_v = {'0': 'very negative', '1': 'negative', '2': 'neutral',
                 '3': 'positive', '4': 'very positive'}
        if not fine_grained:
            tag_v['0'] = tag_v['1']
            tag_v['4'] = tag_v['3']
        self.tag_v = tag_v
    def _load(self, path):
        """

@@ -395,7 +427,7 @@ class SSTLoader(DataSetLoader):
        for words, tag in datas:
            ds.append(Instance(words=words, target=tag))
        return ds
    @staticmethod
    def _get_one(data, subtree):
        tree = Tree.fromstring(data)
@@ -403,6 +435,40 @@ class SSTLoader(DataSetLoader):
            return [(t.leaves(), t.label()) for t in tree.subtrees()]
        return [(tree.leaves(), tree.label())]

    def process(self,
                paths,
                train_ds: Iterable[str] = None,
                src_vocab_op: VocabularyOption = None,
                tgt_vocab_op: VocabularyOption = None,
                embed_op: EmbeddingOption = None):
        input_name, target_name = 'words', 'target'
        src_vocab = Vocabulary() if src_vocab_op is None else Vocabulary(**src_vocab_op)
        tgt_vocab = Vocabulary() if tgt_vocab_op is None else Vocabulary(**tgt_vocab_op)

        info = DataInfo(datasets=self.load(paths))
        _train_ds = [info.datasets[name]
                     for name in train_ds] if train_ds else info.datasets.values()

        # Vocabulary.from_dataset / index_dataset take DataSet objects as positional
        # arguments, so the collections are unpacked here.
        src_vocab.from_dataset(*_train_ds, field_name=input_name)
        tgt_vocab.from_dataset(*_train_ds, field_name=target_name)
        src_vocab.index_dataset(
            *info.datasets.values(),
            field_name=input_name, new_field_name=input_name)
        tgt_vocab.index_dataset(
            *info.datasets.values(),
            field_name=target_name, new_field_name=target_name)
        info.vocabs = {
            input_name: src_vocab,
            target_name: tgt_vocab
        }

        if embed_op is not None:
            embed_op.vocab = src_vocab
            init_emb = EmbedLoader.load_with_vocab(**embed_op)
            info.embeddings[input_name] = init_emb

        return info

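Note: a minimal sketch of calling the new SSTLoader.process; the file paths are illustrative (SST ships one bracketed parse tree per line).

loader = SSTLoader(fine_grained=True)
info = loader.process(
    paths={'train': 'sst/train.txt', 'dev': 'sst/dev.txt', 'test': 'sst/test.txt'},
    train_ds=['train'],                        # build the vocabularies on the training split only
    src_vocab_op=VocabularyOption(min_freq=2)
)
print(info.vocabs['words'])                    # source vocabulary
print(info.datasets['dev'][0])                 # 'words' and 'target' are now index-encoded
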

class JsonLoader(DataSetLoader):
"""
@@ -417,7 +483,7 @@ class JsonLoader(DataSetLoader):
    :param bool dropna: whether to ignore malformed data; if ``True`` it is skipped, if ``False`` a ``ValueError`` is raised when malformed data is encountered.
        Default: ``False``
    """
    def __init__(self, fields=None, dropna=False):
        super(JsonLoader, self).__init__()
        self.dropna = dropna
@@ -428,7 +494,7 @@ class JsonLoader(DataSetLoader):
            for k, v in fields.items():
                self.fields[k] = k if v is None else v
            self.fields_list = list(self.fields.keys())
    def _load(self, path):
        ds = DataSet()
        for idx, d in _read_json(path, fields=self.fields_list, dropna=self.dropna):
@@ -452,7 +518,7 @@ class SNLILoader(JsonLoader):

    Data source: https://nlp.stanford.edu/projects/snli/snli_1.0.zip
    """
    def __init__(self):
        fields = {
            'sentence1_parse': 'words1',
@@ -460,14 +526,14 @@ class SNLILoader(JsonLoader):
            'gold_label': 'target',
        }
        super(SNLILoader, self).__init__(fields=fields)
    def _load(self, path):
        ds = super(SNLILoader, self)._load(path)
        def parse_tree(x):
            t = Tree.fromstring(x)
            return t.leaves()
        ds.apply(lambda ins: parse_tree(
            ins['words1']), new_field_name='words1')
        ds.apply(lambda ins: parse_tree(
@@ -488,12 +554,12 @@ class CSVLoader(DataSetLoader):
    :param bool dropna: whether to ignore malformed data; if ``True`` it is skipped, if ``False`` a ``ValueError`` is raised when malformed data is encountered.
        Default: ``False``
    """
    def __init__(self, headers=None, sep=",", dropna=False):
        self.headers = headers
        self.sep = sep
        self.dropna = dropna
    def _load(self, path):
        ds = DataSet()
        for idx, data in _read_csv(path, headers=self.headers,
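Note (following the CSVLoader hunk above): a minimal, illustrative use of CSVLoader; the separator, headers and path are only examples.

loader = CSVLoader(headers=['raw_words', 'target'], sep='\t')
ds = loader.load('data/train.tsv')  # each row becomes an Instance with 'raw_words' and 'target' fields
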
@@ -508,7 +574,7 @@ def _add_seg_tag(data):
    :param data: list of ([word], [pos], [heads], [head_tags])
    :return: list of ([word], [pos])
    """
    _processed = []
    for word_list, pos_list, _, _ in data:
        new_sample = []

