From 6862a8f16979565d2aa915aec0a0974cdae2350d Mon Sep 17 00:00:00 2001 From: yunfan Date: Wed, 22 May 2019 11:19:14 +0800 Subject: [PATCH] - update dataset_loader --- fastNLP/io/dataset_loader.py | 140 +++++++++++++++++++++-------------- 1 file changed, 83 insertions(+), 57 deletions(-) diff --git a/fastNLP/io/dataset_loader.py b/fastNLP/io/dataset_loader.py index 8273d2f8..19600ef7 100644 --- a/fastNLP/io/dataset_loader.py +++ b/fastNLP/io/dataset_loader.py @@ -26,6 +26,8 @@ from nltk.tree import Tree from ..core.dataset import DataSet from ..core.instance import Instance from .file_reader import _read_csv, _read_json, _read_conll +from typing import Union +import os def _download_from_url(url, path): @@ -34,7 +36,7 @@ def _download_from_url(url, path): except: from ..core.utils import _pseudo_tqdm as tqdm import requests - + """Download file""" r = requests.get(url, headers={'User-Agent': 'Mozilla/5.0'}, stream=True) chunk_size = 16 * 1024 @@ -49,12 +51,15 @@ def _download_from_url(url, path): def _uncompress(src, dst): - import zipfile, gzip, tarfile, os - + import zipfile + import gzip + import tarfile + import os + def unzip(src, dst): with zipfile.ZipFile(src, 'r') as f: f.extractall(dst) - + def ungz(src, dst): with gzip.open(src, 'rb') as f, open(dst, 'wb') as uf: length = 16 * 1024 # 16KB @@ -62,11 +67,11 @@ def _uncompress(src, dst): while buf: uf.write(buf) buf = f.read(length) - + def untar(src, dst): with tarfile.open(src, 'r:gz') as f: f.extractall(dst) - + fn, ext = os.path.splitext(src) _, ext_2 = os.path.splitext(fn) if ext == '.zip': @@ -79,27 +84,52 @@ def _uncompress(src, dst): raise ValueError('unsupported file {}'.format(src)) +class DataInfo(): + def __init__(self, vocabs: dict = None, embeddings: dict = None, datasets: dict = None): + self.vocabs = vocabs or {} + self.embeddings = embeddings or {} + self.datasets = datasets or {} + + class DataSetLoader: """ 别名::class:`fastNLP.io.DataSetLoader` :class:`fastNLP.io.dataset_loader.DataSetLoader` 所有 DataSetLoader 的 API 接口,你可以继承它实现自己的 DataSetLoader """ - - def load(self, path): + def _download(self, url: str, path: str, uncompress=True) -> str: + """从 ``url`` 下载数据到 ``path``, 如果 ``uncompress`` 为 ``True`` ,自动解压。 + 返回数据的路径。 + """ + pdir = os.path.dirname(path) + os.makedirs(pdir, exist_ok=True) + _download_from_url(url, path) + if uncompress: + dst = os.path.join(pdir, 'data') + _uncompress(path, dst) + return dst + return path + + def load(self, paths: Union[str, dict]) -> Union[DataSet, dict]: + """从指定一个或多个 ``paths`` 的文件中读取数据,返回DataSet + + :param str or dict paths: 文件路径 + :return: 一个存储 :class:`~fastNLP.DataSet` 的字典 + """ + if isinstance(paths, str): + return self._load(paths) + return {name: self._load(path) for name, path in paths.items()} + + def _load(self, path: str) -> DataSet: """从指定 ``path`` 的文件中读取数据,返回DataSet :param str path: 文件路径 :return: 一个 :class:`~fastNLP.DataSet` 类型的对象 """ raise NotImplementedError - - def convert(self, data): - """ - 用Python数据对象创建DataSet,各个子类需要自行实现这个方法 - :param data: Python 内置的数据结构 - :return: 一个 :class:`~fastNLP.DataSet` 类型的对象 + def process(self, paths: Union[str, dict], **options) -> Union[DataInfo, dict]: + """读取并处理数据,返回处理结果 """ raise NotImplementedError @@ -110,21 +140,13 @@ class PeopleDailyCorpusLoader(DataSetLoader): 读取人民日报数据集 """ - - def __init__(self): + + def __init__(self, pos=True, ner=True): super(PeopleDailyCorpusLoader, self).__init__() - self.pos = True - self.ner = True - - def load(self, data_path, pos=True, ner=True): - """ + self.pos = pos + self.ner = ner - :param str data_path: 数据路径 - :param bool pos: 是否使用词性标签 - :param bool ner: 是否使用命名实体标签 - :return: 一个 :class:`~fastNLP.DataSet` 类型的对象 - """ - self.pos, self.ner = pos, ner + def _load(self, data_path): with open(data_path, "r", encoding="utf-8") as f: sents = f.readlines() examples = [] @@ -168,10 +190,10 @@ class PeopleDailyCorpusLoader(DataSetLoader): example.append(sent_ner) examples.append(example) return self.convert(examples) - + def convert(self, data): """ - + :param data: python 内置对象 :return: 一个 :class:`~fastNLP.DataSet` 类型的对象 """ @@ -179,7 +201,8 @@ class PeopleDailyCorpusLoader(DataSetLoader): for item in data: sent_words = item[0] if self.pos is True and self.ner is True: - instance = Instance(words=sent_words, pos_tags=item[1], ner=item[2]) + instance = Instance( + words=sent_words, pos_tags=item[1], ner=item[2]) elif self.pos is True: instance = Instance(words=sent_words, pos_tags=item[1]) elif self.ner is True: @@ -218,11 +241,12 @@ class ConllLoader(DataSetLoader): :param indexes: 需要保留的数据列下标,从0开始。若为 ``None`` ,则所有列都保留。Default: ``None`` :param dropna: 是否忽略非法数据,若 ``False`` ,遇到非法数据时抛出 ``ValueError`` 。Default: ``False`` """ - + def __init__(self, headers, indexes=None, dropna=False): super(ConllLoader, self).__init__() if not isinstance(headers, (list, tuple)): - raise TypeError('invalid headers: {}, should be list of strings'.format(headers)) + raise TypeError( + 'invalid headers: {}, should be list of strings'.format(headers)) self.headers = headers self.dropna = dropna if indexes is None: @@ -231,8 +255,8 @@ class ConllLoader(DataSetLoader): if len(indexes) != len(headers): raise ValueError self.indexes = indexes - - def load(self, path): + + def _load(self, path): ds = DataSet() for idx, data in _read_conll(path, indexes=self.indexes, dropna=self.dropna): ins = {h: data[i] for i, h in enumerate(self.headers)} @@ -245,11 +269,11 @@ class Conll2003Loader(ConllLoader): 别名::class:`fastNLP.io.Conll2003Loader` :class:`fastNLP.io.dataset_loader.Conll2003Loader` 读取Conll2003数据 - + 关于数据集的更多信息,参考: https://sites.google.com/site/ermasoftware/getting-started/ne-tagging-conll2003-data """ - + def __init__(self): headers = [ 'tokens', 'pos', 'chunks', 'ner', @@ -290,7 +314,7 @@ def _cut_long_sentence(sent, max_sample_length=200): class SSTLoader(DataSetLoader): """ 别名::class:`fastNLP.io.SSTLoader` :class:`fastNLP.io.dataset_loader.SSTLoader` - + 读取SST数据集, DataSet包含fields:: words: list(str) 需要分类的文本 @@ -301,18 +325,18 @@ class SSTLoader(DataSetLoader): :param subtree: 是否将数据展开为子树,扩充数据量. Default: ``False`` :param fine_grained: 是否使用SST-5标准,若 ``False`` , 使用SST-2。Default: ``False`` """ - + def __init__(self, subtree=False, fine_grained=False): self.subtree = subtree - + tag_v = {'0': 'very negative', '1': 'negative', '2': 'neutral', '3': 'positive', '4': 'very positive'} if not fine_grained: tag_v['0'] = tag_v['1'] tag_v['4'] = tag_v['3'] self.tag_v = tag_v - - def load(self, path): + + def _load(self, path): """ :param str path: 存储数据的路径 @@ -328,7 +352,7 @@ class SSTLoader(DataSetLoader): for words, tag in datas: ds.append(Instance(words=words, target=tag)) return ds - + @staticmethod def _get_one(data, subtree): tree = Tree.fromstring(data) @@ -350,7 +374,7 @@ class JsonLoader(DataSetLoader): :param bool dropna: 是否忽略非法数据,若 ``True`` 则忽略,若 ``False`` ,在遇到非法数据时,抛出 ``ValueError`` . Default: ``False`` """ - + def __init__(self, fields=None, dropna=False): super(JsonLoader, self).__init__() self.dropna = dropna @@ -361,8 +385,8 @@ class JsonLoader(DataSetLoader): for k, v in fields.items(): self.fields[k] = k if v is None else v self.fields_list = list(self.fields.keys()) - - def load(self, path): + + def _load(self, path): ds = DataSet() for idx, d in _read_json(path, fields=self.fields_list, dropna=self.dropna): if self.fields: @@ -385,7 +409,7 @@ class SNLILoader(JsonLoader): 数据来源: https://nlp.stanford.edu/projects/snli/snli_1.0.zip """ - + def __init__(self): fields = { 'sentence1_parse': 'words1', @@ -393,16 +417,18 @@ class SNLILoader(JsonLoader): 'gold_label': 'target', } super(SNLILoader, self).__init__(fields=fields) - - def load(self, path): - ds = super(SNLILoader, self).load(path) - + + def _load(self, path): + ds = super(SNLILoader, self)._load(path) + def parse_tree(x): t = Tree.fromstring(x) return t.leaves() - - ds.apply(lambda ins: parse_tree(ins['words1']), new_field_name='words1') - ds.apply(lambda ins: parse_tree(ins['words2']), new_field_name='words2') + + ds.apply(lambda ins: parse_tree( + ins['words1']), new_field_name='words1') + ds.apply(lambda ins: parse_tree( + ins['words2']), new_field_name='words2') ds.drop(lambda x: x['target'] == '-') return ds @@ -419,13 +445,13 @@ class CSVLoader(DataSetLoader): :param bool dropna: 是否忽略非法数据,若 ``True`` 则忽略,若 ``False`` ,在遇到非法数据时,抛出 ``ValueError`` . Default: ``False`` """ - + def __init__(self, headers=None, sep=",", dropna=False): self.headers = headers self.sep = sep self.dropna = dropna - - def load(self, path): + + def _load(self, path): ds = DataSet() for idx, data in _read_csv(path, headers=self.headers, sep=self.sep, dropna=self.dropna): @@ -439,7 +465,7 @@ def _add_seg_tag(data): :param data: list of ([word], [pos], [heads], [head_tags]) :return: list of ([word], [pos]) """ - + _processed = [] for word_list, pos_list, _, _ in data: new_sample = []