diff --git a/fastNLP/io/__init__.py b/fastNLP/io/__init__.py
index fe4ca245..d1d1dc5d 100644
--- a/fastNLP/io/__init__.py
+++ b/fastNLP/io/__init__.py
@@ -12,7 +12,7 @@ __all__ = [
     'EmbedLoader',

-    'DataInfo',
+    'DataBundle',
     'DataSetLoader',

     'CSVLoader',
@@ -35,7 +35,7 @@ __all__ = [
 ]

 from .embed_loader import EmbedLoader
-from .base_loader import DataInfo, DataSetLoader
+from .base_loader import DataBundle, DataSetLoader
 from .dataset_loader import CSVLoader, JsonLoader
 from .model_io import ModelLoader, ModelSaver
diff --git a/fastNLP/io/base_loader.py b/fastNLP/io/base_loader.py
index 8cff1da1..62793836 100644
--- a/fastNLP/io/base_loader.py
+++ b/fastNLP/io/base_loader.py
@@ -1,6 +1,6 @@
 __all__ = [
     "BaseLoader",
-    'DataInfo',
+    'DataBundle',
     'DataSetLoader',
 ]

@@ -109,7 +109,7 @@ def _uncompress(src, dst):
         raise ValueError('unsupported file {}'.format(src))


-class DataInfo:
+class DataBundle:
     """
     Processed data: a collection of datasets (e.g. separate train, dev and test sets) together with the vocabularies and embeddings they use.

@@ -201,20 +201,20 @@ class DataSetLoader:
         """
         raise NotImplementedError

-    def process(self, paths: Union[str, Dict[str, str]], **options) -> DataInfo:
+    def process(self, paths: Union[str, Dict[str, str]], **options) -> DataBundle:
         """
-        For a given task and dataset, read and process the data, returning a DataInfo object or a dict.
+        For a given task and dataset, read and process the data, returning a DataBundle object or a dict.

-        Data is read from the file(s) at one or more paths; a DataInfo object can hold one or more datasets.
-        When several paths are processed, the keys of the dict passed in match the keys of the dict in the returned DataInfo.
+        Data is read from the file(s) at one or more paths; a DataBundle object can hold one or more datasets.
+        When several paths are processed, the keys of the dict passed in match the keys of the dict in the returned DataBundle.

-        The returned :class:`DataInfo` object has the following attributes:
+        The returned :class:`DataBundle` object has the following attributes:

        - vocabs: a dict of the vocabularies built from the datasets
        - datasets: a dict of :class:`~fastNLP.DataSet` objects, with field names following :mod:`~fastNLP.core.const`

        :param paths: path(s) to read the raw data from
        :param options: task- and dataset-specific parameters of your own design
-        :return: a DataInfo
+        :return: a DataBundle
         """
         raise NotImplementedError
diff --git a/fastNLP/io/data_loader/imdb.py b/fastNLP/io/data_loader/imdb.py
index b4c2c1f9..bf53c5be 100644
--- a/fastNLP/io/data_loader/imdb.py
+++ b/fastNLP/io/data_loader/imdb.py
@@ -2,7 +2,7 @@ from typing import Union, Dict

 from ..embed_loader import EmbeddingOption, EmbedLoader
-from ..base_loader import DataSetLoader, DataInfo
+from ..base_loader import DataSetLoader, DataBundle
 from ...core.vocabulary import VocabularyOption, Vocabulary
 from ...core.dataset import DataSet
 from ...core.instance import Instance
@@ -48,7 +48,7 @@ class IMDBLoader(DataSetLoader):
                 char_level_op=False):

         datasets = {}
-        info = DataInfo()
+        info = DataBundle()
         for name, path in paths.items():
             dataset = self.load(path)
             datasets[name] = dataset
diff --git a/fastNLP/io/data_loader/matching.py b/fastNLP/io/data_loader/matching.py
index ce9c280b..3f5759d6 100644
--- a/fastNLP/io/data_loader/matching.py
+++ b/fastNLP/io/data_loader/matching.py
@@ -4,7 +4,7 @@ from typing import Union, Dict, List

 from ...core.const import Const
 from ...core.vocabulary import Vocabulary
-from ..base_loader import DataInfo, DataSetLoader
+from ..base_loader import DataBundle, DataSetLoader
 from ..file_utils import _get_base_url, cached_path, PRETRAINED_BERT_MODEL_DIR
 from ...modules.encoder._bert import BertTokenizer
@@ -34,7 +34,7 @@ class MatchingLoader(DataSetLoader):
                 cut_text: int = None, get_index=True, auto_pad_length: int=None,
                 auto_pad_token: str='<pad>', set_input: Union[list, str, bool]=True,
                 set_target: Union[list, str, bool] = True, concat: Union[str, list, bool]=None,
-                extra_split: List[str]=['-'], ) -> DataInfo:
+                extra_split: List[str]=['-'], ) -> DataBundle:
         """
         :param paths: str or Dict[str, str]. If a str, it is the folder holding the datasets or a full file path;
             for a folder, the dataset names and file names are looked up in self.paths. If a Dict, it maps dataset names (e.g. train, dev, test) to
@@ -80,7 +80,7 @@ class MatchingLoader(DataSetLoader):
         else:
             path = paths

-        data_info = DataInfo()
+        data_info = DataBundle()
         for data_name in path.keys():
             data_info.datasets[data_name] = self._load(path[data_name])
diff --git a/fastNLP/io/data_loader/mtl.py b/fastNLP/io/data_loader/mtl.py
index 548a985b..940ece51 100644
--- a/fastNLP/io/data_loader/mtl.py
+++ b/fastNLP/io/data_loader/mtl.py
@@ -1,7 +1,7 @@

 from typing import Union, Dict

-from ..base_loader import DataInfo
+from ..base_loader import DataBundle
 from ..dataset_loader import CSVLoader
 from ...core.vocabulary import Vocabulary, VocabularyOption
 from ...core.const import Const
@@ -37,7 +37,7 @@ class MTL16Loader(CSVLoader):

         paths = check_dataloader_paths(paths)
         datasets = {}
-        info = DataInfo()
+        info = DataBundle()
         for name, path in paths.items():
             dataset = self.load(path)
             datasets[name] = dataset
diff --git a/fastNLP/io/data_loader/sst.py b/fastNLP/io/data_loader/sst.py
index 05d63e2f..ecbabd49 100644
--- a/fastNLP/io/data_loader/sst.py
+++ b/fastNLP/io/data_loader/sst.py
@@ -2,7 +2,7 @@ from typing import Union, Dict

 from nltk import Tree

-from ..base_loader import DataInfo, DataSetLoader
+from ..base_loader import DataBundle, DataSetLoader
 from ..dataset_loader import CSVLoader
 from ...core.vocabulary import VocabularyOption, Vocabulary
 from ...core.dataset import DataSet
@@ -73,7 +73,7 @@ class SSTLoader(DataSetLoader):
         tgt_vocab = Vocabulary(unknown=None, padding=None) \
             if tgt_vocab_op is None else Vocabulary(**tgt_vocab_op)

-        info = DataInfo()
+        info = DataBundle()
         origin_subtree = self.subtree
         self.subtree = train_subtree
         info.datasets['train'] = self._load(paths['train'])
@@ -126,7 +126,7 @@ class SST2Loader(CSVLoader):

         paths = check_dataloader_paths(paths)
         datasets = {}
-        info = DataInfo()
+        info = DataBundle()
         for name, path in paths.items():
             dataset = self.load(path)
             datasets[name] = dataset
diff --git a/fastNLP/io/data_loader/yelp.py b/fastNLP/io/data_loader/yelp.py
index 41856200..c287a90c 100644
--- a/fastNLP/io/data_loader/yelp.py
+++ b/fastNLP/io/data_loader/yelp.py
@@ -6,7 +6,7 @@
 from ...core.const import Const
 from ...core.dataset import DataSet
 from ...core.instance import Instance
 from ...core.vocabulary import VocabularyOption, Vocabulary
-from ..base_loader import DataInfo, DataSetLoader
+from ..base_loader import DataBundle, DataSetLoader
 from typing import Union, Dict
 from ..utils import check_dataloader_paths, get_tokenizer
@@ -58,7 +58,7 @@ class YelpLoader(DataSetLoader):
             tgt_vocab_op: VocabularyOption = None,
             char_level_op=False):
         paths = check_dataloader_paths(paths)
-        info = DataInfo(datasets=self.load(paths))
+        info = DataBundle(datasets=self.load(paths))
         src_vocab = Vocabulary() if src_vocab_op is None else Vocabulary(**src_vocab_op)
         tgt_vocab = Vocabulary(unknown=None, padding=None) \
             if tgt_vocab_op is None else Vocabulary(**tgt_vocab_op)
diff --git a/reproduction/Summarization/Baseline/data/dataloader.py b/reproduction/Summarization/Baseline/data/dataloader.py
index fe787c31..57702904 100644
--- a/reproduction/Summarization/Baseline/data/dataloader.py
+++ b/reproduction/Summarization/Baseline/data/dataloader.py
@@ -2,7 +2,7 @@ import pickle
 import numpy as np

 from fastNLP.core.vocabulary import Vocabulary
-from fastNLP.io.base_loader import DataInfo
+from fastNLP.io.base_loader import DataBundle
 from fastNLP.io.dataset_loader import JsonLoader
 from fastNLP.core.const import Const
@@ -66,7 +66,7 @@ class SummarizationLoader(JsonLoader):
         :param domain: bool  build vocab for publication, use 'X' for unknown
         :param tag: bool  build vocab for tag, use 'X' for unknown
         :param load_vocab: bool  build vocab (False) or load vocab (True)
-        :return: DataInfo
+        :return: DataBundle
             datasets: dict  keys correspond to the paths dict
             vocabs: dict  key: vocab(if "train" in paths), domain(if domain=True), tag(if tag=True)
             embeddings: optional
@@ -182,7 +182,7 @@ class SummarizationLoader(JsonLoader):
         for ds in datasets.values():
             vocab_dict["vocab"].index_dataset(ds, field_name=Const.INPUT, new_field_name=Const.INPUT)

-        return DataInfo(vocabs=vocab_dict, datasets=datasets)
+        return DataBundle(vocabs=vocab_dict, datasets=datasets)
diff --git a/reproduction/Summarization/BertSum/dataloader.py b/reproduction/Summarization/BertSum/dataloader.py
index cb1acd53..c5201261 100644
--- a/reproduction/Summarization/BertSum/dataloader.py
+++ b/reproduction/Summarization/BertSum/dataloader.py
@@ -3,7 +3,7 @@ from datetime import timedelta

 from fastNLP.io.dataset_loader import JsonLoader
 from fastNLP.modules.encoder._bert import BertTokenizer
-from fastNLP.io.base_loader import DataInfo
+from fastNLP.io.base_loader import DataBundle
 from fastNLP.core.const import Const

 class BertData(JsonLoader):
@@ -110,7 +110,7 @@ class BertData(JsonLoader):
             # set padding value
             datasets[name].set_pad_val('article', 0)

-        return DataInfo(datasets=datasets)
+        return DataBundle(datasets=datasets)


 class BertSumLoader(JsonLoader):
@@ -154,4 +154,4 @@ class BertSumLoader(JsonLoader):

         print('Finished in {}'.format(timedelta(seconds=time()-start)))

-        return DataInfo(datasets=datasets)
+        return DataBundle(datasets=datasets)
diff --git a/reproduction/coreference_resolution/data_load/cr_loader.py b/reproduction/coreference_resolution/data_load/cr_loader.py
index 986afcd5..a424b0d1 100644
--- a/reproduction/coreference_resolution/data_load/cr_loader.py
+++ b/reproduction/coreference_resolution/data_load/cr_loader.py
@@ -1,7 +1,7 @@
 from fastNLP.io.dataset_loader import JsonLoader,DataSet,Instance
 from fastNLP.io.file_reader import _read_json
 from fastNLP.core.vocabulary import Vocabulary
-from fastNLP.io.base_loader import DataInfo
+from fastNLP.io.base_loader import DataBundle
 from reproduction.coreference_resolution.model.config import Config
 import reproduction.coreference_resolution.model.preprocess as preprocess
@@ -26,7 +26,7 @@ class CRLoader(JsonLoader):
         return dataset

     def process(self, paths, **kwargs):
-        data_info = DataInfo()
+        data_info = DataBundle()
         for name in ['train', 'test', 'dev']:
             data_info.datasets[name] = self.load(paths[name])
diff --git a/reproduction/joint_cws_parse/data/data_loader.py b/reproduction/joint_cws_parse/data/data_loader.py
index 0644b01d..3e6fec4b 100644
--- a/reproduction/joint_cws_parse/data/data_loader.py
+++ b/reproduction/joint_cws_parse/data/data_loader.py
@@ -1,6 +1,6 @@


-from fastNLP.io.base_loader import DataSetLoader, DataInfo
+from fastNLP.io.base_loader import DataSetLoader, DataBundle
 from fastNLP.io.data_loader import ConllLoader
 import numpy as np

@@ -76,7 +76,7 @@ class CTBxJointLoader(DataSetLoader):
             gold_label_word_pairs:
         """
         paths = check_dataloader_paths(paths)
-        data = DataInfo()
+        data = DataBundle()

         for name, path in paths.items():
             dataset = self.load(path)
diff --git a/reproduction/matching/data/MatchingDataLoader.py b/reproduction/matching/data/MatchingDataLoader.py
index 7c32899c..67fa4c8d 100644
--- a/reproduction/matching/data/MatchingDataLoader.py
+++ b/reproduction/matching/data/MatchingDataLoader.py
@@ -5,7 +5,7 @@ from typing import Union, Dict

 from fastNLP.core.const import Const
 from fastNLP.core.vocabulary import Vocabulary
-from fastNLP.io.base_loader import DataInfo, DataSetLoader
+from fastNLP.io.base_loader import DataBundle, DataSetLoader
 from fastNLP.io.dataset_loader import JsonLoader, CSVLoader
 from fastNLP.io.file_utils import _get_base_url, cached_path, PRETRAINED_BERT_MODEL_DIR
 from fastNLP.modules.encoder._bert import BertTokenizer
@@ -35,7 +35,7 @@ class MatchingLoader(DataSetLoader):
                 to_lower=False, seq_len_type: str=None, bert_tokenizer: str=None,
                 cut_text: int = None, get_index=True, auto_pad_length: int=None,
                 auto_pad_token: str='<pad>', set_input: Union[list, str, bool]=True,
-                set_target: Union[list, str, bool] = True, concat: Union[str, list, bool]=None, ) -> DataInfo:
+                set_target: Union[list, str, bool] = True, concat: Union[str, list, bool]=None, ) -> DataBundle:
         """
         :param paths: str or Dict[str, str]. If a str, it is the folder holding the datasets or a full file path;
             for a folder, the dataset names and file names are looked up in self.paths. If a Dict, it maps dataset names (e.g. train, dev, test) to
@@ -80,7 +80,7 @@ class MatchingLoader(DataSetLoader):
         else:
             path = paths

-        data_info = DataInfo()
+        data_info = DataBundle()
         for data_name in path.keys():
             data_info.datasets[data_name] = self._load(path[data_name])
diff --git a/reproduction/seqence_labelling/cws/data/CWSDataLoader.py b/reproduction/seqence_labelling/cws/data/CWSDataLoader.py
index e8440289..3c82d814 100644
--- a/reproduction/seqence_labelling/cws/data/CWSDataLoader.py
+++ b/reproduction/seqence_labelling/cws/data/CWSDataLoader.py
@@ -1,7 +1,7 @@

 from fastNLP.io.embed_loader import EmbeddingOption, EmbedLoader
 from fastNLP.core.vocabulary import VocabularyOption
-from fastNLP.io.base_loader import DataSetLoader, DataInfo
+from fastNLP.io.base_loader import DataSetLoader, DataBundle
 from typing import Union, Dict, List, Iterator
 from fastNLP import DataSet
 from fastNLP import Instance
@@ -161,7 +161,7 @@ class SigHanLoader(DataSetLoader):
         # check_dataloader_paths is the recommended way to validate paths
         paths = check_dataloader_paths(paths)
         datasets = {}
-        data = DataInfo()
+        data = DataBundle()
         bigram = bigram_vocab_opt is not None
         for name, path in paths.items():
             dataset = self.load(path, bigram=bigram)
diff --git a/reproduction/seqence_labelling/ner/data/Conll2003Loader.py b/reproduction/seqence_labelling/ner/data/Conll2003Loader.py
index 0e464640..1aeddcf8 100644
--- a/reproduction/seqence_labelling/ner/data/Conll2003Loader.py
+++ b/reproduction/seqence_labelling/ner/data/Conll2003Loader.py
@@ -1,6 +1,6 @@

 from fastNLP.core.vocabulary import VocabularyOption
-from fastNLP.io.base_loader import DataSetLoader, DataInfo
+from fastNLP.io.base_loader import DataSetLoader, DataBundle
 from typing import Union, Dict
 from fastNLP import Vocabulary
 from fastNLP import Const
@@ -51,7 +51,7 @@ class Conll2003DataLoader(DataSetLoader):
         """
         # read the data
         paths = check_dataloader_paths(paths)
-        data = DataInfo()
+        data = DataBundle()
         input_fields = [Const.TARGET, Const.INPUT, Const.INPUT_LEN]
         target_fields = [Const.TARGET, Const.INPUT_LEN]
         for name, path in paths.items():
diff --git a/reproduction/seqence_labelling/ner/data/OntoNoteLoader.py b/reproduction/seqence_labelling/ner/data/OntoNoteLoader.py
index 8a2c567d..f1ff83d8 100644
--- a/reproduction/seqence_labelling/ner/data/OntoNoteLoader.py
+++ b/reproduction/seqence_labelling/ner/data/OntoNoteLoader.py
@@ -1,5 +1,5 @@
 from fastNLP.core.vocabulary import VocabularyOption
-from fastNLP.io.base_loader import DataSetLoader, DataInfo
+from fastNLP.io.base_loader import DataSetLoader, DataBundle
 from typing import Union, Dict
 from fastNLP import DataSet
 from fastNLP import Vocabulary
@@ -76,7 +76,7 @@ class OntoNoteNERDataLoader(DataSetLoader):
         return dataset

     def process(self, paths: Union[str, Dict[str, str]], word_vocab_opt:VocabularyOption=None,
-                lower:bool=True)->DataInfo:
+                lower:bool=True)->DataBundle:
         """
-        Read and process the data. The returned DataInfo contains:
+        Read and process the data. The returned DataBundle contains:
         vocabs:
@@ -96,7 +96,7 @@ class OntoNoteNERDataLoader(DataSetLoader):
         :return:
         """
         paths = check_dataloader_paths(paths)
-        data = DataInfo()
+        data = DataBundle()
         input_fields = [Const.TARGET, Const.INPUT, Const.INPUT_LEN]
         target_fields = [Const.TARGET, Const.INPUT_LEN]
         for name, path in paths.items():
diff --git a/reproduction/text_classification/data/IMDBLoader.py b/reproduction/text_classification/data/IMDBLoader.py
index 0cdab15e..d57ee41b 100644
--- a/reproduction/text_classification/data/IMDBLoader.py
+++ b/reproduction/text_classification/data/IMDBLoader.py
@@ -1,6 +1,6 @@
 from fastNLP.io.embed_loader import EmbeddingOption, EmbedLoader
 from fastNLP.core.vocabulary import VocabularyOption
-from fastNLP.io.base_loader import DataSetLoader, DataInfo
+from fastNLP.io.base_loader import DataSetLoader, DataBundle
 from typing import Union, Dict, List, Iterator
 from fastNLP import DataSet
 from fastNLP import Instance
@@ -50,7 +50,7 @@ class IMDBLoader(DataSetLoader):
                 char_level_op=False):

         datasets = {}
-        info = DataInfo()
+        info = DataBundle()
         for name, path in paths.items():
             dataset = self.load(path)
             datasets[name] = dataset
diff --git a/reproduction/text_classification/data/MTL16Loader.py b/reproduction/text_classification/data/MTL16Loader.py
index 066b53b4..68969069 100644
--- a/reproduction/text_classification/data/MTL16Loader.py
+++ b/reproduction/text_classification/data/MTL16Loader.py
@@ -1,6 +1,6 @@
 from fastNLP.io.embed_loader import EmbeddingOption, EmbedLoader
 from fastNLP.core.vocabulary import VocabularyOption
-from fastNLP.io.base_loader import DataSetLoader, DataInfo
+from fastNLP.io.base_loader import DataSetLoader, DataBundle
 from typing import Union, Dict, List, Iterator
 from fastNLP import DataSet
 from fastNLP import Instance
@@ -47,7 +47,7 @@ class MTL16Loader(DataSetLoader):

         paths = check_dataloader_paths(paths)
         datasets = {}
-        info = DataInfo()
+        info = DataBundle()
         for name, path in paths.items():
             dataset = self.load(path)
             datasets[name] = dataset
diff --git a/reproduction/text_classification/data/sstloader.py b/reproduction/text_classification/data/sstloader.py
index 97cd935e..fa4d1837 100644
--- a/reproduction/text_classification/data/sstloader.py
+++ b/reproduction/text_classification/data/sstloader.py
@@ -1,6 +1,6 @@
 from typing import Iterable
 from nltk import Tree
-from fastNLP.io.base_loader import DataInfo, DataSetLoader
+from fastNLP.io.base_loader import DataBundle, DataSetLoader
 from fastNLP.core.vocabulary import VocabularyOption, Vocabulary
 from fastNLP import DataSet
 from fastNLP import Instance
@@ -68,7 +68,7 @@ class SSTLoader(DataSetLoader):
         tgt_vocab = Vocabulary(unknown=None, padding=None) \
             if tgt_vocab_op is None else Vocabulary(**tgt_vocab_op)

-        info = DataInfo(datasets=self.load(paths))
+        info = DataBundle(datasets=self.load(paths))
         _train_ds = [info.datasets[name]
                      for name in train_ds] if train_ds else info.datasets.values()
         src_vocab.from_dataset(*_train_ds, field_name=input_name)
@@ -134,7 +134,7 @@ class sst2Loader(DataSetLoader):

         paths = check_dataloader_paths(paths)
         datasets = {}
-        info = DataInfo()
+        info = DataBundle()
         for name, path in paths.items():
             dataset = self.load(path)
             datasets[name] = dataset
diff --git a/reproduction/text_classification/data/yelpLoader.py b/reproduction/text_classification/data/yelpLoader.py
index f34cfbbf..04ae94f9 100644
--- a/reproduction/text_classification/data/yelpLoader.py
+++ b/reproduction/text_classification/data/yelpLoader.py
@@ -4,7 +4,7 @@ from typing import Iterable
 from fastNLP import DataSet, Instance, Vocabulary
 from fastNLP.core.vocabulary import VocabularyOption
 from fastNLP.io import JsonLoader
-from fastNLP.io.base_loader import DataInfo,DataSetLoader
+from fastNLP.io.base_loader import DataBundle,DataSetLoader
 from fastNLP.io.embed_loader import EmbeddingOption
 from fastNLP.io.file_reader import _read_json
 from typing import Union, Dict
@@ -134,7 +134,7 @@ class yelpLoader(DataSetLoader):
                  char_level_op=False):
         paths = check_dataloader_paths(paths)
         datasets = {}
-        info = DataInfo(datasets=self.load(paths))
+        info = DataBundle(datasets=self.load(paths))
         src_vocab = Vocabulary() if src_vocab_op is None else Vocabulary(**src_vocab_op)
         tgt_vocab = Vocabulary(unknown=None, padding=None) \
             if tgt_vocab_op is None else Vocabulary(**tgt_vocab_op)
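
After this rename, callers only swap the class name; the container API (datasets, vocabs, embeddings) is otherwise unchanged. A minimal sketch of post-rename usage, assuming the SST2Loader updated above; the file paths are hypothetical, and the 'words' vocab key follows fastNLP.core.const and may differ by version:

    from fastNLP.io import DataBundle          # re-exported by fastNLP/io/__init__.py
    from fastNLP.io.data_loader.sst import SST2Loader

    # Hypothetical paths; any {name: path} dict accepted by process() works.
    paths = {'train': 'data/sst2/train.tsv', 'dev': 'data/sst2/dev.tsv'}

    loader = SST2Loader()
    bundle = loader.process(paths)             # now returns a DataBundle, not a DataInfo
    assert isinstance(bundle, DataBundle)

    train_ds = bundle.datasets['train']        # fastNLP DataSet objects, keyed by the paths dict
    word_vocab = bundle.vocabs['words']        # vocabularies built during process()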