From 9a5cc3801c48e9836cee74333ab135855a34e8c7 Mon Sep 17 00:00:00 2001 From: yunfan Date: Thu, 30 May 2019 14:27:12 +0800 Subject: [PATCH] - update sst data loader - add Option --- fastNLP/core/vocabulary.py | 19 ++++++++-- fastNLP/io/dataset_loader.py | 68 ++++++++++++++---------------------- fastNLP/io/embed_loader.py | 17 ++++++++- 3 files changed, 60 insertions(+), 44 deletions(-) diff --git a/fastNLP/core/vocabulary.py b/fastNLP/core/vocabulary.py index cbde9cba..0cf45049 100644 --- a/fastNLP/core/vocabulary.py +++ b/fastNLP/core/vocabulary.py @@ -1,11 +1,26 @@ __all__ = [ - "Vocabulary" + "Vocabulary", + "VocabularyOption", ] from functools import wraps from collections import Counter - from .dataset import DataSet +from .utils import Example + + +class VocabularyOption(Example): + def __init__(self, + max_size=None, + min_freq=None, + padding='', + unknown=''): + super().__init__( + max_size=max_size, + min_freq=min_freq, + padding=padding, + unknown=unknown + ) def _check_build_vocab(func): diff --git a/fastNLP/io/dataset_loader.py b/fastNLP/io/dataset_loader.py index 8c697e4a..9ad5dff8 100644 --- a/fastNLP/io/dataset_loader.py +++ b/fastNLP/io/dataset_loader.py @@ -22,8 +22,6 @@ __all__ = [ 'SSTLoader', 'PeopleDailyCorpusLoader', 'Conll2003Loader', - 'VocabularyOption', - 'EmbeddingOption', ] from nltk.tree import Tree @@ -37,6 +35,8 @@ from ..core.utils import Example from ..core.vocabulary import Vocabulary from ..io.embed_loader import EmbedLoader import numpy as np +from ..core.vocabulary import VocabularyOption +from .embed_loader import EmbeddingOption def _download_from_url(url, path): @@ -56,7 +56,6 @@ def _download_from_url(url, path): if chunk: file.write(chunk) t.update(len(chunk)) - return def _uncompress(src, dst): @@ -93,34 +92,6 @@ def _uncompress(src, dst): raise ValueError('unsupported file {}'.format(src)) -class VocabularyOption(Example): - def __init__(self, - max_size=None, - min_freq=None, - padding='', - unknown=''): - super().__init__( - max_size=max_size, - min_freq=min_freq, - padding=padding, - unknown=unknown - ) - - -class EmbeddingOption(Example): - def __init__(self, - embed_filepath=None, - dtype=np.float32, - normalize=True, - error='ignore'): - super().__init__( - embed_filepath=embed_filepath, - dtype=dtype, - normalize=normalize, - error=error - ) - - class DataInfo: """ 经过处理的数据信息,包括一系列数据集(比如:分开的训练集、验证集和测试集)及它们所用的词表和词嵌入。 @@ -151,26 +122,41 @@ class DataSetLoader: **process 函数中可以 调用load 函数或 _load 函数** """ + URL = '' + DATA_DIR = '' - def _download(self, url: str, path: str, uncompress=True) -> str: + ROOT_DIR = '.fastnlp/datasets/' + UNCOMPRESS = True + + def _download(self, url: str, pdir: str, uncompress=True) -> str: """ 从 ``url`` 下载数据到 ``path``, 如果 ``uncompress`` 为 ``True`` ,自动解压。 :param url: 下载的网站 - :param path: 下载到的目录 + :param pdir: 下载到的目录 :param uncompress: 是否自动解压缩 :return: 数据的存放路径 """ - pdir = os.path.dirname(path) - os.makedirs(pdir, exist_ok=True) - _download_from_url(url, path) + fn = os.path.basename(url) + path = os.path.join(pdir, fn) + """check data exists""" + if not os.path.exists(path): + os.makedirs(pdir, exist_ok=True) + _download_from_url(url, path) if uncompress: dst = os.path.join(pdir, 'data') - _uncompress(path, dst) + if not os.path.exists(dst): + _uncompress(path, dst) return dst return path + def download(self): + return self._download( + self.URL, + os.path.join(self.ROOT_DIR, self.DATA_DIR), + uncompress=self.UNCOMPRESS) + def load(self, paths: Union[str, Dict[str, str]]) -> Union[DataSet, Dict[str, DataSet]]: """ 从指定一个或多个路径中的文件中读取数据,返回一个或多个数据集 :class:`~fastNLP.DataSet` 。 @@ -442,7 +428,7 @@ class SSTLoader(DataSetLoader): train_ds: Iterable[str] = None, src_vocab_op: VocabularyOption = None, tgt_vocab_op: VocabularyOption = None, - embed_op: EmbeddingOption = None): + src_embed_op: EmbeddingOption = None): input_name, target_name = 'words', 'target' src_vocab = Vocabulary() if src_vocab_op is None else Vocabulary(**src_vocab_op) tgt_vocab = Vocabulary() if tgt_vocab_op is None else Vocabulary(**tgt_vocab_op) @@ -463,9 +449,9 @@ class SSTLoader(DataSetLoader): target_name: tgt_vocab } - if embed_op is not None: - embed_op.vocab = src_vocab - init_emb = EmbedLoader.load_with_vocab(**embed_op) + if src_embed_op is not None: + src_embed_op.vocab = src_vocab + init_emb = EmbedLoader.load_with_vocab(**src_embed_op) info.embeddings[input_name] = init_emb return info diff --git a/fastNLP/io/embed_loader.py b/fastNLP/io/embed_loader.py index fb024e73..93861258 100644 --- a/fastNLP/io/embed_loader.py +++ b/fastNLP/io/embed_loader.py @@ -1,5 +1,6 @@ __all__ = [ - "EmbedLoader" + "EmbedLoader", + "EmbeddingOption", ] import os @@ -9,8 +10,22 @@ import numpy as np from ..core.vocabulary import Vocabulary from .base_loader import BaseLoader +from ..core.utils import Example +class EmbeddingOption(Example): + def __init__(self, + embed_filepath=None, + dtype=np.float32, + normalize=True, + error='ignore'): + super().__init__( + embed_filepath=embed_filepath, + dtype=dtype, + normalize=normalize, + error=error + ) + class EmbedLoader(BaseLoader): """ 别名::class:`fastNLP.io.EmbedLoader` :class:`fastNLP.io.embed_loader.EmbedLoader`