@@ -20,7 +20,7 @@ server: | |||||
cd build/html && python -m http.server | cd build/html && python -m http.server | ||||
dev: | dev: | ||||
rm -rf build/html && make html && make server | |||||
rm -rf build && make html && make server | |||||
.PHONY: help Makefile | .PHONY: help Makefile | ||||
@@ -59,7 +59,10 @@ def clear(path='./source/'): | |||||
else: | else: | ||||
shorten(path + file, to_delete) | shorten(path + file, to_delete) | ||||
for file in to_delete: | for file in to_delete: | ||||
os.remove(path + file + ".rst") | |||||
try: | |||||
os.remove(path + file + ".rst") | |||||
except FileNotFoundError: | |||||
pass | |||||
clear() | clear() |
@@ -1,7 +0,0 @@ | |||||
fastNLP.io.base\_loader | |||||
======================= | |||||
.. automodule:: fastNLP.io.base_loader | |||||
:members: | |||||
:undoc-members: | |||||
:show-inheritance: |
@@ -0,0 +1,7 @@ | |||||
fastNLP.io.data\_bundle | |||||
======================= | |||||
.. automodule:: fastNLP.io.data_bundle | |||||
:members: | |||||
:undoc-members: | |||||
:show-inheritance: |
@@ -20,7 +20,7 @@ Submodules | |||||
.. toctree:: | .. toctree:: | ||||
fastNLP.io.base_loader | |||||
fastNLP.io.data_bundle | |||||
fastNLP.io.dataset_loader | fastNLP.io.dataset_loader | ||||
fastNLP.io.embed_loader | fastNLP.io.embed_loader | ||||
fastNLP.io.file_utils | fastNLP.io.file_utils | ||||
@@ -12,10 +12,9 @@ | |||||
这些类的使用方法如下: | 这些类的使用方法如下: | ||||
""" | """ | ||||
__all__ = [ | __all__ = [ | ||||
'EmbedLoader', | |||||
'DataBundle', | 'DataBundle', | ||||
'DataSetLoader', | |||||
'EmbedLoader', | |||||
'YelpLoader', | 'YelpLoader', | ||||
'YelpFullLoader', | 'YelpFullLoader', | ||||
@@ -69,7 +68,7 @@ __all__ = [ | |||||
] | ] | ||||
from .embed_loader import EmbedLoader | from .embed_loader import EmbedLoader | ||||
from .base_loader import DataBundle, DataSetLoader | |||||
from .data_bundle import DataBundle | |||||
from .dataset_loader import CSVLoader, JsonLoader | from .dataset_loader import CSVLoader, JsonLoader | ||||
from .model_io import ModelLoader, ModelSaver | from .model_io import ModelLoader, ModelSaver | ||||
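Since this change renames `fastNLP/io/base_loader.py` to `fastNLP/io/data_bundle.py`, downstream code only needs a one-line import update. A minimal sketch of the migration (the package-level import is unaffected):

```python
# Old import path, removed by this change:
# from fastNLP.io.base_loader import DataBundle

# New import path after the rename:
from fastNLP.io.data_bundle import DataBundle

# The package-level import keeps working, since fastNLP/io/__init__.py
# now re-exports DataBundle from data_bundle:
from fastNLP.io import DataBundle
```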
@@ -1,313 +0,0 @@ | |||||
""" | |||||
用于读入和处理和保存 config 文件 | |||||
.. todo:: | |||||
这个模块中的类可能被抛弃? | |||||
""" | |||||
__all__ = [ | |||||
"ConfigLoader", | |||||
"ConfigSection", | |||||
"ConfigSaver" | |||||
] | |||||
import configparser | |||||
import json | |||||
import os | |||||
from .base_loader import BaseLoader | |||||
class ConfigLoader(BaseLoader): | |||||
""" | |||||
别名::class:`fastNLP.io.ConfigLoader` :class:`fastNLP.io.config_io.ConfigLoader` | |||||
读取配置文件的Loader | |||||
:param str data_path: 配置文件的路径 | |||||
""" | |||||
def __init__(self, data_path=None): | |||||
super(ConfigLoader, self).__init__() | |||||
if data_path is not None: | |||||
self.config = self.parse(super(ConfigLoader, self).load(data_path)) | |||||
@staticmethod | |||||
def parse(string): | |||||
raise NotImplementedError | |||||
@staticmethod | |||||
def load_config(file_path, sections): | |||||
""" | |||||
把配置文件的section 存入提供的 ``sections`` 中 | |||||
:param str file_path: 配置文件的路径 | |||||
:param dict sections: 符合如下键值对组成的字典 `section_name(string)` : :class:`~fastNLP.io.ConfigSection` | |||||
Example:: | |||||
test_args = ConfigSection() | |||||
ConfigLoader("config.cfg").load_config("./data_for_tests/config", {"POS_test": test_args}) | |||||
""" | |||||
assert isinstance(sections, dict) | |||||
cfg = configparser.ConfigParser() | |||||
if not os.path.exists(file_path): | |||||
raise FileNotFoundError("config file {} not found. ".format(file_path)) | |||||
cfg.read(file_path) | |||||
for s in sections: | |||||
attr_list = [i for i in sections[s].__dict__.keys() if | |||||
not callable(getattr(sections[s], i)) and not i.startswith("__")] | |||||
if s not in cfg: | |||||
print('section %s not found in config file' % (s)) | |||||
continue | |||||
gen_sec = cfg[s] | |||||
for attr in gen_sec.keys(): | |||||
try: | |||||
val = json.loads(gen_sec[attr]) | |||||
# print(s, attr, val, type(val)) | |||||
if attr in attr_list: | |||||
assert type(val) == type(getattr(sections[s], attr)), \ | |||||
'type not match, except %s but got %s' % \ | |||||
(type(getattr(sections[s], attr)), type(val)) | |||||
""" | |||||
if attr in attr_list then check its type and | |||||
update its value. | |||||
else add a new attr in sections[s] | |||||
""" | |||||
setattr(sections[s], attr, val) | |||||
except Exception as e: | |||||
print("cannot load attribute %s in section %s" | |||||
% (attr, s)) | |||||
pass | |||||
class ConfigSection(object): | |||||
""" | |||||
别名::class:`fastNLP.io.ConfigSection` :class:`fastNLP.io.config_io.ConfigSection` | |||||
ConfigSection是一个存储了一个section中所有键值对的数据结构,推荐使用此类的实例来配合 :meth:`ConfigLoader.load_config` 使用 | |||||
""" | |||||
def __init__(self): | |||||
super(ConfigSection, self).__init__() | |||||
def __getitem__(self, key): | |||||
""" | |||||
:param key: str, the name of the attribute | |||||
:return attr: the value of this attribute | |||||
if key not in self.__dict__.keys(): | |||||
return self[key] | |||||
else: | |||||
raise AttributeError | |||||
""" | |||||
if key in self.__dict__.keys(): | |||||
return getattr(self, key) | |||||
raise AttributeError("do NOT have attribute %s" % key) | |||||
def __setitem__(self, key, value): | |||||
""" | |||||
:param key: str, the name of the attribute | |||||
:param value: the value of this attribute | |||||
if key not in self.__dict__.keys(): | |||||
self[key] will be added | |||||
else: | |||||
self[key] will be updated | |||||
""" | |||||
if key in self.__dict__.keys(): | |||||
if not isinstance(value, type(getattr(self, key))): | |||||
raise AttributeError("attr %s except %s but got %s" % | |||||
(key, str(type(getattr(self, key))), str(type(value)))) | |||||
setattr(self, key, value) | |||||
def __contains__(self, item): | |||||
""" | |||||
:param item: The key of item. | |||||
:return: True if the key in self.__dict__.keys() else False. | |||||
""" | |||||
return item in self.__dict__.keys() | |||||
def __eq__(self, other): | |||||
"""Overwrite the == operator | |||||
:param other: Another ConfigSection() object which to be compared. | |||||
:return: True if value of each key in each ConfigSection() object are equal to the other, else False. | |||||
""" | |||||
for k in self.__dict__.keys(): | |||||
if k not in other.__dict__.keys(): | |||||
return False | |||||
if getattr(self, k) != getattr(self, k): | |||||
return False | |||||
for k in other.__dict__.keys(): | |||||
if k not in self.__dict__.keys(): | |||||
return False | |||||
if getattr(self, k) != getattr(self, k): | |||||
return False | |||||
return True | |||||
def __ne__(self, other): | |||||
"""Overwrite the != operator | |||||
:param other: | |||||
:return: | |||||
""" | |||||
return not self.__eq__(other) | |||||
@property | |||||
def data(self): | |||||
return self.__dict__ | |||||
class ConfigSaver(object): | |||||
""" | |||||
别名::class:`fastNLP.io.ConfigSaver` :class:`fastNLP.io.config_io.ConfigSaver` | |||||
ConfigSaver 是用来存储配置文件并解决相关冲突的类 | |||||
:param str file_path: 配置文件的路径 | |||||
""" | |||||
def __init__(self, file_path): | |||||
self.file_path = file_path | |||||
if not os.path.exists(self.file_path): | |||||
raise FileNotFoundError("file {} NOT found!".__format__(self.file_path)) | |||||
def _get_section(self, sect_name): | |||||
""" | |||||
This is the function to get the section with the section name. | |||||
:param sect_name: The name of section what wants to load. | |||||
:return: The section. | |||||
""" | |||||
sect = ConfigSection() | |||||
ConfigLoader().load_config(self.file_path, {sect_name: sect}) | |||||
return sect | |||||
def _read_section(self): | |||||
""" | |||||
This is the function to read sections from the config file. | |||||
:return: sect_list, sect_key_list | |||||
sect_list: A list of ConfigSection(). | |||||
sect_key_list: A list of names in sect_list. | |||||
""" | |||||
sect_name = None | |||||
sect_list = {} | |||||
sect_key_list = [] | |||||
single_section = {} | |||||
single_section_key = [] | |||||
with open(self.file_path, 'r') as f: | |||||
lines = f.readlines() | |||||
for line in lines: | |||||
if line.startswith('[') and line.endswith(']\n'): | |||||
if sect_name is None: | |||||
pass | |||||
else: | |||||
sect_list[sect_name] = single_section, single_section_key | |||||
single_section = {} | |||||
single_section_key = [] | |||||
sect_key_list.append(sect_name) | |||||
sect_name = line[1: -2] | |||||
continue | |||||
if line.startswith('#'): | |||||
single_section[line] = '#' | |||||
single_section_key.append(line) | |||||
continue | |||||
if line.startswith('\n'): | |||||
single_section_key.append('\n') | |||||
continue | |||||
if '=' not in line: | |||||
raise RuntimeError("can NOT load config file {}".__format__(self.file_path)) | |||||
key = line.split('=', maxsplit=1)[0].strip() | |||||
value = line.split('=', maxsplit=1)[1].strip() + '\n' | |||||
single_section[key] = value | |||||
single_section_key.append(key) | |||||
if sect_name is not None: | |||||
sect_list[sect_name] = single_section, single_section_key | |||||
sect_key_list.append(sect_name) | |||||
return sect_list, sect_key_list | |||||
def _write_section(self, sect_list, sect_key_list): | |||||
""" | |||||
This is the function to write config file with section list and name list. | |||||
:param sect_list: A list of ConfigSection() need to be writen into file. | |||||
:param sect_key_list: A list of name of sect_list. | |||||
:return: | |||||
""" | |||||
with open(self.file_path, 'w') as f: | |||||
for sect_key in sect_key_list: | |||||
single_section, single_section_key = sect_list[sect_key] | |||||
f.write('[' + sect_key + ']\n') | |||||
for key in single_section_key: | |||||
if key == '\n': | |||||
f.write('\n') | |||||
continue | |||||
if single_section[key] == '#': | |||||
f.write(key) | |||||
continue | |||||
f.write(key + ' = ' + single_section[key]) | |||||
f.write('\n') | |||||
def save_config_file(self, section_name, section): | |||||
""" | |||||
这个方法可以用来修改并保存配置文件中单独的一个 section | |||||
:param str section_name: 需要保存的 section 的名字. | |||||
:param section: 你需要修改并保存的 section, :class:`~fastNLP.io.ConfigSaver` 类型 | |||||
""" | |||||
section_file = self._get_section(section_name) | |||||
if len(section_file.__dict__.keys()) == 0: # the section not in the file before | |||||
# append this section to config file | |||||
with open(self.file_path, 'a') as f: | |||||
f.write('[' + section_name + ']\n') | |||||
for k in section.__dict__.keys(): | |||||
f.write(k + ' = ') | |||||
if isinstance(section[k], str): | |||||
f.write('\"' + str(section[k]) + '\"\n\n') | |||||
else: | |||||
f.write(str(section[k]) + '\n\n') | |||||
else: | |||||
# the section exists | |||||
change_file = False | |||||
for k in section.__dict__.keys(): | |||||
if k not in section_file: | |||||
# find a new key in this section | |||||
change_file = True | |||||
break | |||||
if section_file[k] != section[k]: | |||||
change_file = True | |||||
break | |||||
if not change_file: | |||||
return | |||||
sect_list, sect_key_list = self._read_section() | |||||
if section_name not in sect_key_list: | |||||
raise AttributeError() | |||||
sect, sect_key = sect_list[section_name] | |||||
for k in section.__dict__.keys(): | |||||
if k not in sect_key: | |||||
if sect_key[-1] != '\n': | |||||
sect_key.append('\n') | |||||
sect_key.append(k) | |||||
sect[k] = str(section[k]) | |||||
if isinstance(section[k], str): | |||||
sect[k] = "\"" + sect[k] + "\"" | |||||
sect[k] = sect[k] + "\n" | |||||
sect_list[section_name] = sect, sect_key | |||||
self._write_section(sect_list, sect_key_list) |
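For reference, the removed `ConfigLoader` / `ConfigSection` / `ConfigSaver` classes were used roughly as in the sketch below; the config path and section name are hypothetical and follow the docstring example above:

```python
from fastNLP.io.config_io import ConfigLoader, ConfigSection, ConfigSaver  # module removed by this change

# Read one section of an INI-style config file into a ConfigSection.
test_args = ConfigSection()
ConfigLoader().load_config("./data_for_tests/config", {"POS_test": test_args})

# Parsed values are exposed as attributes / items of the section object.
print(test_args.data)            # dict of all key-value pairs in the section
# test_args["missing_key"]       # would raise AttributeError

# Write a (possibly modified) section back, merging it with the existing file.
ConfigSaver("./data_for_tests/config").save_config_file("POS_test", test_args)
```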
@@ -1,7 +1,5 @@ | |||||
__all__ = [ | __all__ = [ | ||||
"BaseLoader", | |||||
'DataBundle', | 'DataBundle', | ||||
'DataSetLoader', | |||||
] | ] | ||||
import _pickle as pickle | import _pickle as pickle |
@@ -1,11 +1,11 @@ | |||||
from ...core.dataset import DataSet | from ...core.dataset import DataSet | ||||
from ...core.instance import Instance | from ...core.instance import Instance | ||||
from ..base_loader import DataSetLoader | |||||
from ..data_bundle import DataSetLoader | |||||
from ..file_reader import _read_conll | from ..file_reader import _read_conll | ||||
from typing import Union, Dict | from typing import Union, Dict | ||||
from ..utils import check_loader_paths | from ..utils import check_loader_paths | ||||
from ..base_loader import DataBundle | |||||
from ..data_bundle import DataBundle | |||||
class ConllLoader(DataSetLoader): | class ConllLoader(DataSetLoader): | ||||
""" | """ | ||||
@@ -2,7 +2,7 @@ | |||||
from typing import Union, Dict | from typing import Union, Dict | ||||
from ..embed_loader import EmbeddingOption, EmbedLoader | from ..embed_loader import EmbeddingOption, EmbedLoader | ||||
from ..base_loader import DataSetLoader, DataBundle | |||||
from ..data_bundle import DataSetLoader, DataBundle | |||||
from ...core.vocabulary import VocabularyOption, Vocabulary | from ...core.vocabulary import VocabularyOption, Vocabulary | ||||
from ...core.dataset import DataSet | from ...core.dataset import DataSet | ||||
from ...core.instance import Instance | from ...core.instance import Instance | ||||
@@ -4,7 +4,7 @@ from typing import Union, Dict, List | |||||
from ...core.const import Const | from ...core.const import Const | ||||
from ...core.vocabulary import Vocabulary | from ...core.vocabulary import Vocabulary | ||||
from ..base_loader import DataBundle, DataSetLoader | |||||
from ..data_bundle import DataBundle, DataSetLoader | |||||
from ..file_utils import _get_base_url, cached_path, PRETRAINED_BERT_MODEL_DIR | from ..file_utils import _get_base_url, cached_path, PRETRAINED_BERT_MODEL_DIR | ||||
from ...modules.encoder.bert import BertTokenizer | from ...modules.encoder.bert import BertTokenizer | ||||
@@ -1,7 +1,7 @@ | |||||
from typing import Union, Dict | from typing import Union, Dict | ||||
from ..base_loader import DataBundle | |||||
from ..data_bundle import DataBundle | |||||
from ..dataset_loader import CSVLoader | from ..dataset_loader import CSVLoader | ||||
from ...core.vocabulary import Vocabulary, VocabularyOption | from ...core.vocabulary import Vocabulary, VocabularyOption | ||||
from ...core.const import Const | from ...core.const import Const | ||||
@@ -1,5 +1,5 @@ | |||||
from ..base_loader import DataSetLoader | |||||
from ..data_bundle import DataSetLoader | |||||
from ...core.dataset import DataSet | from ...core.dataset import DataSet | ||||
from ...core.instance import Instance | from ...core.instance import Instance | ||||
from ...core.const import Const | from ...core.const import Const | ||||
@@ -2,7 +2,7 @@ | |||||
from typing import Union, Dict | from typing import Union, Dict | ||||
from nltk import Tree | from nltk import Tree | ||||
from ..base_loader import DataBundle, DataSetLoader | |||||
from ..data_bundle import DataBundle, DataSetLoader | |||||
from ..dataset_loader import CSVLoader | from ..dataset_loader import CSVLoader | ||||
from ...core.vocabulary import VocabularyOption, Vocabulary | from ...core.vocabulary import VocabularyOption, Vocabulary | ||||
from ...core.dataset import DataSet | from ...core.dataset import DataSet | ||||
@@ -6,7 +6,7 @@ from ...core.const import Const | |||||
from ...core.dataset import DataSet | from ...core.dataset import DataSet | ||||
from ...core.instance import Instance | from ...core.instance import Instance | ||||
from ...core.vocabulary import VocabularyOption, Vocabulary | from ...core.vocabulary import VocabularyOption, Vocabulary | ||||
from ..base_loader import DataBundle, DataSetLoader | |||||
from ..data_bundle import DataBundle, DataSetLoader | |||||
from typing import Union, Dict | from typing import Union, Dict | ||||
from ..utils import check_loader_paths, get_tokenizer | from ..utils import check_loader_paths, get_tokenizer | ||||
@@ -26,7 +26,7 @@ __all__ = [ | |||||
from ..core.dataset import DataSet | from ..core.dataset import DataSet | ||||
from ..core.instance import Instance | from ..core.instance import Instance | ||||
from .file_reader import _read_csv, _read_json | from .file_reader import _read_csv, _read_json | ||||
from .base_loader import DataSetLoader | |||||
from .data_bundle import DataSetLoader | |||||
class JsonLoader(DataSetLoader): | class JsonLoader(DataSetLoader): | ||||
@@ -9,7 +9,7 @@ import warnings | |||||
import numpy as np | import numpy as np | ||||
from ..core.vocabulary import Vocabulary | from ..core.vocabulary import Vocabulary | ||||
from .base_loader import BaseLoader | |||||
from .data_bundle import BaseLoader | |||||
from ..core.utils import Option | from ..core.utils import Option | ||||
@@ -44,6 +44,8 @@ fastNLP 目前提供了如下的 Loader | |||||
""" | """ | ||||
__all__ = [ | __all__ = [ | ||||
'Loader', | |||||
'YelpLoader', | 'YelpLoader', | ||||
'YelpFullLoader', | 'YelpFullLoader', | ||||
'YelpPolarityLoader', | 'YelpPolarityLoader', | ||||
@@ -57,7 +59,6 @@ __all__ = [ | |||||
'OntoNotesNERLoader', | 'OntoNotesNERLoader', | ||||
'CTBLoader', | 'CTBLoader', | ||||
'Loader', | |||||
'CSVLoader', | 'CSVLoader', | ||||
'JsonLoader', | 'JsonLoader', | ||||
@@ -7,6 +7,7 @@ import random | |||||
import shutil | import shutil | ||||
import numpy as np | import numpy as np | ||||
class YelpLoader(Loader): | class YelpLoader(Loader): | ||||
""" | """ | ||||
别名::class:`fastNLP.io.YelpLoader` :class:`fastNLP.io.loader.YelpLoader` | 别名::class:`fastNLP.io.YelpLoader` :class:`fastNLP.io.loader.YelpLoader` | ||||
@@ -14,6 +15,7 @@ class YelpLoader(Loader): | |||||
原始数据中内容应该为, 每一行为一个sample,第一个逗号之前为target,第一个逗号之后为文本内容。 | 原始数据中内容应该为, 每一行为一个sample,第一个逗号之前为target,第一个逗号之后为文本内容。 | ||||
Example:: | Example:: | ||||
"1","I got 'new' tires from the..." | "1","I got 'new' tires from the..." | ||||
"1","Don't waste your time..." | "1","Don't waste your time..." | ||||
@@ -28,11 +30,11 @@ class YelpLoader(Loader): | |||||
"...", "..." | "...", "..." | ||||
""" | """ | ||||
def __init__(self): | def __init__(self): | ||||
super(YelpLoader, self).__init__() | super(YelpLoader, self).__init__() | ||||
def _load(self, path: str=None): | |||||
def _load(self, path: str = None): | |||||
ds = DataSet() | ds = DataSet() | ||||
with open(path, 'r', encoding='utf-8') as f: | with open(path, 'r', encoding='utf-8') as f: | ||||
for line in f: | for line in f: | ||||
@@ -69,12 +71,12 @@ class YelpFullLoader(YelpLoader): | |||||
:param int seed: 划分dev时的随机数种子 | :param int seed: 划分dev时的随机数种子 | ||||
:return: str, 数据集的目录地址 | :return: str, 数据集的目录地址 | ||||
""" | """ | ||||
dataset_name = 'yelp-review-full' | dataset_name = 'yelp-review-full' | ||||
data_dir = self._get_dataset_path(dataset_name=dataset_name) | data_dir = self._get_dataset_path(dataset_name=dataset_name) | ||||
if os.path.exists(os.path.join(data_dir, 'dev.csv')): # 存在dev的话,check是否需要重新下载 | if os.path.exists(os.path.join(data_dir, 'dev.csv')): # 存在dev的话,check是否需要重新下载 | ||||
re_download = True | re_download = True | ||||
if dev_ratio>0: | |||||
if dev_ratio > 0: | |||||
dev_line_count = 0 | dev_line_count = 0 | ||||
tr_line_count = 0 | tr_line_count = 0 | ||||
with open(os.path.join(data_dir, 'train.csv'), 'r', encoding='utf-8') as f1, \ | with open(os.path.join(data_dir, 'train.csv'), 'r', encoding='utf-8') as f1, \ | ||||
@@ -83,14 +85,14 @@ class YelpFullLoader(YelpLoader): | |||||
tr_line_count += 1 | tr_line_count += 1 | ||||
for line in f2: | for line in f2: | ||||
dev_line_count += 1 | dev_line_count += 1 | ||||
if not np.isclose(dev_line_count, dev_ratio*(tr_line_count + dev_line_count), rtol=0.005): | |||||
if not np.isclose(dev_line_count, dev_ratio * (tr_line_count + dev_line_count), rtol=0.005): | |||||
re_download = True | re_download = True | ||||
else: | else: | ||||
re_download = False | re_download = False | ||||
if re_download: | if re_download: | ||||
shutil.rmtree(data_dir) | shutil.rmtree(data_dir) | ||||
data_dir = self._get_dataset_path(dataset_name=dataset_name) | data_dir = self._get_dataset_path(dataset_name=dataset_name) | ||||
if not os.path.exists(os.path.join(data_dir, 'dev.csv')): | if not os.path.exists(os.path.join(data_dir, 'dev.csv')): | ||||
if dev_ratio > 0: | if dev_ratio > 0: | ||||
assert 0 < dev_ratio < 1, "dev_ratio should be in range (0,1)." | assert 0 < dev_ratio < 1, "dev_ratio should be in range (0,1)." | ||||
@@ -109,7 +111,7 @@ class YelpFullLoader(YelpLoader): | |||||
finally: | finally: | ||||
if os.path.exists(os.path.join(data_dir, 'middle_file.csv')): | if os.path.exists(os.path.join(data_dir, 'middle_file.csv')): | ||||
os.remove(os.path.join(data_dir, 'middle_file.csv')) | os.remove(os.path.join(data_dir, 'middle_file.csv')) | ||||
return data_dir | return data_dir | ||||
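A minimal usage sketch for the Yelp loaders above (the first call downloads and caches the dataset, so network access is assumed; a dev split is carved out of train.csv according to `dev_ratio`):

```python
from fastNLP.io.loader import YelpFullLoader

loader = YelpFullLoader()
# Download yelp-review-full (or reuse the cached copy) and get its directory.
data_dir = loader.download(dev_ratio=0.1, seed=0)

# Read train/dev/test from that directory into a DataBundle.
data_bundle = loader.load(data_dir)
print(len(data_bundle.datasets['train']))
```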
@@ -131,7 +133,7 @@ class YelpPolarityLoader(YelpLoader): | |||||
data_dir = self._get_dataset_path(dataset_name=dataset_name) | data_dir = self._get_dataset_path(dataset_name=dataset_name) | ||||
if os.path.exists(os.path.join(data_dir, 'dev.csv')): # 存在dev的话,check是否符合比例要求 | if os.path.exists(os.path.join(data_dir, 'dev.csv')): # 存在dev的话,check是否符合比例要求 | ||||
re_download = True | re_download = True | ||||
if dev_ratio>0: | |||||
if dev_ratio > 0: | |||||
dev_line_count = 0 | dev_line_count = 0 | ||||
tr_line_count = 0 | tr_line_count = 0 | ||||
with open(os.path.join(data_dir, 'train.csv'), 'r', encoding='utf-8') as f1, \ | with open(os.path.join(data_dir, 'train.csv'), 'r', encoding='utf-8') as f1, \ | ||||
@@ -140,14 +142,14 @@ class YelpPolarityLoader(YelpLoader): | |||||
tr_line_count += 1 | tr_line_count += 1 | ||||
for line in f2: | for line in f2: | ||||
dev_line_count += 1 | dev_line_count += 1 | ||||
if not np.isclose(dev_line_count, dev_ratio*(tr_line_count + dev_line_count), rtol=0.005): | |||||
if not np.isclose(dev_line_count, dev_ratio * (tr_line_count + dev_line_count), rtol=0.005): | |||||
re_download = True | re_download = True | ||||
else: | else: | ||||
re_download = False | re_download = False | ||||
if re_download: | if re_download: | ||||
shutil.rmtree(data_dir) | shutil.rmtree(data_dir) | ||||
data_dir = self._get_dataset_path(dataset_name=dataset_name) | data_dir = self._get_dataset_path(dataset_name=dataset_name) | ||||
if not os.path.exists(os.path.join(data_dir, 'dev.csv')): | if not os.path.exists(os.path.join(data_dir, 'dev.csv')): | ||||
if dev_ratio > 0: | if dev_ratio > 0: | ||||
assert 0 < dev_ratio < 1, "dev_ratio should be in range (0,1)." | assert 0 < dev_ratio < 1, "dev_ratio should be in range (0,1)." | ||||
@@ -166,7 +168,7 @@ class YelpPolarityLoader(YelpLoader): | |||||
finally: | finally: | ||||
if os.path.exists(os.path.join(data_dir, 'middle_file.csv')): | if os.path.exists(os.path.join(data_dir, 'middle_file.csv')): | ||||
os.remove(os.path.join(data_dir, 'middle_file.csv')) | os.remove(os.path.join(data_dir, 'middle_file.csv')) | ||||
return data_dir | return data_dir | ||||
@@ -185,10 +187,10 @@ class IMDBLoader(Loader): | |||||
"...", "..." | "...", "..." | ||||
""" | """ | ||||
def __init__(self): | def __init__(self): | ||||
super(IMDBLoader, self).__init__() | super(IMDBLoader, self).__init__() | ||||
def _load(self, path: str): | def _load(self, path: str): | ||||
dataset = DataSet() | dataset = DataSet() | ||||
with open(path, 'r', encoding="utf-8") as f: | with open(path, 'r', encoding="utf-8") as f: | ||||
@@ -201,12 +203,12 @@ class IMDBLoader(Loader): | |||||
words = parts[1] | words = parts[1] | ||||
if words: | if words: | ||||
dataset.append(Instance(raw_words=words, target=target)) | dataset.append(Instance(raw_words=words, target=target)) | ||||
if len(dataset) == 0: | if len(dataset) == 0: | ||||
raise RuntimeError(f"{path} has no valid data.") | raise RuntimeError(f"{path} has no valid data.") | ||||
return dataset | return dataset | ||||
def download(self, dev_ratio: float = 0.1, seed: int = 0): | def download(self, dev_ratio: float = 0.1, seed: int = 0): | ||||
""" | """ | ||||
自动下载数据集,如果你使用了这个数据集,请引用以下的文章 | 自动下载数据集,如果你使用了这个数据集,请引用以下的文章 | ||||
@@ -221,9 +223,9 @@ class IMDBLoader(Loader): | |||||
""" | """ | ||||
dataset_name = 'aclImdb' | dataset_name = 'aclImdb' | ||||
data_dir = self._get_dataset_path(dataset_name=dataset_name) | data_dir = self._get_dataset_path(dataset_name=dataset_name) | ||||
if os.path.exists(os.path.join(data_dir, 'dev.txt')): # 存在dev的话,check是否符合比例要求 | |||||
if os.path.exists(os.path.join(data_dir, 'dev.txt')): # 存在dev的话,check是否符合比例要求 | |||||
re_download = True | re_download = True | ||||
if dev_ratio>0: | |||||
if dev_ratio > 0: | |||||
dev_line_count = 0 | dev_line_count = 0 | ||||
tr_line_count = 0 | tr_line_count = 0 | ||||
with open(os.path.join(data_dir, 'train.txt'), 'r', encoding='utf-8') as f1, \ | with open(os.path.join(data_dir, 'train.txt'), 'r', encoding='utf-8') as f1, \ | ||||
@@ -232,14 +234,14 @@ class IMDBLoader(Loader): | |||||
tr_line_count += 1 | tr_line_count += 1 | ||||
for line in f2: | for line in f2: | ||||
dev_line_count += 1 | dev_line_count += 1 | ||||
if not np.isclose(dev_line_count, dev_ratio*(tr_line_count + dev_line_count), rtol=0.005): | |||||
if not np.isclose(dev_line_count, dev_ratio * (tr_line_count + dev_line_count), rtol=0.005): | |||||
re_download = True | re_download = True | ||||
else: | else: | ||||
re_download = False | re_download = False | ||||
if re_download: | if re_download: | ||||
shutil.rmtree(data_dir) | shutil.rmtree(data_dir) | ||||
data_dir = self._get_dataset_path(dataset_name=dataset_name) | data_dir = self._get_dataset_path(dataset_name=dataset_name) | ||||
if not os.path.exists(os.path.join(data_dir, 'dev.csv')): | if not os.path.exists(os.path.join(data_dir, 'dev.csv')): | ||||
if dev_ratio > 0: | if dev_ratio > 0: | ||||
assert 0 < dev_ratio < 1, "dev_ratio should be in range (0,1)." | assert 0 < dev_ratio < 1, "dev_ratio should be in range (0,1)." | ||||
@@ -258,7 +260,7 @@ class IMDBLoader(Loader): | |||||
finally: | finally: | ||||
if os.path.exists(os.path.join(data_dir, 'middle_file.txt')): | if os.path.exists(os.path.join(data_dir, 'middle_file.txt')): | ||||
os.remove(os.path.join(data_dir, 'middle_file.txt')) | os.remove(os.path.join(data_dir, 'middle_file.txt')) | ||||
return data_dir | return data_dir | ||||
@@ -278,10 +280,10 @@ class SSTLoader(Loader): | |||||
raw_words列是str。 | raw_words列是str。 | ||||
""" | """ | ||||
def __init__(self): | def __init__(self): | ||||
super().__init__() | super().__init__() | ||||
def _load(self, path: str): | def _load(self, path: str): | ||||
""" | """ | ||||
从path读取SST文件 | 从path读取SST文件 | ||||
@@ -296,7 +298,7 @@ class SSTLoader(Loader): | |||||
if line: | if line: | ||||
ds.append(Instance(raw_words=line)) | ds.append(Instance(raw_words=line)) | ||||
return ds | return ds | ||||
def download(self): | def download(self): | ||||
""" | """ | ||||
自动下载数据集,如果你使用了这个数据集,请引用以下的文章 | 自动下载数据集,如果你使用了这个数据集,请引用以下的文章 | ||||
@@ -323,10 +325,10 @@ class SST2Loader(Loader): | |||||
test的DataSet没有target列。 | test的DataSet没有target列。 | ||||
""" | """ | ||||
def __init__(self): | def __init__(self): | ||||
super().__init__() | super().__init__() | ||||
def _load(self, path: str): | def _load(self, path: str): | ||||
""" | """ | ||||
从path读取SST2文件 | 从path读取SST2文件 | ||||
@@ -335,7 +337,7 @@ class SST2Loader(Loader): | |||||
:return: DataSet | :return: DataSet | ||||
""" | """ | ||||
ds = DataSet() | ds = DataSet() | ||||
with open(path, 'r', encoding='utf-8') as f: | with open(path, 'r', encoding='utf-8') as f: | ||||
f.readline() # 跳过header | f.readline() # 跳过header | ||||
if 'test' in os.path.split(path)[1]: | if 'test' in os.path.split(path)[1]: | ||||
@@ -356,7 +358,7 @@ class SST2Loader(Loader): | |||||
if raw_words: | if raw_words: | ||||
ds.append(Instance(raw_words=raw_words, target=target)) | ds.append(Instance(raw_words=raw_words, target=target)) | ||||
return ds | return ds | ||||
def download(self): | def download(self): | ||||
""" | """ | ||||
自动下载数据集,如果你使用了该数据集,请引用以下的文章 | 自动下载数据集,如果你使用了该数据集,请引用以下的文章 | ||||
@@ -2,17 +2,21 @@ from ...core.dataset import DataSet | |||||
from .. import DataBundle | from .. import DataBundle | ||||
from ..utils import check_loader_paths | from ..utils import check_loader_paths | ||||
from typing import Union, Dict | from typing import Union, Dict | ||||
import os | |||||
from ..file_utils import _get_dataset_url, get_cache_path, cached_path | from ..file_utils import _get_dataset_url, get_cache_path, cached_path | ||||
class Loader: | class Loader: | ||||
""" | |||||
各种数据 Loader 的基类,提供了 API 的参考. | |||||
""" | |||||
def __init__(self): | def __init__(self): | ||||
pass | pass | ||||
def _load(self, path:str) -> DataSet: | |||||
def _load(self, path: str) -> DataSet: | |||||
raise NotImplementedError | raise NotImplementedError | ||||
def load(self, paths: Union[str, Dict[str, str]]=None) -> DataBundle: | |||||
def load(self, paths: Union[str, Dict[str, str]] = None) -> DataBundle: | |||||
""" | """ | ||||
从指定一个或多个路径中的文件中读取数据,返回:class:`~fastNLP.io.DataBundle` 。 | 从指定一个或多个路径中的文件中读取数据,返回:class:`~fastNLP.io.DataBundle` 。 | ||||
@@ -22,31 +26,25 @@ class Loader: | |||||
(0) 如果为None,则先查看本地是否有缓存,如果没有则自动下载并缓存。 | (0) 如果为None,则先查看本地是否有缓存,如果没有则自动下载并缓存。 | ||||
(1) 传入一个目录, 该目录下名称包含train的被认为是train,包含test的被认为是test,包含dev的被认为是dev,如果检测到多个文件 | (1) 传入一个目录, 该目录下名称包含train的被认为是train,包含test的被认为是test,包含dev的被认为是dev,如果检测到多个文件 | ||||
名包含'train'、 'dev'、 'test'则会报错 | |||||
Example:: | |||||
名包含'train'、 'dev'、 'test'则会报错:: | |||||
data_bundle = ConllLoader().load('/path/to/dir') # 返回的DataBundle中datasets根据目录下是否检测到train、 | data_bundle = ConllLoader().load('/path/to/dir') # 返回的DataBundle中datasets根据目录下是否检测到train、 | ||||
# dev、 test等有所变化,可以通过以下的方式取出DataSet | # dev、 test等有所变化,可以通过以下的方式取出DataSet | ||||
tr_data = data_bundle.datasets['train'] | tr_data = data_bundle.datasets['train'] | ||||
te_data = data_bundle.datasets['test'] # 如果目录下有文件包含test这个字段 | te_data = data_bundle.datasets['test'] # 如果目录下有文件包含test这个字段 | ||||
(2) 传入文件路径 | |||||
Example:: | |||||
(2) 传入文件路径:: | |||||
data_bundle = ConllLoader().load("/path/to/a/train.conll") # 返回DataBundle对象, datasets中仅包含'train' | data_bundle = ConllLoader().load("/path/to/a/train.conll") # 返回DataBundle对象, datasets中仅包含'train' | ||||
tr_data = data_bundle.datasets['train'] # 可以通过以下的方式取出DataSet | tr_data = data_bundle.datasets['train'] # 可以通过以下的方式取出DataSet | ||||
(3) 传入一个dict,比如train,dev,test不在同一个目录下,或者名称中不包含train, dev, test | |||||
Example:: | |||||
(3) 传入一个dict,比如train,dev,test不在同一个目录下,或者名称中不包含train, dev, test:: | |||||
paths = {'train':"/path/to/tr.conll", 'dev':"/to/validate.conll", "test":"/to/te.conll"} | paths = {'train':"/path/to/tr.conll", 'dev':"/to/validate.conll", "test":"/to/te.conll"} | ||||
data_bundle = ConllLoader().load(paths) # 返回的DataBundle中的dataset中包含"train", "dev", "test" | data_bundle = ConllLoader().load(paths) # 返回的DataBundle中的dataset中包含"train", "dev", "test" | ||||
dev_data = data_bundle.datasets['dev'] | dev_data = data_bundle.datasets['dev'] | ||||
:return: 返回的:class:`~fastNLP.io.DataBundle` | |||||
:return: 返回的 :class:`~fastNLP.io.DataBundle` | |||||
""" | """ | ||||
if paths is None: | if paths is None: | ||||
paths = self.download() | paths = self.download() | ||||
@@ -54,10 +52,10 @@ class Loader: | |||||
datasets = {name: self._load(path) for name, path in paths.items()} | datasets = {name: self._load(path) for name, path in paths.items()} | ||||
data_bundle = DataBundle(datasets=datasets) | data_bundle = DataBundle(datasets=datasets) | ||||
return data_bundle | return data_bundle | ||||
def download(self): | def download(self): | ||||
raise NotImplementedError(f"{self.__class__} cannot download data automatically.") | raise NotImplementedError(f"{self.__class__} cannot download data automatically.") | ||||
def _get_dataset_path(self, dataset_name): | def _get_dataset_path(self, dataset_name): | ||||
""" | """ | ||||
传入dataset的名称,获取读取数据的目录。如果数据不存在,会尝试自动下载并缓存 | 传入dataset的名称,获取读取数据的目录。如果数据不存在,会尝试自动下载并缓存 | ||||
@@ -65,11 +63,9 @@ class Loader: | |||||
:param str dataset_name: 数据集的名称 | :param str dataset_name: 数据集的名称 | ||||
:return: str, 数据集的目录地址。直接到该目录下读取相应的数据即可。 | :return: str, 数据集的目录地址。直接到该目录下读取相应的数据即可。 | ||||
""" | """ | ||||
default_cache_path = get_cache_path() | default_cache_path = get_cache_path() | ||||
url = _get_dataset_url(dataset_name) | url = _get_dataset_url(dataset_name) | ||||
output_dir = cached_path(url_or_filename=url, cache_dir=default_cache_path, name='dataset') | output_dir = cached_path(url_or_filename=url, cache_dir=default_cache_path, name='dataset') | ||||
return output_dir | return output_dir | ||||
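Because the `Loader` base class above only requires `_load()` (and, optionally, `download()` for auto-download), a custom loader can be sketched as follows; the tab-separated format and field names here are hypothetical:

```python
from fastNLP.core.dataset import DataSet
from fastNLP.core.instance import Instance
from fastNLP.io.loader import Loader


class TabSeparatedLoader(Loader):
    """Hypothetical loader: one sample per line, formatted as '<target>TAB<raw_words>'."""

    def _load(self, path: str) -> DataSet:
        ds = DataSet()
        with open(path, 'r', encoding='utf-8') as f:
            for line in f:
                line = line.strip()
                if not line:
                    continue
                target, raw_words = line.split('\t', maxsplit=1)
                ds.append(Instance(raw_words=raw_words, target=target))
        return ds

    # download() is not implemented here, so load(None) would raise NotImplementedError.


# paths may be a single file, a directory, or a dict (see the load() docstring above):
# data_bundle = TabSeparatedLoader().load({'train': 'tr.tsv', 'dev': 'dev.tsv'})
```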
@@ -203,7 +203,8 @@ class QNLILoader(JsonLoader): | |||||
""" | """ | ||||
如果您的实验使用到了该数据,请引用 | 如果您的实验使用到了该数据,请引用 | ||||
TODO 补充 | |||||
.. todo:: | |||||
补充 | |||||
:return: | :return: | ||||
""" | """ | ||||
@@ -8,7 +8,7 @@ __all__ = [ | |||||
import torch | import torch | ||||
from .base_loader import BaseLoader | |||||
from .data_bundle import BaseLoader | |||||
class ModelLoader(BaseLoader): | class ModelLoader(BaseLoader): | ||||
@@ -1,6 +1,6 @@ | |||||
from nltk import Tree | from nltk import Tree | ||||
from ..base_loader import DataBundle | |||||
from ..data_bundle import DataBundle | |||||
from ...core.vocabulary import Vocabulary | from ...core.vocabulary import Vocabulary | ||||
from ...core.const import Const | from ...core.const import Const | ||||
from ..loader.classification import IMDBLoader, YelpFullLoader, SSTLoader, SST2Loader, YelpPolarityLoader | from ..loader.classification import IMDBLoader, YelpFullLoader, SSTLoader, SST2Loader, YelpPolarityLoader | ||||
@@ -1,188 +1,188 @@ | |||||
import pickle | |||||
import numpy as np | |||||
from fastNLP.core.vocabulary import Vocabulary | |||||
from fastNLP.io.base_loader import DataBundle | |||||
from fastNLP.io.dataset_loader import JsonLoader | |||||
from fastNLP.core.const import Const | |||||
from tools.logger import * | |||||
WORD_PAD = "[PAD]" | |||||
WORD_UNK = "[UNK]" | |||||
DOMAIN_UNK = "X" | |||||
TAG_UNK = "X" | |||||
class SummarizationLoader(JsonLoader): | |||||
""" | |||||
读取summarization数据集,读取的DataSet包含fields:: | |||||
text: list(str),document | |||||
summary: list(str), summary | |||||
text_wd: list(list(str)),tokenized document | |||||
summary_wd: list(list(str)), tokenized summary | |||||
labels: list(int), | |||||
flatten_label: list(int), 0 or 1, flatten labels | |||||
domain: str, optional | |||||
tag: list(str), optional | |||||
数据来源: CNN_DailyMail Newsroom DUC | |||||
""" | |||||
def __init__(self): | |||||
super(SummarizationLoader, self).__init__() | |||||
def _load(self, path): | |||||
ds = super(SummarizationLoader, self)._load(path) | |||||
def _lower_text(text_list): | |||||
return [text.lower() for text in text_list] | |||||
def _split_list(text_list): | |||||
return [text.split() for text in text_list] | |||||
def _convert_label(label, sent_len): | |||||
np_label = np.zeros(sent_len, dtype=int) | |||||
if label != []: | |||||
np_label[np.array(label)] = 1 | |||||
return np_label.tolist() | |||||
ds.apply(lambda x: _lower_text(x['text']), new_field_name='text') | |||||
ds.apply(lambda x: _lower_text(x['summary']), new_field_name='summary') | |||||
ds.apply(lambda x:_split_list(x['text']), new_field_name='text_wd') | |||||
ds.apply(lambda x:_split_list(x['summary']), new_field_name='summary_wd') | |||||
ds.apply(lambda x:_convert_label(x["label"], len(x["text"])), new_field_name="flatten_label") | |||||
return ds | |||||
def process(self, paths, vocab_size, vocab_path, sent_max_len, doc_max_timesteps, domain=False, tag=False, load_vocab_file=True): | |||||
""" | |||||
:param paths: dict path for each dataset | |||||
:param vocab_size: int max_size for vocab | |||||
:param vocab_path: str vocab path | |||||
:param sent_max_len: int max token number of the sentence | |||||
:param doc_max_timesteps: int max sentence number of the document | |||||
:param domain: bool build vocab for publication, use 'X' for unknown | |||||
:param tag: bool build vocab for tag, use 'X' for unknown | |||||
:param load_vocab_file: bool build vocab (False) or load vocab (True) | |||||
:return: DataBundle | |||||
datasets: dict keys correspond to the paths dict | |||||
vocabs: dict key: vocab(if "train" in paths), domain(if domain=True), tag(if tag=True) | |||||
embeddings: optional | |||||
""" | |||||
def _pad_sent(text_wd): | |||||
pad_text_wd = [] | |||||
for sent_wd in text_wd: | |||||
if len(sent_wd) < sent_max_len: | |||||
pad_num = sent_max_len - len(sent_wd) | |||||
sent_wd.extend([WORD_PAD] * pad_num) | |||||
else: | |||||
sent_wd = sent_wd[:sent_max_len] | |||||
pad_text_wd.append(sent_wd) | |||||
return pad_text_wd | |||||
def _token_mask(text_wd): | |||||
token_mask_list = [] | |||||
for sent_wd in text_wd: | |||||
token_num = len(sent_wd) | |||||
if token_num < sent_max_len: | |||||
mask = [1] * token_num + [0] * (sent_max_len - token_num) | |||||
else: | |||||
mask = [1] * sent_max_len | |||||
token_mask_list.append(mask) | |||||
return token_mask_list | |||||
def _pad_label(label): | |||||
text_len = len(label) | |||||
if text_len < doc_max_timesteps: | |||||
pad_label = label + [0] * (doc_max_timesteps - text_len) | |||||
else: | |||||
pad_label = label[:doc_max_timesteps] | |||||
return pad_label | |||||
def _pad_doc(text_wd): | |||||
text_len = len(text_wd) | |||||
if text_len < doc_max_timesteps: | |||||
padding = [WORD_PAD] * sent_max_len | |||||
pad_text = text_wd + [padding] * (doc_max_timesteps - text_len) | |||||
else: | |||||
pad_text = text_wd[:doc_max_timesteps] | |||||
return pad_text | |||||
def _sent_mask(text_wd): | |||||
text_len = len(text_wd) | |||||
if text_len < doc_max_timesteps: | |||||
sent_mask = [1] * text_len + [0] * (doc_max_timesteps - text_len) | |||||
else: | |||||
sent_mask = [1] * doc_max_timesteps | |||||
return sent_mask | |||||
datasets = {} | |||||
train_ds = None | |||||
for key, value in paths.items(): | |||||
ds = self.load(value) | |||||
# pad sent | |||||
ds.apply(lambda x:_pad_sent(x["text_wd"]), new_field_name="pad_text_wd") | |||||
ds.apply(lambda x:_token_mask(x["text_wd"]), new_field_name="pad_token_mask") | |||||
# pad document | |||||
ds.apply(lambda x:_pad_doc(x["pad_text_wd"]), new_field_name="pad_text") | |||||
ds.apply(lambda x:_sent_mask(x["pad_text_wd"]), new_field_name="seq_len") | |||||
ds.apply(lambda x:_pad_label(x["flatten_label"]), new_field_name="pad_label") | |||||
# rename field | |||||
ds.rename_field("pad_text", Const.INPUT) | |||||
ds.rename_field("seq_len", Const.INPUT_LEN) | |||||
ds.rename_field("pad_label", Const.TARGET) | |||||
# set input and target | |||||
ds.set_input(Const.INPUT, Const.INPUT_LEN) | |||||
ds.set_target(Const.TARGET, Const.INPUT_LEN) | |||||
datasets[key] = ds | |||||
if "train" in key: | |||||
train_ds = datasets[key] | |||||
vocab_dict = {} | |||||
if load_vocab_file == False: | |||||
logger.info("[INFO] Build new vocab from training dataset!") | |||||
if train_ds == None: | |||||
raise ValueError("Lack train file to build vocabulary!") | |||||
vocabs = Vocabulary(max_size=vocab_size, padding=WORD_PAD, unknown=WORD_UNK) | |||||
vocabs.from_dataset(train_ds, field_name=["text_wd","summary_wd"]) | |||||
vocab_dict["vocab"] = vocabs | |||||
else: | |||||
logger.info("[INFO] Load existing vocab from %s!" % vocab_path) | |||||
word_list = [] | |||||
with open(vocab_path, 'r', encoding='utf8') as vocab_f: | |||||
cnt = 2 # pad and unk | |||||
for line in vocab_f: | |||||
pieces = line.split("\t") | |||||
word_list.append(pieces[0]) | |||||
cnt += 1 | |||||
if cnt > vocab_size: | |||||
break | |||||
vocabs = Vocabulary(max_size=vocab_size, padding=WORD_PAD, unknown=WORD_UNK) | |||||
vocabs.add_word_lst(word_list) | |||||
vocabs.build_vocab() | |||||
vocab_dict["vocab"] = vocabs | |||||
if domain == True: | |||||
domaindict = Vocabulary(padding=None, unknown=DOMAIN_UNK) | |||||
domaindict.from_dataset(train_ds, field_name="publication") | |||||
vocab_dict["domain"] = domaindict | |||||
if tag == True: | |||||
tagdict = Vocabulary(padding=None, unknown=TAG_UNK) | |||||
tagdict.from_dataset(train_ds, field_name="tag") | |||||
vocab_dict["tag"] = tagdict | |||||
for ds in datasets.values(): | |||||
vocab_dict["vocab"].index_dataset(ds, field_name=Const.INPUT, new_field_name=Const.INPUT) | |||||
return DataBundle(vocabs=vocab_dict, datasets=datasets) | |||||
import pickle | |||||
import numpy as np | |||||
from fastNLP.core.vocabulary import Vocabulary | |||||
from fastNLP.io.data_bundle import DataBundle | |||||
from fastNLP.io.dataset_loader import JsonLoader | |||||
from fastNLP.core.const import Const | |||||
from tools.logger import * | |||||
WORD_PAD = "[PAD]" | |||||
WORD_UNK = "[UNK]" | |||||
DOMAIN_UNK = "X" | |||||
TAG_UNK = "X" | |||||
class SummarizationLoader(JsonLoader): | |||||
""" | |||||
读取summarization数据集,读取的DataSet包含fields:: | |||||
text: list(str),document | |||||
summary: list(str), summary | |||||
text_wd: list(list(str)),tokenized document | |||||
summary_wd: list(list(str)), tokenized summary | |||||
labels: list(int), | |||||
flatten_label: list(int), 0 or 1, flatten labels | |||||
domain: str, optional | |||||
tag: list(str), optional | |||||
数据来源: CNN_DailyMail Newsroom DUC | |||||
""" | |||||
def __init__(self): | |||||
super(SummarizationLoader, self).__init__() | |||||
def _load(self, path): | |||||
ds = super(SummarizationLoader, self)._load(path) | |||||
def _lower_text(text_list): | |||||
return [text.lower() for text in text_list] | |||||
def _split_list(text_list): | |||||
return [text.split() for text in text_list] | |||||
def _convert_label(label, sent_len): | |||||
np_label = np.zeros(sent_len, dtype=int) | |||||
if label != []: | |||||
np_label[np.array(label)] = 1 | |||||
return np_label.tolist() | |||||
ds.apply(lambda x: _lower_text(x['text']), new_field_name='text') | |||||
ds.apply(lambda x: _lower_text(x['summary']), new_field_name='summary') | |||||
ds.apply(lambda x:_split_list(x['text']), new_field_name='text_wd') | |||||
ds.apply(lambda x:_split_list(x['summary']), new_field_name='summary_wd') | |||||
ds.apply(lambda x:_convert_label(x["label"], len(x["text"])), new_field_name="flatten_label") | |||||
return ds | |||||
def process(self, paths, vocab_size, vocab_path, sent_max_len, doc_max_timesteps, domain=False, tag=False, load_vocab_file=True): | |||||
""" | |||||
:param paths: dict path for each dataset | |||||
:param vocab_size: int max_size for vocab | |||||
:param vocab_path: str vocab path | |||||
:param sent_max_len: int max token number of the sentence | |||||
:param doc_max_timesteps: int max sentence number of the document | |||||
:param domain: bool build vocab for publication, use 'X' for unknown | |||||
:param tag: bool build vocab for tag, use 'X' for unknown | |||||
:param load_vocab_file: bool build vocab (False) or load vocab (True) | |||||
:return: DataBundle | |||||
datasets: dict keys correspond to the paths dict | |||||
vocabs: dict key: vocab(if "train" in paths), domain(if domain=True), tag(if tag=True) | |||||
embeddings: optional | |||||
""" | |||||
def _pad_sent(text_wd): | |||||
pad_text_wd = [] | |||||
for sent_wd in text_wd: | |||||
if len(sent_wd) < sent_max_len: | |||||
pad_num = sent_max_len - len(sent_wd) | |||||
sent_wd.extend([WORD_PAD] * pad_num) | |||||
else: | |||||
sent_wd = sent_wd[:sent_max_len] | |||||
pad_text_wd.append(sent_wd) | |||||
return pad_text_wd | |||||
def _token_mask(text_wd): | |||||
token_mask_list = [] | |||||
for sent_wd in text_wd: | |||||
token_num = len(sent_wd) | |||||
if token_num < sent_max_len: | |||||
mask = [1] * token_num + [0] * (sent_max_len - token_num) | |||||
else: | |||||
mask = [1] * sent_max_len | |||||
token_mask_list.append(mask) | |||||
return token_mask_list | |||||
def _pad_label(label): | |||||
text_len = len(label) | |||||
if text_len < doc_max_timesteps: | |||||
pad_label = label + [0] * (doc_max_timesteps - text_len) | |||||
else: | |||||
pad_label = label[:doc_max_timesteps] | |||||
return pad_label | |||||
def _pad_doc(text_wd): | |||||
text_len = len(text_wd) | |||||
if text_len < doc_max_timesteps: | |||||
padding = [WORD_PAD] * sent_max_len | |||||
pad_text = text_wd + [padding] * (doc_max_timesteps - text_len) | |||||
else: | |||||
pad_text = text_wd[:doc_max_timesteps] | |||||
return pad_text | |||||
def _sent_mask(text_wd): | |||||
text_len = len(text_wd) | |||||
if text_len < doc_max_timesteps: | |||||
sent_mask = [1] * text_len + [0] * (doc_max_timesteps - text_len) | |||||
else: | |||||
sent_mask = [1] * doc_max_timesteps | |||||
return sent_mask | |||||
datasets = {} | |||||
train_ds = None | |||||
for key, value in paths.items(): | |||||
ds = self.load(value) | |||||
# pad sent | |||||
ds.apply(lambda x:_pad_sent(x["text_wd"]), new_field_name="pad_text_wd") | |||||
ds.apply(lambda x:_token_mask(x["text_wd"]), new_field_name="pad_token_mask") | |||||
# pad document | |||||
ds.apply(lambda x:_pad_doc(x["pad_text_wd"]), new_field_name="pad_text") | |||||
ds.apply(lambda x:_sent_mask(x["pad_text_wd"]), new_field_name="seq_len") | |||||
ds.apply(lambda x:_pad_label(x["flatten_label"]), new_field_name="pad_label") | |||||
# rename field | |||||
ds.rename_field("pad_text", Const.INPUT) | |||||
ds.rename_field("seq_len", Const.INPUT_LEN) | |||||
ds.rename_field("pad_label", Const.TARGET) | |||||
# set input and target | |||||
ds.set_input(Const.INPUT, Const.INPUT_LEN) | |||||
ds.set_target(Const.TARGET, Const.INPUT_LEN) | |||||
datasets[key] = ds | |||||
if "train" in key: | |||||
train_ds = datasets[key] | |||||
vocab_dict = {} | |||||
if load_vocab_file == False: | |||||
logger.info("[INFO] Build new vocab from training dataset!") | |||||
if train_ds == None: | |||||
raise ValueError("Lack train file to build vocabulary!") | |||||
vocabs = Vocabulary(max_size=vocab_size, padding=WORD_PAD, unknown=WORD_UNK) | |||||
vocabs.from_dataset(train_ds, field_name=["text_wd","summary_wd"]) | |||||
vocab_dict["vocab"] = vocabs | |||||
else: | |||||
logger.info("[INFO] Load existing vocab from %s!" % vocab_path) | |||||
word_list = [] | |||||
with open(vocab_path, 'r', encoding='utf8') as vocab_f: | |||||
cnt = 2 # pad and unk | |||||
for line in vocab_f: | |||||
pieces = line.split("\t") | |||||
word_list.append(pieces[0]) | |||||
cnt += 1 | |||||
if cnt > vocab_size: | |||||
break | |||||
vocabs = Vocabulary(max_size=vocab_size, padding=WORD_PAD, unknown=WORD_UNK) | |||||
vocabs.add_word_lst(word_list) | |||||
vocabs.build_vocab() | |||||
vocab_dict["vocab"] = vocabs | |||||
if domain == True: | |||||
domaindict = Vocabulary(padding=None, unknown=DOMAIN_UNK) | |||||
domaindict.from_dataset(train_ds, field_name="publication") | |||||
vocab_dict["domain"] = domaindict | |||||
if tag == True: | |||||
tagdict = Vocabulary(padding=None, unknown=TAG_UNK) | |||||
tagdict.from_dataset(train_ds, field_name="tag") | |||||
vocab_dict["tag"] = tagdict | |||||
for ds in datasets.values(): | |||||
vocab_dict["vocab"].index_dataset(ds, field_name=Const.INPUT, new_field_name=Const.INPUT) | |||||
return DataBundle(vocabs=vocab_dict, datasets=datasets) | |||||
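A usage sketch for `SummarizationLoader.process` defined above (the file paths, vocab file, and size limits are hypothetical; the returned `DataBundle` carries the padded, renamed fields described in the code):

```python
loader = SummarizationLoader()
paths = {"train": "train.jsonl", "valid": "valid.jsonl", "test": "test.jsonl"}

data_bundle = loader.process(
    paths=paths,
    vocab_size=50000,           # maximum vocabulary size
    vocab_path="vocab.txt",     # read when load_vocab_file=True
    sent_max_len=100,           # tokens per sentence (padded / truncated)
    doc_max_timesteps=50,       # sentences per document (padded / truncated)
    load_vocab_file=True,
)

train_ds = data_bundle.datasets["train"]   # Const.INPUT / Const.TARGET fields are already padded and indexed
vocab = data_bundle.vocabs["vocab"]
print(len(train_ds), len(vocab))
```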
@@ -3,7 +3,7 @@ from datetime import timedelta | |||||
from fastNLP.io.dataset_loader import JsonLoader | from fastNLP.io.dataset_loader import JsonLoader | ||||
from fastNLP.modules.encoder._bert import BertTokenizer | from fastNLP.modules.encoder._bert import BertTokenizer | ||||
from fastNLP.io.base_loader import DataBundle | |||||
from fastNLP.io.data_bundle import DataBundle | |||||
from fastNLP.core.const import Const | from fastNLP.core.const import Const | ||||
class BertData(JsonLoader): | class BertData(JsonLoader): | ||||
@@ -1,7 +1,7 @@ | |||||
from fastNLP.io.dataset_loader import JsonLoader,DataSet,Instance | from fastNLP.io.dataset_loader import JsonLoader,DataSet,Instance | ||||
from fastNLP.io.file_reader import _read_json | from fastNLP.io.file_reader import _read_json | ||||
from fastNLP.core.vocabulary import Vocabulary | from fastNLP.core.vocabulary import Vocabulary | ||||
from fastNLP.io.base_loader import DataBundle | |||||
from fastNLP.io.data_bundle import DataBundle | |||||
from reproduction.coreference_resolution.model.config import Config | from reproduction.coreference_resolution.model.config import Config | ||||
import reproduction.coreference_resolution.model.preprocess as preprocess | import reproduction.coreference_resolution.model.preprocess as preprocess | ||||
@@ -1,6 +1,6 @@ | |||||
from fastNLP.io.base_loader import DataSetLoader, DataBundle | |||||
from fastNLP.io.data_bundle import DataSetLoader, DataBundle | |||||
from fastNLP.io.data_loader import ConllLoader | from fastNLP.io.data_loader import ConllLoader | ||||
import numpy as np | import numpy as np | ||||
@@ -9,7 +9,7 @@ from typing import Union, Dict | |||||
from fastNLP.core.const import Const | from fastNLP.core.const import Const | ||||
from fastNLP.core.vocabulary import Vocabulary | from fastNLP.core.vocabulary import Vocabulary | ||||
from fastNLP.io.base_loader import DataBundle, DataSetLoader | |||||
from fastNLP.io.data_bundle import DataBundle, DataSetLoader | |||||
from fastNLP.io.dataset_loader import JsonLoader, CSVLoader | from fastNLP.io.dataset_loader import JsonLoader, CSVLoader | ||||
from fastNLP.io.file_utils import _get_base_url, cached_path, PRETRAINED_BERT_MODEL_DIR | from fastNLP.io.file_utils import _get_base_url, cached_path, PRETRAINED_BERT_MODEL_DIR | ||||
from fastNLP.modules.encoder._bert import BertTokenizer | from fastNLP.modules.encoder._bert import BertTokenizer | ||||
@@ -1,6 +1,6 @@ | |||||
from fastNLP.io.base_loader import DataSetLoader, DataBundle | |||||
from fastNLP.io.data_bundle import DataSetLoader, DataBundle | |||||
from fastNLP.io import ConllLoader | from fastNLP.io import ConllLoader | ||||
from reproduction.seqence_labelling.ner.data.utils import iob2bioes, iob2 | from reproduction.seqence_labelling.ner.data.utils import iob2bioes, iob2 | ||||
from fastNLP import Const | from fastNLP import Const | ||||
@@ -1,7 +1,7 @@ | |||||
from fastNLP.io.embed_loader import EmbeddingOption, EmbedLoader | from fastNLP.io.embed_loader import EmbeddingOption, EmbedLoader | ||||
from fastNLP.core.vocabulary import VocabularyOption | from fastNLP.core.vocabulary import VocabularyOption | ||||
from fastNLP.io.base_loader import DataSetLoader, DataBundle | |||||
from fastNLP.io.data_bundle import DataSetLoader, DataBundle | |||||
from typing import Union, Dict, List, Iterator | from typing import Union, Dict, List, Iterator | ||||
from fastNLP import DataSet | from fastNLP import DataSet | ||||
from fastNLP import Instance | from fastNLP import Instance | ||||
@@ -1,6 +1,6 @@ | |||||
from fastNLP.core.vocabulary import VocabularyOption | from fastNLP.core.vocabulary import VocabularyOption | ||||
from fastNLP.io.base_loader import DataSetLoader, DataBundle | |||||
from fastNLP.io.data_bundle import DataSetLoader, DataBundle | |||||
from typing import Union, Dict | from typing import Union, Dict | ||||
from fastNLP import Vocabulary | from fastNLP import Vocabulary | ||||
from fastNLP import Const | from fastNLP import Const | ||||
@@ -1,5 +1,5 @@ | |||||
from fastNLP.core.vocabulary import VocabularyOption | from fastNLP.core.vocabulary import VocabularyOption | ||||
from fastNLP.io.base_loader import DataSetLoader, DataBundle | |||||
from fastNLP.io.data_bundle import DataSetLoader, DataBundle | |||||
from typing import Union, Dict | from typing import Union, Dict | ||||
from fastNLP import DataSet | from fastNLP import DataSet | ||||
from fastNLP import Vocabulary | from fastNLP import Vocabulary | ||||
@@ -1,6 +1,6 @@ | |||||
from fastNLP.io.embed_loader import EmbeddingOption, EmbedLoader | from fastNLP.io.embed_loader import EmbeddingOption, EmbedLoader | ||||
from fastNLP.core.vocabulary import VocabularyOption | from fastNLP.core.vocabulary import VocabularyOption | ||||
from fastNLP.io.base_loader import DataSetLoader, DataBundle | |||||
from fastNLP.io.data_bundle import DataSetLoader, DataBundle | |||||
from typing import Union, Dict, List, Iterator | from typing import Union, Dict, List, Iterator | ||||
from fastNLP import DataSet | from fastNLP import DataSet | ||||
from fastNLP import Instance | from fastNLP import Instance | ||||
@@ -1,6 +1,6 @@ | |||||
from fastNLP.io.embed_loader import EmbeddingOption, EmbedLoader | from fastNLP.io.embed_loader import EmbeddingOption, EmbedLoader | ||||
from fastNLP.core.vocabulary import VocabularyOption | from fastNLP.core.vocabulary import VocabularyOption | ||||
from fastNLP.io.base_loader import DataSetLoader, DataBundle | |||||
from fastNLP.io.data_bundle import DataSetLoader, DataBundle | |||||
from typing import Union, Dict, List, Iterator | from typing import Union, Dict, List, Iterator | ||||
from fastNLP import DataSet | from fastNLP import DataSet | ||||
from fastNLP import Instance | from fastNLP import Instance | ||||
@@ -1,6 +1,6 @@ | |||||
from typing import Iterable | from typing import Iterable | ||||
from nltk import Tree | from nltk import Tree | ||||
from fastNLP.io.base_loader import DataBundle, DataSetLoader | |||||
from fastNLP.io.data_bundle import DataBundle, DataSetLoader | |||||
from fastNLP.core.vocabulary import VocabularyOption, Vocabulary | from fastNLP.core.vocabulary import VocabularyOption, Vocabulary | ||||
from fastNLP import DataSet | from fastNLP import DataSet | ||||
from fastNLP import Instance | from fastNLP import Instance | ||||
@@ -4,7 +4,7 @@ from typing import Iterable | |||||
from fastNLP import DataSet, Instance, Vocabulary | from fastNLP import DataSet, Instance, Vocabulary | ||||
from fastNLP.core.vocabulary import VocabularyOption | from fastNLP.core.vocabulary import VocabularyOption | ||||
from fastNLP.io import JsonLoader | from fastNLP.io import JsonLoader | ||||
from fastNLP.io.base_loader import DataBundle,DataSetLoader | |||||
from fastNLP.io.data_bundle import DataBundle,DataSetLoader | |||||
from fastNLP.io.embed_loader import EmbeddingOption | from fastNLP.io.embed_loader import EmbeddingOption | ||||
from fastNLP.io.file_reader import _read_json | from fastNLP.io.file_reader import _read_json | ||||
from typing import Union, Dict | from typing import Union, Dict | ||||