
Rename DataInfo to DataBundle

tags/v0.4.10
yh committed 6 years ago · commit 5c80c6f4f5
19 changed files with 48 additions and 48 deletions
 1. fastNLP/io/__init__.py (+2 -2)
 2. fastNLP/io/base_loader.py (+5 -5)
 3. fastNLP/io/data_loader/imdb.py (+2 -2)
 4. fastNLP/io/data_loader/matching.py (+3 -3)
 5. fastNLP/io/data_loader/mtl.py (+2 -2)
 6. fastNLP/io/data_loader/sst.py (+3 -3)
 7. fastNLP/io/data_loader/yelp.py (+2 -2)
 8. reproduction/Summarization/Baseline/data/dataloader.py (+3 -3)
 9. reproduction/Summarization/BertSum/dataloader.py (+3 -3)
10. reproduction/coreference_resolution/data_load/cr_loader.py (+2 -2)
11. reproduction/joint_cws_parse/data/data_loader.py (+2 -2)
12. reproduction/matching/data/MatchingDataLoader.py (+3 -3)
13. reproduction/seqence_labelling/cws/data/CWSDataLoader.py (+2 -2)
14. reproduction/seqence_labelling/ner/data/Conll2003Loader.py (+2 -2)
15. reproduction/seqence_labelling/ner/data/OntoNoteLoader.py (+3 -3)
16. reproduction/text_classification/data/IMDBLoader.py (+2 -2)
17. reproduction/text_classification/data/MTL16Loader.py (+2 -2)
18. reproduction/text_classification/data/sstloader.py (+3 -3)
19. reproduction/text_classification/data/yelpLoader.py (+2 -2)
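
Every hunk below applies the same mechanical swap of the class name, so for downstream code the commit amounts to a one-line change at each import and construction site. A minimal before/after sketch (illustrative, not part of the diff):

    # Before this commit:
    # from fastNLP.io.base_loader import DataInfo
    # bundle = DataInfo(datasets={})

    # After this commit (tags/v0.4.10):
    from fastNLP.io.base_loader import DataBundle

    bundle = DataBundle(datasets={})  # same constructor arguments, new name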

fastNLP/io/__init__.py (+2 -2)

@@ -12,7 +12,7 @@
 __all__ = [
     'EmbedLoader',

-    'DataInfo',
+    'DataBundle',
     'DataSetLoader',

     'CSVLoader',
@@ -35,7 +35,7 @@ __all__ = [
 ]

 from .embed_loader import EmbedLoader
-from .base_loader import DataInfo, DataSetLoader
+from .base_loader import DataBundle, DataSetLoader
 from .dataset_loader import CSVLoader, JsonLoader
 from .model_io import ModelLoader, ModelSaver




fastNLP/io/base_loader.py (+5 -5)

@@ -1,6 +1,6 @@
 __all__ = [
     "BaseLoader",
-    'DataInfo',
+    'DataBundle',
     'DataSetLoader',
 ]

@@ -109,7 +109,7 @@ def _uncompress(src, dst):
         raise ValueError('unsupported file {}'.format(src))


-class DataInfo:
+class DataBundle:
     """
     Processed data: a set of datasets (e.g. separate train, dev and test sets) plus the vocabularies and embeddings they use.

@@ -201,20 +201,20 @@ class DataSetLoader:
         """
         raise NotImplementedError

-    def process(self, paths: Union[str, Dict[str, str]], **options) -> DataInfo:
+    def process(self, paths: Union[str, Dict[str, str]], **options) -> DataBundle:
         """
         For a specific task and dataset, read and process the data, returning a processed DataInfo object or a dict.

         Data is read from the file(s) at one or more given paths; a DataInfo object can hold one or more datasets.
         When several paths are processed, the keys of the dict passed in match the keys of the dict in the returned DataInfo.

-        The returned :class:`DataInfo` object has the following attributes:
+        The returned :class:`DataBundle` object has the following attributes:

         - vocabs: a dict of the vocabularies built from the datasets; each vocabulary
         - datasets: a dict of :class:`~fastNLP.DataSet` objects, whose field names follow :mod:`~fastNLP.core.const`

         :param paths: path(s) to read the raw data from
         :param options: parameters of your own design, depending on the task and dataset
-        :return: a DataInfo
+        :return: a DataBundle
         """
         raise NotImplementedError
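
The process() contract above is easiest to see in a concrete subclass. A minimal sketch, assuming a one-sentence-per-line text format; ToyLoader, the 'words' field name and the 'words' vocab key are illustrative, not taken from this commit:

    from typing import Dict, Union

    from fastNLP import DataSet, Instance, Vocabulary
    from fastNLP.io.base_loader import DataBundle, DataSetLoader

    class ToyLoader(DataSetLoader):
        # Illustrative loader: one whitespace-tokenized sentence per line.

        def _load(self, path: str) -> DataSet:
            ds = DataSet()
            with open(path, encoding='utf-8') as f:
                for line in f:
                    ds.append(Instance(words=line.split()))
            return ds

        def process(self, paths: Union[str, Dict[str, str]], **options) -> DataBundle:
            if isinstance(paths, str):
                paths = {'train': paths}      # a single path is treated as the train split
            bundle = DataBundle()
            for name, path in paths.items():  # keys of `paths` become keys of bundle.datasets
                bundle.datasets[name] = self._load(path)
            vocab = Vocabulary()
            vocab.from_dataset(bundle.datasets['train'], field_name='words')
            bundle.vocabs['words'] = vocab    # vocabularies travel with the datasets
            return bundle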

fastNLP/io/data_loader/imdb.py (+2 -2)

@@ -2,7 +2,7 @@
 from typing import Union, Dict

 from ..embed_loader import EmbeddingOption, EmbedLoader
-from ..base_loader import DataSetLoader, DataInfo
+from ..base_loader import DataSetLoader, DataBundle
 from ...core.vocabulary import VocabularyOption, Vocabulary
 from ...core.dataset import DataSet
 from ...core.instance import Instance
@@ -48,7 +48,7 @@ class IMDBLoader(DataSetLoader):
                 char_level_op=False):

         datasets = {}
-        info = DataInfo()
+        info = DataBundle()
         for name, path in paths.items():
             dataset = self.load(path)
             datasets[name] = dataset


fastNLP/io/data_loader/matching.py (+3 -3)

@@ -4,7 +4,7 @@ from typing import Union, Dict , List

 from ...core.const import Const
 from ...core.vocabulary import Vocabulary
-from ..base_loader import DataInfo, DataSetLoader
+from ..base_loader import DataBundle, DataSetLoader
 from ..file_utils import _get_base_url, cached_path, PRETRAINED_BERT_MODEL_DIR
 from ...modules.encoder._bert import BertTokenizer

@@ -34,7 +34,7 @@ class MatchingLoader(DataSetLoader):
                 cut_text: int = None, get_index=True, auto_pad_length: int=None,
                 auto_pad_token: str='<pad>', set_input: Union[list, str, bool]=True,
                 set_target: Union[list, str, bool] = True, concat: Union[str, list, bool]=None,
-                extra_split: List[str]=['-'], ) -> DataInfo:
+                extra_split: List[str]=['-'], ) -> DataBundle:
         """
         :param paths: str or Dict[str, str]. If a str, it is either the folder the dataset lives in or a full file path: for a folder,
             the dataset names and file names are looked up in self.paths. For a Dict, the keys are dataset names (e.g. train, dev, test) and
@@ -80,7 +80,7 @@ class MatchingLoader(DataSetLoader):
         else:
             path = paths

-        data_info = DataInfo()
+        data_info = DataBundle()
         for data_name in path.keys():
             data_info.datasets[data_name] = self._load(path[data_name])
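
As the paths docstring notes, process() accepts either a single folder/file string or an explicit split dict, and the returned DataBundle reuses the dict's keys. A schematic sketch of the two call styles (SomeMatchingLoader and the file paths are placeholders, not names from this diff):

    # SomeMatchingLoader stands in for a concrete MatchingLoader subclass.
    loader = SomeMatchingLoader()

    # Style 1: a folder; split file names are resolved from loader.paths.
    bundle_a = loader.process('/data/snli/')

    # Style 2: an explicit dict; its keys become the keys of bundle_b.datasets.
    bundle_b = loader.process({'train': '/data/snli/train.txt',
                               'dev': '/data/snli/dev.txt'})
    assert set(bundle_b.datasets.keys()) == {'train', 'dev'}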




fastNLP/io/data_loader/mtl.py (+2 -2)

@@ -1,7 +1,7 @@

 from typing import Union, Dict

-from ..base_loader import DataInfo
+from ..base_loader import DataBundle
 from ..dataset_loader import CSVLoader
 from ...core.vocabulary import Vocabulary, VocabularyOption
 from ...core.const import Const
@@ -37,7 +37,7 @@ class MTL16Loader(CSVLoader):

         paths = check_dataloader_paths(paths)
         datasets = {}
-        info = DataInfo()
+        info = DataBundle()
         for name, path in paths.items():
             dataset = self.load(path)
             datasets[name] = dataset


fastNLP/io/data_loader/sst.py (+3 -3)

@@ -2,7 +2,7 @@
 from typing import Union, Dict
 from nltk import Tree

-from ..base_loader import DataInfo, DataSetLoader
+from ..base_loader import DataBundle, DataSetLoader
 from ..dataset_loader import CSVLoader
 from ...core.vocabulary import VocabularyOption, Vocabulary
 from ...core.dataset import DataSet
@@ -73,7 +73,7 @@ class SSTLoader(DataSetLoader):
         tgt_vocab = Vocabulary(unknown=None, padding=None) \
             if tgt_vocab_op is None else Vocabulary(**tgt_vocab_op)

-        info = DataInfo()
+        info = DataBundle()
         origin_subtree = self.subtree
         self.subtree = train_subtree
         info.datasets['train'] = self._load(paths['train'])
@@ -126,7 +126,7 @@ class SST2Loader(CSVLoader):

         paths = check_dataloader_paths(paths)
         datasets = {}
-        info = DataInfo()
+        info = DataBundle()
         for name, path in paths.items():
             dataset = self.load(path)
             datasets[name] = dataset


fastNLP/io/data_loader/yelp.py (+2 -2)

@@ -6,7 +6,7 @@ from ...core.const import Const
 from ...core.dataset import DataSet
 from ...core.instance import Instance
 from ...core.vocabulary import VocabularyOption, Vocabulary
-from ..base_loader import DataInfo, DataSetLoader
+from ..base_loader import DataBundle, DataSetLoader
 from typing import Union, Dict
 from ..utils import check_dataloader_paths, get_tokenizer

@@ -58,7 +58,7 @@ class YelpLoader(DataSetLoader):
                 tgt_vocab_op: VocabularyOption = None,
                 char_level_op=False):
         paths = check_dataloader_paths(paths)
-        info = DataInfo(datasets=self.load(paths))
+        info = DataBundle(datasets=self.load(paths))
         src_vocab = Vocabulary() if src_vocab_op is None else Vocabulary(**src_vocab_op)
         tgt_vocab = Vocabulary(unknown=None, padding=None) \
             if tgt_vocab_op is None else Vocabulary(**tgt_vocab_op)


reproduction/Summarization/Baseline/data/dataloader.py (+3 -3)

@@ -2,7 +2,7 @@ import pickle
 import numpy as np
 from fastNLP.core.vocabulary import Vocabulary
-from fastNLP.io.base_loader import DataInfo
+from fastNLP.io.base_loader import DataBundle
 from fastNLP.io.dataset_loader import JsonLoader
 from fastNLP.core.const import Const
@@ -66,7 +66,7 @@ class SummarizationLoader(JsonLoader):
         :param domain: bool  build vocab for publication, use 'X' for unknown
         :param tag: bool  build vocab for tag, use 'X' for unknown
         :param load_vocab: bool  build vocab (False) or load vocab (True)
-        :return: DataInfo
+        :return: DataBundle
             datasets: dict  keys correspond to the paths dict
             vocabs: dict  key: vocab(if "train" in paths), domain(if domain=True), tag(if tag=True)
             embeddings: optional
@@ -182,7 +182,7 @@ class SummarizationLoader(JsonLoader):
         for ds in datasets.values():
             vocab_dict["vocab"].index_dataset(ds, field_name=Const.INPUT, new_field_name=Const.INPUT)
-        return DataInfo(vocabs=vocab_dict, datasets=datasets)
+        return DataBundle(vocabs=vocab_dict, datasets=datasets)

reproduction/Summarization/BertSum/dataloader.py (+3 -3)

@@ -3,7 +3,7 @@ from datetime import timedelta

 from fastNLP.io.dataset_loader import JsonLoader
 from fastNLP.modules.encoder._bert import BertTokenizer
-from fastNLP.io.base_loader import DataInfo
+from fastNLP.io.base_loader import DataBundle
 from fastNLP.core.const import Const

 class BertData(JsonLoader):
@@ -110,7 +110,7 @@ class BertData(JsonLoader):
             # set paddding value
             datasets[name].set_pad_val('article', 0)

-        return DataInfo(datasets=datasets)
+        return DataBundle(datasets=datasets)


 class BertSumLoader(JsonLoader):
@@ -154,4 +154,4 @@ class BertSumLoader(JsonLoader):

     print('Finished in {}'.format(timedelta(seconds=time()-start)))

-    return DataInfo(datasets=datasets)
+    return DataBundle(datasets=datasets)

reproduction/coreference_resolution/data_load/cr_loader.py (+2 -2)

@@ -1,7 +1,7 @@
 from fastNLP.io.dataset_loader import JsonLoader,DataSet,Instance
 from fastNLP.io.file_reader import _read_json
 from fastNLP.core.vocabulary import Vocabulary
-from fastNLP.io.base_loader import DataInfo
+from fastNLP.io.base_loader import DataBundle
 from reproduction.coreference_resolution.model.config import Config
 import reproduction.coreference_resolution.model.preprocess as preprocess

@@ -26,7 +26,7 @@ class CRLoader(JsonLoader):
         return dataset

     def process(self, paths, **kwargs):
-        data_info = DataInfo()
+        data_info = DataBundle()
         for name in ['train', 'test', 'dev']:
             data_info.datasets[name] = self.load(paths[name])




reproduction/joint_cws_parse/data/data_loader.py (+2 -2)

@@ -1,6 +1,6 @@


-from fastNLP.io.base_loader import DataSetLoader, DataInfo
+from fastNLP.io.base_loader import DataSetLoader, DataBundle
 from fastNLP.io.data_loader import ConllLoader
 import numpy as np

@@ -76,7 +76,7 @@ class CTBxJointLoader(DataSetLoader):
         gold_label_word_pairs:
         """
         paths = check_dataloader_paths(paths)
-        data = DataInfo()
+        data = DataBundle()

         for name, path in paths.items():
             dataset = self.load(path)


reproduction/matching/data/MatchingDataLoader.py (+3 -3)

@@ -5,7 +5,7 @@ from typing import Union, Dict

 from fastNLP.core.const import Const
 from fastNLP.core.vocabulary import Vocabulary
-from fastNLP.io.base_loader import DataInfo, DataSetLoader
+from fastNLP.io.base_loader import DataBundle, DataSetLoader
 from fastNLP.io.dataset_loader import JsonLoader, CSVLoader
 from fastNLP.io.file_utils import _get_base_url, cached_path, PRETRAINED_BERT_MODEL_DIR
 from fastNLP.modules.encoder._bert import BertTokenizer
@@ -35,7 +35,7 @@ class MatchingLoader(DataSetLoader):
                 to_lower=False, seq_len_type: str=None, bert_tokenizer: str=None,
                 cut_text: int = None, get_index=True, auto_pad_length: int=None,
                 auto_pad_token: str='<pad>', set_input: Union[list, str, bool]=True,
-                set_target: Union[list, str, bool] = True, concat: Union[str, list, bool]=None, ) -> DataInfo:
+                set_target: Union[list, str, bool] = True, concat: Union[str, list, bool]=None, ) -> DataBundle:
         """
         :param paths: str or Dict[str, str]. If a str, it is either the folder the dataset lives in or a full file path: for a folder,
             the dataset names and file names are looked up in self.paths. For a Dict, the keys are dataset names (e.g. train, dev, test) and
@@ -80,7 +80,7 @@ class MatchingLoader(DataSetLoader):
         else:
             path = paths

-        data_info = DataInfo()
+        data_info = DataBundle()
         for data_name in path.keys():
             data_info.datasets[data_name] = self._load(path[data_name])




reproduction/seqence_labelling/cws/data/CWSDataLoader.py (+2 -2)

@@ -1,7 +1,7 @@

 from fastNLP.io.embed_loader import EmbeddingOption, EmbedLoader
 from fastNLP.core.vocabulary import VocabularyOption
-from fastNLP.io.base_loader import DataSetLoader, DataInfo
+from fastNLP.io.base_loader import DataSetLoader, DataBundle
 from typing import Union, Dict, List, Iterator
 from fastNLP import DataSet
 from fastNLP import Instance
@@ -161,7 +161,7 @@ class SigHanLoader(DataSetLoader):
         # check_data_loader_paths is the recommended way to validate paths
         paths = check_dataloader_paths(paths)
         datasets = {}
-        data = DataInfo()
+        data = DataBundle()
         bigram = bigram_vocab_opt is not None
         for name, path in paths.items():
             dataset = self.load(path, bigram=bigram)


reproduction/seqence_labelling/ner/data/Conll2003Loader.py (+2 -2)

@@ -1,6 +1,6 @@

 from fastNLP.core.vocabulary import VocabularyOption
-from fastNLP.io.base_loader import DataSetLoader, DataInfo
+from fastNLP.io.base_loader import DataSetLoader, DataBundle
 from typing import Union, Dict
 from fastNLP import Vocabulary
 from fastNLP import Const
@@ -51,7 +51,7 @@ class Conll2003DataLoader(DataSetLoader):
         """
         # read the data
         paths = check_dataloader_paths(paths)
-        data = DataInfo()
+        data = DataBundle()
         input_fields = [Const.TARGET, Const.INPUT, Const.INPUT_LEN]
         target_fields = [Const.TARGET, Const.INPUT_LEN]
         for name, path in paths.items():


reproduction/seqence_labelling/ner/data/OntoNoteLoader.py (+3 -3)

@@ -1,5 +1,5 @@
 from fastNLP.core.vocabulary import VocabularyOption
-from fastNLP.io.base_loader import DataSetLoader, DataInfo
+from fastNLP.io.base_loader import DataSetLoader, DataBundle
 from typing import Union, Dict
 from fastNLP import DataSet
 from fastNLP import Vocabulary
@@ -76,7 +76,7 @@ class OntoNoteNERDataLoader(DataSetLoader):
         return dataset

     def process(self, paths: Union[str, Dict[str, str]], word_vocab_opt:VocabularyOption=None,
-                lower:bool=True)->DataInfo:
+                lower:bool=True)->DataBundle:
         """
         Read and process the data. The returned DataInfo contains the following
         vocabs:
@@ -96,7 +96,7 @@ class OntoNoteNERDataLoader(DataSetLoader):
         :return:
         """
         paths = check_dataloader_paths(paths)
-        data = DataInfo()
+        data = DataBundle()
         input_fields = [Const.TARGET, Const.INPUT, Const.INPUT_LEN]
         target_fields = [Const.TARGET, Const.INPUT_LEN]
         for name, path in paths.items():


reproduction/text_classification/data/IMDBLoader.py (+2 -2)

@@ -1,6 +1,6 @@
 from fastNLP.io.embed_loader import EmbeddingOption, EmbedLoader
 from fastNLP.core.vocabulary import VocabularyOption
-from fastNLP.io.base_loader import DataSetLoader, DataInfo
+from fastNLP.io.base_loader import DataSetLoader, DataBundle
 from typing import Union, Dict, List, Iterator
 from fastNLP import DataSet
 from fastNLP import Instance
@@ -50,7 +50,7 @@ class IMDBLoader(DataSetLoader):
                 char_level_op=False):
         datasets = {}
-        info = DataInfo()
+        info = DataBundle()
         for name, path in paths.items():
             dataset = self.load(path)
             datasets[name] = dataset


reproduction/text_classification/data/MTL16Loader.py (+2 -2)

@@ -1,6 +1,6 @@
 from fastNLP.io.embed_loader import EmbeddingOption, EmbedLoader
 from fastNLP.core.vocabulary import VocabularyOption
-from fastNLP.io.base_loader import DataSetLoader, DataInfo
+from fastNLP.io.base_loader import DataSetLoader, DataBundle
 from typing import Union, Dict, List, Iterator
 from fastNLP import DataSet
 from fastNLP import Instance
@@ -47,7 +47,7 @@ class MTL16Loader(DataSetLoader):
         paths = check_dataloader_paths(paths)
         datasets = {}
-        info = DataInfo()
+        info = DataBundle()
         for name, path in paths.items():
             dataset = self.load(path)
             datasets[name] = dataset


reproduction/text_classification/data/sstloader.py (+3 -3)

@@ -1,6 +1,6 @@
 from typing import Iterable
 from nltk import Tree
-from fastNLP.io.base_loader import DataInfo, DataSetLoader
+from fastNLP.io.base_loader import DataBundle, DataSetLoader
 from fastNLP.core.vocabulary import VocabularyOption, Vocabulary
 from fastNLP import DataSet
 from fastNLP import Instance
@@ -68,7 +68,7 @@ class SSTLoader(DataSetLoader):
         tgt_vocab = Vocabulary(unknown=None, padding=None) \
             if tgt_vocab_op is None else Vocabulary(**tgt_vocab_op)

-        info = DataInfo(datasets=self.load(paths))
+        info = DataBundle(datasets=self.load(paths))
         _train_ds = [info.datasets[name]
                      for name in train_ds] if train_ds else info.datasets.values()
         src_vocab.from_dataset(*_train_ds, field_name=input_name)
@@ -134,7 +134,7 @@ class sst2Loader(DataSetLoader):

         paths = check_dataloader_paths(paths)
         datasets = {}
-        info = DataInfo()
+        info = DataBundle()
         for name, path in paths.items():
             dataset = self.load(path)
             datasets[name] = dataset


reproduction/text_classification/data/yelpLoader.py (+2 -2)

@@ -4,7 +4,7 @@ from typing import Iterable
 from fastNLP import DataSet, Instance, Vocabulary
 from fastNLP.core.vocabulary import VocabularyOption
 from fastNLP.io import JsonLoader
-from fastNLP.io.base_loader import DataInfo,DataSetLoader
+from fastNLP.io.base_loader import DataBundle,DataSetLoader
 from fastNLP.io.embed_loader import EmbeddingOption
 from fastNLP.io.file_reader import _read_json
 from typing import Union, Dict
@@ -134,7 +134,7 @@ class yelpLoader(DataSetLoader):
                 char_level_op=False):
         paths = check_dataloader_paths(paths)
-        info = DataInfo(datasets=self.load(paths))
+        info = DataBundle(datasets=self.load(paths))
         src_vocab = Vocabulary() if src_vocab_op is None else Vocabulary(**src_vocab_op)
         tgt_vocab = Vocabulary(unknown=None, padding=None) \
             if tgt_vocab_op is None else Vocabulary(**tgt_vocab_op)

