@@ -12,13 +12,14 @@ | |||
这些类的使用方法可以在对应module的文档下查看. | |||
""" | |||
from .embed_loader import EmbedLoader | |||
from .dataset_loader import * | |||
from .config_io import * | |||
from .model_io import * | |||
from .dataset_loader import DataSetLoader, CSVLoader, JsonLoader, ConllLoader, SNLILoader, SSTLoader, \ | |||
PeopleDailyCorpusLoader, Conll2003Loader | |||
from .config_io import ConfigLoader, ConfigSection, ConfigSaver | |||
from .model_io import ModelLoader as ModelLoader, ModelSaver as ModelSaver | |||
__all__ = [ | |||
'EmbedLoader', | |||
'DataSetLoader', | |||
'CSVLoader', | |||
'JsonLoader', | |||
@@ -27,11 +28,11 @@ __all__ = [ | |||
'SSTLoader', | |||
'PeopleDailyCorpusLoader', | |||
'Conll2003Loader', | |||
'ConfigLoader', | |||
'ConfigSection', | |||
'ConfigSaver', | |||
'ModelLoader', | |||
'ModelSaver', | |||
] | |||
] |
@@ -3,7 +3,8 @@ import os | |||
class BaseLoader(object): | |||
"""Base loader for all loaders. | |||
""" | |||
各个 Loader 的基类,提供了 API 的参考。 | |||
""" | |||
def __init__(self): | |||
@@ -11,7 +12,10 @@ class BaseLoader(object): | |||
@staticmethod | |||
def load_lines(data_path): | |||
"""按行读取,舍弃每行两侧空白字符,返回list of str | |||
""" | |||
按行读取,舍弃每行两侧空白字符,返回list of str | |||
:param data_path: 读取数据的路径 | |||
""" | |||
with open(data_path, "r", encoding="utf=8") as f: | |||
text = f.readlines() | |||
@@ -19,7 +23,10 @@ class BaseLoader(object): | |||
@classmethod | |||
def load(cls, data_path): | |||
"""先按行读取,去除一行两侧空白,再提取每行的字符。返回list of list of str | |||
""" | |||
先按行读取,去除一行两侧空白,再提取每行的字符。返回list of list of str | |||
:param data_path: | |||
""" | |||
with open(data_path, "r", encoding="utf-8") as f: | |||
text = f.readlines() | |||
@@ -40,9 +47,7 @@ class BaseLoader(object): | |||
class DataLoaderRegister: | |||
"""Register for all data sets. | |||
""" | |||
# TODO 这个类使用在何处? | |||
_readers = {} | |||
@classmethod | |||
@@ -1,19 +1,22 @@ | |||
""" | |||
.. _config-io: | |||
用于读入和处理和保存 config 文件 | |||
""" | |||
__all__ = ["ConfigLoader","ConfigSection","ConfigSaver"] | |||
import configparser | |||
import json | |||
import os | |||
from fastNLP.io.base_loader import BaseLoader | |||
from .base_loader import BaseLoader | |||
class ConfigLoader(BaseLoader): | |||
"""Loader for configuration. | |||
""" | |||
别名::class:`fastNLP.io.ConfigLoader` :class:`fastNLP.io.config_io.ConfigLoader` | |||
读取配置文件的Loader | |||
:param str data_path: path to the config | |||
:param str data_path: 配置文件的路径 | |||
""" | |||
def __init__(self, data_path=None): | |||
@@ -27,14 +30,16 @@ class ConfigLoader(BaseLoader): | |||
@staticmethod | |||
def load_config(file_path, sections): | |||
"""Load section(s) of configuration into the ``sections`` provided. No returns. | |||
""" | |||
把配置文件的section 存入提供的 ``sections`` 中 | |||
:param str file_path: the path of config file | |||
:param dict sections: the dict of ``{section_name(string): ConfigSection object}`` | |||
Example:: | |||
test_args = ConfigSection() | |||
ConfigLoader("config.cfg").load_config("./data_for_tests/config", {"POS_test": test_args}) | |||
:param str file_path: 配置文件的路径 | |||
:param dict sections: 符合如下键值对组成的字典 `section_name(string)` : :class:`~fastNLP.io.ConfigSection` | |||
Example:: | |||
test_args = ConfigSection() | |||
ConfigLoader("config.cfg").load_config("./data_for_tests/config", {"POS_test": test_args}) | |||
""" | |||
assert isinstance(sections, dict) | |||
@@ -70,7 +75,10 @@ class ConfigLoader(BaseLoader): | |||
class ConfigSection(object): | |||
"""ConfigSection is the data structure storing all key-value pairs in one section in a config file. | |||
""" | |||
别名::class:`fastNLP.io.ConfigSection` :class:`fastNLP.io.config_io.ConfigSection` | |||
ConfigSection是一个存储了一个section中所有键值对的数据结构,推荐使用此类的实例来配合 :meth:`ConfigLoader.load_config` 使用 | |||
""" | |||
@@ -146,9 +154,12 @@ class ConfigSection(object): | |||
class ConfigSaver(object): | |||
"""ConfigSaver is used to save config file and solve related conflicts. | |||
""" | |||
别名::class:`fastNLP.io.ConfigSaver` :class:`fastNLP.io.config_io.ConfigSaver` | |||
ConfigSaver 是用来存储配置文件并解决相关冲突的类 | |||
:param str file_path: path to the config file | |||
:param str file_path: 配置文件的路径 | |||
""" | |||
def __init__(self, file_path): | |||
@@ -157,7 +168,8 @@ class ConfigSaver(object): | |||
raise FileNotFoundError("file {} NOT found!".__format__(self.file_path)) | |||
def _get_section(self, sect_name): | |||
"""This is the function to get the section with the section name. | |||
""" | |||
This is the function to get the section with the section name. | |||
:param sect_name: The name of section what wants to load. | |||
:return: The section. | |||
@@ -167,7 +179,8 @@ class ConfigSaver(object): | |||
return sect | |||
def _read_section(self): | |||
"""This is the function to read sections from the config file. | |||
""" | |||
This is the function to read sections from the config file. | |||
:return: sect_list, sect_key_list | |||
sect_list: A list of ConfigSection(). | |||
@@ -219,7 +232,8 @@ class ConfigSaver(object): | |||
return sect_list, sect_key_list | |||
def _write_section(self, sect_list, sect_key_list): | |||
"""This is the function to write config file with section list and name list. | |||
""" | |||
This is the function to write config file with section list and name list. | |||
:param sect_list: A list of ConfigSection() need to be writen into file. | |||
:param sect_key_list: A list of name of sect_list. | |||
@@ -240,10 +254,11 @@ class ConfigSaver(object): | |||
f.write('\n') | |||
def save_config_file(self, section_name, section): | |||
"""This is the function to be called to change the config file with a single section and its name. | |||
""" | |||
这个方法可以用来修改并保存配置文件中单独的一个 section | |||
:param str section_name: The name of section what needs to be changed and saved. | |||
:param ConfigSection section: The section with key and value what needs to be changed and saved. | |||
:param str section_name: 需要保存的 section 的名字. | |||
:param section: 你需要修改并保存的 section, :class:`~fastNLP.io.ConfigSection` 类型 | |||
""" | |||
section_file = self._get_section(section_name) | |||
if len(section_file.__dict__.keys()) == 0: # the section not in the file before | |||
@@ -1,8 +1,6 @@ | |||
""" | |||
.. _dataset-loader: | |||
DataSetLoader 的 API, 用于读取不同格式的数据, 并返回 `DataSet` , | |||
得到的 `DataSet` 对象可以直接传入 `Trainer`, `Tester`, 用于模型的训练和测试 | |||
dataset_loader模块实现了许多 DataSetLoader, 用于读取不同格式的数据, 并返回 `DataSet` , | |||
得到的 :class:`~fastNLP.DataSet` 对象可以直接传入 :class:`~fastNLP.Trainer`, :class:`~fastNLP.Tester`, 用于模型的训练和测试 | |||
Example:: | |||
@@ -13,50 +11,50 @@ Example:: | |||
# ... do stuff | |||
""" | |||
import os | |||
import json | |||
from nltk.tree import Tree | |||
from fastNLP.core.dataset import DataSet | |||
from fastNLP.core.instance import Instance | |||
from fastNLP.io.file_reader import _read_csv, _read_json, _read_conll | |||
from ..core.dataset import DataSet | |||
from ..core.instance import Instance | |||
from .file_reader import _read_csv, _read_json, _read_conll | |||
def _download_from_url(url, path):
    """
    Download the resource at ``url`` and write it to ``path`` with a progress bar.

    The response is streamed in 16 KB chunks, so the whole payload is never
    held in memory at once.

    :param str url: address to download from
    :param str path: local file to write; the text after its last ``/`` is
        used as the progress-bar label
    """
    # Function-scope imports: tqdm and requests are only loaded when this runs.
    from tqdm import tqdm
    import requests

    r = requests.get(url, headers={'User-Agent': 'Mozilla/5.0'}, stream=True)
    chunk_size = 16 * 1024
    # Falls back to 0 when the server sends no Content-length header.
    total_size = int(r.headers.get('Content-length', 0))
    with open(path, "wb") as file, \
            tqdm(total=total_size, unit='B', unit_scale=1, desc=path.split('/')[-1]) as t:
        for chunk in r.iter_content(chunk_size):
            if chunk:  # skip empty chunks
                file.write(chunk)
                t.update(len(chunk))
def _uncompress(src, dst): | |||
import zipfile, gzip, tarfile, os | |||
def unzip(src, dst): | |||
with zipfile.ZipFile(src, 'r') as f: | |||
f.extractall(dst) | |||
def ungz(src, dst): | |||
with gzip.open(src, 'rb') as f, open(dst, 'wb') as uf: | |||
length = 16 * 1024 # 16KB | |||
length = 16 * 1024 # 16KB | |||
buf = f.read(length) | |||
while buf: | |||
uf.write(buf) | |||
buf = f.read(length) | |||
def untar(src, dst): | |||
with tarfile.open(src, 'r:gz') as f: | |||
f.extractall(dst) | |||
fn, ext = os.path.splitext(src) | |||
_, ext_2 = os.path.splitext(fn) | |||
if ext == '.zip': | |||
@@ -71,42 +69,48 @@ def _uncompress(src, dst): | |||
class DataSetLoader:
    """
    Aliases: :class:`fastNLP.io.DataSetLoader` :class:`fastNLP.io.dataset_loader.DataSetLoader`

    The API shared by all DataSetLoaders; subclass it to implement your own
    DataSetLoader.
    """

    def load(self, path):
        """
        Read data from the file at ``path`` and return a DataSet.

        :param str path: file path
        :return: a :class:`~fastNLP.DataSet` object
        :raises NotImplementedError: always, in this base class
        """
        raise NotImplementedError

    def convert(self, data):
        """
        Build a DataSet from Python data objects; each subclass has to
        implement this method itself.

        :param data: a built-in Python data structure
        :return: a :class:`~fastNLP.DataSet` object
        :raises NotImplementedError: always, in this base class
        """
        raise NotImplementedError
class PeopleDailyCorpusLoader(DataSetLoader): | |||
"""读取人民日报数据集 | |||
""" | |||
别名::class:`fastNLP.io.PeopleDailyCorpusLoader` :class:`fastNLP.io.dataset_loader.PeopleDailyCorpusLoader` | |||
读取人民日报数据集 | |||
""" | |||
def __init__(self): | |||
super(PeopleDailyCorpusLoader, self).__init__() | |||
self.pos = True | |||
self.ner = True | |||
def load(self, data_path, pos=True, ner=True): | |||
""" | |||
:param str data_path: 数据路径 | |||
:param bool pos: 是否使用词性标签 | |||
:param bool ner: 是否使用命名实体标签 | |||
:return: a DataSet object | |||
:return: 一个 :class:`~fastNLP.DataSet` 类型的对象 | |||
""" | |||
self.pos, self.ner = pos, ner | |||
with open(data_path, "r", encoding="utf-8") as f: | |||
@@ -152,8 +156,13 @@ class PeopleDailyCorpusLoader(DataSetLoader): | |||
example.append(sent_ner) | |||
examples.append(example) | |||
return self.convert(examples) | |||
def convert(self, data): | |||
""" | |||
:param data: python 内置对象 | |||
:return: 一个 :class:`~fastNLP.DataSet` 类型的对象 | |||
""" | |||
data_set = DataSet() | |||
for item in data: | |||
sent_words = item[0] | |||
@@ -172,6 +181,8 @@ class PeopleDailyCorpusLoader(DataSetLoader): | |||
class ConllLoader(DataSetLoader): | |||
""" | |||
别名::class:`fastNLP.io.ConllLoader` :class:`fastNLP.io.dataset_loader.ConllLoader` | |||
读取Conll格式的数据. 数据格式详见 http://conll.cemantix.org/2012/data.html | |||
列号从0开始, 每列对应内容为:: | |||
@@ -195,6 +206,7 @@ class ConllLoader(DataSetLoader): | |||
:param indexs: 需要保留的数据列下标,从0开始。若为 ``None`` ,则所有列都保留。Default: ``None`` | |||
:param dropna: 是否忽略非法数据,若 ``False`` ,遇到非法数据时抛出 ``ValueError`` 。Default: ``False`` | |||
""" | |||
def __init__(self, headers, indexs=None, dropna=False): | |||
super(ConllLoader, self).__init__() | |||
if not isinstance(headers, (list, tuple)): | |||
@@ -207,21 +219,25 @@ class ConllLoader(DataSetLoader): | |||
if len(indexs) != len(headers): | |||
raise ValueError | |||
self.indexs = indexs | |||
def load(self, path):
    """
    Read the CoNLL-format file at ``path`` into a DataSet.

    Each record yielded by ``_read_conll`` becomes one Instance whose field
    names are taken, in order, from ``self.headers``.

    :param str path: file path
    :return: the DataSet holding one Instance per record
    """
    ds = DataSet()
    # _read_conll yields (index, columns) pairs; the index is not needed here.
    for _, data in _read_conll(path, indexes=self.indexs, dropna=self.dropna):
        ins = {h: data[i] for i, h in enumerate(self.headers)}
        ds.append(Instance(**ins))
    return ds
class Conll2003Loader(ConllLoader): | |||
"""读取Conll2003数据 | |||
""" | |||
别名::class:`fastNLP.io.Conll2003Loader` :class:`fastNLP.io.dataset_loader.Conll2003Loader` | |||
读取Conll2003数据 | |||
关于数据集的更多信息,参考: | |||
https://sites.google.com/site/ermasoftware/getting-started/ne-tagging-conll2003-data | |||
""" | |||
def __init__(self): | |||
headers = [ | |||
'tokens', 'pos', 'chunks', 'ner', | |||
@@ -260,7 +276,10 @@ def _cut_long_sentence(sent, max_sample_length=200): | |||
class SSTLoader(DataSetLoader): | |||
"""读取SST数据集, DataSet包含fields:: | |||
""" | |||
别名::class:`fastNLP.io.SSTLoader` :class:`fastNLP.io.dataset_loader.SSTLoader` | |||
读取SST数据集, DataSet包含fields:: | |||
words: list(str) 需要分类的文本 | |||
target: str 文本的标签 | |||
@@ -270,21 +289,22 @@ class SSTLoader(DataSetLoader): | |||
:param subtree: 是否将数据展开为子树,扩充数据量. Default: ``False`` | |||
:param fine_grained: 是否使用SST-5标准,若 ``False`` , 使用SST-2。Default: ``False`` | |||
""" | |||
def __init__(self, subtree=False, fine_grained=False): | |||
self.subtree = subtree | |||
tag_v = {'0':'very negative', '1':'negative', '2':'neutral', | |||
'3':'positive', '4':'very positive'} | |||
tag_v = {'0': 'very negative', '1': 'negative', '2': 'neutral', | |||
'3': 'positive', '4': 'very positive'} | |||
if not fine_grained: | |||
tag_v['0'] = tag_v['1'] | |||
tag_v['4'] = tag_v['3'] | |||
self.tag_v = tag_v | |||
def load(self, path): | |||
""" | |||
:param path: str,存储数据的路径 | |||
:return: DataSet。 | |||
:param str path: 存储数据的路径 | |||
:return: 一个 :class:`~fastNLP.DataSet` 类型的对象 | |||
""" | |||
datalist = [] | |||
with open(path, 'r', encoding='utf-8') as f: | |||
@@ -296,7 +316,7 @@ class SSTLoader(DataSetLoader): | |||
for words, tag in datas: | |||
ds.append(Instance(words=words, target=tag)) | |||
return ds | |||
@staticmethod | |||
def _get_one(data, subtree): | |||
tree = Tree.fromstring(data) | |||
@@ -307,15 +327,18 @@ class SSTLoader(DataSetLoader): | |||
class JsonLoader(DataSetLoader): | |||
""" | |||
别名::class:`fastNLP.io.JsonLoader` :class:`fastNLP.io.dataset_loader.JsonLoader` | |||
读取json格式数据.数据必须按行存储,每行是一个包含各类属性的json对象 | |||
:param dict fields: 需要读入的json属性名称, 和读入后在DataSet中存储的field_name | |||
``fields`` 的`key`必须是json对象的属性名. ``fields`` 的`value`为读入后在DataSet存储的`field_name`, | |||
`value`也可为 ``None`` , 这时读入后的`field_name`与json对象对应属性同名 | |||
``fields`` 的 `key` 必须是json对象的属性名. ``fields`` 的 `value` 为读入后在DataSet存储的 `field_name` , | |||
`value` 也可为 ``None`` , 这时读入后的 `field_name` 与json对象对应属性同名 | |||
``fields`` 可为 ``None`` , 这时,json对象所有属性都保存在DataSet中. Default: ``None`` | |||
:param bool dropna: 是否忽略非法数据,若 ``True`` 则忽略,若 ``False`` ,在遇到非法数据时,抛出 ``ValueError`` . | |||
Default: ``False`` | |||
""" | |||
def __init__(self, fields=None, dropna=False): | |||
super(JsonLoader, self).__init__() | |||
self.dropna = dropna | |||
@@ -326,12 +349,12 @@ class JsonLoader(DataSetLoader): | |||
for k, v in fields.items(): | |||
self.fields[k] = k if v is None else v | |||
self.fields_list = list(self.fields.keys()) | |||
def load(self, path): | |||
ds = DataSet() | |||
for idx, d in _read_json(path, fields=self.fields_list, dropna=self.dropna): | |||
if self.fields: | |||
ins = {self.fields[k]:v for k,v in d.items()} | |||
ins = {self.fields[k]: v for k, v in d.items()} | |||
else: | |||
ins = d | |||
ds.append(Instance(**ins)) | |||
@@ -340,6 +363,8 @@ class JsonLoader(DataSetLoader): | |||
class SNLILoader(JsonLoader): | |||
""" | |||
别名::class:`fastNLP.io.SNLILoader` :class:`fastNLP.io.dataset_loader.SNLILoader` | |||
读取SNLI数据集,读取的DataSet包含fields:: | |||
words1: list(str),第一句文本, premise | |||
@@ -348,6 +373,7 @@ class SNLILoader(JsonLoader): | |||
数据来源: https://nlp.stanford.edu/projects/snli/snli_1.0.zip | |||
""" | |||
def __init__(self): | |||
fields = { | |||
'sentence1_parse': 'words1', | |||
@@ -355,12 +381,14 @@ class SNLILoader(JsonLoader): | |||
'gold_label': 'target', | |||
} | |||
super(SNLILoader, self).__init__(fields=fields) | |||
def load(self, path): | |||
ds = super(SNLILoader, self).load(path) | |||
def parse_tree(x): | |||
t = Tree.fromstring(x) | |||
return t.leaves() | |||
ds.apply(lambda ins: parse_tree(ins['words1']), new_field_name='words1') | |||
ds.apply(lambda ins: parse_tree(ins['words2']), new_field_name='words2') | |||
ds.drop(lambda x: x['target'] == '-') | |||
@@ -369,6 +397,8 @@ class SNLILoader(JsonLoader): | |||
class CSVLoader(DataSetLoader): | |||
""" | |||
别名::class:`fastNLP.io.CSVLoader` :class:`fastNLP.io.dataset_loader.CSVLoader` | |||
读取CSV格式的数据集。返回 ``DataSet`` | |||
:param List[str] headers: CSV文件的文件头.定义每一列的属性名称,即返回的DataSet中`field`的名称 | |||
@@ -377,11 +407,12 @@ class CSVLoader(DataSetLoader): | |||
:param bool dropna: 是否忽略非法数据,若 ``True`` 则忽略,若 ``False`` ,在遇到非法数据时,抛出 ``ValueError`` . | |||
Default: ``False`` | |||
""" | |||
def __init__(self, headers=None, sep=",", dropna=False): | |||
self.headers = headers | |||
self.sep = sep | |||
self.dropna = dropna | |||
def load(self, path): | |||
ds = DataSet() | |||
for idx, data in _read_csv(path, headers=self.headers, | |||
@@ -396,7 +427,7 @@ def _add_seg_tag(data): | |||
:param data: list of ([word], [pos], [heads], [head_tags]) | |||
:return: list of ([word], [pos]) | |||
""" | |||
_processed = [] | |||
for word_list, pos_list, _, _ in data: | |||
new_sample = [] | |||
@@ -410,4 +441,3 @@ def _add_seg_tag(data): | |||
new_sample.append((word[-1], 'E-' + pos)) | |||
_processed.append(list(map(list, zip(*new_sample)))) | |||
return _processed | |||
@@ -7,13 +7,17 @@ import os | |||
import numpy as np | |||
from fastNLP.core.vocabulary import Vocabulary | |||
from fastNLP.io.base_loader import BaseLoader | |||
from ..core.vocabulary import Vocabulary | |||
from .base_loader import BaseLoader | |||
import warnings | |||
class EmbedLoader(BaseLoader): | |||
"""这个类用于从预训练的Embedding中load数据。""" | |||
""" | |||
别名::class:`fastNLP.io.EmbedLoader` :class:`fastNLP.io.embed_loader.EmbedLoader` | |||
这个类用于从预训练的Embedding中load数据。 | |||
""" | |||
def __init__(self): | |||
super(EmbedLoader, self).__init__() | |||
@@ -25,13 +29,13 @@ class EmbedLoader(BaseLoader): | |||
word2vec(第一行只有两个元素)还是glove格式的数据。 | |||
:param str embed_filepath: 预训练的embedding的路径。 | |||
:param Vocabulary vocab: 词表,读取出现在vocab中的词的embedding。没有出现在vocab中的词的embedding将通过找到的词的 | |||
embedding的正态分布采样出来,以使得整个Embedding是同分布的。 | |||
:param vocab: 词表 :class:`~fastNLP.Vocabulary` 类型,读取出现在vocab中的词的embedding。 | |||
没有出现在vocab中的词的embedding将通过找到的词的embedding的正态分布采样出来,以使得整个Embedding是同分布的。 | |||
:param dtype: 读出的embedding的类型 | |||
:param bool normalize: 是否将每个vector归一化到norm为1 | |||
:param str error: 'ignore', 'strict'; 如果'ignore',错误将自动跳过; 如果strict, 错误将抛出。这里主要可能出错的地 | |||
方在于词表有空行或者词表出现了维度不一致。 | |||
:return: numpy.ndarray, shape为 [len(vocab), dimension], dimension由pretrain的embedding决定。 | |||
:param str error: `ignore` , `strict` ; 如果 `ignore` ,错误将自动跳过; 如果 `strict` , 错误将抛出。 | |||
这里主要可能出错的地方在于词表有空行或者词表出现了维度不一致。 | |||
:return numpy.ndarray: shape为 [len(vocab), dimension], dimension由pretrain的embedding决定。 | |||
""" | |||
assert isinstance(vocab, Vocabulary), "Only fastNLP.Vocabulary is supported." | |||
if not os.path.exists(embed_filepath): | |||
@@ -87,11 +91,11 @@ class EmbedLoader(BaseLoader): | |||
:param str padding: the padding tag for vocabulary. | |||
:param str unknown: the unknown tag for vocabulary. | |||
:param bool normalize: 是否将每个vector归一化到norm为1 | |||
:param str error: 'ignore', 'strict'; 如果'ignore',错误将自动跳过; 如果strict, 错误将抛出。这里主要可能出错的地 | |||
:param str error: `ignore` , `strict` ; 如果 `ignore` ,错误将自动跳过; 如果 `strict` , 错误将抛出。这里主要可能出错的地 | |||
方在于词表有空行或者词表出现了维度不一致。 | |||
:return: numpy.ndarray, shape为 [len(vocab), dimension], dimension由pretrain的embedding决定。 | |||
:return: numpy.ndarray,Vocabulary embedding的shape是[词表大小+x, 词表维度], "词表大小+x"是由于最终的大小还取决与 | |||
是否使用padding, 以及unknown有没有在词表中找到对应的词。Vocabulary中的词的顺序与Embedding的顺序是一一对应的。 | |||
:return numpy.ndarray: shape为 [len(vocab), dimension], dimension由pretrain的embedding决定。 | |||
:return numpy.ndarray: Vocabulary Embedding的shape是[词表大小+x, 词表维度], "词表大小+x"是由于最终的大小还取决与 | |||
是否使用padding, 以及unknown有没有在词表中找到对应的词。 Vocabulary中的词的顺序与Embedding的顺序是一一对应的。 | |||
""" | |||
vocab = Vocabulary(padding=padding, unknown=unknown) | |||
vec_dict = {} | |||
@@ -1,3 +1,6 @@ | |||
""" | |||
此模块用于给其它模块提供读取文件的函数,没有为用户提供 API | |||
""" | |||
import json | |||
@@ -1,16 +1,16 @@ | |||
""" | |||
.. _model-io: | |||
用于载入和保存模型 | |||
""" | |||
import torch | |||
from fastNLP.io.base_loader import BaseLoader | |||
from .base_loader import BaseLoader | |||
class ModelLoader(BaseLoader): | |||
""" | |||
Loader for models. | |||
别名::class:`fastNLP.io.ModelLoader` :class:`fastNLP.io.model_io.ModelLoader` | |||
用于读取模型 | |||
""" | |||
def __init__(self): | |||
@@ -18,24 +18,30 @@ class ModelLoader(BaseLoader): | |||
@staticmethod | |||
def load_pytorch(empty_model, model_path): | |||
"""Load model parameters from ".pkl" files into the empty PyTorch model. | |||
""" | |||
从 ".pkl" 文件读取 PyTorch 模型 | |||
:param empty_model: a PyTorch model with initialized parameters. | |||
:param str model_path: the path to the saved model. | |||
:param empty_model: 初始化参数的 PyTorch 模型 | |||
:param str model_path: 模型保存的路径 | |||
""" | |||
empty_model.load_state_dict(torch.load(model_path)) | |||
@staticmethod | |||
def load_pytorch_model(model_path): | |||
"""Load the entire model. | |||
""" | |||
读取整个模型 | |||
:param str model_path: the path to the saved model. | |||
:param str model_path: 模型保存的路径 | |||
""" | |||
return torch.load(model_path) | |||
class ModelSaver(object): | |||
"""Save a model | |||
""" | |||
别名::class:`fastNLP.io.ModelSaver` :class:`fastNLP.io.model_io.ModelSaver` | |||
用于保存模型 | |||
Example:: | |||
saver = ModelSaver("./save/model_ckpt_100.pkl") | |||
@@ -46,15 +52,16 @@ class ModelSaver(object): | |||
def __init__(self, save_path): | |||
""" | |||
:param save_path: the path to the saving directory. | |||
:param save_path: 模型保存的路径 | |||
""" | |||
self.save_path = save_path | |||
def save_pytorch(self, model, param_only=True): | |||
"""Save a pytorch model into ".pkl" file. | |||
""" | |||
把 PyTorch 模型存入 ".pkl" 文件 | |||
:param model: a PyTorch model | |||
:param bool param_only: whether only to save the model parameters or the entire model. | |||
:param model: PyTorch 模型 | |||
:param bool param_only: 是否只保存模型的参数(否则保存整个模型) | |||
""" | |||
if param_only is True: | |||