
Rename DataInfo to DataBundle

tags/v0.4.10
yh committed 6 years ago
commit 5c80c6f4f5
19 changed files with 48 additions and 48 deletions
 1. +2 -2  fastNLP/io/__init__.py
 2. +5 -5  fastNLP/io/base_loader.py
 3. +2 -2  fastNLP/io/data_loader/imdb.py
 4. +3 -3  fastNLP/io/data_loader/matching.py
 5. +2 -2  fastNLP/io/data_loader/mtl.py
 6. +3 -3  fastNLP/io/data_loader/sst.py
 7. +2 -2  fastNLP/io/data_loader/yelp.py
 8. +3 -3  reproduction/Summarization/Baseline/data/dataloader.py
 9. +3 -3  reproduction/Summarization/BertSum/dataloader.py
10. +2 -2  reproduction/coreference_resolution/data_load/cr_loader.py
11. +2 -2  reproduction/joint_cws_parse/data/data_loader.py
12. +3 -3  reproduction/matching/data/MatchingDataLoader.py
13. +2 -2  reproduction/seqence_labelling/cws/data/CWSDataLoader.py
14. +2 -2  reproduction/seqence_labelling/ner/data/Conll2003Loader.py
15. +3 -3  reproduction/seqence_labelling/ner/data/OntoNoteLoader.py
16. +2 -2  reproduction/text_classification/data/IMDBLoader.py
17. +2 -2  reproduction/text_classification/data/MTL16Loader.py
18. +3 -3  reproduction/text_classification/data/sstloader.py
19. +2 -2  reproduction/text_classification/data/yelpLoader.py

+2 -2  fastNLP/io/__init__.py

@@ -12,7 +12,7 @@
 __all__ = [
     'EmbedLoader',
 
-    'DataInfo',
+    'DataBundle',
     'DataSetLoader',
 
     'CSVLoader',
@@ -35,7 +35,7 @@ __all__ = [
 ]
 
 from .embed_loader import EmbedLoader
-from .base_loader import DataInfo, DataSetLoader
+from .base_loader import DataBundle, DataSetLoader
 from .dataset_loader import CSVLoader, JsonLoader
 from .model_io import ModelLoader, ModelSaver

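Downstream code migrates by swapping the imported name; a minimal sketch (both forms resolve through the re-export in fastNLP/io/__init__.py shown above):

    # before this commit
    from fastNLP.io import DataInfo

    # after this commit
    from fastNLP.io import DataBundle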


+5 -5  fastNLP/io/base_loader.py

@@ -1,6 +1,6 @@
 __all__ = [
     "BaseLoader",
-    'DataInfo',
+    'DataBundle',
     'DataSetLoader',
 ]
 
@@ -109,7 +109,7 @@ def _uncompress(src, dst):
         raise ValueError('unsupported file {}'.format(src))
 
 
-class DataInfo:
+class DataBundle:
     """
     Processed data information: a collection of datasets (e.g. separate train, dev, and test sets) together with the vocabularies and word embeddings they use.
 
@@ -201,20 +201,20 @@ class DataSetLoader:
         """
         raise NotImplementedError
 
-    def process(self, paths: Union[str, Dict[str, str]], **options) -> DataInfo:
+    def process(self, paths: Union[str, Dict[str, str]], **options) -> DataBundle:
         """
         For a specific task and dataset, read and process the data, returning a processed DataInfo object or a dict.
 
         Reads data from the file(s) at one or more given paths; the DataInfo object may contain one or more datasets.
         When multiple paths are processed, the keys of the dict passed in match the keys of the dict in the returned DataInfo.
 
-        The returned :class:`DataInfo` object has the following attributes:
+        The returned :class:`DataBundle` object has the following attributes:
 
         - vocabs: a dict of the vocabularies built from the datasets; each vocabulary
         - datasets: a dict containing a number of :class:`~fastNLP.DataSet` objects, with field names following :mod:`~fastNLP.core.const`
 
         :param paths: path(s) the raw data is read from
         :param options: parameters of your own design, depending on the task and dataset
-        :return: a DataInfo
+        :return: a DataBundle
         """
         raise NotImplementedError

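For orientation, a brief sketch of how the renamed container is used, assuming only what the call sites in this commit show: a no-argument constructor, optional datasets=/vocabs= keyword arguments, and dict-valued datasets/vocabs attributes:

    from fastNLP import DataSet, Vocabulary
    from fastNLP.io.base_loader import DataBundle

    # build a bundle incrementally, the way most loaders in this commit do
    bundle = DataBundle()
    bundle.datasets['train'] = DataSet()   # normally the output of a loader's _load()
    bundle.vocabs['vocab'] = Vocabulary()  # normally built from the training set

    # or build it in one call, as SummarizationLoader and YelpLoader do
    bundle = DataBundle(datasets={'train': DataSet()}, vocabs={'vocab': Vocabulary()})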
+2 -2  fastNLP/io/data_loader/imdb.py

@@ -2,7 +2,7 @@
 from typing import Union, Dict
 
 from ..embed_loader import EmbeddingOption, EmbedLoader
-from ..base_loader import DataSetLoader, DataInfo
+from ..base_loader import DataSetLoader, DataBundle
 from ...core.vocabulary import VocabularyOption, Vocabulary
 from ...core.dataset import DataSet
 from ...core.instance import Instance
@@ -48,7 +48,7 @@ class IMDBLoader(DataSetLoader):
                 char_level_op=False):
 
         datasets = {}
-        info = DataInfo()
+        info = DataBundle()
         for name, path in paths.items():
             dataset = self.load(path)
             datasets[name] = dataset

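Callers interact with the renamed class mainly through a loader's process() method. A hedged usage sketch; the paths and filenames are illustrative, and the no-argument IMDBLoader constructor is an assumption, not something this commit shows:

    from fastNLP.io.data_loader.imdb import IMDBLoader

    # load() expects IMDB-formatted text files at the given paths
    loader = IMDBLoader()
    bundle = loader.process({'train': 'train.txt', 'dev': 'dev.txt'})

    train_set = bundle.datasets['train']  # a fastNLP DataSet, keyed as passed in
    vocabs = bundle.vocabs                # vocabularies built during process()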

+3 -3  fastNLP/io/data_loader/matching.py

@@ -4,7 +4,7 @@ from typing import Union, Dict , List
 
 from ...core.const import Const
 from ...core.vocabulary import Vocabulary
-from ..base_loader import DataInfo, DataSetLoader
+from ..base_loader import DataBundle, DataSetLoader
 from ..file_utils import _get_base_url, cached_path, PRETRAINED_BERT_MODEL_DIR
 from ...modules.encoder._bert import BertTokenizer
 
@@ -34,7 +34,7 @@ class MatchingLoader(DataSetLoader):
                 cut_text: int = None, get_index=True, auto_pad_length: int=None,
                 auto_pad_token: str='<pad>', set_input: Union[list, str, bool]=True,
                 set_target: Union[list, str, bool] = True, concat: Union[str, list, bool]=None,
-                extra_split: List[str]=['-'], ) -> DataInfo:
+                extra_split: List[str]=['-'], ) -> DataBundle:
         """
         :param paths: str or Dict[str, str]. If a str, it is the folder containing the datasets or a full path to a file; for a folder,
             the dataset names and file names are looked up in self.paths. If a Dict, its keys are dataset names (such as train, dev, test) and
@@ -80,7 +80,7 @@ class MatchingLoader(DataSetLoader):
         else:
             path = paths
 
-        data_info = DataInfo()
+        data_info = DataBundle()
         for data_name in path.keys():
             data_info.datasets[data_name] = self._load(path[data_name])



+2 -2  fastNLP/io/data_loader/mtl.py

@@ -1,7 +1,7 @@
 
 from typing import Union, Dict
 
-from ..base_loader import DataInfo
+from ..base_loader import DataBundle
 from ..dataset_loader import CSVLoader
 from ...core.vocabulary import Vocabulary, VocabularyOption
 from ...core.const import Const
@@ -37,7 +37,7 @@ class MTL16Loader(CSVLoader):
 
         paths = check_dataloader_paths(paths)
         datasets = {}
-        info = DataInfo()
+        info = DataBundle()
         for name, path in paths.items():
             dataset = self.load(path)
             datasets[name] = dataset


+3 -3  fastNLP/io/data_loader/sst.py

@@ -2,7 +2,7 @@
 from typing import Union, Dict
 from nltk import Tree
 
-from ..base_loader import DataInfo, DataSetLoader
+from ..base_loader import DataBundle, DataSetLoader
 from ..dataset_loader import CSVLoader
 from ...core.vocabulary import VocabularyOption, Vocabulary
 from ...core.dataset import DataSet
@@ -73,7 +73,7 @@ class SSTLoader(DataSetLoader):
         tgt_vocab = Vocabulary(unknown=None, padding=None) \
             if tgt_vocab_op is None else Vocabulary(**tgt_vocab_op)
 
-        info = DataInfo()
+        info = DataBundle()
         origin_subtree = self.subtree
         self.subtree = train_subtree
         info.datasets['train'] = self._load(paths['train'])
@@ -126,7 +126,7 @@ class SST2Loader(CSVLoader):
 
         paths = check_dataloader_paths(paths)
         datasets = {}
-        info = DataInfo()
+        info = DataBundle()
         for name, path in paths.items():
             dataset = self.load(path)
             datasets[name] = dataset


+2 -2  fastNLP/io/data_loader/yelp.py

@@ -6,7 +6,7 @@ from ...core.const import Const
 from ...core.dataset import DataSet
 from ...core.instance import Instance
 from ...core.vocabulary import VocabularyOption, Vocabulary
-from ..base_loader import DataInfo, DataSetLoader
+from ..base_loader import DataBundle, DataSetLoader
 from typing import Union, Dict
 from ..utils import check_dataloader_paths, get_tokenizer
 
@@ -58,7 +58,7 @@ class YelpLoader(DataSetLoader):
                 tgt_vocab_op: VocabularyOption = None,
                 char_level_op=False):
         paths = check_dataloader_paths(paths)
-        info = DataInfo(datasets=self.load(paths))
+        info = DataBundle(datasets=self.load(paths))
         src_vocab = Vocabulary() if src_vocab_op is None else Vocabulary(**src_vocab_op)
         tgt_vocab = Vocabulary(unknown=None, padding=None) \
             if tgt_vocab_op is None else Vocabulary(**tgt_vocab_op)


+3 -3  reproduction/Summarization/Baseline/data/dataloader.py

@@ -2,7 +2,7 @@ import pickle
 import numpy as np
 from fastNLP.core.vocabulary import Vocabulary
-from fastNLP.io.base_loader import DataInfo
+from fastNLP.io.base_loader import DataBundle
 from fastNLP.io.dataset_loader import JsonLoader
 from fastNLP.core.const import Const
@@ -66,7 +66,7 @@ class SummarizationLoader(JsonLoader):
         :param domain: bool build vocab for publication, use 'X' for unknown
         :param tag: bool build vocab for tag, use 'X' for unknown
         :param load_vocab: bool build vocab (False) or load vocab (True)
-        :return: DataInfo
+        :return: DataBundle
            datasets: dict keys correspond to the paths dict
            vocabs: dict key: vocab(if "train" in paths), domain(if domain=True), tag(if tag=True)
            embeddings: optional
@@ -182,7 +182,7 @@ class SummarizationLoader(JsonLoader):
         for ds in datasets.values():
             vocab_dict["vocab"].index_dataset(ds, field_name=Const.INPUT, new_field_name=Const.INPUT)
-        return DataInfo(vocabs=vocab_dict, datasets=datasets)
+        return DataBundle(vocabs=vocab_dict, datasets=datasets)

+3 -3  reproduction/Summarization/BertSum/dataloader.py

@@ -3,7 +3,7 @@ from datetime import timedelta
 
 from fastNLP.io.dataset_loader import JsonLoader
 from fastNLP.modules.encoder._bert import BertTokenizer
-from fastNLP.io.base_loader import DataInfo
+from fastNLP.io.base_loader import DataBundle
 from fastNLP.core.const import Const
 
 class BertData(JsonLoader):
@@ -110,7 +110,7 @@ class BertData(JsonLoader):
             # set padding value
             datasets[name].set_pad_val('article', 0)
 
-        return DataInfo(datasets=datasets)
+        return DataBundle(datasets=datasets)
 
 
 class BertSumLoader(JsonLoader):
@@ -154,4 +154,4 @@ class BertSumLoader(JsonLoader):
 
         print('Finished in {}'.format(timedelta(seconds=time()-start)))
 
-        return DataInfo(datasets=datasets)
+        return DataBundle(datasets=datasets)

+2 -2  reproduction/coreference_resolution/data_load/cr_loader.py

@@ -1,7 +1,7 @@
 from fastNLP.io.dataset_loader import JsonLoader,DataSet,Instance
 from fastNLP.io.file_reader import _read_json
 from fastNLP.core.vocabulary import Vocabulary
-from fastNLP.io.base_loader import DataInfo
+from fastNLP.io.base_loader import DataBundle
 from reproduction.coreference_resolution.model.config import Config
 import reproduction.coreference_resolution.model.preprocess as preprocess
 
@@ -26,7 +26,7 @@ class CRLoader(JsonLoader):
         return dataset
 
     def process(self, paths, **kwargs):
-        data_info = DataInfo()
+        data_info = DataBundle()
         for name in ['train', 'test', 'dev']:
             data_info.datasets[name] = self.load(paths[name])



+2 -2  reproduction/joint_cws_parse/data/data_loader.py

@@ -1,6 +1,6 @@
 
 
-from fastNLP.io.base_loader import DataSetLoader, DataInfo
+from fastNLP.io.base_loader import DataSetLoader, DataBundle
 from fastNLP.io.data_loader import ConllLoader
 import numpy as np
 
@@ -76,7 +76,7 @@ class CTBxJointLoader(DataSetLoader):
             gold_label_word_pairs:
         """
         paths = check_dataloader_paths(paths)
-        data = DataInfo()
+        data = DataBundle()
 
         for name, path in paths.items():
             dataset = self.load(path)


+3 -3  reproduction/matching/data/MatchingDataLoader.py

@@ -5,7 +5,7 @@ from typing import Union, Dict
 
 from fastNLP.core.const import Const
 from fastNLP.core.vocabulary import Vocabulary
-from fastNLP.io.base_loader import DataInfo, DataSetLoader
+from fastNLP.io.base_loader import DataBundle, DataSetLoader
 from fastNLP.io.dataset_loader import JsonLoader, CSVLoader
 from fastNLP.io.file_utils import _get_base_url, cached_path, PRETRAINED_BERT_MODEL_DIR
 from fastNLP.modules.encoder._bert import BertTokenizer
@@ -35,7 +35,7 @@ class MatchingLoader(DataSetLoader):
                 to_lower=False, seq_len_type: str=None, bert_tokenizer: str=None,
                 cut_text: int = None, get_index=True, auto_pad_length: int=None,
                 auto_pad_token: str='<pad>', set_input: Union[list, str, bool]=True,
-                set_target: Union[list, str, bool] = True, concat: Union[str, list, bool]=None, ) -> DataInfo:
+                set_target: Union[list, str, bool] = True, concat: Union[str, list, bool]=None, ) -> DataBundle:
         """
         :param paths: str or Dict[str, str]. If a str, it is the folder containing the datasets or a full path to a file; for a folder,
             the dataset names and file names are looked up in self.paths. If a Dict, its keys are dataset names (such as train, dev, test) and
@@ -80,7 +80,7 @@ class MatchingLoader(DataSetLoader):
         else:
             path = paths
 
-        data_info = DataInfo()
+        data_info = DataBundle()
         for data_name in path.keys():
             data_info.datasets[data_name] = self._load(path[data_name])



+2 -2  reproduction/seqence_labelling/cws/data/CWSDataLoader.py

@@ -1,7 +1,7 @@
 
 from fastNLP.io.embed_loader import EmbeddingOption, EmbedLoader
 from fastNLP.core.vocabulary import VocabularyOption
-from fastNLP.io.base_loader import DataSetLoader, DataInfo
+from fastNLP.io.base_loader import DataSetLoader, DataBundle
 from typing import Union, Dict, List, Iterator
 from fastNLP import DataSet
 from fastNLP import Instance
@@ -161,7 +161,7 @@ class SigHanLoader(DataSetLoader):
         # using check_data_loader_paths to validate paths is recommended
         paths = check_dataloader_paths(paths)
         datasets = {}
-        data = DataInfo()
+        data = DataBundle()
         bigram = bigram_vocab_opt is not None
         for name, path in paths.items():
             dataset = self.load(path, bigram=bigram)


+2 -2  reproduction/seqence_labelling/ner/data/Conll2003Loader.py

@@ -1,6 +1,6 @@
 
 from fastNLP.core.vocabulary import VocabularyOption
-from fastNLP.io.base_loader import DataSetLoader, DataInfo
+from fastNLP.io.base_loader import DataSetLoader, DataBundle
 from typing import Union, Dict
 from fastNLP import Vocabulary
 from fastNLP import Const
@@ -51,7 +51,7 @@ class Conll2003DataLoader(DataSetLoader):
         """
         # read the data
         paths = check_dataloader_paths(paths)
-        data = DataInfo()
+        data = DataBundle()
         input_fields = [Const.TARGET, Const.INPUT, Const.INPUT_LEN]
         target_fields = [Const.TARGET, Const.INPUT_LEN]
         for name, path in paths.items():


+3 -3  reproduction/seqence_labelling/ner/data/OntoNoteLoader.py

@@ -1,5 +1,5 @@
 from fastNLP.core.vocabulary import VocabularyOption
-from fastNLP.io.base_loader import DataSetLoader, DataInfo
+from fastNLP.io.base_loader import DataSetLoader, DataBundle
 from typing import Union, Dict
 from fastNLP import DataSet
 from fastNLP import Vocabulary
@@ -76,7 +76,7 @@ class OntoNoteNERDataLoader(DataSetLoader):
         return dataset
 
     def process(self, paths: Union[str, Dict[str, str]], word_vocab_opt:VocabularyOption=None,
-                lower:bool=True)->DataInfo:
+                lower:bool=True)->DataBundle:
         """
         Reads and processes the data. The returned DataInfo contains the following
         vocabs:
@@ -96,7 +96,7 @@ class OntoNoteNERDataLoader(DataSetLoader):
         :return:
         """
         paths = check_dataloader_paths(paths)
-        data = DataInfo()
+        data = DataBundle()
         input_fields = [Const.TARGET, Const.INPUT, Const.INPUT_LEN]
         target_fields = [Const.TARGET, Const.INPUT_LEN]
         for name, path in paths.items():


+2 -2  reproduction/text_classification/data/IMDBLoader.py

@@ -1,6 +1,6 @@
 from fastNLP.io.embed_loader import EmbeddingOption, EmbedLoader
 from fastNLP.core.vocabulary import VocabularyOption
-from fastNLP.io.base_loader import DataSetLoader, DataInfo
+from fastNLP.io.base_loader import DataSetLoader, DataBundle
 from typing import Union, Dict, List, Iterator
 from fastNLP import DataSet
 from fastNLP import Instance
@@ -50,7 +50,7 @@ class IMDBLoader(DataSetLoader):
                 char_level_op=False):
         datasets = {}
-        info = DataInfo()
+        info = DataBundle()
         for name, path in paths.items():
             dataset = self.load(path)
             datasets[name] = dataset


+2 -2  reproduction/text_classification/data/MTL16Loader.py

@@ -1,6 +1,6 @@
 from fastNLP.io.embed_loader import EmbeddingOption, EmbedLoader
 from fastNLP.core.vocabulary import VocabularyOption
-from fastNLP.io.base_loader import DataSetLoader, DataInfo
+from fastNLP.io.base_loader import DataSetLoader, DataBundle
 from typing import Union, Dict, List, Iterator
 from fastNLP import DataSet
 from fastNLP import Instance
@@ -47,7 +47,7 @@ class MTL16Loader(DataSetLoader):
         paths = check_dataloader_paths(paths)
         datasets = {}
-        info = DataInfo()
+        info = DataBundle()
         for name, path in paths.items():
             dataset = self.load(path)
             datasets[name] = dataset


+3 -3  reproduction/text_classification/data/sstloader.py

@@ -1,6 +1,6 @@
 from typing import Iterable
 from nltk import Tree
-from fastNLP.io.base_loader import DataInfo, DataSetLoader
+from fastNLP.io.base_loader import DataBundle, DataSetLoader
 from fastNLP.core.vocabulary import VocabularyOption, Vocabulary
 from fastNLP import DataSet
 from fastNLP import Instance
@@ -68,7 +68,7 @@ class SSTLoader(DataSetLoader):
         tgt_vocab = Vocabulary(unknown=None, padding=None) \
             if tgt_vocab_op is None else Vocabulary(**tgt_vocab_op)
 
-        info = DataInfo(datasets=self.load(paths))
+        info = DataBundle(datasets=self.load(paths))
         _train_ds = [info.datasets[name]
                      for name in train_ds] if train_ds else info.datasets.values()
         src_vocab.from_dataset(*_train_ds, field_name=input_name)
@@ -134,7 +134,7 @@ class sst2Loader(DataSetLoader):
 
         paths = check_dataloader_paths(paths)
         datasets = {}
-        info = DataInfo()
+        info = DataBundle()
         for name, path in paths.items():
             dataset = self.load(path)
             datasets[name] = dataset


+2 -2  reproduction/text_classification/data/yelpLoader.py

@@ -4,7 +4,7 @@ from typing import Iterable
 from fastNLP import DataSet, Instance, Vocabulary
 from fastNLP.core.vocabulary import VocabularyOption
 from fastNLP.io import JsonLoader
-from fastNLP.io.base_loader import DataInfo,DataSetLoader
+from fastNLP.io.base_loader import DataBundle,DataSetLoader
 from fastNLP.io.embed_loader import EmbeddingOption
 from fastNLP.io.file_reader import _read_json
 from typing import Union, Dict
@@ -134,7 +134,7 @@ class yelpLoader(DataSetLoader):
                 char_level_op=False):
         paths = check_dataloader_paths(paths)
         datasets = {}
-        info = DataInfo(datasets=self.load(paths))
+        info = DataBundle(datasets=self.load(paths))
         src_vocab = Vocabulary() if src_vocab_op is None else Vocabulary(**src_vocab_op)
         tgt_vocab = Vocabulary(unknown=None, padding=None) \
             if tgt_vocab_op is None else Vocabulary(**tgt_vocab_op)

