@@ -12,7 +12,7 @@
 __all__ = [
     'EmbedLoader',
-    'DataInfo',
+    'DataBundle',
     'DataSetLoader',
     'CSVLoader',
@@ -35,7 +35,7 @@ __all__ = [
 ]
 from .embed_loader import EmbedLoader
-from .base_loader import DataInfo, DataSetLoader
+from .base_loader import DataBundle, DataSetLoader
 from .dataset_loader import CSVLoader, JsonLoader
 from .model_io import ModelLoader, ModelSaver
@@ -1,6 +1,6 @@
 __all__ = [
     "BaseLoader",
-    'DataInfo',
+    'DataBundle',
     'DataSetLoader',
 ]
@@ -109,7 +109,7 @@ def _uncompress(src, dst):
         raise ValueError('unsupported file {}'.format(src))
-class DataInfo:
+class DataBundle:
""" | |||
经过处理的数据信息,包括一系列数据集(比如:分开的训练集、验证集和测试集)及它们所用的词表和词嵌入。 | |||
@@ -201,20 +201,20 @@ class DataSetLoader:
         """
         raise NotImplementedError
-    def process(self, paths: Union[str, Dict[str, str]], **options) -> DataInfo:
+    def process(self, paths: Union[str, Dict[str, str]], **options) -> DataBundle:
         """
         For the given task and dataset, read and process the data, returning a processed DataInfo object or a dict.
         Data is read from the file(s) at one or more paths; a DataInfo object may contain one or more datasets.
         When several paths are processed, the keys of the dict passed in are kept consistent with the keys of the dict in the returned DataInfo.
-        The returned :class:`DataInfo` object has the following attributes:
+        The returned :class:`DataBundle` object has the following attributes:
         - vocabs: a dict made up of the vocabularies obtained from the datasets; each vocabulary
         - datasets: a dict of :class:`~fastNLP.DataSet` objects; field names follow :mod:`~fastNLP.core.const`
         :param paths: path(s) to the raw data
         :param options: parameters of your own design, depending on the task and dataset
-        :return: a DataInfo
+        :return: a DataBundle
         """
         raise NotImplementedError
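
A minimal usage sketch of the renamed container, assuming only what the patch itself shows: DataBundle accepts datasets=/vocabs= dicts (as in the SummarizationLoader and YelpLoader hunks below) and exposes them as attributes (as in the SSTLoader and MatchingLoader hunks). The toy field names here are invented for illustration:

# Hedged usage sketch; not part of this patch, field names are invented.
from fastNLP import DataSet, Instance, Vocabulary
from fastNLP.io.base_loader import DataBundle

ds = DataSet()
ds.append(Instance(raw_words="a tiny example".split(), target=1))

vocab = Vocabulary()
vocab.from_dataset(ds, field_name="raw_words")   # build the vocab from the data

# built in one call, as YelpLoader and SummarizationLoader do:
bundle = DataBundle(datasets={"train": ds}, vocabs={"words": vocab})

# or filled incrementally, as SSTLoader and MatchingLoader do:
bundle = DataBundle()
bundle.datasets["train"] = ds
bundle.vocabs["words"] = vocab
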
@@ -2,7 +2,7 @@
 from typing import Union, Dict
 from ..embed_loader import EmbeddingOption, EmbedLoader
-from ..base_loader import DataSetLoader, DataInfo
+from ..base_loader import DataSetLoader, DataBundle
 from ...core.vocabulary import VocabularyOption, Vocabulary
 from ...core.dataset import DataSet
 from ...core.instance import Instance
@@ -48,7 +48,7 @@ class IMDBLoader(DataSetLoader):
                 char_level_op=False):
         datasets = {}
-        info = DataInfo()
+        info = DataBundle()
         for name, path in paths.items():
             dataset = self.load(path)
             datasets[name] = dataset
@@ -4,7 +4,7 @@ from typing import Union, Dict , List
 from ...core.const import Const
 from ...core.vocabulary import Vocabulary
-from ..base_loader import DataInfo, DataSetLoader
+from ..base_loader import DataBundle, DataSetLoader
 from ..file_utils import _get_base_url, cached_path, PRETRAINED_BERT_MODEL_DIR
 from ...modules.encoder._bert import BertTokenizer
@@ -34,7 +34,7 @@ class MatchingLoader(DataSetLoader):
                 cut_text: int = None, get_index=True, auto_pad_length: int=None,
                 auto_pad_token: str='<pad>', set_input: Union[list, str, bool]=True,
                 set_target: Union[list, str, bool] = True, concat: Union[str, list, bool]=None,
-                extra_split: List[str]=['-'], ) -> DataInfo:
+                extra_split: List[str]=['-'], ) -> DataBundle:
""" | |||
:param paths: str或者Dict[str, str]。如果是str,则为数据集所在的文件夹或者是全路径文件名:如果是文件夹, | |||
则会从self.paths里面找对应的数据集名称与文件名。如果是Dict,则为数据集名称(如train、dev、test)和 | |||
@@ -80,7 +80,7 @@ class MatchingLoader(DataSetLoader):
         else:
             path = paths
-        data_info = DataInfo()
+        data_info = DataBundle()
         for data_name in path.keys():
             data_info.datasets[data_name] = self._load(path[data_name])
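
Per the docstring above, paths may be a single folder/file or a dict of split names. A hedged sketch of both call shapes, where SNLILoader (as a concrete MatchingLoader subclass), the directory, and the file names are assumptions rather than content of this patch:

# Hypothetical call shapes; SNLILoader and all paths below are assumptions.
from fastNLP.io.data_loader import SNLILoader

loader = SNLILoader()

# 1) one folder: split names and file names are resolved via self.paths
bundle = loader.process('/data/snli', to_lower=True)

# 2) an explicit dict: its keys come back as the keys of bundle.datasets
bundle = loader.process({'train': '/data/snli/train.jsonl',
                         'dev': '/data/snli/dev.jsonl'})
print(bundle.datasets.keys())   # dict_keys(['train', 'dev'])
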
@@ -1,7 +1,7 @@
 from typing import Union, Dict
-from ..base_loader import DataInfo
+from ..base_loader import DataBundle
 from ..dataset_loader import CSVLoader
 from ...core.vocabulary import Vocabulary, VocabularyOption
 from ...core.const import Const
@@ -37,7 +37,7 @@ class MTL16Loader(CSVLoader):
         paths = check_dataloader_paths(paths)
         datasets = {}
-        info = DataInfo()
+        info = DataBundle()
         for name, path in paths.items():
             dataset = self.load(path)
             datasets[name] = dataset
@@ -2,7 +2,7 @@
 from typing import Union, Dict
 from nltk import Tree
-from ..base_loader import DataInfo, DataSetLoader
+from ..base_loader import DataBundle, DataSetLoader
 from ..dataset_loader import CSVLoader
 from ...core.vocabulary import VocabularyOption, Vocabulary
 from ...core.dataset import DataSet
@@ -73,7 +73,7 @@ class SSTLoader(DataSetLoader):
         tgt_vocab = Vocabulary(unknown=None, padding=None) \
             if tgt_vocab_op is None else Vocabulary(**tgt_vocab_op)
-        info = DataInfo()
+        info = DataBundle()
         origin_subtree = self.subtree
         self.subtree = train_subtree
         info.datasets['train'] = self._load(paths['train'])
@@ -126,7 +126,7 @@ class SST2Loader(CSVLoader):
         paths = check_dataloader_paths(paths)
         datasets = {}
-        info = DataInfo()
+        info = DataBundle()
         for name, path in paths.items():
             dataset = self.load(path)
             datasets[name] = dataset
@@ -6,7 +6,7 @@ from ...core.const import Const
 from ...core.dataset import DataSet
 from ...core.instance import Instance
 from ...core.vocabulary import VocabularyOption, Vocabulary
-from ..base_loader import DataInfo, DataSetLoader
+from ..base_loader import DataBundle, DataSetLoader
 from typing import Union, Dict
 from ..utils import check_dataloader_paths, get_tokenizer
@@ -58,7 +58,7 @@ class YelpLoader(DataSetLoader):
                 tgt_vocab_op: VocabularyOption = None,
                 char_level_op=False):
         paths = check_dataloader_paths(paths)
-        info = DataInfo(datasets=self.load(paths))
+        info = DataBundle(datasets=self.load(paths))
         src_vocab = Vocabulary() if src_vocab_op is None else Vocabulary(**src_vocab_op)
         tgt_vocab = Vocabulary(unknown=None, padding=None) \
             if tgt_vocab_op is None else Vocabulary(**tgt_vocab_op)
@@ -2,7 +2,7 @@ import pickle
 import numpy as np
 from fastNLP.core.vocabulary import Vocabulary
-from fastNLP.io.base_loader import DataInfo
+from fastNLP.io.base_loader import DataBundle
 from fastNLP.io.dataset_loader import JsonLoader
 from fastNLP.core.const import Const
@@ -66,7 +66,7 @@ class SummarizationLoader(JsonLoader):
         :param domain: bool; build a vocab for publication, using 'X' for unknown
         :param tag: bool; build a vocab for tags, using 'X' for unknown
         :param load_vocab: bool; build the vocab (False) or load it (True)
-        :return: DataInfo
+        :return: DataBundle
             datasets: dict whose keys correspond to the keys of the paths dict
             vocabs: dict with keys: vocab (if "train" in paths), domain (if domain=True), tag (if tag=True)
             embeddings: optional
@@ -182,7 +182,7 @@ class SummarizationLoader(JsonLoader):
         for ds in datasets.values():
             vocab_dict["vocab"].index_dataset(ds, field_name=Const.INPUT, new_field_name=Const.INPUT)
-        return DataInfo(vocabs=vocab_dict, datasets=datasets)
+        return DataBundle(vocabs=vocab_dict, datasets=datasets)
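
The same assembly pattern recurs in nearly every loader touched by this patch: build the split datasets keyed by name, then hang them on a DataBundle. A self-contained sketch of that pattern with a stubbed-out load(); only DataBundle and its datasets dict are taken from the patch:

# Sketch of the recurring loader pattern; load() is a stand-in for _load().
from fastNLP import DataSet, Instance
from fastNLP.io.base_loader import DataBundle

def load(path):
    # stand-in for a loader's _load(): one toy instance per split
    ds = DataSet()
    ds.append(Instance(raw_words=["from", path], target="0"))
    return ds

paths = {"train": "train.txt", "dev": "dev.txt", "test": "test.txt"}

info = DataBundle()
for name, path in paths.items():
    info.datasets[name] = load(path)    # same keys in, same keys out

print(sorted(info.datasets))            # ['dev', 'test', 'train']
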
@@ -3,7 +3,7 @@ from datetime import timedelta
 from fastNLP.io.dataset_loader import JsonLoader
 from fastNLP.modules.encoder._bert import BertTokenizer
-from fastNLP.io.base_loader import DataInfo
+from fastNLP.io.base_loader import DataBundle
 from fastNLP.core.const import Const
 class BertData(JsonLoader):
@@ -110,7 +110,7 @@ class BertData(JsonLoader):
         # set padding value
         datasets[name].set_pad_val('article', 0)
-        return DataInfo(datasets=datasets)
+        return DataBundle(datasets=datasets)
 class BertSumLoader(JsonLoader):
@@ -154,4 +154,4 @@ class BertSumLoader(JsonLoader):
         print('Finished in {}'.format(timedelta(seconds=time()-start)))
-        return DataInfo(datasets=datasets)
+        return DataBundle(datasets=datasets)
@@ -1,7 +1,7 @@
 from fastNLP.io.dataset_loader import JsonLoader,DataSet,Instance
 from fastNLP.io.file_reader import _read_json
 from fastNLP.core.vocabulary import Vocabulary
-from fastNLP.io.base_loader import DataInfo
+from fastNLP.io.base_loader import DataBundle
 from reproduction.coreference_resolution.model.config import Config
 import reproduction.coreference_resolution.model.preprocess as preprocess
@@ -26,7 +26,7 @@ class CRLoader(JsonLoader):
         return dataset
     def process(self, paths, **kwargs):
-        data_info = DataInfo()
+        data_info = DataBundle()
         for name in ['train', 'test', 'dev']:
             data_info.datasets[name] = self.load(paths[name])
@@ -1,6 +1,6 @@
-from fastNLP.io.base_loader import DataSetLoader, DataInfo
+from fastNLP.io.base_loader import DataSetLoader, DataBundle
 from fastNLP.io.data_loader import ConllLoader
 import numpy as np
@@ -76,7 +76,7 @@ class CTBxJointLoader(DataSetLoader):
         gold_label_word_pairs:
         """
         paths = check_dataloader_paths(paths)
-        data = DataInfo()
+        data = DataBundle()
         for name, path in paths.items():
             dataset = self.load(path)
@@ -5,7 +5,7 @@ from typing import Union, Dict
 from fastNLP.core.const import Const
 from fastNLP.core.vocabulary import Vocabulary
-from fastNLP.io.base_loader import DataInfo, DataSetLoader
+from fastNLP.io.base_loader import DataBundle, DataSetLoader
 from fastNLP.io.dataset_loader import JsonLoader, CSVLoader
 from fastNLP.io.file_utils import _get_base_url, cached_path, PRETRAINED_BERT_MODEL_DIR
 from fastNLP.modules.encoder._bert import BertTokenizer
@@ -35,7 +35,7 @@ class MatchingLoader(DataSetLoader):
                 to_lower=False, seq_len_type: str=None, bert_tokenizer: str=None,
                 cut_text: int = None, get_index=True, auto_pad_length: int=None,
                 auto_pad_token: str='<pad>', set_input: Union[list, str, bool]=True,
-                set_target: Union[list, str, bool] = True, concat: Union[str, list, bool]=None, ) -> DataInfo:
+                set_target: Union[list, str, bool] = True, concat: Union[str, list, bool]=None, ) -> DataBundle:
""" | |||
:param paths: str或者Dict[str, str]。如果是str,则为数据集所在的文件夹或者是全路径文件名:如果是文件夹, | |||
则会从self.paths里面找对应的数据集名称与文件名。如果是Dict,则为数据集名称(如train、dev、test)和 | |||
@@ -80,7 +80,7 @@ class MatchingLoader(DataSetLoader):
         else:
             path = paths
-        data_info = DataInfo()
+        data_info = DataBundle()
         for data_name in path.keys():
             data_info.datasets[data_name] = self._load(path[data_name])
@@ -1,7 +1,7 @@
 from fastNLP.io.embed_loader import EmbeddingOption, EmbedLoader
 from fastNLP.core.vocabulary import VocabularyOption
-from fastNLP.io.base_loader import DataSetLoader, DataInfo
+from fastNLP.io.base_loader import DataSetLoader, DataBundle
 from typing import Union, Dict, List, Iterator
 from fastNLP import DataSet
 from fastNLP import Instance
@@ -161,7 +161,7 @@ class SigHanLoader(DataSetLoader):
         # check_dataloader_paths is the recommended way to validate paths
         paths = check_dataloader_paths(paths)
         datasets = {}
-        data = DataInfo()
+        data = DataBundle()
         bigram = bigram_vocab_opt is not None
         for name, path in paths.items():
             dataset = self.load(path, bigram=bigram)
@@ -1,6 +1,6 @@
 from fastNLP.core.vocabulary import VocabularyOption
-from fastNLP.io.base_loader import DataSetLoader, DataInfo
+from fastNLP.io.base_loader import DataSetLoader, DataBundle
 from typing import Union, Dict
 from fastNLP import Vocabulary
 from fastNLP import Const
@@ -51,7 +51,7 @@ class Conll2003DataLoader(DataSetLoader):
         """
         # read the data
         paths = check_dataloader_paths(paths)
-        data = DataInfo()
+        data = DataBundle()
         input_fields = [Const.TARGET, Const.INPUT, Const.INPUT_LEN]
         target_fields = [Const.TARGET, Const.INPUT_LEN]
         for name, path in paths.items():
@@ -1,5 +1,5 @@
 from fastNLP.core.vocabulary import VocabularyOption
-from fastNLP.io.base_loader import DataSetLoader, DataInfo
+from fastNLP.io.base_loader import DataSetLoader, DataBundle
 from typing import Union, Dict
 from fastNLP import DataSet
 from fastNLP import Vocabulary
@@ -76,7 +76,7 @@ class OntoNoteNERDataLoader(DataSetLoader):
         return dataset
     def process(self, paths: Union[str, Dict[str, str]], word_vocab_opt:VocabularyOption=None,
-                lower:bool=True)->DataInfo:
+                lower:bool=True)->DataBundle:
         """
         Reads and processes the data. The returned DataInfo contains the following:
            vocabs:
@@ -96,7 +96,7 @@ class OntoNoteNERDataLoader(DataSetLoader):
         :return:
         """
         paths = check_dataloader_paths(paths)
-        data = DataInfo()
+        data = DataBundle()
         input_fields = [Const.TARGET, Const.INPUT, Const.INPUT_LEN]
         target_fields = [Const.TARGET, Const.INPUT_LEN]
         for name, path in paths.items():
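
The signature shown above, process(paths, word_vocab_opt=None, lower=True) -> DataBundle, can be exercised as below. The import path, the data directory, and the min_freq field of VocabularyOption are all assumptions, not content of this patch:

# Hypothetical call; the import path and all values below are assumptions.
from fastNLP.core.vocabulary import VocabularyOption
# OntoNoteNERDataLoader is the class patched above; its module path is assumed:
from reproduction.seqence_labelling.ner.data.OntoNoteLoader import OntoNoteNERDataLoader

loader = OntoNoteNERDataLoader()
bundle = loader.process('/data/ontonotes',
                        word_vocab_opt=VocabularyOption(min_freq=2),
                        lower=True)
# per the hunk above, fields follow fastNLP.core.const:
#   input fields:  [Const.TARGET, Const.INPUT, Const.INPUT_LEN]
#   target fields: [Const.TARGET, Const.INPUT_LEN]
for name, ds in bundle.datasets.items():
    print(name, len(ds))
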
@@ -1,6 +1,6 @@
 from fastNLP.io.embed_loader import EmbeddingOption, EmbedLoader
 from fastNLP.core.vocabulary import VocabularyOption
-from fastNLP.io.base_loader import DataSetLoader, DataInfo
+from fastNLP.io.base_loader import DataSetLoader, DataBundle
 from typing import Union, Dict, List, Iterator
 from fastNLP import DataSet
 from fastNLP import Instance
@@ -50,7 +50,7 @@ class IMDBLoader(DataSetLoader):
                 char_level_op=False):
         datasets = {}
-        info = DataInfo()
+        info = DataBundle()
         for name, path in paths.items():
             dataset = self.load(path)
             datasets[name] = dataset
@@ -1,6 +1,6 @@
 from fastNLP.io.embed_loader import EmbeddingOption, EmbedLoader
 from fastNLP.core.vocabulary import VocabularyOption
-from fastNLP.io.base_loader import DataSetLoader, DataInfo
+from fastNLP.io.base_loader import DataSetLoader, DataBundle
 from typing import Union, Dict, List, Iterator
 from fastNLP import DataSet
 from fastNLP import Instance
@@ -47,7 +47,7 @@ class MTL16Loader(DataSetLoader):
         paths = check_dataloader_paths(paths)
         datasets = {}
-        info = DataInfo()
+        info = DataBundle()
         for name, path in paths.items():
             dataset = self.load(path)
             datasets[name] = dataset
@@ -1,6 +1,6 @@
 from typing import Iterable
 from nltk import Tree
-from fastNLP.io.base_loader import DataInfo, DataSetLoader
+from fastNLP.io.base_loader import DataBundle, DataSetLoader
 from fastNLP.core.vocabulary import VocabularyOption, Vocabulary
 from fastNLP import DataSet
 from fastNLP import Instance
@@ -68,7 +68,7 @@ class SSTLoader(DataSetLoader):
         tgt_vocab = Vocabulary(unknown=None, padding=None) \
             if tgt_vocab_op is None else Vocabulary(**tgt_vocab_op)
-        info = DataInfo(datasets=self.load(paths))
+        info = DataBundle(datasets=self.load(paths))
         _train_ds = [info.datasets[name]
                      for name in train_ds] if train_ds else info.datasets.values()
         src_vocab.from_dataset(*_train_ds, field_name=input_name)
@@ -134,7 +134,7 @@ class sst2Loader(DataSetLoader):
         paths = check_dataloader_paths(paths)
         datasets = {}
-        info = DataInfo()
+        info = DataBundle()
         for name, path in paths.items():
             dataset = self.load(path)
             datasets[name] = dataset
@@ -4,7 +4,7 @@ from typing import Iterable
 from fastNLP import DataSet, Instance, Vocabulary
 from fastNLP.core.vocabulary import VocabularyOption
 from fastNLP.io import JsonLoader
-from fastNLP.io.base_loader import DataInfo,DataSetLoader
+from fastNLP.io.base_loader import DataBundle,DataSetLoader
 from fastNLP.io.embed_loader import EmbeddingOption
 from fastNLP.io.file_reader import _read_json
 from typing import Union, Dict
@@ -134,7 +134,7 @@ class yelpLoader(DataSetLoader):
                 char_level_op=False):
         paths = check_dataloader_paths(paths)
         datasets = {}
-        info = DataInfo(datasets=self.load(paths))
+        info = DataBundle(datasets=self.load(paths))
         src_vocab = Vocabulary() if src_vocab_op is None else Vocabulary(**src_vocab_op)
         tgt_vocab = Vocabulary(unknown=None, padding=None) \
             if tgt_vocab_op is None else Vocabulary(**tgt_vocab_op)