Browse Source

- update sst data loader

- add Option
tags/v0.4.10
yunfan 6 years ago
parent
commit
9a5cc3801c
3 changed files with 60 additions and 44 deletions
  1. +17
    -2
      fastNLP/core/vocabulary.py
  2. +27
    -41
      fastNLP/io/dataset_loader.py
  3. +16
    -1
      fastNLP/io/embed_loader.py

+ 17
- 2
fastNLP/core/vocabulary.py View File

@@ -1,11 +1,26 @@
__all__ = [
"Vocabulary"
"Vocabulary",
"VocabularyOption",
]

from functools import wraps
from collections import Counter

from .dataset import DataSet
from .utils import Example


class VocabularyOption(Example):
def __init__(self,
max_size=None,
min_freq=None,
padding='<pad>',
unknown='<unk>'):
super().__init__(
max_size=max_size,
min_freq=min_freq,
padding=padding,
unknown=unknown
)


def _check_build_vocab(func):


+ 27
- 41
fastNLP/io/dataset_loader.py View File

@@ -22,8 +22,6 @@ __all__ = [
'SSTLoader',
'PeopleDailyCorpusLoader',
'Conll2003Loader',
'VocabularyOption',
'EmbeddingOption',
]

from nltk.tree import Tree
@@ -37,6 +35,8 @@ from ..core.utils import Example
from ..core.vocabulary import Vocabulary
from ..io.embed_loader import EmbedLoader
import numpy as np
from ..core.vocabulary import VocabularyOption
from .embed_loader import EmbeddingOption


def _download_from_url(url, path):
@@ -56,7 +56,6 @@ def _download_from_url(url, path):
if chunk:
file.write(chunk)
t.update(len(chunk))
return


def _uncompress(src, dst):
@@ -93,34 +92,6 @@ def _uncompress(src, dst):
raise ValueError('unsupported file {}'.format(src))


class VocabularyOption(Example):
def __init__(self,
max_size=None,
min_freq=None,
padding='<pad>',
unknown='<unk>'):
super().__init__(
max_size=max_size,
min_freq=min_freq,
padding=padding,
unknown=unknown
)


class EmbeddingOption(Example):
def __init__(self,
embed_filepath=None,
dtype=np.float32,
normalize=True,
error='ignore'):
super().__init__(
embed_filepath=embed_filepath,
dtype=dtype,
normalize=normalize,
error=error
)


class DataInfo:
"""
经过处理的数据信息,包括一系列数据集(比如:分开的训练集、验证集和测试集)及它们所用的词表和词嵌入。
@@ -151,26 +122,41 @@ class DataSetLoader:
**process 函数中可以 调用load 函数或 _load 函数**

"""
URL = ''
DATA_DIR = ''

def _download(self, url: str, path: str, uncompress=True) -> str:
ROOT_DIR = '.fastnlp/datasets/'
UNCOMPRESS = True

def _download(self, url: str, pdir: str, uncompress=True) -> str:
"""

从 ``url`` 下载数据到 ``path``, 如果 ``uncompress`` 为 ``True`` ,自动解压。

:param url: 下载的网站
:param path: 下载到的目录
:param pdir: 下载到的目录
:param uncompress: 是否自动解压缩
:return: 数据的存放路径
"""
pdir = os.path.dirname(path)
os.makedirs(pdir, exist_ok=True)
_download_from_url(url, path)
fn = os.path.basename(url)
path = os.path.join(pdir, fn)
"""check data exists"""
if not os.path.exists(path):
os.makedirs(pdir, exist_ok=True)
_download_from_url(url, path)
if uncompress:
dst = os.path.join(pdir, 'data')
_uncompress(path, dst)
if not os.path.exists(dst):
_uncompress(path, dst)
return dst
return path

def download(self):
return self._download(
self.URL,
os.path.join(self.ROOT_DIR, self.DATA_DIR),
uncompress=self.UNCOMPRESS)

def load(self, paths: Union[str, Dict[str, str]]) -> Union[DataSet, Dict[str, DataSet]]:
"""
从指定一个或多个路径中的文件中读取数据,返回一个或多个数据集 :class:`~fastNLP.DataSet` 。
@@ -442,7 +428,7 @@ class SSTLoader(DataSetLoader):
train_ds: Iterable[str] = None,
src_vocab_op: VocabularyOption = None,
tgt_vocab_op: VocabularyOption = None,
embed_op: EmbeddingOption = None):
src_embed_op: EmbeddingOption = None):
input_name, target_name = 'words', 'target'
src_vocab = Vocabulary() if src_vocab_op is None else Vocabulary(**src_vocab_op)
tgt_vocab = Vocabulary() if tgt_vocab_op is None else Vocabulary(**tgt_vocab_op)
@@ -463,9 +449,9 @@ class SSTLoader(DataSetLoader):
target_name: tgt_vocab
}

if embed_op is not None:
embed_op.vocab = src_vocab
init_emb = EmbedLoader.load_with_vocab(**embed_op)
if src_embed_op is not None:
src_embed_op.vocab = src_vocab
init_emb = EmbedLoader.load_with_vocab(**src_embed_op)
info.embeddings[input_name] = init_emb

return info


+ 16
- 1
fastNLP/io/embed_loader.py View File

@@ -1,5 +1,6 @@
__all__ = [
"EmbedLoader"
"EmbedLoader",
"EmbeddingOption",
]

import os
@@ -9,8 +10,22 @@ import numpy as np

from ..core.vocabulary import Vocabulary
from .base_loader import BaseLoader
from ..core.utils import Example


class EmbeddingOption(Example):
def __init__(self,
embed_filepath=None,
dtype=np.float32,
normalize=True,
error='ignore'):
super().__init__(
embed_filepath=embed_filepath,
dtype=dtype,
normalize=normalize,
error=error
)

class EmbedLoader(BaseLoader):
"""
别名::class:`fastNLP.io.EmbedLoader` :class:`fastNLP.io.embed_loader.EmbedLoader`


Loading…
Cancel
Save