@@ -22,8 +22,6 @@ __all__ = [
     'SSTLoader',
     'PeopleDailyCorpusLoader',
     'Conll2003Loader',
-    'VocabularyOption',
-    'EmbeddingOption',
 ]
 
 from nltk.tree import Tree
@@ -37,6 +35,8 @@ from ..core.utils import Example
 from ..core.vocabulary import Vocabulary
 from ..io.embed_loader import EmbedLoader
 import numpy as np
+from ..core.vocabulary import VocabularyOption
+from .embed_loader import EmbeddingOption
 
 
 def _download_from_url(url, path):
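The two option classes dropped from `__all__` above now live in `fastNLP.core.vocabulary` and `fastNLP.io.embed_loader`, matching the relative imports added here. A minimal sketch of the new import sites, assuming the relocated classes keep the constructor signatures shown in the removed definitions further down:

```python
# Assumed absolute import paths, mirroring the relative imports added above.
from fastNLP.core.vocabulary import VocabularyOption
from fastNLP.io.embed_loader import EmbeddingOption

# Keyword arguments taken from the removed class definitions below.
src_vocab_op = VocabularyOption(max_size=30000, min_freq=2)
src_embed_op = EmbeddingOption(embed_filepath='glove.6B.100d.txt')  # placeholder path
```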
@@ -56,7 +56,6 @@ def _download_from_url(url, path):
             if chunk:
                 file.write(chunk)
                 t.update(len(chunk))
-    return
 
 
 def _uncompress(src, dst):
@@ -93,34 +92,6 @@ def _uncompress(src, dst):
         raise ValueError('unsupported file {}'.format(src))
 
 
-class VocabularyOption(Example):
-    def __init__(self,
-                 max_size=None,
-                 min_freq=None,
-                 padding='<pad>',
-                 unknown='<unk>'):
-        super().__init__(
-            max_size=max_size,
-            min_freq=min_freq,
-            padding=padding,
-            unknown=unknown
-        )
-
-
-class EmbeddingOption(Example):
-    def __init__(self,
-                 embed_filepath=None,
-                 dtype=np.float32,
-                 normalize=True,
-                 error='ignore'):
-        super().__init__(
-            embed_filepath=embed_filepath,
-            dtype=dtype,
-            normalize=normalize,
-            error=error
-        )
-
-
 class DataInfo:
     """
     Processed data information: a collection of datasets (e.g. separate train, dev, and test sets) together with the vocabularies and word embeddings they use.
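For context, a hedged sketch of what a `DataInfo` instance carries, based on the docstring above and on how `SSTLoader.process` populates it at the end of this diff; the constructor keywords are assumed, since `__init__` is not shown here:

```python
# Assumed constructor keywords (datasets/vocabs/embeddings); only attribute
# access such as `info.embeddings[...]` is visible in this diff.
info = DataInfo(
    datasets={'train': train_set, 'test': test_set},   # fastNLP DataSet objects
    vocabs={'words': src_vocab, 'target': tgt_vocab},  # fastNLP Vocabulary objects
    embeddings={'words': init_emb},                    # e.g. a numpy matrix
)
```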
@@ -151,26 +122,41 @@ class DataSetLoader:
     **The process function may call load or _load internally.**
 
     """
-    URL = ''
-    DATA_DIR = ''
-
-    def _download(self, url: str, path: str, uncompress=True) -> str:
+    ROOT_DIR = '.fastnlp/datasets/'
+    UNCOMPRESS = True
+
+    def _download(self, url: str, pdir: str, uncompress=True) -> str:
         """
         Download data from ``url`` to ``path``; if ``uncompress`` is ``True``, uncompress it automatically.
 
         :param url: the URL to download from
-        :param path: the directory to download into
+        :param pdir: the directory to download into
         :param uncompress: whether to automatically uncompress the downloaded file
         :return: the path where the data is stored
         """
-        pdir = os.path.dirname(path)
-        os.makedirs(pdir, exist_ok=True)
-        _download_from_url(url, path)
+        fn = os.path.basename(url)
+        path = os.path.join(pdir, fn)
+        """check data exists"""
+        if not os.path.exists(path):
+            os.makedirs(pdir, exist_ok=True)
+            _download_from_url(url, path)
         if uncompress:
             dst = os.path.join(pdir, 'data')
-            _uncompress(path, dst)
+            if not os.path.exists(dst):
+                _uncompress(path, dst)
             return dst
         return path
+
+    def download(self):
+        return self._download(
+            self.URL,
+            os.path.join(self.ROOT_DIR, self.DATA_DIR),
+            uncompress=self.UNCOMPRESS)
 
     def load(self, paths: Union[str, Dict[str, str]]) -> Union[DataSet, Dict[str, DataSet]]:
         """
         Read data from the file(s) at one or more given paths and return one or more :class:`~fastNLP.DataSet` objects.
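A hypothetical subclass showing the new class-attribute contract: `ROOT_DIR` and `UNCOMPRESS` now come with defaults from `DataSetLoader`, while a concrete loader supplies `URL` and `DATA_DIR`, and the new `download()` composes them. The loader name, URL, and directory below are illustrative only:

```python
class ToyCorpusLoader(DataSetLoader):
    URL = 'https://example.com/toy_corpus.zip'  # hypothetical URL
    DATA_DIR = 'toy_corpus/'                    # hypothetical subdirectory


loader = ToyCorpusLoader()
# First call downloads and uncompresses; later calls hit the
# `os.path.exists` checks added above and reuse the cached copy.
data_path = loader.download()
```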
@@ -442,7 +428,7 @@ class SSTLoader(DataSetLoader):
                 train_ds: Iterable[str] = None,
                 src_vocab_op: VocabularyOption = None,
                 tgt_vocab_op: VocabularyOption = None,
-                embed_op: EmbeddingOption = None):
+                src_embed_op: EmbeddingOption = None):
         input_name, target_name = 'words', 'target'
         src_vocab = Vocabulary() if src_vocab_op is None else Vocabulary(**src_vocab_op)
         tgt_vocab = Vocabulary() if tgt_vocab_op is None else Vocabulary(**tgt_vocab_op)
@@ -463,9 +449,9 @@ class SSTLoader(DataSetLoader):
             target_name: tgt_vocab
         }
 
-        if embed_op is not None:
-            embed_op.vocab = src_vocab
-            init_emb = EmbedLoader.load_with_vocab(**embed_op)
+        if src_embed_op is not None:
+            src_embed_op.vocab = src_vocab
+            init_emb = EmbedLoader.load_with_vocab(**src_embed_op)
             info.embeddings[input_name] = init_emb
 
         return info
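A hedged usage sketch after the `embed_op` to `src_embed_op` rename: the embedding option now explicitly targets the source vocabulary. This assumes `process` takes the paths mapping as its first positional argument (as `load` does); file paths and option values are placeholders:

```python
loader = SSTLoader()
info = loader.process(
    {'train': 'sst_train.txt', 'test': 'sst_test.txt'},  # placeholder paths
    src_vocab_op=VocabularyOption(min_freq=2),
    src_embed_op=EmbeddingOption(embed_filepath='glove.6B.100d.txt'),
)
# The loaded matrix is stored under the source field name:
# info.embeddings['words']
```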