- update dataset_loader

commit 6862a8f169 (tags/v0.4.10)
yunfan, 5 years ago

1 changed file with 83 additions and 57 deletions
fastNLP/io/dataset_loader.py  +83  -57

@@ -26,6 +26,8 @@ from nltk.tree import Tree
 from ..core.dataset import DataSet
 from ..core.instance import Instance
 from .file_reader import _read_csv, _read_json, _read_conll
+from typing import Union
+import os


 def _download_from_url(url, path):
@@ -34,7 +36,7 @@ def _download_from_url(url, path):
     except:
         from ..core.utils import _pseudo_tqdm as tqdm
     import requests

     """Download file"""
     r = requests.get(url, headers={'User-Agent': 'Mozilla/5.0'}, stream=True)
     chunk_size = 16 * 1024
@@ -49,12 +51,15 @@ def _download_from_url(url, path):


 def _uncompress(src, dst):
-    import zipfile, gzip, tarfile, os
+    import zipfile
+    import gzip
+    import tarfile
+    import os

     def unzip(src, dst):
         with zipfile.ZipFile(src, 'r') as f:
             f.extractall(dst)

     def ungz(src, dst):
         with gzip.open(src, 'rb') as f, open(dst, 'wb') as uf:
             length = 16 * 1024  # 16KB
@@ -62,11 +67,11 @@ def _uncompress(src, dst):
             while buf:
                 uf.write(buf)
                 buf = f.read(length)

     def untar(src, dst):
         with tarfile.open(src, 'r:gz') as f:
             f.extractall(dst)

     fn, ext = os.path.splitext(src)
     _, ext_2 = os.path.splitext(fn)
     if ext == '.zip':
@@ -79,27 +84,52 @@ def _uncompress(src, dst):
         raise ValueError('unsupported file {}'.format(src))


+class DataInfo():
+    def __init__(self, vocabs: dict = None, embeddings: dict = None, datasets: dict = None):
+        self.vocabs = vocabs or {}
+        self.embeddings = embeddings or {}
+        self.datasets = datasets or {}
+
+
 class DataSetLoader:
     """
     Alias: :class:`fastNLP.io.DataSetLoader` :class:`fastNLP.io.dataset_loader.DataSetLoader`

     The API interface for all DataSetLoaders; you can subclass it to implement your own loader.
     """

-    def load(self, path):
+    def _download(self, url: str, path: str, uncompress=True) -> str:
+        """Download data from ``url`` to ``path``; if ``uncompress`` is ``True``, decompress it automatically.
+        Returns the path of the data.
+        """
+        pdir = os.path.dirname(path)
+        os.makedirs(pdir, exist_ok=True)
+        _download_from_url(url, path)
+        if uncompress:
+            dst = os.path.join(pdir, 'data')
+            _uncompress(path, dst)
+            return dst
+        return path
+
+    def load(self, paths: Union[str, dict]) -> Union[DataSet, dict]:
+        """Read data from one or more files given by ``paths`` and return DataSet(s).
+
+        :param str or dict paths: file path(s)
+        :return: a dict storing :class:`~fastNLP.DataSet` objects
+        """
+        if isinstance(paths, str):
+            return self._load(paths)
+        return {name: self._load(path) for name, path in paths.items()}
+
+    def _load(self, path: str) -> DataSet:
         """Read data from the file at ``path`` and return a DataSet.

         :param str path: file path
         :return: an object of type :class:`~fastNLP.DataSet`
         """
         raise NotImplementedError

-    def convert(self, data):
-        """
-        Build a DataSet from built-in Python objects; each subclass implements this method itself.
-
-        :param data: built-in Python data structures
-        :return: an object of type :class:`~fastNLP.DataSet`
+    def process(self, paths: Union[str, dict], **options) -> Union[DataInfo, dict]:
+        """Read and process the data, returning the processed result.
         """
         raise NotImplementedError
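
The hunk above replaces the old two-method contract (`load` + `convert`) with a three-part one: subclasses implement `_load` for a single file, the base `load` accepts either one path or a name-to-path dict, and the new `DataInfo` bundles datasets with vocabularies and embeddings. A minimal sketch of the reworked usage, assuming the module layout shown in this diff; `WhitespaceLoader` and the file names are hypothetical, for illustration only:

    from fastNLP import DataSet, Instance
    from fastNLP.io.dataset_loader import DataSetLoader, DataInfo

    class WhitespaceLoader(DataSetLoader):
        """Hypothetical loader: one whitespace-tokenized sentence per line."""

        def _load(self, path):
            ds = DataSet()
            with open(path, encoding='utf-8') as f:
                for line in f:
                    if line.strip():
                        ds.append(Instance(words=line.split()))
            return ds

    loader = WhitespaceLoader()
    train_ds = loader.load('train.txt')                            # str -> one DataSet
    data = loader.load({'train': 'train.txt', 'dev': 'dev.txt'})   # dict -> dict of DataSets
    info = DataInfo(datasets=data)                                 # bundle for later processing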


@@ -110,21 +140,13 @@ class PeopleDailyCorpusLoader(DataSetLoader):

     Reads the People's Daily corpus.
     """

-    def __init__(self):
+    def __init__(self, pos=True, ner=True):
         super(PeopleDailyCorpusLoader, self).__init__()
-        self.pos = True
-        self.ner = True
-
-    def load(self, data_path, pos=True, ner=True):
-        """
-
-        :param str data_path: path to the data
-        :param bool pos: whether to use part-of-speech tags
-        :param bool ner: whether to use named-entity tags
-        :return: an object of type :class:`~fastNLP.DataSet`
-        """
-        self.pos, self.ner = pos, ner
+        self.pos = pos
+        self.ner = ner
+
+    def _load(self, data_path):
         with open(data_path, "r", encoding="utf-8") as f:
             sents = f.readlines()
         examples = []
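
Here the `pos`/`ner` switches move from the old `load(data_path, pos, ner)` signature into the constructor, and reading now goes through the inherited `load` → `_load` path. A short sketch of the resulting call pattern; the corpus path is a placeholder:

    from fastNLP.io.dataset_loader import PeopleDailyCorpusLoader

    # Tag options are now fixed at construction time rather than per call.
    loader = PeopleDailyCorpusLoader(pos=True, ner=False)
    ds = loader.load('people_daily_corpus.txt')  # placeholder path; dispatches to _load()
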
@@ -168,10 +190,10 @@ class PeopleDailyCorpusLoader(DataSetLoader):
                 example.append(sent_ner)
             examples.append(example)
         return self.convert(examples)

     def convert(self, data):
         """

         :param data: built-in Python objects
         :return: an object of type :class:`~fastNLP.DataSet`
         """
@@ -179,7 +201,8 @@ class PeopleDailyCorpusLoader(DataSetLoader):
         for item in data:
             sent_words = item[0]
             if self.pos is True and self.ner is True:
-                instance = Instance(words=sent_words, pos_tags=item[1], ner=item[2])
+                instance = Instance(
+                    words=sent_words, pos_tags=item[1], ner=item[2])
             elif self.pos is True:
                 instance = Instance(words=sent_words, pos_tags=item[1])
             elif self.ner is True:
@@ -218,11 +241,12 @@ class ConllLoader(DataSetLoader):
     :param indexes: indexes of the data columns to keep, counting from 0. If ``None``, all columns are kept. Default: ``None``
     :param dropna: whether to ignore invalid data; if ``False``, a ``ValueError`` is raised when invalid data is encountered. Default: ``False``
     """

     def __init__(self, headers, indexes=None, dropna=False):
         super(ConllLoader, self).__init__()
         if not isinstance(headers, (list, tuple)):
-            raise TypeError('invalid headers: {}, should be list of strings'.format(headers))
+            raise TypeError(
+                'invalid headers: {}, should be list of strings'.format(headers))
         self.headers = headers
         self.dropna = dropna
         if indexes is None:
@@ -231,8 +255,8 @@ class ConllLoader(DataSetLoader):
             if len(indexes) != len(headers):
                 raise ValueError
             self.indexes = indexes

-    def load(self, path):
+    def _load(self, path):
         ds = DataSet()
         for idx, data in _read_conll(path, indexes=self.indexes, dropna=self.dropna):
             ins = {h: data[i] for i, h in enumerate(self.headers)}
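
Because only `_load` is renamed here, `ConllLoader` picks up the new multi-path `load` from the base class unchanged. A small usage sketch; the column names and file paths are hypothetical:

    from fastNLP.io.dataset_loader import ConllLoader

    # Keep CoNLL columns 0 and 3 and expose them as 'words' and 'ner'.
    loader = ConllLoader(headers=['words', 'ner'], indexes=[0, 3])
    data = loader.load({'train': 'train.conll', 'test': 'test.conll'})
    print(data['train'])
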
@@ -245,11 +269,11 @@ class Conll2003Loader(ConllLoader):
     Alias: :class:`fastNLP.io.Conll2003Loader` :class:`fastNLP.io.dataset_loader.Conll2003Loader`

     Reads Conll2003 data.

     For more information about the dataset, see:
     https://sites.google.com/site/ermasoftware/getting-started/ne-tagging-conll2003-data
     """

     def __init__(self):
         headers = [
             'tokens', 'pos', 'chunks', 'ner',
@@ -290,7 +314,7 @@ def _cut_long_sentence(sent, max_sample_length=200):
 class SSTLoader(DataSetLoader):
     """
     Alias: :class:`fastNLP.io.SSTLoader` :class:`fastNLP.io.dataset_loader.SSTLoader`

     Reads the SST dataset; the DataSet contains the fields::

         words: list(str), the text to classify
@@ -301,18 +325,18 @@ class SSTLoader(DataSetLoader):
     :param subtree: whether to expand the data into subtrees to enlarge the dataset. Default: ``False``
     :param fine_grained: whether to use the SST-5 standard; if ``False``, SST-2 is used. Default: ``False``
     """

     def __init__(self, subtree=False, fine_grained=False):
         self.subtree = subtree

         tag_v = {'0': 'very negative', '1': 'negative', '2': 'neutral',
                  '3': 'positive', '4': 'very positive'}
         if not fine_grained:
             tag_v['0'] = tag_v['1']
             tag_v['4'] = tag_v['3']
         self.tag_v = tag_v

-    def load(self, path):
+    def _load(self, path):
         """

         :param str path: path where the data is stored
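
The `tag_v` table is what collapses SST-5 into SST-2: with `fine_grained=False`, labels 0 and 1 both become 'negative' and 3 and 4 both become 'positive', while 2 stays 'neutral'. A tiny check of that mapping, based only on the constructor shown above:

    from fastNLP.io.dataset_loader import SSTLoader

    sst2 = SSTLoader(fine_grained=False)
    sst5 = SSTLoader(fine_grained=True)
    print(sst2.tag_v['0'], sst2.tag_v['4'])  # negative, positive
    print(sst5.tag_v['0'], sst5.tag_v['4'])  # very negative, very positive
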
@@ -328,7 +352,7 @@ class SSTLoader(DataSetLoader):
         for words, tag in datas:
             ds.append(Instance(words=words, target=tag))
         return ds

     @staticmethod
     def _get_one(data, subtree):
         tree = Tree.fromstring(data)
@@ -350,7 +374,7 @@ class JsonLoader(DataSetLoader):
     :param bool dropna: whether to ignore invalid data; if ``True`` it is skipped, and if ``False`` a ``ValueError`` is raised when invalid data is encountered.
         Default: ``False``
     """

     def __init__(self, fields=None, dropna=False):
         super(JsonLoader, self).__init__()
         self.dropna = dropna
@@ -361,8 +385,8 @@ class JsonLoader(DataSetLoader):
             for k, v in fields.items():
                 self.fields[k] = k if v is None else v
             self.fields_list = list(self.fields.keys())

-    def load(self, path):
+    def _load(self, path):
         ds = DataSet()
         for idx, d in _read_json(path, fields=self.fields_list, dropna=self.dropna):
             if self.fields:
@@ -385,7 +409,7 @@ class SNLILoader(JsonLoader):

     Data source: https://nlp.stanford.edu/projects/snli/snli_1.0.zip
     """

     def __init__(self):
         fields = {
             'sentence1_parse': 'words1',
@@ -393,16 +417,18 @@ class SNLILoader(JsonLoader):
             'gold_label': 'target',
         }
         super(SNLILoader, self).__init__(fields=fields)

-    def load(self, path):
-        ds = super(SNLILoader, self).load(path)
+    def _load(self, path):
+        ds = super(SNLILoader, self)._load(path)

         def parse_tree(x):
             t = Tree.fromstring(x)
             return t.leaves()

-        ds.apply(lambda ins: parse_tree(ins['words1']), new_field_name='words1')
-        ds.apply(lambda ins: parse_tree(ins['words2']), new_field_name='words2')
+        ds.apply(lambda ins: parse_tree(
+            ins['words1']), new_field_name='words1')
+        ds.apply(lambda ins: parse_tree(
+            ins['words2']), new_field_name='words2')
         ds.drop(lambda x: x['target'] == '-')
         return ds
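
The split `ds.apply(...)` calls are only a line-length reformatting; the interesting part is `parse_tree`, which uses NLTK to turn SNLI's bracketed parse strings back into token lists. A standalone illustration of that helper (the parse string is a made-up example):

    from nltk.tree import Tree

    parse = "(ROOT (NP (DT A) (NN person)) (VP (VBZ rides) (NP (DT a) (NN horse))))"
    print(Tree.fromstring(parse).leaves())  # ['A', 'person', 'rides', 'a', 'horse']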


@@ -419,13 +445,13 @@ class CSVLoader(DataSetLoader):
     :param bool dropna: whether to ignore invalid data; if ``True`` it is skipped, and if ``False`` a ``ValueError`` is raised when invalid data is encountered.
         Default: ``False``
     """

     def __init__(self, headers=None, sep=",", dropna=False):
         self.headers = headers
         self.sep = sep
         self.dropna = dropna

-    def load(self, path):
+    def _load(self, path):
         ds = DataSet()
         for idx, data in _read_csv(path, headers=self.headers,
                                    sep=self.sep, dropna=self.dropna):
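
As with the other loaders, `CSVLoader` reading now goes through the base-class `load`, with delimiter and header handling kept in `_load`. A brief sketch with a hypothetical tab-separated file:

    from fastNLP.io.dataset_loader import CSVLoader

    loader = CSVLoader(headers=['raw_words', 'target'], sep='\t')
    ds = loader.load('train.tsv')  # placeholder path
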
@@ -439,7 +465,7 @@ def _add_seg_tag(data):
     :param data: list of ([word], [pos], [heads], [head_tags])
     :return: list of ([word], [pos])
     """

     _processed = []
     for word_list, pos_list, _, _ in data:
         new_sample = []

