@@ -20,7 +20,7 @@ server: | |||
cd build/html && python -m http.server | |||
dev: | |||
rm -rf build/html && make html && make server | |||
rm -rf build && make html && make server | |||
.PHONY: help Makefile | |||
@@ -59,7 +59,10 @@ def clear(path='./source/'): | |||
else: | |||
shorten(path + file, to_delete) | |||
for file in to_delete: | |||
os.remove(path + file + ".rst") | |||
try: | |||
os.remove(path + file + ".rst") | |||
except: | |||
pass | |||
clear() |
@@ -1,7 +0,0 @@ | |||
fastNLP.io.base\_loader | |||
======================= | |||
.. automodule:: fastNLP.io.base_loader | |||
:members: | |||
:undoc-members: | |||
:show-inheritance: |
@@ -0,0 +1,7 @@ | |||
fastNLP.io.data\_bundle | |||
======================= | |||
.. automodule:: fastNLP.io.data_bundle | |||
:members: | |||
:undoc-members: | |||
:show-inheritance: |
@@ -20,7 +20,7 @@ Submodules | |||
.. toctree:: | |||
fastNLP.io.base_loader | |||
fastNLP.io.data_bundle | |||
fastNLP.io.dataset_loader | |||
fastNLP.io.embed_loader | |||
fastNLP.io.file_utils | |||
@@ -12,10 +12,9 @@ | |||
这些类的使用方法如下: | |||
""" | |||
__all__ = [ | |||
'EmbedLoader', | |||
'DataBundle', | |||
'DataSetLoader', | |||
'EmbedLoader', | |||
'YelpLoader', | |||
'YelpFullLoader', | |||
@@ -69,7 +68,7 @@ __all__ = [ | |||
] | |||
from .embed_loader import EmbedLoader | |||
from .base_loader import DataBundle, DataSetLoader | |||
from .data_bundle import DataBundle | |||
from .dataset_loader import CSVLoader, JsonLoader | |||
from .model_io import ModelLoader, ModelSaver | |||
@@ -1,313 +0,0 @@ | |||
""" | |||
用于读入和处理和保存 config 文件 | |||
.. todo:: | |||
这个模块中的类可能被抛弃? | |||
""" | |||
__all__ = [ | |||
"ConfigLoader", | |||
"ConfigSection", | |||
"ConfigSaver" | |||
] | |||
import configparser | |||
import json | |||
import os | |||
from .base_loader import BaseLoader | |||
class ConfigLoader(BaseLoader): | |||
""" | |||
别名::class:`fastNLP.io.ConfigLoader` :class:`fastNLP.io.config_io.ConfigLoader` | |||
读取配置文件的Loader | |||
:param str data_path: 配置文件的路径 | |||
""" | |||
def __init__(self, data_path=None): | |||
super(ConfigLoader, self).__init__() | |||
if data_path is not None: | |||
self.config = self.parse(super(ConfigLoader, self).load(data_path)) | |||
@staticmethod | |||
def parse(string): | |||
raise NotImplementedError | |||
@staticmethod | |||
def load_config(file_path, sections): | |||
""" | |||
把配置文件的section 存入提供的 ``sections`` 中 | |||
:param str file_path: 配置文件的路径 | |||
:param dict sections: 符合如下键值对组成的字典 `section_name(string)` : :class:`~fastNLP.io.ConfigSection` | |||
Example:: | |||
test_args = ConfigSection() | |||
ConfigLoader("config.cfg").load_config("./data_for_tests/config", {"POS_test": test_args}) | |||
""" | |||
assert isinstance(sections, dict) | |||
cfg = configparser.ConfigParser() | |||
if not os.path.exists(file_path): | |||
raise FileNotFoundError("config file {} not found. ".format(file_path)) | |||
cfg.read(file_path) | |||
for s in sections: | |||
attr_list = [i for i in sections[s].__dict__.keys() if | |||
not callable(getattr(sections[s], i)) and not i.startswith("__")] | |||
if s not in cfg: | |||
print('section %s not found in config file' % (s)) | |||
continue | |||
gen_sec = cfg[s] | |||
for attr in gen_sec.keys(): | |||
try: | |||
val = json.loads(gen_sec[attr]) | |||
# print(s, attr, val, type(val)) | |||
if attr in attr_list: | |||
assert type(val) == type(getattr(sections[s], attr)), \ | |||
'type not match, except %s but got %s' % \ | |||
(type(getattr(sections[s], attr)), type(val)) | |||
""" | |||
if attr in attr_list then check its type and | |||
update its value. | |||
else add a new attr in sections[s] | |||
""" | |||
setattr(sections[s], attr, val) | |||
except Exception as e: | |||
print("cannot load attribute %s in section %s" | |||
% (attr, s)) | |||
pass | |||
class ConfigSection(object): | |||
""" | |||
别名::class:`fastNLP.io.ConfigSection` :class:`fastNLP.io.config_io.ConfigSection` | |||
ConfigSection是一个存储了一个section中所有键值对的数据结构,推荐使用此类的实例来配合 :meth:`ConfigLoader.load_config` 使用 | |||
""" | |||
def __init__(self): | |||
super(ConfigSection, self).__init__() | |||
def __getitem__(self, key): | |||
""" | |||
:param key: str, the name of the attribute | |||
:return attr: the value of this attribute | |||
if key not in self.__dict__.keys(): | |||
return self[key] | |||
else: | |||
raise AttributeError | |||
""" | |||
if key in self.__dict__.keys(): | |||
return getattr(self, key) | |||
raise AttributeError("do NOT have attribute %s" % key) | |||
def __setitem__(self, key, value): | |||
""" | |||
:param key: str, the name of the attribute | |||
:param value: the value of this attribute | |||
if key not in self.__dict__.keys(): | |||
self[key] will be added | |||
else: | |||
self[key] will be updated | |||
""" | |||
if key in self.__dict__.keys(): | |||
if not isinstance(value, type(getattr(self, key))): | |||
raise AttributeError("attr %s except %s but got %s" % | |||
(key, str(type(getattr(self, key))), str(type(value)))) | |||
setattr(self, key, value) | |||
def __contains__(self, item): | |||
""" | |||
:param item: The key of item. | |||
:return: True if the key in self.__dict__.keys() else False. | |||
""" | |||
return item in self.__dict__.keys() | |||
def __eq__(self, other): | |||
"""Overwrite the == operator | |||
:param other: Another ConfigSection() object which to be compared. | |||
:return: True if value of each key in each ConfigSection() object are equal to the other, else False. | |||
""" | |||
for k in self.__dict__.keys(): | |||
if k not in other.__dict__.keys(): | |||
return False | |||
if getattr(self, k) != getattr(self, k): | |||
return False | |||
for k in other.__dict__.keys(): | |||
if k not in self.__dict__.keys(): | |||
return False | |||
if getattr(self, k) != getattr(self, k): | |||
return False | |||
return True | |||
def __ne__(self, other): | |||
"""Overwrite the != operator | |||
:param other: | |||
:return: | |||
""" | |||
return not self.__eq__(other) | |||
@property | |||
def data(self): | |||
return self.__dict__ | |||
class ConfigSaver(object): | |||
""" | |||
别名::class:`fastNLP.io.ConfigSaver` :class:`fastNLP.io.config_io.ConfigSaver` | |||
ConfigSaver 是用来存储配置文件并解决相关冲突的类 | |||
:param str file_path: 配置文件的路径 | |||
""" | |||
def __init__(self, file_path): | |||
self.file_path = file_path | |||
if not os.path.exists(self.file_path): | |||
raise FileNotFoundError("file {} NOT found!".__format__(self.file_path)) | |||
def _get_section(self, sect_name): | |||
""" | |||
This is the function to get the section with the section name. | |||
:param sect_name: The name of section what wants to load. | |||
:return: The section. | |||
""" | |||
sect = ConfigSection() | |||
ConfigLoader().load_config(self.file_path, {sect_name: sect}) | |||
return sect | |||
def _read_section(self): | |||
""" | |||
This is the function to read sections from the config file. | |||
:return: sect_list, sect_key_list | |||
sect_list: A list of ConfigSection(). | |||
sect_key_list: A list of names in sect_list. | |||
""" | |||
sect_name = None | |||
sect_list = {} | |||
sect_key_list = [] | |||
single_section = {} | |||
single_section_key = [] | |||
with open(self.file_path, 'r') as f: | |||
lines = f.readlines() | |||
for line in lines: | |||
if line.startswith('[') and line.endswith(']\n'): | |||
if sect_name is None: | |||
pass | |||
else: | |||
sect_list[sect_name] = single_section, single_section_key | |||
single_section = {} | |||
single_section_key = [] | |||
sect_key_list.append(sect_name) | |||
sect_name = line[1: -2] | |||
continue | |||
if line.startswith('#'): | |||
single_section[line] = '#' | |||
single_section_key.append(line) | |||
continue | |||
if line.startswith('\n'): | |||
single_section_key.append('\n') | |||
continue | |||
if '=' not in line: | |||
raise RuntimeError("can NOT load config file {}".__format__(self.file_path)) | |||
key = line.split('=', maxsplit=1)[0].strip() | |||
value = line.split('=', maxsplit=1)[1].strip() + '\n' | |||
single_section[key] = value | |||
single_section_key.append(key) | |||
if sect_name is not None: | |||
sect_list[sect_name] = single_section, single_section_key | |||
sect_key_list.append(sect_name) | |||
return sect_list, sect_key_list | |||
def _write_section(self, sect_list, sect_key_list): | |||
""" | |||
This is the function to write config file with section list and name list. | |||
:param sect_list: A list of ConfigSection() need to be writen into file. | |||
:param sect_key_list: A list of name of sect_list. | |||
:return: | |||
""" | |||
with open(self.file_path, 'w') as f: | |||
for sect_key in sect_key_list: | |||
single_section, single_section_key = sect_list[sect_key] | |||
f.write('[' + sect_key + ']\n') | |||
for key in single_section_key: | |||
if key == '\n': | |||
f.write('\n') | |||
continue | |||
if single_section[key] == '#': | |||
f.write(key) | |||
continue | |||
f.write(key + ' = ' + single_section[key]) | |||
f.write('\n') | |||
def save_config_file(self, section_name, section): | |||
""" | |||
这个方法可以用来修改并保存配置文件中单独的一个 section | |||
:param str section_name: 需要保存的 section 的名字. | |||
:param section: 你需要修改并保存的 section, :class:`~fastNLP.io.ConfigSaver` 类型 | |||
""" | |||
section_file = self._get_section(section_name) | |||
if len(section_file.__dict__.keys()) == 0: # the section not in the file before | |||
# append this section to config file | |||
with open(self.file_path, 'a') as f: | |||
f.write('[' + section_name + ']\n') | |||
for k in section.__dict__.keys(): | |||
f.write(k + ' = ') | |||
if isinstance(section[k], str): | |||
f.write('\"' + str(section[k]) + '\"\n\n') | |||
else: | |||
f.write(str(section[k]) + '\n\n') | |||
else: | |||
# the section exists | |||
change_file = False | |||
for k in section.__dict__.keys(): | |||
if k not in section_file: | |||
# find a new key in this section | |||
change_file = True | |||
break | |||
if section_file[k] != section[k]: | |||
change_file = True | |||
break | |||
if not change_file: | |||
return | |||
sect_list, sect_key_list = self._read_section() | |||
if section_name not in sect_key_list: | |||
raise AttributeError() | |||
sect, sect_key = sect_list[section_name] | |||
for k in section.__dict__.keys(): | |||
if k not in sect_key: | |||
if sect_key[-1] != '\n': | |||
sect_key.append('\n') | |||
sect_key.append(k) | |||
sect[k] = str(section[k]) | |||
if isinstance(section[k], str): | |||
sect[k] = "\"" + sect[k] + "\"" | |||
sect[k] = sect[k] + "\n" | |||
sect_list[section_name] = sect, sect_key | |||
self._write_section(sect_list, sect_key_list) |
@@ -1,7 +1,5 @@ | |||
__all__ = [ | |||
"BaseLoader", | |||
'DataBundle', | |||
'DataSetLoader', | |||
] | |||
import _pickle as pickle |
@@ -1,11 +1,11 @@ | |||
from ...core.dataset import DataSet | |||
from ...core.instance import Instance | |||
from ..base_loader import DataSetLoader | |||
from ..data_bundle import DataSetLoader | |||
from ..file_reader import _read_conll | |||
from typing import Union, Dict | |||
from ..utils import check_loader_paths | |||
from ..base_loader import DataBundle | |||
from ..data_bundle import DataBundle | |||
class ConllLoader(DataSetLoader): | |||
""" | |||
@@ -2,7 +2,7 @@ | |||
from typing import Union, Dict | |||
from ..embed_loader import EmbeddingOption, EmbedLoader | |||
from ..base_loader import DataSetLoader, DataBundle | |||
from ..data_bundle import DataSetLoader, DataBundle | |||
from ...core.vocabulary import VocabularyOption, Vocabulary | |||
from ...core.dataset import DataSet | |||
from ...core.instance import Instance | |||
@@ -4,7 +4,7 @@ from typing import Union, Dict, List | |||
from ...core.const import Const | |||
from ...core.vocabulary import Vocabulary | |||
from ..base_loader import DataBundle, DataSetLoader | |||
from ..data_bundle import DataBundle, DataSetLoader | |||
from ..file_utils import _get_base_url, cached_path, PRETRAINED_BERT_MODEL_DIR | |||
from ...modules.encoder.bert import BertTokenizer | |||
@@ -1,7 +1,7 @@ | |||
from typing import Union, Dict | |||
from ..base_loader import DataBundle | |||
from ..data_bundle import DataBundle | |||
from ..dataset_loader import CSVLoader | |||
from ...core.vocabulary import Vocabulary, VocabularyOption | |||
from ...core.const import Const | |||
@@ -1,5 +1,5 @@ | |||
from ..base_loader import DataSetLoader | |||
from ..data_bundle import DataSetLoader | |||
from ...core.dataset import DataSet | |||
from ...core.instance import Instance | |||
from ...core.const import Const | |||
@@ -2,7 +2,7 @@ | |||
from typing import Union, Dict | |||
from nltk import Tree | |||
from ..base_loader import DataBundle, DataSetLoader | |||
from ..data_bundle import DataBundle, DataSetLoader | |||
from ..dataset_loader import CSVLoader | |||
from ...core.vocabulary import VocabularyOption, Vocabulary | |||
from ...core.dataset import DataSet | |||
@@ -6,7 +6,7 @@ from ...core.const import Const | |||
from ...core.dataset import DataSet | |||
from ...core.instance import Instance | |||
from ...core.vocabulary import VocabularyOption, Vocabulary | |||
from ..base_loader import DataBundle, DataSetLoader | |||
from ..data_bundle import DataBundle, DataSetLoader | |||
from typing import Union, Dict | |||
from ..utils import check_loader_paths, get_tokenizer | |||
@@ -26,7 +26,7 @@ __all__ = [ | |||
from ..core.dataset import DataSet | |||
from ..core.instance import Instance | |||
from .file_reader import _read_csv, _read_json | |||
from .base_loader import DataSetLoader | |||
from .data_bundle import DataSetLoader | |||
class JsonLoader(DataSetLoader): | |||
@@ -9,7 +9,7 @@ import warnings | |||
import numpy as np | |||
from ..core.vocabulary import Vocabulary | |||
from .base_loader import BaseLoader | |||
from .data_bundle import BaseLoader | |||
from ..core.utils import Option | |||
@@ -44,6 +44,8 @@ fastNLP 目前提供了如下的 Loader | |||
""" | |||
__all__ = [ | |||
'Loader', | |||
'YelpLoader', | |||
'YelpFullLoader', | |||
'YelpPolarityLoader', | |||
@@ -57,7 +59,6 @@ __all__ = [ | |||
'OntoNotesNERLoader', | |||
'CTBLoader', | |||
'Loader', | |||
'CSVLoader', | |||
'JsonLoader', | |||
@@ -7,6 +7,7 @@ import random | |||
import shutil | |||
import numpy as np | |||
class YelpLoader(Loader): | |||
""" | |||
别名::class:`fastNLP.io.YelpLoader` :class:`fastNLP.io.loader.YelpLoader` | |||
@@ -14,6 +15,7 @@ class YelpLoader(Loader): | |||
原始数据中内容应该为, 每一行为一个sample,第一个逗号之前为target,第一个逗号之后为文本内容。 | |||
Example:: | |||
"1","I got 'new' tires from the..." | |||
"1","Don't waste your time..." | |||
@@ -28,11 +30,11 @@ class YelpLoader(Loader): | |||
"...", "..." | |||
""" | |||
def __init__(self): | |||
super(YelpLoader, self).__init__() | |||
def _load(self, path: str=None): | |||
def _load(self, path: str = None): | |||
ds = DataSet() | |||
with open(path, 'r', encoding='utf-8') as f: | |||
for line in f: | |||
@@ -69,12 +71,12 @@ class YelpFullLoader(YelpLoader): | |||
:param int seed: 划分dev时的随机数种子 | |||
:return: str, 数据集的目录地址 | |||
""" | |||
dataset_name = 'yelp-review-full' | |||
data_dir = self._get_dataset_path(dataset_name=dataset_name) | |||
if os.path.exists(os.path.join(data_dir, 'dev.csv')): # 存在dev的话,check是否需要重新下载 | |||
re_download = True | |||
if dev_ratio>0: | |||
if dev_ratio > 0: | |||
dev_line_count = 0 | |||
tr_line_count = 0 | |||
with open(os.path.join(data_dir, 'train.csv'), 'r', encoding='utf-8') as f1, \ | |||
@@ -83,14 +85,14 @@ class YelpFullLoader(YelpLoader): | |||
tr_line_count += 1 | |||
for line in f2: | |||
dev_line_count += 1 | |||
if not np.isclose(dev_line_count, dev_ratio*(tr_line_count + dev_line_count), rtol=0.005): | |||
if not np.isclose(dev_line_count, dev_ratio * (tr_line_count + dev_line_count), rtol=0.005): | |||
re_download = True | |||
else: | |||
re_download = False | |||
if re_download: | |||
shutil.rmtree(data_dir) | |||
data_dir = self._get_dataset_path(dataset_name=dataset_name) | |||
if not os.path.exists(os.path.join(data_dir, 'dev.csv')): | |||
if dev_ratio > 0: | |||
assert 0 < dev_ratio < 1, "dev_ratio should be in range (0,1)." | |||
@@ -109,7 +111,7 @@ class YelpFullLoader(YelpLoader): | |||
finally: | |||
if os.path.exists(os.path.join(data_dir, 'middle_file.csv')): | |||
os.remove(os.path.join(data_dir, 'middle_file.csv')) | |||
return data_dir | |||
@@ -131,7 +133,7 @@ class YelpPolarityLoader(YelpLoader): | |||
data_dir = self._get_dataset_path(dataset_name=dataset_name) | |||
if os.path.exists(os.path.join(data_dir, 'dev.csv')): # 存在dev的话,check是否符合比例要求 | |||
re_download = True | |||
if dev_ratio>0: | |||
if dev_ratio > 0: | |||
dev_line_count = 0 | |||
tr_line_count = 0 | |||
with open(os.path.join(data_dir, 'train.csv'), 'r', encoding='utf-8') as f1, \ | |||
@@ -140,14 +142,14 @@ class YelpPolarityLoader(YelpLoader): | |||
tr_line_count += 1 | |||
for line in f2: | |||
dev_line_count += 1 | |||
if not np.isclose(dev_line_count, dev_ratio*(tr_line_count + dev_line_count), rtol=0.005): | |||
if not np.isclose(dev_line_count, dev_ratio * (tr_line_count + dev_line_count), rtol=0.005): | |||
re_download = True | |||
else: | |||
re_download = False | |||
if re_download: | |||
shutil.rmtree(data_dir) | |||
data_dir = self._get_dataset_path(dataset_name=dataset_name) | |||
if not os.path.exists(os.path.join(data_dir, 'dev.csv')): | |||
if dev_ratio > 0: | |||
assert 0 < dev_ratio < 1, "dev_ratio should be in range (0,1)." | |||
@@ -166,7 +168,7 @@ class YelpPolarityLoader(YelpLoader): | |||
finally: | |||
if os.path.exists(os.path.join(data_dir, 'middle_file.csv')): | |||
os.remove(os.path.join(data_dir, 'middle_file.csv')) | |||
return data_dir | |||
@@ -185,10 +187,10 @@ class IMDBLoader(Loader): | |||
"...", "..." | |||
""" | |||
def __init__(self): | |||
super(IMDBLoader, self).__init__() | |||
def _load(self, path: str): | |||
dataset = DataSet() | |||
with open(path, 'r', encoding="utf-8") as f: | |||
@@ -201,12 +203,12 @@ class IMDBLoader(Loader): | |||
words = parts[1] | |||
if words: | |||
dataset.append(Instance(raw_words=words, target=target)) | |||
if len(dataset) == 0: | |||
raise RuntimeError(f"{path} has no valid data.") | |||
return dataset | |||
def download(self, dev_ratio: float = 0.1, seed: int = 0): | |||
""" | |||
自动下载数据集,如果你使用了这个数据集,请引用以下的文章 | |||
@@ -221,9 +223,9 @@ class IMDBLoader(Loader): | |||
""" | |||
dataset_name = 'aclImdb' | |||
data_dir = self._get_dataset_path(dataset_name=dataset_name) | |||
if os.path.exists(os.path.join(data_dir, 'dev.txt')): # 存在dev的话,check是否符合比例要求 | |||
if os.path.exists(os.path.join(data_dir, 'dev.txt')): # 存在dev的话,check是否符合比例要求 | |||
re_download = True | |||
if dev_ratio>0: | |||
if dev_ratio > 0: | |||
dev_line_count = 0 | |||
tr_line_count = 0 | |||
with open(os.path.join(data_dir, 'train.txt'), 'r', encoding='utf-8') as f1, \ | |||
@@ -232,14 +234,14 @@ class IMDBLoader(Loader): | |||
tr_line_count += 1 | |||
for line in f2: | |||
dev_line_count += 1 | |||
if not np.isclose(dev_line_count, dev_ratio*(tr_line_count + dev_line_count), rtol=0.005): | |||
if not np.isclose(dev_line_count, dev_ratio * (tr_line_count + dev_line_count), rtol=0.005): | |||
re_download = True | |||
else: | |||
re_download = False | |||
if re_download: | |||
shutil.rmtree(data_dir) | |||
data_dir = self._get_dataset_path(dataset_name=dataset_name) | |||
if not os.path.exists(os.path.join(data_dir, 'dev.csv')): | |||
if dev_ratio > 0: | |||
assert 0 < dev_ratio < 1, "dev_ratio should be in range (0,1)." | |||
@@ -258,7 +260,7 @@ class IMDBLoader(Loader): | |||
finally: | |||
if os.path.exists(os.path.join(data_dir, 'middle_file.txt')): | |||
os.remove(os.path.join(data_dir, 'middle_file.txt')) | |||
return data_dir | |||
@@ -278,10 +280,10 @@ class SSTLoader(Loader): | |||
raw_words列是str。 | |||
""" | |||
def __init__(self): | |||
super().__init__() | |||
def _load(self, path: str): | |||
""" | |||
从path读取SST文件 | |||
@@ -296,7 +298,7 @@ class SSTLoader(Loader): | |||
if line: | |||
ds.append(Instance(raw_words=line)) | |||
return ds | |||
def download(self): | |||
""" | |||
自动下载数据集,如果你使用了这个数据集,请引用以下的文章 | |||
@@ -323,10 +325,10 @@ class SST2Loader(Loader): | |||
test的DataSet没有target列。 | |||
""" | |||
def __init__(self): | |||
super().__init__() | |||
def _load(self, path: str): | |||
""" | |||
从path读取SST2文件 | |||
@@ -335,7 +337,7 @@ class SST2Loader(Loader): | |||
:return: DataSet | |||
""" | |||
ds = DataSet() | |||
with open(path, 'r', encoding='utf-8') as f: | |||
f.readline() # 跳过header | |||
if 'test' in os.path.split(path)[1]: | |||
@@ -356,7 +358,7 @@ class SST2Loader(Loader): | |||
if raw_words: | |||
ds.append(Instance(raw_words=raw_words, target=target)) | |||
return ds | |||
def download(self): | |||
""" | |||
自动下载数据集,如果你使用了该数据集,请引用以下的文章 | |||
@@ -2,17 +2,21 @@ from ...core.dataset import DataSet | |||
from .. import DataBundle | |||
from ..utils import check_loader_paths | |||
from typing import Union, Dict | |||
import os | |||
from ..file_utils import _get_dataset_url, get_cache_path, cached_path | |||
class Loader: | |||
""" | |||
各种数据 Loader 的基类,提供了 API 的参考. | |||
""" | |||
def __init__(self): | |||
pass | |||
def _load(self, path:str) -> DataSet: | |||
def _load(self, path: str) -> DataSet: | |||
raise NotImplementedError | |||
def load(self, paths: Union[str, Dict[str, str]]=None) -> DataBundle: | |||
def load(self, paths: Union[str, Dict[str, str]] = None) -> DataBundle: | |||
""" | |||
从指定一个或多个路径中的文件中读取数据,返回:class:`~fastNLP.io.DataBundle` 。 | |||
@@ -22,31 +26,25 @@ class Loader: | |||
(0) 如果为None,则先查看本地是否有缓存,如果没有则自动下载并缓存。 | |||
(1) 传入一个目录, 该目录下名称包含train的被认为是train,包含test的被认为是test,包含dev的被认为是dev,如果检测到多个文件 | |||
名包含'train'、 'dev'、 'test'则会报错 | |||
Example:: | |||
名包含'train'、 'dev'、 'test'则会报错:: | |||
data_bundle = ConllLoader().load('/path/to/dir') # 返回的DataBundle中datasets根据目录下是否检测到train、 | |||
# dev、 test等有所变化,可以通过以下的方式取出DataSet | |||
tr_data = data_bundle.datasets['train'] | |||
te_data = data_bundle.datasets['test'] # 如果目录下有文件包含test这个字段 | |||
(2) 传入文件路径 | |||
Example:: | |||
(2) 传入文件路径:: | |||
data_bundle = ConllLoader().load("/path/to/a/train.conll") # 返回DataBundle对象, datasets中仅包含'train' | |||
tr_data = data_bundle.datasets['train'] # 可以通过以下的方式取出DataSet | |||
(3) 传入一个dict,比如train,dev,test不在同一个目录下,或者名称中不包含train, dev, test | |||
Example:: | |||
(3) 传入一个dict,比如train,dev,test不在同一个目录下,或者名称中不包含train, dev, test:: | |||
paths = {'train':"/path/to/tr.conll", 'dev':"/to/validate.conll", "test":"/to/te.conll"} | |||
data_bundle = ConllLoader().load(paths) # 返回的DataBundle中的dataset中包含"train", "dev", "test" | |||
dev_data = data_bundle.datasets['dev'] | |||
:return: 返回的:class:`~fastNLP.io.DataBundle` | |||
:return: 返回的 :class:`~fastNLP.io.DataBundle` | |||
""" | |||
if paths is None: | |||
paths = self.download() | |||
@@ -54,10 +52,10 @@ class Loader: | |||
datasets = {name: self._load(path) for name, path in paths.items()} | |||
data_bundle = DataBundle(datasets=datasets) | |||
return data_bundle | |||
def download(self): | |||
raise NotImplementedError(f"{self.__class__} cannot download data automatically.") | |||
def _get_dataset_path(self, dataset_name): | |||
""" | |||
传入dataset的名称,获取读取数据的目录。如果数据不存在,会尝试自动下载并缓存 | |||
@@ -65,11 +63,9 @@ class Loader: | |||
:param str dataset_name: 数据集的名称 | |||
:return: str, 数据集的目录地址。直接到该目录下读取相应的数据即可。 | |||
""" | |||
default_cache_path = get_cache_path() | |||
url = _get_dataset_url(dataset_name) | |||
output_dir = cached_path(url_or_filename=url, cache_dir=default_cache_path, name='dataset') | |||
return output_dir | |||
@@ -203,7 +203,8 @@ class QNLILoader(JsonLoader): | |||
""" | |||
如果您的实验使用到了该数据,请引用 | |||
TODO 补充 | |||
.. todo:: | |||
补充 | |||
:return: | |||
""" | |||
@@ -8,7 +8,7 @@ __all__ = [ | |||
import torch | |||
from .base_loader import BaseLoader | |||
from .data_bundle import BaseLoader | |||
class ModelLoader(BaseLoader): | |||
@@ -1,6 +1,6 @@ | |||
from nltk import Tree | |||
from ..base_loader import DataBundle | |||
from ..data_bundle import DataBundle | |||
from ...core.vocabulary import Vocabulary | |||
from ...core.const import Const | |||
from ..loader.classification import IMDBLoader, YelpFullLoader, SSTLoader, SST2Loader, YelpPolarityLoader | |||
@@ -1,188 +1,188 @@ | |||
import pickle | |||
import numpy as np | |||
from fastNLP.core.vocabulary import Vocabulary | |||
from fastNLP.io.base_loader import DataBundle | |||
from fastNLP.io.dataset_loader import JsonLoader | |||
from fastNLP.core.const import Const | |||
from tools.logger import * | |||
WORD_PAD = "[PAD]" | |||
WORD_UNK = "[UNK]" | |||
DOMAIN_UNK = "X" | |||
TAG_UNK = "X" | |||
class SummarizationLoader(JsonLoader): | |||
""" | |||
读取summarization数据集,读取的DataSet包含fields:: | |||
text: list(str),document | |||
summary: list(str), summary | |||
text_wd: list(list(str)),tokenized document | |||
summary_wd: list(list(str)), tokenized summary | |||
labels: list(int), | |||
flatten_label: list(int), 0 or 1, flatten labels | |||
domain: str, optional | |||
tag: list(str), optional | |||
数据来源: CNN_DailyMail Newsroom DUC | |||
""" | |||
def __init__(self): | |||
super(SummarizationLoader, self).__init__() | |||
def _load(self, path): | |||
ds = super(SummarizationLoader, self)._load(path) | |||
def _lower_text(text_list): | |||
return [text.lower() for text in text_list] | |||
def _split_list(text_list): | |||
return [text.split() for text in text_list] | |||
def _convert_label(label, sent_len): | |||
np_label = np.zeros(sent_len, dtype=int) | |||
if label != []: | |||
np_label[np.array(label)] = 1 | |||
return np_label.tolist() | |||
ds.apply(lambda x: _lower_text(x['text']), new_field_name='text') | |||
ds.apply(lambda x: _lower_text(x['summary']), new_field_name='summary') | |||
ds.apply(lambda x:_split_list(x['text']), new_field_name='text_wd') | |||
ds.apply(lambda x:_split_list(x['summary']), new_field_name='summary_wd') | |||
ds.apply(lambda x:_convert_label(x["label"], len(x["text"])), new_field_name="flatten_label") | |||
return ds | |||
def process(self, paths, vocab_size, vocab_path, sent_max_len, doc_max_timesteps, domain=False, tag=False, load_vocab_file=True): | |||
""" | |||
:param paths: dict path for each dataset | |||
:param vocab_size: int max_size for vocab | |||
:param vocab_path: str vocab path | |||
:param sent_max_len: int max token number of the sentence | |||
:param doc_max_timesteps: int max sentence number of the document | |||
:param domain: bool build vocab for publication, use 'X' for unknown | |||
:param tag: bool build vocab for tag, use 'X' for unknown | |||
:param load_vocab_file: bool build vocab (False) or load vocab (True) | |||
:return: DataBundle | |||
datasets: dict keys correspond to the paths dict | |||
vocabs: dict key: vocab(if "train" in paths), domain(if domain=True), tag(if tag=True) | |||
embeddings: optional | |||
""" | |||
def _pad_sent(text_wd): | |||
pad_text_wd = [] | |||
for sent_wd in text_wd: | |||
if len(sent_wd) < sent_max_len: | |||
pad_num = sent_max_len - len(sent_wd) | |||
sent_wd.extend([WORD_PAD] * pad_num) | |||
else: | |||
sent_wd = sent_wd[:sent_max_len] | |||
pad_text_wd.append(sent_wd) | |||
return pad_text_wd | |||
def _token_mask(text_wd): | |||
token_mask_list = [] | |||
for sent_wd in text_wd: | |||
token_num = len(sent_wd) | |||
if token_num < sent_max_len: | |||
mask = [1] * token_num + [0] * (sent_max_len - token_num) | |||
else: | |||
mask = [1] * sent_max_len | |||
token_mask_list.append(mask) | |||
return token_mask_list | |||
def _pad_label(label): | |||
text_len = len(label) | |||
if text_len < doc_max_timesteps: | |||
pad_label = label + [0] * (doc_max_timesteps - text_len) | |||
else: | |||
pad_label = label[:doc_max_timesteps] | |||
return pad_label | |||
def _pad_doc(text_wd): | |||
text_len = len(text_wd) | |||
if text_len < doc_max_timesteps: | |||
padding = [WORD_PAD] * sent_max_len | |||
pad_text = text_wd + [padding] * (doc_max_timesteps - text_len) | |||
else: | |||
pad_text = text_wd[:doc_max_timesteps] | |||
return pad_text | |||
def _sent_mask(text_wd): | |||
text_len = len(text_wd) | |||
if text_len < doc_max_timesteps: | |||
sent_mask = [1] * text_len + [0] * (doc_max_timesteps - text_len) | |||
else: | |||
sent_mask = [1] * doc_max_timesteps | |||
return sent_mask | |||
datasets = {} | |||
train_ds = None | |||
for key, value in paths.items(): | |||
ds = self.load(value) | |||
# pad sent | |||
ds.apply(lambda x:_pad_sent(x["text_wd"]), new_field_name="pad_text_wd") | |||
ds.apply(lambda x:_token_mask(x["text_wd"]), new_field_name="pad_token_mask") | |||
# pad document | |||
ds.apply(lambda x:_pad_doc(x["pad_text_wd"]), new_field_name="pad_text") | |||
ds.apply(lambda x:_sent_mask(x["pad_text_wd"]), new_field_name="seq_len") | |||
ds.apply(lambda x:_pad_label(x["flatten_label"]), new_field_name="pad_label") | |||
# rename field | |||
ds.rename_field("pad_text", Const.INPUT) | |||
ds.rename_field("seq_len", Const.INPUT_LEN) | |||
ds.rename_field("pad_label", Const.TARGET) | |||
# set input and target | |||
ds.set_input(Const.INPUT, Const.INPUT_LEN) | |||
ds.set_target(Const.TARGET, Const.INPUT_LEN) | |||
datasets[key] = ds | |||
if "train" in key: | |||
train_ds = datasets[key] | |||
vocab_dict = {} | |||
if load_vocab_file == False: | |||
logger.info("[INFO] Build new vocab from training dataset!") | |||
if train_ds == None: | |||
raise ValueError("Lack train file to build vocabulary!") | |||
vocabs = Vocabulary(max_size=vocab_size, padding=WORD_PAD, unknown=WORD_UNK) | |||
vocabs.from_dataset(train_ds, field_name=["text_wd","summary_wd"]) | |||
vocab_dict["vocab"] = vocabs | |||
else: | |||
logger.info("[INFO] Load existing vocab from %s!" % vocab_path) | |||
word_list = [] | |||
with open(vocab_path, 'r', encoding='utf8') as vocab_f: | |||
cnt = 2 # pad and unk | |||
for line in vocab_f: | |||
pieces = line.split("\t") | |||
word_list.append(pieces[0]) | |||
cnt += 1 | |||
if cnt > vocab_size: | |||
break | |||
vocabs = Vocabulary(max_size=vocab_size, padding=WORD_PAD, unknown=WORD_UNK) | |||
vocabs.add_word_lst(word_list) | |||
vocabs.build_vocab() | |||
vocab_dict["vocab"] = vocabs | |||
if domain == True: | |||
domaindict = Vocabulary(padding=None, unknown=DOMAIN_UNK) | |||
domaindict.from_dataset(train_ds, field_name="publication") | |||
vocab_dict["domain"] = domaindict | |||
if tag == True: | |||
tagdict = Vocabulary(padding=None, unknown=TAG_UNK) | |||
tagdict.from_dataset(train_ds, field_name="tag") | |||
vocab_dict["tag"] = tagdict | |||
for ds in datasets.values(): | |||
vocab_dict["vocab"].index_dataset(ds, field_name=Const.INPUT, new_field_name=Const.INPUT) | |||
return DataBundle(vocabs=vocab_dict, datasets=datasets) | |||
import pickle | |||
import numpy as np | |||
from fastNLP.core.vocabulary import Vocabulary | |||
from fastNLP.io.data_bundle import DataBundle | |||
from fastNLP.io.dataset_loader import JsonLoader | |||
from fastNLP.core.const import Const | |||
from tools.logger import * | |||
WORD_PAD = "[PAD]" | |||
WORD_UNK = "[UNK]" | |||
DOMAIN_UNK = "X" | |||
TAG_UNK = "X" | |||
class SummarizationLoader(JsonLoader): | |||
""" | |||
读取summarization数据集,读取的DataSet包含fields:: | |||
text: list(str),document | |||
summary: list(str), summary | |||
text_wd: list(list(str)),tokenized document | |||
summary_wd: list(list(str)), tokenized summary | |||
labels: list(int), | |||
flatten_label: list(int), 0 or 1, flatten labels | |||
domain: str, optional | |||
tag: list(str), optional | |||
数据来源: CNN_DailyMail Newsroom DUC | |||
""" | |||
def __init__(self): | |||
super(SummarizationLoader, self).__init__() | |||
def _load(self, path): | |||
ds = super(SummarizationLoader, self)._load(path) | |||
def _lower_text(text_list): | |||
return [text.lower() for text in text_list] | |||
def _split_list(text_list): | |||
return [text.split() for text in text_list] | |||
def _convert_label(label, sent_len): | |||
np_label = np.zeros(sent_len, dtype=int) | |||
if label != []: | |||
np_label[np.array(label)] = 1 | |||
return np_label.tolist() | |||
ds.apply(lambda x: _lower_text(x['text']), new_field_name='text') | |||
ds.apply(lambda x: _lower_text(x['summary']), new_field_name='summary') | |||
ds.apply(lambda x:_split_list(x['text']), new_field_name='text_wd') | |||
ds.apply(lambda x:_split_list(x['summary']), new_field_name='summary_wd') | |||
ds.apply(lambda x:_convert_label(x["label"], len(x["text"])), new_field_name="flatten_label") | |||
return ds | |||
def process(self, paths, vocab_size, vocab_path, sent_max_len, doc_max_timesteps, domain=False, tag=False, load_vocab_file=True): | |||
""" | |||
:param paths: dict path for each dataset | |||
:param vocab_size: int max_size for vocab | |||
:param vocab_path: str vocab path | |||
:param sent_max_len: int max token number of the sentence | |||
:param doc_max_timesteps: int max sentence number of the document | |||
:param domain: bool build vocab for publication, use 'X' for unknown | |||
:param tag: bool build vocab for tag, use 'X' for unknown | |||
:param load_vocab_file: bool build vocab (False) or load vocab (True) | |||
:return: DataBundle | |||
datasets: dict keys correspond to the paths dict | |||
vocabs: dict key: vocab(if "train" in paths), domain(if domain=True), tag(if tag=True) | |||
embeddings: optional | |||
""" | |||
def _pad_sent(text_wd): | |||
pad_text_wd = [] | |||
for sent_wd in text_wd: | |||
if len(sent_wd) < sent_max_len: | |||
pad_num = sent_max_len - len(sent_wd) | |||
sent_wd.extend([WORD_PAD] * pad_num) | |||
else: | |||
sent_wd = sent_wd[:sent_max_len] | |||
pad_text_wd.append(sent_wd) | |||
return pad_text_wd | |||
def _token_mask(text_wd): | |||
token_mask_list = [] | |||
for sent_wd in text_wd: | |||
token_num = len(sent_wd) | |||
if token_num < sent_max_len: | |||
mask = [1] * token_num + [0] * (sent_max_len - token_num) | |||
else: | |||
mask = [1] * sent_max_len | |||
token_mask_list.append(mask) | |||
return token_mask_list | |||
def _pad_label(label): | |||
text_len = len(label) | |||
if text_len < doc_max_timesteps: | |||
pad_label = label + [0] * (doc_max_timesteps - text_len) | |||
else: | |||
pad_label = label[:doc_max_timesteps] | |||
return pad_label | |||
def _pad_doc(text_wd): | |||
text_len = len(text_wd) | |||
if text_len < doc_max_timesteps: | |||
padding = [WORD_PAD] * sent_max_len | |||
pad_text = text_wd + [padding] * (doc_max_timesteps - text_len) | |||
else: | |||
pad_text = text_wd[:doc_max_timesteps] | |||
return pad_text | |||
def _sent_mask(text_wd): | |||
text_len = len(text_wd) | |||
if text_len < doc_max_timesteps: | |||
sent_mask = [1] * text_len + [0] * (doc_max_timesteps - text_len) | |||
else: | |||
sent_mask = [1] * doc_max_timesteps | |||
return sent_mask | |||
datasets = {} | |||
train_ds = None | |||
for key, value in paths.items(): | |||
ds = self.load(value) | |||
# pad sent | |||
ds.apply(lambda x:_pad_sent(x["text_wd"]), new_field_name="pad_text_wd") | |||
ds.apply(lambda x:_token_mask(x["text_wd"]), new_field_name="pad_token_mask") | |||
# pad document | |||
ds.apply(lambda x:_pad_doc(x["pad_text_wd"]), new_field_name="pad_text") | |||
ds.apply(lambda x:_sent_mask(x["pad_text_wd"]), new_field_name="seq_len") | |||
ds.apply(lambda x:_pad_label(x["flatten_label"]), new_field_name="pad_label") | |||
# rename field | |||
ds.rename_field("pad_text", Const.INPUT) | |||
ds.rename_field("seq_len", Const.INPUT_LEN) | |||
ds.rename_field("pad_label", Const.TARGET) | |||
# set input and target | |||
ds.set_input(Const.INPUT, Const.INPUT_LEN) | |||
ds.set_target(Const.TARGET, Const.INPUT_LEN) | |||
datasets[key] = ds | |||
if "train" in key: | |||
train_ds = datasets[key] | |||
vocab_dict = {} | |||
if load_vocab_file == False: | |||
logger.info("[INFO] Build new vocab from training dataset!") | |||
if train_ds == None: | |||
raise ValueError("Lack train file to build vocabulary!") | |||
vocabs = Vocabulary(max_size=vocab_size, padding=WORD_PAD, unknown=WORD_UNK) | |||
vocabs.from_dataset(train_ds, field_name=["text_wd","summary_wd"]) | |||
vocab_dict["vocab"] = vocabs | |||
else: | |||
logger.info("[INFO] Load existing vocab from %s!" % vocab_path) | |||
word_list = [] | |||
with open(vocab_path, 'r', encoding='utf8') as vocab_f: | |||
cnt = 2 # pad and unk | |||
for line in vocab_f: | |||
pieces = line.split("\t") | |||
word_list.append(pieces[0]) | |||
cnt += 1 | |||
if cnt > vocab_size: | |||
break | |||
vocabs = Vocabulary(max_size=vocab_size, padding=WORD_PAD, unknown=WORD_UNK) | |||
vocabs.add_word_lst(word_list) | |||
vocabs.build_vocab() | |||
vocab_dict["vocab"] = vocabs | |||
if domain == True: | |||
domaindict = Vocabulary(padding=None, unknown=DOMAIN_UNK) | |||
domaindict.from_dataset(train_ds, field_name="publication") | |||
vocab_dict["domain"] = domaindict | |||
if tag == True: | |||
tagdict = Vocabulary(padding=None, unknown=TAG_UNK) | |||
tagdict.from_dataset(train_ds, field_name="tag") | |||
vocab_dict["tag"] = tagdict | |||
for ds in datasets.values(): | |||
vocab_dict["vocab"].index_dataset(ds, field_name=Const.INPUT, new_field_name=Const.INPUT) | |||
return DataBundle(vocabs=vocab_dict, datasets=datasets) | |||
@@ -3,7 +3,7 @@ from datetime import timedelta | |||
from fastNLP.io.dataset_loader import JsonLoader | |||
from fastNLP.modules.encoder._bert import BertTokenizer | |||
from fastNLP.io.base_loader import DataBundle | |||
from fastNLP.io.data_bundle import DataBundle | |||
from fastNLP.core.const import Const | |||
class BertData(JsonLoader): | |||
@@ -1,7 +1,7 @@ | |||
from fastNLP.io.dataset_loader import JsonLoader,DataSet,Instance | |||
from fastNLP.io.file_reader import _read_json | |||
from fastNLP.core.vocabulary import Vocabulary | |||
from fastNLP.io.base_loader import DataBundle | |||
from fastNLP.io.data_bundle import DataBundle | |||
from reproduction.coreference_resolution.model.config import Config | |||
import reproduction.coreference_resolution.model.preprocess as preprocess | |||
@@ -1,6 +1,6 @@ | |||
from fastNLP.io.base_loader import DataSetLoader, DataBundle | |||
from fastNLP.io.data_bundle import DataSetLoader, DataBundle | |||
from fastNLP.io.data_loader import ConllLoader | |||
import numpy as np | |||
@@ -9,7 +9,7 @@ from typing import Union, Dict | |||
from fastNLP.core.const import Const | |||
from fastNLP.core.vocabulary import Vocabulary | |||
from fastNLP.io.base_loader import DataBundle, DataSetLoader | |||
from fastNLP.io.data_bundle import DataBundle, DataSetLoader | |||
from fastNLP.io.dataset_loader import JsonLoader, CSVLoader | |||
from fastNLP.io.file_utils import _get_base_url, cached_path, PRETRAINED_BERT_MODEL_DIR | |||
from fastNLP.modules.encoder._bert import BertTokenizer | |||
@@ -1,6 +1,6 @@ | |||
from fastNLP.io.base_loader import DataSetLoader, DataBundle | |||
from fastNLP.io.data_bundle import DataSetLoader, DataBundle | |||
from fastNLP.io import ConllLoader | |||
from reproduction.seqence_labelling.ner.data.utils import iob2bioes, iob2 | |||
from fastNLP import Const | |||
@@ -1,7 +1,7 @@ | |||
from fastNLP.io.embed_loader import EmbeddingOption, EmbedLoader | |||
from fastNLP.core.vocabulary import VocabularyOption | |||
from fastNLP.io.base_loader import DataSetLoader, DataBundle | |||
from fastNLP.io.data_bundle import DataSetLoader, DataBundle | |||
from typing import Union, Dict, List, Iterator | |||
from fastNLP import DataSet | |||
from fastNLP import Instance | |||
@@ -1,6 +1,6 @@ | |||
from fastNLP.core.vocabulary import VocabularyOption | |||
from fastNLP.io.base_loader import DataSetLoader, DataBundle | |||
from fastNLP.io.data_bundle import DataSetLoader, DataBundle | |||
from typing import Union, Dict | |||
from fastNLP import Vocabulary | |||
from fastNLP import Const | |||
@@ -1,5 +1,5 @@ | |||
from fastNLP.core.vocabulary import VocabularyOption | |||
from fastNLP.io.base_loader import DataSetLoader, DataBundle | |||
from fastNLP.io.data_bundle import DataSetLoader, DataBundle | |||
from typing import Union, Dict | |||
from fastNLP import DataSet | |||
from fastNLP import Vocabulary | |||
@@ -1,6 +1,6 @@ | |||
from fastNLP.io.embed_loader import EmbeddingOption, EmbedLoader | |||
from fastNLP.core.vocabulary import VocabularyOption | |||
from fastNLP.io.base_loader import DataSetLoader, DataBundle | |||
from fastNLP.io.data_bundle import DataSetLoader, DataBundle | |||
from typing import Union, Dict, List, Iterator | |||
from fastNLP import DataSet | |||
from fastNLP import Instance | |||
@@ -1,6 +1,6 @@ | |||
from fastNLP.io.embed_loader import EmbeddingOption, EmbedLoader | |||
from fastNLP.core.vocabulary import VocabularyOption | |||
from fastNLP.io.base_loader import DataSetLoader, DataBundle | |||
from fastNLP.io.data_bundle import DataSetLoader, DataBundle | |||
from typing import Union, Dict, List, Iterator | |||
from fastNLP import DataSet | |||
from fastNLP import Instance | |||
@@ -1,6 +1,6 @@ | |||
from typing import Iterable | |||
from nltk import Tree | |||
from fastNLP.io.base_loader import DataBundle, DataSetLoader | |||
from fastNLP.io.data_bundle import DataBundle, DataSetLoader | |||
from fastNLP.core.vocabulary import VocabularyOption, Vocabulary | |||
from fastNLP import DataSet | |||
from fastNLP import Instance | |||
@@ -4,7 +4,7 @@ from typing import Iterable | |||
from fastNLP import DataSet, Instance, Vocabulary | |||
from fastNLP.core.vocabulary import VocabularyOption | |||
from fastNLP.io import JsonLoader | |||
from fastNLP.io.base_loader import DataBundle,DataSetLoader | |||
from fastNLP.io.data_bundle import DataBundle,DataSetLoader | |||
from fastNLP.io.embed_loader import EmbeddingOption | |||
from fastNLP.io.file_reader import _read_json | |||
from typing import Union, Dict | |||