diff --git a/docs/Makefile b/docs/Makefile index b9f1cf95..b41beb44 100644 --- a/docs/Makefile +++ b/docs/Makefile @@ -20,7 +20,7 @@ server: cd build/html && python -m http.server dev: - rm -rf build/html && make html && make server + rm -rf build && make html && make server .PHONY: help Makefile diff --git a/docs/format.py b/docs/format.py index 7cc341c2..67671ae7 100644 --- a/docs/format.py +++ b/docs/format.py @@ -59,7 +59,10 @@ def clear(path='./source/'): else: shorten(path + file, to_delete) for file in to_delete: - os.remove(path + file + ".rst") + try: + os.remove(path + file + ".rst") + except: + pass clear() diff --git a/docs/source/fastNLP.io.base_loader.rst b/docs/source/fastNLP.io.base_loader.rst deleted file mode 100644 index 057867f4..00000000 --- a/docs/source/fastNLP.io.base_loader.rst +++ /dev/null @@ -1,7 +0,0 @@ -fastNLP.io.base\_loader -======================= - -.. automodule:: fastNLP.io.base_loader - :members: - :undoc-members: - :show-inheritance: diff --git a/docs/source/fastNLP.io.data_bundle.rst b/docs/source/fastNLP.io.data_bundle.rst new file mode 100644 index 00000000..a6273956 --- /dev/null +++ b/docs/source/fastNLP.io.data_bundle.rst @@ -0,0 +1,7 @@ +fastNLP.io.data\_bundle +======================= + +.. automodule:: fastNLP.io.data_bundle + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/source/fastNLP.io.rst b/docs/source/fastNLP.io.rst index 0a006709..0cd5d3f2 100644 --- a/docs/source/fastNLP.io.rst +++ b/docs/source/fastNLP.io.rst @@ -20,7 +20,7 @@ Submodules .. 
toctree:: - fastNLP.io.base_loader + fastNLP.io.data_bundle fastNLP.io.dataset_loader fastNLP.io.embed_loader fastNLP.io.file_utils diff --git a/fastNLP/io/__init__.py b/fastNLP/io/__init__.py index 5234b209..90d4d12c 100644 --- a/fastNLP/io/__init__.py +++ b/fastNLP/io/__init__.py @@ -12,10 +12,9 @@ 这些类的使用方法如下: """ __all__ = [ - 'EmbedLoader', - 'DataBundle', - 'DataSetLoader', + + 'EmbedLoader', 'YelpLoader', 'YelpFullLoader', @@ -69,7 +68,7 @@ __all__ = [ ] from .embed_loader import EmbedLoader -from .base_loader import DataBundle, DataSetLoader +from .data_bundle import DataBundle from .dataset_loader import CSVLoader, JsonLoader from .model_io import ModelLoader, ModelSaver diff --git a/fastNLP/io/config_io.py b/fastNLP/io/config_io.py deleted file mode 100644 index ac349080..00000000 --- a/fastNLP/io/config_io.py +++ /dev/null @@ -1,313 +0,0 @@ -""" -用于读入和处理和保存 config 文件 - -.. todo:: - 这个模块中的类可能被抛弃? - -""" -__all__ = [ - "ConfigLoader", - "ConfigSection", - "ConfigSaver" -] - -import configparser -import json -import os - -from .base_loader import BaseLoader - - -class ConfigLoader(BaseLoader): - """ - 别名::class:`fastNLP.io.ConfigLoader` :class:`fastNLP.io.config_io.ConfigLoader` - - 读取配置文件的Loader - - :param str data_path: 配置文件的路径 - - """ - - def __init__(self, data_path=None): - super(ConfigLoader, self).__init__() - if data_path is not None: - self.config = self.parse(super(ConfigLoader, self).load(data_path)) - - @staticmethod - def parse(string): - raise NotImplementedError - - @staticmethod - def load_config(file_path, sections): - """ - 把配置文件的section 存入提供的 ``sections`` 中 - - :param str file_path: 配置文件的路径 - :param dict sections: 符合如下键值对组成的字典 `section_name(string)` : :class:`~fastNLP.io.ConfigSection` - - Example:: - - test_args = ConfigSection() - ConfigLoader("config.cfg").load_config("./data_for_tests/config", {"POS_test": test_args}) - - """ - assert isinstance(sections, dict) - cfg = configparser.ConfigParser() - if not os.path.exists(file_path): - 
raise FileNotFoundError("config file {} not found. ".format(file_path)) - cfg.read(file_path) - for s in sections: - attr_list = [i for i in sections[s].__dict__.keys() if - not callable(getattr(sections[s], i)) and not i.startswith("__")] - if s not in cfg: - print('section %s not found in config file' % (s)) - continue - gen_sec = cfg[s] - for attr in gen_sec.keys(): - try: - val = json.loads(gen_sec[attr]) - # print(s, attr, val, type(val)) - if attr in attr_list: - assert type(val) == type(getattr(sections[s], attr)), \ - 'type not match, except %s but got %s' % \ - (type(getattr(sections[s], attr)), type(val)) - """ - if attr in attr_list then check its type and - update its value. - else add a new attr in sections[s] - """ - setattr(sections[s], attr, val) - except Exception as e: - print("cannot load attribute %s in section %s" - % (attr, s)) - pass - - -class ConfigSection(object): - """ - 别名::class:`fastNLP.io.ConfigSection` :class:`fastNLP.io.config_io.ConfigSection` - - ConfigSection是一个存储了一个section中所有键值对的数据结构,推荐使用此类的实例来配合 :meth:`ConfigLoader.load_config` 使用 - - """ - - def __init__(self): - super(ConfigSection, self).__init__() - - def __getitem__(self, key): - """ - :param key: str, the name of the attribute - :return attr: the value of this attribute - if key not in self.__dict__.keys(): - return self[key] - else: - raise AttributeError - """ - if key in self.__dict__.keys(): - return getattr(self, key) - raise AttributeError("do NOT have attribute %s" % key) - - def __setitem__(self, key, value): - """ - :param key: str, the name of the attribute - :param value: the value of this attribute - if key not in self.__dict__.keys(): - self[key] will be added - else: - self[key] will be updated - """ - if key in self.__dict__.keys(): - if not isinstance(value, type(getattr(self, key))): - raise AttributeError("attr %s except %s but got %s" % - (key, str(type(getattr(self, key))), str(type(value)))) - setattr(self, key, value) - - def __contains__(self, 
item): - """ - :param item: The key of item. - :return: True if the key in self.__dict__.keys() else False. - """ - return item in self.__dict__.keys() - - def __eq__(self, other): - """Overwrite the == operator - - :param other: Another ConfigSection() object which to be compared. - :return: True if value of each key in each ConfigSection() object are equal to the other, else False. - """ - for k in self.__dict__.keys(): - if k not in other.__dict__.keys(): - return False - if getattr(self, k) != getattr(self, k): - return False - - for k in other.__dict__.keys(): - if k not in self.__dict__.keys(): - return False - if getattr(self, k) != getattr(self, k): - return False - - return True - - def __ne__(self, other): - """Overwrite the != operator - - :param other: - :return: - """ - return not self.__eq__(other) - - @property - def data(self): - return self.__dict__ - - -class ConfigSaver(object): - """ - 别名::class:`fastNLP.io.ConfigSaver` :class:`fastNLP.io.config_io.ConfigSaver` - - ConfigSaver 是用来存储配置文件并解决相关冲突的类 - - :param str file_path: 配置文件的路径 - - """ - - def __init__(self, file_path): - self.file_path = file_path - if not os.path.exists(self.file_path): - raise FileNotFoundError("file {} NOT found!".__format__(self.file_path)) - - def _get_section(self, sect_name): - """ - This is the function to get the section with the section name. - - :param sect_name: The name of section what wants to load. - :return: The section. - """ - sect = ConfigSection() - ConfigLoader().load_config(self.file_path, {sect_name: sect}) - return sect - - def _read_section(self): - """ - This is the function to read sections from the config file. - - :return: sect_list, sect_key_list - sect_list: A list of ConfigSection(). - sect_key_list: A list of names in sect_list. 
- """ - sect_name = None - - sect_list = {} - sect_key_list = [] - - single_section = {} - single_section_key = [] - - with open(self.file_path, 'r') as f: - lines = f.readlines() - - for line in lines: - if line.startswith('[') and line.endswith(']\n'): - if sect_name is None: - pass - else: - sect_list[sect_name] = single_section, single_section_key - single_section = {} - single_section_key = [] - sect_key_list.append(sect_name) - sect_name = line[1: -2] - continue - - if line.startswith('#'): - single_section[line] = '#' - single_section_key.append(line) - continue - - if line.startswith('\n'): - single_section_key.append('\n') - continue - - if '=' not in line: - raise RuntimeError("can NOT load config file {}".__format__(self.file_path)) - - key = line.split('=', maxsplit=1)[0].strip() - value = line.split('=', maxsplit=1)[1].strip() + '\n' - single_section[key] = value - single_section_key.append(key) - - if sect_name is not None: - sect_list[sect_name] = single_section, single_section_key - sect_key_list.append(sect_name) - return sect_list, sect_key_list - - def _write_section(self, sect_list, sect_key_list): - """ - This is the function to write config file with section list and name list. - - :param sect_list: A list of ConfigSection() need to be writen into file. - :param sect_key_list: A list of name of sect_list. - :return: - """ - with open(self.file_path, 'w') as f: - for sect_key in sect_key_list: - single_section, single_section_key = sect_list[sect_key] - f.write('[' + sect_key + ']\n') - for key in single_section_key: - if key == '\n': - f.write('\n') - continue - if single_section[key] == '#': - f.write(key) - continue - f.write(key + ' = ' + single_section[key]) - f.write('\n') - - def save_config_file(self, section_name, section): - """ - 这个方法可以用来修改并保存配置文件中单独的一个 section - - :param str section_name: 需要保存的 section 的名字. 
- :param section: 你需要修改并保存的 section, :class:`~fastNLP.io.ConfigSaver` 类型 - """ - section_file = self._get_section(section_name) - if len(section_file.__dict__.keys()) == 0: # the section not in the file before - # append this section to config file - with open(self.file_path, 'a') as f: - f.write('[' + section_name + ']\n') - for k in section.__dict__.keys(): - f.write(k + ' = ') - if isinstance(section[k], str): - f.write('\"' + str(section[k]) + '\"\n\n') - else: - f.write(str(section[k]) + '\n\n') - else: - # the section exists - change_file = False - for k in section.__dict__.keys(): - if k not in section_file: - # find a new key in this section - change_file = True - break - if section_file[k] != section[k]: - change_file = True - break - if not change_file: - return - - sect_list, sect_key_list = self._read_section() - if section_name not in sect_key_list: - raise AttributeError() - - sect, sect_key = sect_list[section_name] - for k in section.__dict__.keys(): - if k not in sect_key: - if sect_key[-1] != '\n': - sect_key.append('\n') - sect_key.append(k) - sect[k] = str(section[k]) - if isinstance(section[k], str): - sect[k] = "\"" + sect[k] + "\"" - sect[k] = sect[k] + "\n" - sect_list[section_name] = sect, sect_key - self._write_section(sect_list, sect_key_list) diff --git a/fastNLP/io/base_loader.py b/fastNLP/io/data_bundle.py similarity index 99% rename from fastNLP/io/base_loader.py rename to fastNLP/io/data_bundle.py index 5cbd5bb1..4203294b 100644 --- a/fastNLP/io/base_loader.py +++ b/fastNLP/io/data_bundle.py @@ -1,7 +1,5 @@ __all__ = [ - "BaseLoader", 'DataBundle', - 'DataSetLoader', ] import _pickle as pickle diff --git a/fastNLP/io/data_loader/conll.py b/fastNLP/io/data_loader/conll.py index 0285173c..7083b98d 100644 --- a/fastNLP/io/data_loader/conll.py +++ b/fastNLP/io/data_loader/conll.py @@ -1,11 +1,11 @@ from ...core.dataset import DataSet from ...core.instance import Instance -from ..base_loader import DataSetLoader +from ..data_bundle import 
DataSetLoader from ..file_reader import _read_conll from typing import Union, Dict from ..utils import check_loader_paths -from ..base_loader import DataBundle +from ..data_bundle import DataBundle class ConllLoader(DataSetLoader): """ diff --git a/fastNLP/io/data_loader/imdb.py b/fastNLP/io/data_loader/imdb.py index d3636cde..c9dda76e 100644 --- a/fastNLP/io/data_loader/imdb.py +++ b/fastNLP/io/data_loader/imdb.py @@ -2,7 +2,7 @@ from typing import Union, Dict from ..embed_loader import EmbeddingOption, EmbedLoader -from ..base_loader import DataSetLoader, DataBundle +from ..data_bundle import DataSetLoader, DataBundle from ...core.vocabulary import VocabularyOption, Vocabulary from ...core.dataset import DataSet from ...core.instance import Instance diff --git a/fastNLP/io/data_loader/matching.py b/fastNLP/io/data_loader/matching.py index 1242b432..41c9a98d 100644 --- a/fastNLP/io/data_loader/matching.py +++ b/fastNLP/io/data_loader/matching.py @@ -4,7 +4,7 @@ from typing import Union, Dict, List from ...core.const import Const from ...core.vocabulary import Vocabulary -from ..base_loader import DataBundle, DataSetLoader +from ..data_bundle import DataBundle, DataSetLoader from ..file_utils import _get_base_url, cached_path, PRETRAINED_BERT_MODEL_DIR from ...modules.encoder.bert import BertTokenizer diff --git a/fastNLP/io/data_loader/mtl.py b/fastNLP/io/data_loader/mtl.py index 20824958..923aadfb 100644 --- a/fastNLP/io/data_loader/mtl.py +++ b/fastNLP/io/data_loader/mtl.py @@ -1,7 +1,7 @@ from typing import Union, Dict -from ..base_loader import DataBundle +from ..data_bundle import DataBundle from ..dataset_loader import CSVLoader from ...core.vocabulary import Vocabulary, VocabularyOption from ...core.const import Const diff --git a/fastNLP/io/data_loader/people_daily.py b/fastNLP/io/data_loader/people_daily.py index 5efadb7d..afd66744 100644 --- a/fastNLP/io/data_loader/people_daily.py +++ b/fastNLP/io/data_loader/people_daily.py @@ -1,5 +1,5 @@ -from 
..base_loader import DataSetLoader +from ..data_bundle import DataSetLoader from ...core.dataset import DataSet from ...core.instance import Instance from ...core.const import Const diff --git a/fastNLP/io/data_loader/sst.py b/fastNLP/io/data_loader/sst.py index c2e0eca1..2034fc2b 100644 --- a/fastNLP/io/data_loader/sst.py +++ b/fastNLP/io/data_loader/sst.py @@ -2,7 +2,7 @@ from typing import Union, Dict from nltk import Tree -from ..base_loader import DataBundle, DataSetLoader +from ..data_bundle import DataBundle, DataSetLoader from ..dataset_loader import CSVLoader from ...core.vocabulary import VocabularyOption, Vocabulary from ...core.dataset import DataSet diff --git a/fastNLP/io/data_loader/yelp.py b/fastNLP/io/data_loader/yelp.py index 15533b04..f2bc60c8 100644 --- a/fastNLP/io/data_loader/yelp.py +++ b/fastNLP/io/data_loader/yelp.py @@ -6,7 +6,7 @@ from ...core.const import Const from ...core.dataset import DataSet from ...core.instance import Instance from ...core.vocabulary import VocabularyOption, Vocabulary -from ..base_loader import DataBundle, DataSetLoader +from ..data_bundle import DataBundle, DataSetLoader from typing import Union, Dict from ..utils import check_loader_paths, get_tokenizer diff --git a/fastNLP/io/dataset_loader.py b/fastNLP/io/dataset_loader.py index e1e06ec9..82e96597 100644 --- a/fastNLP/io/dataset_loader.py +++ b/fastNLP/io/dataset_loader.py @@ -26,7 +26,7 @@ __all__ = [ from ..core.dataset import DataSet from ..core.instance import Instance from .file_reader import _read_csv, _read_json -from .base_loader import DataSetLoader +from .data_bundle import DataSetLoader class JsonLoader(DataSetLoader): diff --git a/fastNLP/io/embed_loader.py b/fastNLP/io/embed_loader.py index 91a0919c..48048983 100644 --- a/fastNLP/io/embed_loader.py +++ b/fastNLP/io/embed_loader.py @@ -9,7 +9,7 @@ import warnings import numpy as np from ..core.vocabulary import Vocabulary -from .base_loader import BaseLoader +from .data_bundle import BaseLoader 
from ..core.utils import Option diff --git a/fastNLP/io/loader/__init__.py b/fastNLP/io/loader/__init__.py index a4e6a6f5..bcb3b730 100644 --- a/fastNLP/io/loader/__init__.py +++ b/fastNLP/io/loader/__init__.py @@ -44,6 +44,8 @@ fastNLP 目前提供了如下的 Loader """ __all__ = [ + 'Loader', + 'YelpLoader', 'YelpFullLoader', 'YelpPolarityLoader', @@ -57,7 +59,6 @@ __all__ = [ 'OntoNotesNERLoader', 'CTBLoader', - 'Loader', 'CSVLoader', 'JsonLoader', diff --git a/fastNLP/io/loader/classification.py b/fastNLP/io/loader/classification.py index dd85b4fe..ad56101d 100644 --- a/fastNLP/io/loader/classification.py +++ b/fastNLP/io/loader/classification.py @@ -7,6 +7,7 @@ import random import shutil import numpy as np + class YelpLoader(Loader): """ 别名::class:`fastNLP.io.YelpLoader` :class:`fastNLP.io.loader.YelpLoader` @@ -14,6 +15,7 @@ class YelpLoader(Loader): 原始数据中内容应该为, 每一行为一个sample,第一个逗号之前为target,第一个逗号之后为文本内容。 Example:: + "1","I got 'new' tires from the..." "1","Don't waste your time..." @@ -28,11 +30,11 @@ class YelpLoader(Loader): "...", "..." 
""" - + def __init__(self): super(YelpLoader, self).__init__() - - def _load(self, path: str=None): + + def _load(self, path: str = None): ds = DataSet() with open(path, 'r', encoding='utf-8') as f: for line in f: @@ -69,12 +71,12 @@ class YelpFullLoader(YelpLoader): :param int seed: 划分dev时的随机数种子 :return: str, 数据集的目录地址 """ - + dataset_name = 'yelp-review-full' data_dir = self._get_dataset_path(dataset_name=dataset_name) if os.path.exists(os.path.join(data_dir, 'dev.csv')): # 存在dev的话,check是否需要重新下载 re_download = True - if dev_ratio>0: + if dev_ratio > 0: dev_line_count = 0 tr_line_count = 0 with open(os.path.join(data_dir, 'train.csv'), 'r', encoding='utf-8') as f1, \ @@ -83,14 +85,14 @@ class YelpFullLoader(YelpLoader): tr_line_count += 1 for line in f2: dev_line_count += 1 - if not np.isclose(dev_line_count, dev_ratio*(tr_line_count + dev_line_count), rtol=0.005): + if not np.isclose(dev_line_count, dev_ratio * (tr_line_count + dev_line_count), rtol=0.005): re_download = True else: re_download = False if re_download: shutil.rmtree(data_dir) data_dir = self._get_dataset_path(dataset_name=dataset_name) - + if not os.path.exists(os.path.join(data_dir, 'dev.csv')): if dev_ratio > 0: assert 0 < dev_ratio < 1, "dev_ratio should be in range (0,1)." 
@@ -109,7 +111,7 @@ class YelpFullLoader(YelpLoader): finally: if os.path.exists(os.path.join(data_dir, 'middle_file.csv')): os.remove(os.path.join(data_dir, 'middle_file.csv')) - + return data_dir @@ -131,7 +133,7 @@ class YelpPolarityLoader(YelpLoader): data_dir = self._get_dataset_path(dataset_name=dataset_name) if os.path.exists(os.path.join(data_dir, 'dev.csv')): # 存在dev的话,check是否符合比例要求 re_download = True - if dev_ratio>0: + if dev_ratio > 0: dev_line_count = 0 tr_line_count = 0 with open(os.path.join(data_dir, 'train.csv'), 'r', encoding='utf-8') as f1, \ @@ -140,14 +142,14 @@ class YelpPolarityLoader(YelpLoader): tr_line_count += 1 for line in f2: dev_line_count += 1 - if not np.isclose(dev_line_count, dev_ratio*(tr_line_count + dev_line_count), rtol=0.005): + if not np.isclose(dev_line_count, dev_ratio * (tr_line_count + dev_line_count), rtol=0.005): re_download = True else: re_download = False if re_download: shutil.rmtree(data_dir) data_dir = self._get_dataset_path(dataset_name=dataset_name) - + if not os.path.exists(os.path.join(data_dir, 'dev.csv')): if dev_ratio > 0: assert 0 < dev_ratio < 1, "dev_ratio should be in range (0,1)." @@ -166,7 +168,7 @@ class YelpPolarityLoader(YelpLoader): finally: if os.path.exists(os.path.join(data_dir, 'middle_file.csv')): os.remove(os.path.join(data_dir, 'middle_file.csv')) - + return data_dir @@ -185,10 +187,10 @@ class IMDBLoader(Loader): "...", "..." 
""" - + def __init__(self): super(IMDBLoader, self).__init__() - + def _load(self, path: str): dataset = DataSet() with open(path, 'r', encoding="utf-8") as f: @@ -201,12 +203,12 @@ class IMDBLoader(Loader): words = parts[1] if words: dataset.append(Instance(raw_words=words, target=target)) - + if len(dataset) == 0: raise RuntimeError(f"{path} has no valid data.") - + return dataset - + def download(self, dev_ratio: float = 0.1, seed: int = 0): """ 自动下载数据集,如果你使用了这个数据集,请引用以下的文章 @@ -221,9 +223,9 @@ class IMDBLoader(Loader): """ dataset_name = 'aclImdb' data_dir = self._get_dataset_path(dataset_name=dataset_name) - if os.path.exists(os.path.join(data_dir, 'dev.txt')): # 存在dev的话,check是否符合比例要求 + if os.path.exists(os.path.join(data_dir, 'dev.txt')): # 存在dev的话,check是否符合比例要求 re_download = True - if dev_ratio>0: + if dev_ratio > 0: dev_line_count = 0 tr_line_count = 0 with open(os.path.join(data_dir, 'train.txt'), 'r', encoding='utf-8') as f1, \ @@ -232,14 +234,14 @@ class IMDBLoader(Loader): tr_line_count += 1 for line in f2: dev_line_count += 1 - if not np.isclose(dev_line_count, dev_ratio*(tr_line_count + dev_line_count), rtol=0.005): + if not np.isclose(dev_line_count, dev_ratio * (tr_line_count + dev_line_count), rtol=0.005): re_download = True else: re_download = False if re_download: shutil.rmtree(data_dir) data_dir = self._get_dataset_path(dataset_name=dataset_name) - + if not os.path.exists(os.path.join(data_dir, 'dev.csv')): if dev_ratio > 0: assert 0 < dev_ratio < 1, "dev_ratio should be in range (0,1)." 
@@ -258,7 +260,7 @@ class IMDBLoader(Loader): finally: if os.path.exists(os.path.join(data_dir, 'middle_file.txt')): os.remove(os.path.join(data_dir, 'middle_file.txt')) - + return data_dir @@ -278,10 +280,10 @@ class SSTLoader(Loader): raw_words列是str。 """ - + def __init__(self): super().__init__() - + def _load(self, path: str): """ 从path读取SST文件 @@ -296,7 +298,7 @@ class SSTLoader(Loader): if line: ds.append(Instance(raw_words=line)) return ds - + def download(self): """ 自动下载数据集,如果你使用了这个数据集,请引用以下的文章 @@ -323,10 +325,10 @@ class SST2Loader(Loader): test的DataSet没有target列。 """ - + def __init__(self): super().__init__() - + def _load(self, path: str): """ 从path读取SST2文件 @@ -335,7 +337,7 @@ class SST2Loader(Loader): :return: DataSet """ ds = DataSet() - + with open(path, 'r', encoding='utf-8') as f: f.readline() # 跳过header if 'test' in os.path.split(path)[1]: @@ -356,7 +358,7 @@ class SST2Loader(Loader): if raw_words: ds.append(Instance(raw_words=raw_words, target=target)) return ds - + def download(self): """ 自动下载数据集,如果你使用了该数据集,请引用以下的文章 diff --git a/fastNLP/io/loader/loader.py b/fastNLP/io/loader/loader.py index 607d6920..296714bf 100644 --- a/fastNLP/io/loader/loader.py +++ b/fastNLP/io/loader/loader.py @@ -2,17 +2,21 @@ from ...core.dataset import DataSet from .. import DataBundle from ..utils import check_loader_paths from typing import Union, Dict -import os from ..file_utils import _get_dataset_url, get_cache_path, cached_path + class Loader: + """ + 各种数据 Loader 的基类,提供了 API 的参考. 
+ + """ def __init__(self): pass - - def _load(self, path:str) -> DataSet: + + def _load(self, path: str) -> DataSet: raise NotImplementedError - - def load(self, paths: Union[str, Dict[str, str]]=None) -> DataBundle: + + def load(self, paths: Union[str, Dict[str, str]] = None) -> DataBundle: """ 从指定一个或多个路径中的文件中读取数据,返回:class:`~fastNLP.io.DataBundle` 。 @@ -22,31 +26,25 @@ class Loader: (0) 如果为None,则先查看本地是否有缓存,如果没有则自动下载并缓存。 (1) 传入一个目录, 该目录下名称包含train的被认为是train,包含test的被认为是test,包含dev的被认为是dev,如果检测到多个文件 - 名包含'train'、 'dev'、 'test'则会报错 - - Example:: + 名包含'train'、 'dev'、 'test'则会报错:: data_bundle = ConllLoader().load('/path/to/dir') # 返回的DataBundle中datasets根据目录下是否检测到train、 # dev、 test等有所变化,可以通过以下的方式取出DataSet tr_data = data_bundle.datasets['train'] te_data = data_bundle.datasets['test'] # 如果目录下有文件包含test这个字段 - (2) 传入文件路径 - - Example:: + (2) 传入文件路径:: data_bundle = ConllLoader().load("/path/to/a/train.conll") # 返回DataBundle对象, datasets中仅包含'train' tr_data = data_bundle.datasets['train'] # 可以通过以下的方式取出DataSet - (3) 传入一个dict,比如train,dev,test不在同一个目录下,或者名称中不包含train, dev, test - - Example:: + (3) 传入一个dict,比如train,dev,test不在同一个目录下,或者名称中不包含train, dev, test:: paths = {'train':"/path/to/tr.conll", 'dev':"/to/validate.conll", "test":"/to/te.conll"} data_bundle = ConllLoader().load(paths) # 返回的DataBundle中的dataset中包含"train", "dev", "test" dev_data = data_bundle.datasets['dev'] - :return: 返回的:class:`~fastNLP.io.DataBundle` + :return: 返回的 :class:`~fastNLP.io.DataBundle` """ if paths is None: paths = self.download() @@ -54,10 +52,10 @@ class Loader: datasets = {name: self._load(path) for name, path in paths.items()} data_bundle = DataBundle(datasets=datasets) return data_bundle - + def download(self): raise NotImplementedError(f"{self.__class__} cannot download data automatically.") - + def _get_dataset_path(self, dataset_name): """ 传入dataset的名称,获取读取数据的目录。如果数据不存在,会尝试自动下载并缓存 @@ -65,11 +63,9 @@ class Loader: :param str dataset_name: 数据集的名称 :return: str, 数据集的目录地址。直接到该目录下读取相应的数据即可。 """ - + 
default_cache_path = get_cache_path() url = _get_dataset_url(dataset_name) output_dir = cached_path(url_or_filename=url, cache_dir=default_cache_path, name='dataset') - + return output_dir - - diff --git a/fastNLP/io/loader/matching.py b/fastNLP/io/loader/matching.py index 58fa0d6f..26455914 100644 --- a/fastNLP/io/loader/matching.py +++ b/fastNLP/io/loader/matching.py @@ -203,7 +203,8 @@ class QNLILoader(JsonLoader): """ 如果您的实验使用到了该数据,请引用 - TODO 补充 + .. todo:: + 补充 :return: """ diff --git a/fastNLP/io/model_io.py b/fastNLP/io/model_io.py index ffaa4ef5..22ced1ce 100644 --- a/fastNLP/io/model_io.py +++ b/fastNLP/io/model_io.py @@ -8,7 +8,7 @@ __all__ = [ import torch -from .base_loader import BaseLoader +from .data_bundle import BaseLoader class ModelLoader(BaseLoader): diff --git a/fastNLP/io/pipe/classification.py b/fastNLP/io/pipe/classification.py index 429b6552..daa17da9 100644 --- a/fastNLP/io/pipe/classification.py +++ b/fastNLP/io/pipe/classification.py @@ -1,6 +1,6 @@ from nltk import Tree -from ..base_loader import DataBundle +from ..data_bundle import DataBundle from ...core.vocabulary import Vocabulary from ...core.const import Const from ..loader.classification import IMDBLoader, YelpFullLoader, SSTLoader, SST2Loader, YelpPolarityLoader diff --git a/reproduction/Summarization/Baseline/data/dataloader.py b/reproduction/Summarization/Baseline/data/dataloader.py index 47cd0856..dcb294b0 100644 --- a/reproduction/Summarization/Baseline/data/dataloader.py +++ b/reproduction/Summarization/Baseline/data/dataloader.py @@ -1,188 +1,188 @@ -import pickle -import numpy as np - -from fastNLP.core.vocabulary import Vocabulary -from fastNLP.io.base_loader import DataBundle -from fastNLP.io.dataset_loader import JsonLoader -from fastNLP.core.const import Const - -from tools.logger import * - -WORD_PAD = "[PAD]" -WORD_UNK = "[UNK]" -DOMAIN_UNK = "X" -TAG_UNK = "X" - - -class SummarizationLoader(JsonLoader): - """ - 读取summarization数据集,读取的DataSet包含fields:: - - text: 
list(str),document - summary: list(str), summary - text_wd: list(list(str)),tokenized document - summary_wd: list(list(str)), tokenized summary - labels: list(int), - flatten_label: list(int), 0 or 1, flatten labels - domain: str, optional - tag: list(str), optional - - 数据来源: CNN_DailyMail Newsroom DUC - """ - - def __init__(self): - super(SummarizationLoader, self).__init__() - - def _load(self, path): - ds = super(SummarizationLoader, self)._load(path) - - def _lower_text(text_list): - return [text.lower() for text in text_list] - - def _split_list(text_list): - return [text.split() for text in text_list] - - def _convert_label(label, sent_len): - np_label = np.zeros(sent_len, dtype=int) - if label != []: - np_label[np.array(label)] = 1 - return np_label.tolist() - - ds.apply(lambda x: _lower_text(x['text']), new_field_name='text') - ds.apply(lambda x: _lower_text(x['summary']), new_field_name='summary') - ds.apply(lambda x:_split_list(x['text']), new_field_name='text_wd') - ds.apply(lambda x:_split_list(x['summary']), new_field_name='summary_wd') - ds.apply(lambda x:_convert_label(x["label"], len(x["text"])), new_field_name="flatten_label") - - return ds - - def process(self, paths, vocab_size, vocab_path, sent_max_len, doc_max_timesteps, domain=False, tag=False, load_vocab_file=True): - """ - :param paths: dict path for each dataset - :param vocab_size: int max_size for vocab - :param vocab_path: str vocab path - :param sent_max_len: int max token number of the sentence - :param doc_max_timesteps: int max sentence number of the document - :param domain: bool build vocab for publication, use 'X' for unknown - :param tag: bool build vocab for tag, use 'X' for unknown - :param load_vocab_file: bool build vocab (False) or load vocab (True) - :return: DataBundle - datasets: dict keys correspond to the paths dict - vocabs: dict key: vocab(if "train" in paths), domain(if domain=True), tag(if tag=True) - embeddings: optional - """ - - def _pad_sent(text_wd): - 
pad_text_wd = [] - for sent_wd in text_wd: - if len(sent_wd) < sent_max_len: - pad_num = sent_max_len - len(sent_wd) - sent_wd.extend([WORD_PAD] * pad_num) - else: - sent_wd = sent_wd[:sent_max_len] - pad_text_wd.append(sent_wd) - return pad_text_wd - - def _token_mask(text_wd): - token_mask_list = [] - for sent_wd in text_wd: - token_num = len(sent_wd) - if token_num < sent_max_len: - mask = [1] * token_num + [0] * (sent_max_len - token_num) - else: - mask = [1] * sent_max_len - token_mask_list.append(mask) - return token_mask_list - - def _pad_label(label): - text_len = len(label) - if text_len < doc_max_timesteps: - pad_label = label + [0] * (doc_max_timesteps - text_len) - else: - pad_label = label[:doc_max_timesteps] - return pad_label - - def _pad_doc(text_wd): - text_len = len(text_wd) - if text_len < doc_max_timesteps: - padding = [WORD_PAD] * sent_max_len - pad_text = text_wd + [padding] * (doc_max_timesteps - text_len) - else: - pad_text = text_wd[:doc_max_timesteps] - return pad_text - - def _sent_mask(text_wd): - text_len = len(text_wd) - if text_len < doc_max_timesteps: - sent_mask = [1] * text_len + [0] * (doc_max_timesteps - text_len) - else: - sent_mask = [1] * doc_max_timesteps - return sent_mask - - - datasets = {} - train_ds = None - for key, value in paths.items(): - ds = self.load(value) - # pad sent - ds.apply(lambda x:_pad_sent(x["text_wd"]), new_field_name="pad_text_wd") - ds.apply(lambda x:_token_mask(x["text_wd"]), new_field_name="pad_token_mask") - # pad document - ds.apply(lambda x:_pad_doc(x["pad_text_wd"]), new_field_name="pad_text") - ds.apply(lambda x:_sent_mask(x["pad_text_wd"]), new_field_name="seq_len") - ds.apply(lambda x:_pad_label(x["flatten_label"]), new_field_name="pad_label") - - # rename field - ds.rename_field("pad_text", Const.INPUT) - ds.rename_field("seq_len", Const.INPUT_LEN) - ds.rename_field("pad_label", Const.TARGET) - - # set input and target - ds.set_input(Const.INPUT, Const.INPUT_LEN) - 
import pickle
import numpy as np

from fastNLP.core.vocabulary import Vocabulary
from fastNLP.io.data_bundle import DataBundle
from fastNLP.io.dataset_loader import JsonLoader
from fastNLP.core.const import Const

from tools.logger import *

WORD_PAD = "[PAD]"
WORD_UNK = "[UNK]"
DOMAIN_UNK = "X"
TAG_UNK = "X"


class SummarizationLoader(JsonLoader):
    """Loader for extractive-summarization datasets.

    The loaded DataSet contains the fields::

        text: list(str), document sentences
        summary: list(str), summary sentences
        text_wd: list(list(str)), tokenized document
        summary_wd: list(list(str)), tokenized summary
        labels: list(int),
        flatten_label: list(int), 0 or 1, flattened sentence labels
        domain: str, optional
        tag: list(str), optional

    Data sources: CNN_DailyMail, Newsroom, DUC
    """

    def __init__(self):
        super(SummarizationLoader, self).__init__()

    def _load(self, path):
        """Load one json file and derive the lower-cased, tokenized and
        flattened-label fields from the raw ``text``/``summary``/``label`` fields.

        :param path: str, path of the json file
        :return: DataSet with the derived fields added
        """
        ds = super(SummarizationLoader, self)._load(path)

        def _lower_text(text_list):
            return [text.lower() for text in text_list]

        def _split_list(text_list):
            return [text.split() for text in text_list]

        def _convert_label(label, sent_len):
            # Dense 0/1 vector: position i is 1 iff sentence i belongs to the summary.
            np_label = np.zeros(sent_len, dtype=int)
            if label != []:
                np_label[np.array(label)] = 1
            return np_label.tolist()

        ds.apply(lambda x: _lower_text(x['text']), new_field_name='text')
        ds.apply(lambda x: _lower_text(x['summary']), new_field_name='summary')
        ds.apply(lambda x: _split_list(x['text']), new_field_name='text_wd')
        ds.apply(lambda x: _split_list(x['summary']), new_field_name='summary_wd')
        ds.apply(lambda x: _convert_label(x["label"], len(x["text"])), new_field_name="flatten_label")

        return ds

    def process(self, paths, vocab_size, vocab_path, sent_max_len, doc_max_timesteps, domain=False, tag=False, load_vocab_file=True):
        """Pad/truncate every dataset in ``paths`` and build (or load) the vocabularies.

        :param paths: dict, path for each dataset
        :param vocab_size: int, max_size for vocab
        :param vocab_path: str, vocab path (used when load_vocab_file=True)
        :param sent_max_len: int, max token number of the sentence
        :param doc_max_timesteps: int, max sentence number of the document
        :param domain: bool, build vocab for publication, use 'X' for unknown
        :param tag: bool, build vocab for tag, use 'X' for unknown
        :param load_vocab_file: bool, build vocab (False) or load vocab (True)
        :return: DataBundle
            datasets: dict, keys correspond to the paths dict
            vocabs: dict, key: vocab (if "train" in paths), domain (if domain=True), tag (if tag=True)
            embeddings: optional
        """

        def _pad_sent(text_wd):
            # Pad each sentence with WORD_PAD up to sent_max_len, or truncate.
            pad_text_wd = []
            for sent_wd in text_wd:
                if len(sent_wd) < sent_max_len:
                    pad_num = sent_max_len - len(sent_wd)
                    sent_wd.extend([WORD_PAD] * pad_num)
                else:
                    sent_wd = sent_wd[:sent_max_len]
                pad_text_wd.append(sent_wd)
            return pad_text_wd

        def _token_mask(text_wd):
            # Per sentence: 1 for a real token, 0 for a padding slot.
            token_mask_list = []
            for sent_wd in text_wd:
                token_num = len(sent_wd)
                if token_num < sent_max_len:
                    mask = [1] * token_num + [0] * (sent_max_len - token_num)
                else:
                    mask = [1] * sent_max_len
                token_mask_list.append(mask)
            return token_mask_list

        def _pad_label(label):
            # Pad/truncate the flattened 0/1 sentence labels to doc_max_timesteps.
            text_len = len(label)
            if text_len < doc_max_timesteps:
                return label + [0] * (doc_max_timesteps - text_len)
            return label[:doc_max_timesteps]

        def _pad_doc(text_wd):
            # Pad the document with all-PAD sentences up to doc_max_timesteps, or truncate.
            text_len = len(text_wd)
            if text_len < doc_max_timesteps:
                padding = [WORD_PAD] * sent_max_len
                return text_wd + [padding] * (doc_max_timesteps - text_len)
            return text_wd[:doc_max_timesteps]

        def _sent_mask(text_wd):
            # 1 for a real sentence, 0 for a padding sentence.
            text_len = len(text_wd)
            if text_len < doc_max_timesteps:
                return [1] * text_len + [0] * (doc_max_timesteps - text_len)
            return [1] * doc_max_timesteps

        datasets = {}
        train_ds = None
        for key, value in paths.items():
            ds = self.load(value)
            # pad sentences
            ds.apply(lambda x: _pad_sent(x["text_wd"]), new_field_name="pad_text_wd")
            ds.apply(lambda x: _token_mask(x["text_wd"]), new_field_name="pad_token_mask")
            # pad documents
            ds.apply(lambda x: _pad_doc(x["pad_text_wd"]), new_field_name="pad_text")
            ds.apply(lambda x: _sent_mask(x["pad_text_wd"]), new_field_name="seq_len")
            ds.apply(lambda x: _pad_label(x["flatten_label"]), new_field_name="pad_label")

            # rename fields to the fastNLP constants expected downstream
            ds.rename_field("pad_text", Const.INPUT)
            ds.rename_field("seq_len", Const.INPUT_LEN)
            ds.rename_field("pad_label", Const.TARGET)

            # set input and target
            ds.set_input(Const.INPUT, Const.INPUT_LEN)
            ds.set_target(Const.TARGET, Const.INPUT_LEN)

            datasets[key] = ds
            if "train" in key:
                train_ds = datasets[key]

        vocab_dict = {}
        if not load_vocab_file:
            logger.info("[INFO] Build new vocab from training dataset!")
            if train_ds is None:
                raise ValueError("Lack train file to build vocabulary!")

            vocabs = Vocabulary(max_size=vocab_size, padding=WORD_PAD, unknown=WORD_UNK)
            vocabs.from_dataset(train_ds, field_name=["text_wd", "summary_wd"])
            vocab_dict["vocab"] = vocabs
        else:
            logger.info("[INFO] Load existing vocab from %s!" % vocab_path)
            word_list = []
            with open(vocab_path, 'r', encoding='utf8') as vocab_f:
                cnt = 2  # pad and unk already occupy two slots
                for line in vocab_f:
                    # Vocab file format is "word\tcount"; strip the newline so a
                    # line without a tab does not leave '\n' attached to the word.
                    pieces = line.rstrip("\n").split("\t")
                    word_list.append(pieces[0])
                    cnt += 1
                    if cnt > vocab_size:
                        break
            vocabs = Vocabulary(max_size=vocab_size, padding=WORD_PAD, unknown=WORD_UNK)
            vocabs.add_word_lst(word_list)
            vocabs.build_vocab()
            vocab_dict["vocab"] = vocabs

        if domain:
            domaindict = Vocabulary(padding=None, unknown=DOMAIN_UNK)
            domaindict.from_dataset(train_ds, field_name="publication")
            vocab_dict["domain"] = domaindict
        if tag:
            tagdict = Vocabulary(padding=None, unknown=TAG_UNK)
            tagdict.from_dataset(train_ds, field_name="tag")
            vocab_dict["tag"] = tagdict

        # Map tokens to ids in place for every dataset.
        for ds in datasets.values():
            vocab_dict["vocab"].index_dataset(ds, field_name=Const.INPUT, new_field_name=Const.INPUT)

        return DataBundle(vocabs=vocab_dict, datasets=datasets)
a/reproduction/coreference_resolution/data_load/cr_loader.py b/reproduction/coreference_resolution/data_load/cr_loader.py index a424b0d1..5ed73473 100644 --- a/reproduction/coreference_resolution/data_load/cr_loader.py +++ b/reproduction/coreference_resolution/data_load/cr_loader.py @@ -1,7 +1,7 @@ from fastNLP.io.dataset_loader import JsonLoader,DataSet,Instance from fastNLP.io.file_reader import _read_json from fastNLP.core.vocabulary import Vocabulary -from fastNLP.io.base_loader import DataBundle +from fastNLP.io.data_bundle import DataBundle from reproduction.coreference_resolution.model.config import Config import reproduction.coreference_resolution.model.preprocess as preprocess diff --git a/reproduction/joint_cws_parse/data/data_loader.py b/reproduction/joint_cws_parse/data/data_loader.py index 3e6fec4b..4df46b04 100644 --- a/reproduction/joint_cws_parse/data/data_loader.py +++ b/reproduction/joint_cws_parse/data/data_loader.py @@ -1,6 +1,6 @@ -from fastNLP.io.base_loader import DataSetLoader, DataBundle +from fastNLP.io.data_bundle import DataSetLoader, DataBundle from fastNLP.io.data_loader import ConllLoader import numpy as np diff --git a/reproduction/matching/data/MatchingDataLoader.py b/reproduction/matching/data/MatchingDataLoader.py index bba26a8a..f13618aa 100644 --- a/reproduction/matching/data/MatchingDataLoader.py +++ b/reproduction/matching/data/MatchingDataLoader.py @@ -9,7 +9,7 @@ from typing import Union, Dict from fastNLP.core.const import Const from fastNLP.core.vocabulary import Vocabulary -from fastNLP.io.base_loader import DataBundle, DataSetLoader +from fastNLP.io.data_bundle import DataBundle, DataSetLoader from fastNLP.io.dataset_loader import JsonLoader, CSVLoader from fastNLP.io.file_utils import _get_base_url, cached_path, PRETRAINED_BERT_MODEL_DIR from fastNLP.modules.encoder._bert import BertTokenizer diff --git a/reproduction/seqence_labelling/chinese_ner/data/ChineseNER.py 
b/reproduction/seqence_labelling/chinese_ner/data/ChineseNER.py index 0d292bdc..a2ee4663 100644 --- a/reproduction/seqence_labelling/chinese_ner/data/ChineseNER.py +++ b/reproduction/seqence_labelling/chinese_ner/data/ChineseNER.py @@ -1,6 +1,6 @@ -from fastNLP.io.base_loader import DataSetLoader, DataBundle +from fastNLP.io.data_bundle import DataSetLoader, DataBundle from fastNLP.io import ConllLoader from reproduction.seqence_labelling.ner.data.utils import iob2bioes, iob2 from fastNLP import Const diff --git a/reproduction/seqence_labelling/cws/data/CWSDataLoader.py b/reproduction/seqence_labelling/cws/data/CWSDataLoader.py index 3c82d814..5f69c0ad 100644 --- a/reproduction/seqence_labelling/cws/data/CWSDataLoader.py +++ b/reproduction/seqence_labelling/cws/data/CWSDataLoader.py @@ -1,7 +1,7 @@ from fastNLP.io.embed_loader import EmbeddingOption, EmbedLoader from fastNLP.core.vocabulary import VocabularyOption -from fastNLP.io.base_loader import DataSetLoader, DataBundle +from fastNLP.io.data_bundle import DataSetLoader, DataBundle from typing import Union, Dict, List, Iterator from fastNLP import DataSet from fastNLP import Instance diff --git a/reproduction/seqence_labelling/ner/data/Conll2003Loader.py b/reproduction/seqence_labelling/ner/data/Conll2003Loader.py index 1aeddcf8..0af4681e 100644 --- a/reproduction/seqence_labelling/ner/data/Conll2003Loader.py +++ b/reproduction/seqence_labelling/ner/data/Conll2003Loader.py @@ -1,6 +1,6 @@ from fastNLP.core.vocabulary import VocabularyOption -from fastNLP.io.base_loader import DataSetLoader, DataBundle +from fastNLP.io.data_bundle import DataSetLoader, DataBundle from typing import Union, Dict from fastNLP import Vocabulary from fastNLP import Const diff --git a/reproduction/seqence_labelling/ner/data/OntoNoteLoader.py b/reproduction/seqence_labelling/ner/data/OntoNoteLoader.py index a6070f39..25c6f29b 100644 --- a/reproduction/seqence_labelling/ner/data/OntoNoteLoader.py +++ 
b/reproduction/seqence_labelling/ner/data/OntoNoteLoader.py @@ -1,5 +1,5 @@ from fastNLP.core.vocabulary import VocabularyOption -from fastNLP.io.base_loader import DataSetLoader, DataBundle +from fastNLP.io.data_bundle import DataSetLoader, DataBundle from typing import Union, Dict from fastNLP import DataSet from fastNLP import Vocabulary diff --git a/reproduction/text_classification/data/IMDBLoader.py b/reproduction/text_classification/data/IMDBLoader.py index 94244431..1585fe44 100644 --- a/reproduction/text_classification/data/IMDBLoader.py +++ b/reproduction/text_classification/data/IMDBLoader.py @@ -1,6 +1,6 @@ from fastNLP.io.embed_loader import EmbeddingOption, EmbedLoader from fastNLP.core.vocabulary import VocabularyOption -from fastNLP.io.base_loader import DataSetLoader, DataBundle +from fastNLP.io.data_bundle import DataSetLoader, DataBundle from typing import Union, Dict, List, Iterator from fastNLP import DataSet from fastNLP import Instance diff --git a/reproduction/text_classification/data/MTL16Loader.py b/reproduction/text_classification/data/MTL16Loader.py index 68969069..225fffe6 100644 --- a/reproduction/text_classification/data/MTL16Loader.py +++ b/reproduction/text_classification/data/MTL16Loader.py @@ -1,6 +1,6 @@ from fastNLP.io.embed_loader import EmbeddingOption, EmbedLoader from fastNLP.core.vocabulary import VocabularyOption -from fastNLP.io.base_loader import DataSetLoader, DataBundle +from fastNLP.io.data_bundle import DataSetLoader, DataBundle from typing import Union, Dict, List, Iterator from fastNLP import DataSet from fastNLP import Instance diff --git a/reproduction/text_classification/data/sstloader.py b/reproduction/text_classification/data/sstloader.py index fa4d1837..b635a14a 100644 --- a/reproduction/text_classification/data/sstloader.py +++ b/reproduction/text_classification/data/sstloader.py @@ -1,6 +1,6 @@ from typing import Iterable from nltk import Tree -from fastNLP.io.base_loader import DataBundle, DataSetLoader 
+from fastNLP.io.data_bundle import DataBundle, DataSetLoader from fastNLP.core.vocabulary import VocabularyOption, Vocabulary from fastNLP import DataSet from fastNLP import Instance diff --git a/reproduction/text_classification/data/yelpLoader.py b/reproduction/text_classification/data/yelpLoader.py index d2272a88..1f7634fc 100644 --- a/reproduction/text_classification/data/yelpLoader.py +++ b/reproduction/text_classification/data/yelpLoader.py @@ -4,7 +4,7 @@ from typing import Iterable from fastNLP import DataSet, Instance, Vocabulary from fastNLP.core.vocabulary import VocabularyOption from fastNLP.io import JsonLoader -from fastNLP.io.base_loader import DataBundle,DataSetLoader +from fastNLP.io.data_bundle import DataBundle,DataSetLoader from fastNLP.io.embed_loader import EmbeddingOption from fastNLP.io.file_reader import _read_json from typing import Union, Dict