Browse Source

- update sst data loader

- add Option
tags/v0.4.10
yunfan 6 years ago
parent
commit
9a5cc3801c
3 changed files with 60 additions and 44 deletions
  1. +17
    -2
      fastNLP/core/vocabulary.py
  2. +27
    -41
      fastNLP/io/dataset_loader.py
  3. +16
    -1
      fastNLP/io/embed_loader.py

+ 17
- 2
fastNLP/core/vocabulary.py View File

@@ -1,11 +1,26 @@
__all__ = [ __all__ = [
"Vocabulary"
"Vocabulary",
"VocabularyOption",
] ]


from functools import wraps from functools import wraps
from collections import Counter from collections import Counter

from .dataset import DataSet from .dataset import DataSet
from .utils import Example


class VocabularyOption(Example):
    """Container for the keyword options accepted by ``Vocabulary``.

    Holds ``max_size``, ``min_freq``, ``padding`` and ``unknown`` so a caller
    can build a vocabulary with ``Vocabulary(**option)``.
    # NOTE(review): the exact mapping semantics come from the ``Example`` base
    # in ..core.utils (not visible in this diff) — confirm it behaves dict-like.
    """

    def __init__(self, max_size=None, min_freq=None,
                 padding='<pad>', unknown='<unk>'):
        # Collect the options once, then hand them all to the base class.
        opts = dict(max_size=max_size,
                    min_freq=min_freq,
                    padding=padding,
                    unknown=unknown)
        super().__init__(**opts)




def _check_build_vocab(func): def _check_build_vocab(func):


+ 27
- 41
fastNLP/io/dataset_loader.py View File

@@ -22,8 +22,6 @@ __all__ = [
'SSTLoader', 'SSTLoader',
'PeopleDailyCorpusLoader', 'PeopleDailyCorpusLoader',
'Conll2003Loader', 'Conll2003Loader',
'VocabularyOption',
'EmbeddingOption',
] ]


from nltk.tree import Tree from nltk.tree import Tree
@@ -37,6 +35,8 @@ from ..core.utils import Example
from ..core.vocabulary import Vocabulary from ..core.vocabulary import Vocabulary
from ..io.embed_loader import EmbedLoader from ..io.embed_loader import EmbedLoader
import numpy as np import numpy as np
from ..core.vocabulary import VocabularyOption
from .embed_loader import EmbeddingOption




def _download_from_url(url, path): def _download_from_url(url, path):
@@ -56,7 +56,6 @@ def _download_from_url(url, path):
if chunk: if chunk:
file.write(chunk) file.write(chunk)
t.update(len(chunk)) t.update(len(chunk))
return




def _uncompress(src, dst): def _uncompress(src, dst):
@@ -93,34 +92,6 @@ def _uncompress(src, dst):
raise ValueError('unsupported file {}'.format(src)) raise ValueError('unsupported file {}'.format(src))




# NOTE(review): this definition is the one REMOVED from fastNLP/io/dataset_loader.py
# by this commit; an identical class now lives in fastNLP/core/vocabulary.py.
# Fields mirror the keyword arguments accepted by Vocabulary().
class VocabularyOption(Example):
def __init__(self,
max_size=None,
min_freq=None,
padding='<pad>',
unknown='<unk>'):
# All options are forwarded unchanged to the Example base class.
super().__init__(
max_size=max_size,
min_freq=min_freq,
padding=padding,
unknown=unknown
)


# NOTE(review): this definition is the one REMOVED from fastNLP/io/dataset_loader.py
# by this commit; an identical class now lives in fastNLP/io/embed_loader.py.
# Fields mirror the keyword arguments of EmbedLoader.load_with_vocab().
class EmbeddingOption(Example):
def __init__(self,
embed_filepath=None,
dtype=np.float32,
normalize=True,
error='ignore'):
# All options are forwarded unchanged to the Example base class.
super().__init__(
embed_filepath=embed_filepath,
dtype=dtype,
normalize=normalize,
error=error
)


class DataInfo: class DataInfo:
""" """
经过处理的数据信息,包括一系列数据集(比如:分开的训练集、验证集和测试集)及它们所用的词表和词嵌入。 经过处理的数据信息,包括一系列数据集(比如:分开的训练集、验证集和测试集)及它们所用的词表和词嵌入。
@@ -151,26 +122,41 @@ class DataSetLoader:
**process 函数中可以 调用load 函数或 _load 函数** **process 函数中可以 调用load 函数或 _load 函数**


""" """
URL = ''
DATA_DIR = ''


def _download(self, url: str, path: str, uncompress=True) -> str:
ROOT_DIR = '.fastnlp/datasets/'
UNCOMPRESS = True

def _download(self, url: str, pdir: str, uncompress=True) -> str:
""" """


从 ``url`` 下载数据到 ``path``, 如果 ``uncompress`` 为 ``True`` ,自动解压。 从 ``url`` 下载数据到 ``path``, 如果 ``uncompress`` 为 ``True`` ,自动解压。


:param url: 下载的网站 :param url: 下载的网站
:param path: 下载到的目录
:param pdir: 下载到的目录
:param uncompress: 是否自动解压缩 :param uncompress: 是否自动解压缩
:return: 数据的存放路径 :return: 数据的存放路径
""" """
pdir = os.path.dirname(path)
os.makedirs(pdir, exist_ok=True)
_download_from_url(url, path)
fn = os.path.basename(url)
path = os.path.join(pdir, fn)
"""check data exists"""
if not os.path.exists(path):
os.makedirs(pdir, exist_ok=True)
_download_from_url(url, path)
if uncompress: if uncompress:
dst = os.path.join(pdir, 'data') dst = os.path.join(pdir, 'data')
_uncompress(path, dst)
if not os.path.exists(dst):
_uncompress(path, dst)
return dst return dst
return path return path


# Download this loader's dataset to ROOT_DIR/DATA_DIR via _download,
# uncompressing when UNCOMPRESS is True; returns the local data path.
# NOTE(review): relies on class attributes URL / ROOT_DIR / DATA_DIR /
# UNCOMPRESS added to DataSetLoader in this same commit (hunk only
# partially visible here) — confirm they are all defined on the class.
def download(self):
return self._download(
self.URL,
os.path.join(self.ROOT_DIR, self.DATA_DIR),
uncompress=self.UNCOMPRESS)

def load(self, paths: Union[str, Dict[str, str]]) -> Union[DataSet, Dict[str, DataSet]]: def load(self, paths: Union[str, Dict[str, str]]) -> Union[DataSet, Dict[str, DataSet]]:
""" """
从指定一个或多个路径中的文件中读取数据,返回一个或多个数据集 :class:`~fastNLP.DataSet` 。 从指定一个或多个路径中的文件中读取数据,返回一个或多个数据集 :class:`~fastNLP.DataSet` 。
@@ -442,7 +428,7 @@ class SSTLoader(DataSetLoader):
train_ds: Iterable[str] = None, train_ds: Iterable[str] = None,
src_vocab_op: VocabularyOption = None, src_vocab_op: VocabularyOption = None,
tgt_vocab_op: VocabularyOption = None, tgt_vocab_op: VocabularyOption = None,
embed_op: EmbeddingOption = None):
src_embed_op: EmbeddingOption = None):
input_name, target_name = 'words', 'target' input_name, target_name = 'words', 'target'
src_vocab = Vocabulary() if src_vocab_op is None else Vocabulary(**src_vocab_op) src_vocab = Vocabulary() if src_vocab_op is None else Vocabulary(**src_vocab_op)
tgt_vocab = Vocabulary() if tgt_vocab_op is None else Vocabulary(**tgt_vocab_op) tgt_vocab = Vocabulary() if tgt_vocab_op is None else Vocabulary(**tgt_vocab_op)
@@ -463,9 +449,9 @@ class SSTLoader(DataSetLoader):
target_name: tgt_vocab target_name: tgt_vocab
} }


if embed_op is not None:
embed_op.vocab = src_vocab
init_emb = EmbedLoader.load_with_vocab(**embed_op)
if src_embed_op is not None:
src_embed_op.vocab = src_vocab
init_emb = EmbedLoader.load_with_vocab(**src_embed_op)
info.embeddings[input_name] = init_emb info.embeddings[input_name] = init_emb


return info return info


+ 16
- 1
fastNLP/io/embed_loader.py View File

@@ -1,5 +1,6 @@
__all__ = [ __all__ = [
"EmbedLoader"
"EmbedLoader",
"EmbeddingOption",
] ]


import os import os
@@ -9,8 +10,22 @@ import numpy as np


from ..core.vocabulary import Vocabulary from ..core.vocabulary import Vocabulary
from .base_loader import BaseLoader from .base_loader import BaseLoader
from ..core.utils import Example




class EmbeddingOption(Example):
    """Container for the keyword options of ``EmbedLoader.load_with_vocab``.

    Holds ``embed_filepath``, ``dtype``, ``normalize`` and ``error`` so a
    caller can forward them with ``EmbedLoader.load_with_vocab(**option)``.
    # NOTE(review): Example base semantics live in ..core.utils — confirm.
    """

    def __init__(self, embed_filepath=None, dtype=np.float32,
                 normalize=True, error='ignore'):
        # Gather the options into one mapping before delegating to the base.
        opts = {
            'embed_filepath': embed_filepath,
            'dtype': dtype,
            'normalize': normalize,
            'error': error,
        }
        super().__init__(**opts)

class EmbedLoader(BaseLoader): class EmbedLoader(BaseLoader):
""" """
别名::class:`fastNLP.io.EmbedLoader` :class:`fastNLP.io.embed_loader.EmbedLoader` 别名::class:`fastNLP.io.EmbedLoader` :class:`fastNLP.io.embed_loader.EmbedLoader`


Loading…
Cancel
Save