Browse Source

修改了 io 部分的注释

tags/v0.4.10
ChenXin 5 years ago
parent
commit
a1f8cdec48
7 changed files with 172 additions and 107 deletions
  1. +8
    -7
      fastNLP/io/__init__.py
  2. +11
    -6
      fastNLP/io/base_loader.py
  3. +35
    -20
      fastNLP/io/config_io.py
  4. +78
    -48
      fastNLP/io/dataset_loader.py
  5. +16
    -12
      fastNLP/io/embed_loader.py
  6. +3
    -0
      fastNLP/io/file_reader.py
  7. +21
    -14
      fastNLP/io/model_io.py

+ 8
- 7
fastNLP/io/__init__.py View File

@@ -12,13 +12,14 @@
这些类的使用方法可以在对应module的文档下查看. 这些类的使用方法可以在对应module的文档下查看.
""" """
from .embed_loader import EmbedLoader from .embed_loader import EmbedLoader
from .dataset_loader import *
from .config_io import *
from .model_io import *
from .dataset_loader import DataSetLoader, CSVLoader, JsonLoader, ConllLoader, SNLILoader, SSTLoader, \
PeopleDailyCorpusLoader, Conll2003Loader
from .config_io import ConfigLoader, ConfigSection, ConfigSaver
from .model_io import ModelLoader as ModelLoader, ModelSaver as ModelSaver


__all__ = [ __all__ = [
'EmbedLoader', 'EmbedLoader',
'DataSetLoader', 'DataSetLoader',
'CSVLoader', 'CSVLoader',
'JsonLoader', 'JsonLoader',
@@ -27,11 +28,11 @@ __all__ = [
'SSTLoader', 'SSTLoader',
'PeopleDailyCorpusLoader', 'PeopleDailyCorpusLoader',
'Conll2003Loader', 'Conll2003Loader',
'ConfigLoader', 'ConfigLoader',
'ConfigSection', 'ConfigSection',
'ConfigSaver', 'ConfigSaver',
'ModelLoader', 'ModelLoader',
'ModelSaver', 'ModelSaver',
]
]

+ 11
- 6
fastNLP/io/base_loader.py View File

@@ -3,7 +3,8 @@ import os




class BaseLoader(object): class BaseLoader(object):
"""Base loader for all loaders.
"""
各个 Loader 的基类,提供了 API 的参考。


""" """
def __init__(self): def __init__(self):
@@ -11,7 +12,10 @@ class BaseLoader(object):


@staticmethod @staticmethod
def load_lines(data_path): def load_lines(data_path):
"""按行读取,舍弃每行两侧空白字符,返回list of str
"""
按行读取,舍弃每行两侧空白字符,返回list of str

:param data_path: 读取数据的路径
""" """
with open(data_path, "r", encoding="utf=8") as f: with open(data_path, "r", encoding="utf=8") as f:
text = f.readlines() text = f.readlines()
@@ -19,7 +23,10 @@ class BaseLoader(object):


@classmethod @classmethod
def load(cls, data_path): def load(cls, data_path):
"""先按行读取,去除一行两侧空白,再提取每行的字符。返回list of list of str
"""
先按行读取,去除一行两侧空白,再提取每行的字符。返回list of list of str
:param data_path:
""" """
with open(data_path, "r", encoding="utf-8") as f: with open(data_path, "r", encoding="utf-8") as f:
text = f.readlines() text = f.readlines()
@@ -40,9 +47,7 @@ class BaseLoader(object):




class DataLoaderRegister: class DataLoaderRegister:
"""Register for all data sets.

"""
# TODO 这个类使用在何处?
_readers = {} _readers = {}


@classmethod @classmethod


+ 35
- 20
fastNLP/io/config_io.py View File

@@ -1,19 +1,22 @@
""" """
.. _config-io:


用于读入和处理和保存 config 文件 用于读入和处理和保存 config 文件
""" """
__all__ = ["ConfigLoader","ConfigSection","ConfigSaver"]
import configparser import configparser
import json import json
import os import os


from fastNLP.io.base_loader import BaseLoader
from .base_loader import BaseLoader




class ConfigLoader(BaseLoader): class ConfigLoader(BaseLoader):
"""Loader for configuration.
"""
别名::class:`fastNLP.io.ConfigLoader` :class:`fastNLP.io.config_io.ConfigLoader`

读取配置文件的Loader


:param str data_path: path to the config
:param str data_path: 配置文件的路径


""" """
def __init__(self, data_path=None): def __init__(self, data_path=None):
@@ -27,14 +30,16 @@ class ConfigLoader(BaseLoader):


@staticmethod @staticmethod
def load_config(file_path, sections): def load_config(file_path, sections):
"""Load section(s) of configuration into the ``sections`` provided. No returns.
"""
把配置文件的section 存入提供的 ``sections`` 中


:param str file_path: the path of config file
:param dict sections: the dict of ``{section_name(string): ConfigSection object}``
Example::
test_args = ConfigSection()
ConfigLoader("config.cfg").load_config("./data_for_tests/config", {"POS_test": test_args})
:param str file_path: 配置文件的路径
:param dict sections: 符合如下键值对组成的字典 `section_name(string)` : :class:`~fastNLP.io.ConfigSection`
Example::

test_args = ConfigSection()
ConfigLoader("config.cfg").load_config("./data_for_tests/config", {"POS_test": test_args})


""" """
assert isinstance(sections, dict) assert isinstance(sections, dict)
@@ -70,7 +75,10 @@ class ConfigLoader(BaseLoader):




class ConfigSection(object): class ConfigSection(object):
"""ConfigSection is the data structure storing all key-value pairs in one section in a config file.
"""
别名::class:`fastNLP.io.ConfigSection` :class:`fastNLP.io.config_io.ConfigSection`

ConfigSection是一个存储了一个section中所有键值对的数据结构,推荐使用此类的实例来配合 :meth:`ConfigLoader.load_config` 使用


""" """


@@ -146,9 +154,12 @@ class ConfigSection(object):




class ConfigSaver(object): class ConfigSaver(object):
"""ConfigSaver is used to save config file and solve related conflicts.
"""
别名::class:`fastNLP.io.ConfigSaver` :class:`fastNLP.io.config_io.ConfigSaver`

ConfigSaver 是用来存储配置文件并解决相关冲突的类


:param str file_path: path to the config file
:param str file_path: 配置文件的路径


""" """
def __init__(self, file_path): def __init__(self, file_path):
@@ -157,7 +168,8 @@ class ConfigSaver(object):
raise FileNotFoundError("file {} NOT found!".__format__(self.file_path)) raise FileNotFoundError("file {} NOT found!".__format__(self.file_path))


def _get_section(self, sect_name): def _get_section(self, sect_name):
"""This is the function to get the section with the section name.
"""
This is the function to get the section with the section name.


:param sect_name: The name of the section to load. :param sect_name: The name of the section to load.
:return: The section. :return: The section.
@@ -167,7 +179,8 @@ class ConfigSaver(object):
return sect return sect


def _read_section(self): def _read_section(self):
"""This is the function to read sections from the config file.
"""
This is the function to read sections from the config file.


:return: sect_list, sect_key_list :return: sect_list, sect_key_list
sect_list: A list of ConfigSection(). sect_list: A list of ConfigSection().
@@ -219,7 +232,8 @@ class ConfigSaver(object):
return sect_list, sect_key_list return sect_list, sect_key_list


def _write_section(self, sect_list, sect_key_list): def _write_section(self, sect_list, sect_key_list):
"""This is the function to write config file with section list and name list.
"""
This is the function to write config file with section list and name list.


:param sect_list: A list of ConfigSection() that need to be written into the file. :param sect_list: A list of ConfigSection() that need to be written into the file.
:param sect_key_list: A list of name of sect_list. :param sect_key_list: A list of name of sect_list.
@@ -240,10 +254,11 @@ class ConfigSaver(object):
f.write('\n') f.write('\n')


def save_config_file(self, section_name, section): def save_config_file(self, section_name, section):
"""This is the function to be called to change the config file with a single section and its name.
"""
这个方法可以用来修改并保存配置文件中单独的一个 section


:param str section_name: The name of section what needs to be changed and saved.
:param ConfigSection section: The section with key and value what needs to be changed and saved.
:param str section_name: 需要保存的 section 的名字.
:param section: 你需要修改并保存的 section, :class:`~fastNLP.io.ConfigSection` 类型
""" """
section_file = self._get_section(section_name) section_file = self._get_section(section_name)
if len(section_file.__dict__.keys()) == 0: # the section not in the file before if len(section_file.__dict__.keys()) == 0: # the section not in the file before


+ 78
- 48
fastNLP/io/dataset_loader.py View File

@@ -1,8 +1,6 @@
""" """
.. _dataset-loader:

DataSetLoader 的 API, 用于读取不同格式的数据, 并返回 `DataSet` ,
得到的 `DataSet` 对象可以直接传入 `Trainer`, `Tester`, 用于模型的训练和测试
dataset_loader模块实现了许多 DataSetLoader, 用于读取不同格式的数据, 并返回 `DataSet` ,
得到的 :class:`~fastNLP.DataSet` 对象可以直接传入 :class:`~fastNLP.Trainer`, :class:`~fastNLP.Tester`, 用于模型的训练和测试


Example:: Example::


@@ -13,50 +11,50 @@ Example::


# ... do stuff # ... do stuff
""" """
import os
import json

from nltk.tree import Tree from nltk.tree import Tree


from fastNLP.core.dataset import DataSet
from fastNLP.core.instance import Instance
from fastNLP.io.file_reader import _read_csv, _read_json, _read_conll
from ..core.dataset import DataSet
from ..core.instance import Instance
from .file_reader import _read_csv, _read_json, _read_conll




def _download_from_url(url, path): def _download_from_url(url, path):
from tqdm import tqdm from tqdm import tqdm
import requests import requests
"""Download file""" """Download file"""
r = requests.get(url, headers={'User-Agent': 'Mozilla/5.0'}, stream=True) r = requests.get(url, headers={'User-Agent': 'Mozilla/5.0'}, stream=True)
chunk_size = 16 * 1024 chunk_size = 16 * 1024
total_size = int(r.headers.get('Content-length', 0)) total_size = int(r.headers.get('Content-length', 0))
with open(path, "wb") as file ,\
tqdm(total=total_size, unit='B', unit_scale=1, desc=path.split('/')[-1]) as t:
with open(path, "wb") as file, \
tqdm(total=total_size, unit='B', unit_scale=1, desc=path.split('/')[-1]) as t:
for chunk in r.iter_content(chunk_size): for chunk in r.iter_content(chunk_size):
if chunk: if chunk:
file.write(chunk) file.write(chunk)
t.update(len(chunk)) t.update(len(chunk))
return return



def _uncompress(src, dst): def _uncompress(src, dst):
import zipfile, gzip, tarfile, os import zipfile, gzip, tarfile, os
def unzip(src, dst): def unzip(src, dst):
with zipfile.ZipFile(src, 'r') as f: with zipfile.ZipFile(src, 'r') as f:
f.extractall(dst) f.extractall(dst)
def ungz(src, dst): def ungz(src, dst):
with gzip.open(src, 'rb') as f, open(dst, 'wb') as uf: with gzip.open(src, 'rb') as f, open(dst, 'wb') as uf:
length = 16 * 1024 # 16KB
length = 16 * 1024 # 16KB
buf = f.read(length) buf = f.read(length)
while buf: while buf:
uf.write(buf) uf.write(buf)
buf = f.read(length) buf = f.read(length)
def untar(src, dst): def untar(src, dst):
with tarfile.open(src, 'r:gz') as f: with tarfile.open(src, 'r:gz') as f:
f.extractall(dst) f.extractall(dst)
fn, ext = os.path.splitext(src) fn, ext = os.path.splitext(src)
_, ext_2 = os.path.splitext(fn) _, ext_2 = os.path.splitext(fn)
if ext == '.zip': if ext == '.zip':
@@ -71,42 +69,48 @@ def _uncompress(src, dst):


class DataSetLoader: class DataSetLoader:
""" """
别名::class:`fastNLP.io.DataSetLoader` :class:`fastNLP.io.dataset_loader.DataSetLoader`


所有`DataSetLoader`的接口
所有 DataSetLoader 的 API 接口,你可以继承它实现自己的 DataSetLoader
""" """
def load(self, path): def load(self, path):
"""从指定 ``path`` 的文件中读取数据,返回DataSet """从指定 ``path`` 的文件中读取数据,返回DataSet


:param str path: file path
:return: a DataSet object
:param str path: 文件路径
:return: 一个 :class:`~fastNLP.DataSet` 类型的对象
""" """
raise NotImplementedError raise NotImplementedError
def convert(self, data): def convert(self, data):
"""用Python数据对象创建DataSet
"""
用Python数据对象创建DataSet,各个子类需要自行实现这个方法


:param data: inner data structure (user-defined) to represent the data.
:return: a DataSet object
:param data: Python 内置的数据结构
:return: 一个 :class:`~fastNLP.DataSet` 类型的对象
""" """
raise NotImplementedError raise NotImplementedError




class PeopleDailyCorpusLoader(DataSetLoader): class PeopleDailyCorpusLoader(DataSetLoader):
"""读取人民日报数据集
""" """
别名::class:`fastNLP.io.PeopleDailyCorpusLoader` :class:`fastNLP.io.dataset_loader.PeopleDailyCorpusLoader`

读取人民日报数据集
"""
def __init__(self): def __init__(self):
super(PeopleDailyCorpusLoader, self).__init__() super(PeopleDailyCorpusLoader, self).__init__()
self.pos = True self.pos = True
self.ner = True self.ner = True

def load(self, data_path, pos=True, ner=True): def load(self, data_path, pos=True, ner=True):
""" """


:param str data_path: 数据路径 :param str data_path: 数据路径
:param bool pos: 是否使用词性标签 :param bool pos: 是否使用词性标签
:param bool ner: 是否使用命名实体标签 :param bool ner: 是否使用命名实体标签
:return: a DataSet object
:return: 一个 :class:`~fastNLP.DataSet` 类型的对象
""" """
self.pos, self.ner = pos, ner self.pos, self.ner = pos, ner
with open(data_path, "r", encoding="utf-8") as f: with open(data_path, "r", encoding="utf-8") as f:
@@ -152,8 +156,13 @@ class PeopleDailyCorpusLoader(DataSetLoader):
example.append(sent_ner) example.append(sent_ner)
examples.append(example) examples.append(example)
return self.convert(examples) return self.convert(examples)
def convert(self, data): def convert(self, data):
"""
:param data: python 内置对象
:return: 一个 :class:`~fastNLP.DataSet` 类型的对象
"""
data_set = DataSet() data_set = DataSet()
for item in data: for item in data:
sent_words = item[0] sent_words = item[0]
@@ -172,6 +181,8 @@ class PeopleDailyCorpusLoader(DataSetLoader):


class ConllLoader(DataSetLoader): class ConllLoader(DataSetLoader):
""" """
别名::class:`fastNLP.io.ConllLoader` :class:`fastNLP.io.dataset_loader.ConllLoader`

读取Conll格式的数据. 数据格式详见 http://conll.cemantix.org/2012/data.html 读取Conll格式的数据. 数据格式详见 http://conll.cemantix.org/2012/data.html


列号从0开始, 每列对应内容为:: 列号从0开始, 每列对应内容为::
@@ -195,6 +206,7 @@ class ConllLoader(DataSetLoader):
:param indexs: 需要保留的数据列下标,从0开始。若为 ``None`` ,则所有列都保留。Default: ``None`` :param indexs: 需要保留的数据列下标,从0开始。若为 ``None`` ,则所有列都保留。Default: ``None``
:param dropna: 是否忽略非法数据,若 ``False`` ,遇到非法数据时抛出 ``ValueError`` 。Default: ``False`` :param dropna: 是否忽略非法数据,若 ``False`` ,遇到非法数据时抛出 ``ValueError`` 。Default: ``False``
""" """
def __init__(self, headers, indexs=None, dropna=False): def __init__(self, headers, indexs=None, dropna=False):
super(ConllLoader, self).__init__() super(ConllLoader, self).__init__()
if not isinstance(headers, (list, tuple)): if not isinstance(headers, (list, tuple)):
@@ -207,21 +219,25 @@ class ConllLoader(DataSetLoader):
if len(indexs) != len(headers): if len(indexs) != len(headers):
raise ValueError raise ValueError
self.indexs = indexs self.indexs = indexs
def load(self, path): def load(self, path):
ds = DataSet() ds = DataSet()
for idx, data in _read_conll(path, indexes=self.indexs, dropna=self.dropna): for idx, data in _read_conll(path, indexes=self.indexs, dropna=self.dropna):
ins = {h:data[i] for i, h in enumerate(self.headers)}
ins = {h: data[i] for i, h in enumerate(self.headers)}
ds.append(Instance(**ins)) ds.append(Instance(**ins))
return ds return ds




class Conll2003Loader(ConllLoader): class Conll2003Loader(ConllLoader):
"""读取Conll2003数据
"""
别名::class:`fastNLP.io.Conll2003Loader` :class:`fastNLP.io.dataset_loader.Conll2003Loader`

读取Conll2003数据
关于数据集的更多信息,参考: 关于数据集的更多信息,参考:
https://sites.google.com/site/ermasoftware/getting-started/ne-tagging-conll2003-data https://sites.google.com/site/ermasoftware/getting-started/ne-tagging-conll2003-data
""" """
def __init__(self): def __init__(self):
headers = [ headers = [
'tokens', 'pos', 'chunks', 'ner', 'tokens', 'pos', 'chunks', 'ner',
@@ -260,7 +276,10 @@ def _cut_long_sentence(sent, max_sample_length=200):




class SSTLoader(DataSetLoader): class SSTLoader(DataSetLoader):
"""读取SST数据集, DataSet包含fields::
"""
别名::class:`fastNLP.io.SSTLoader` :class:`fastNLP.io.dataset_loader.SSTLoader`
读取SST数据集, DataSet包含fields::


words: list(str) 需要分类的文本 words: list(str) 需要分类的文本
target: str 文本的标签 target: str 文本的标签
@@ -270,21 +289,22 @@ class SSTLoader(DataSetLoader):
:param subtree: 是否将数据展开为子树,扩充数据量. Default: ``False`` :param subtree: 是否将数据展开为子树,扩充数据量. Default: ``False``
:param fine_grained: 是否使用SST-5标准,若 ``False`` , 使用SST-2。Default: ``False`` :param fine_grained: 是否使用SST-5标准,若 ``False`` , 使用SST-2。Default: ``False``
""" """
def __init__(self, subtree=False, fine_grained=False): def __init__(self, subtree=False, fine_grained=False):
self.subtree = subtree self.subtree = subtree
tag_v = {'0':'very negative', '1':'negative', '2':'neutral',
'3':'positive', '4':'very positive'}
tag_v = {'0': 'very negative', '1': 'negative', '2': 'neutral',
'3': 'positive', '4': 'very positive'}
if not fine_grained: if not fine_grained:
tag_v['0'] = tag_v['1'] tag_v['0'] = tag_v['1']
tag_v['4'] = tag_v['3'] tag_v['4'] = tag_v['3']
self.tag_v = tag_v self.tag_v = tag_v
def load(self, path): def load(self, path):
""" """


:param path: str,存储数据的路径
:return: DataSet。
:param str path: 存储数据的路径
:return: 一个 :class:`~fastNLP.DataSet` 类型的对象
""" """
datalist = [] datalist = []
with open(path, 'r', encoding='utf-8') as f: with open(path, 'r', encoding='utf-8') as f:
@@ -296,7 +316,7 @@ class SSTLoader(DataSetLoader):
for words, tag in datas: for words, tag in datas:
ds.append(Instance(words=words, target=tag)) ds.append(Instance(words=words, target=tag))
return ds return ds
@staticmethod @staticmethod
def _get_one(data, subtree): def _get_one(data, subtree):
tree = Tree.fromstring(data) tree = Tree.fromstring(data)
@@ -307,15 +327,18 @@ class SSTLoader(DataSetLoader):


class JsonLoader(DataSetLoader): class JsonLoader(DataSetLoader):
""" """
别名::class:`fastNLP.io.JsonLoader` :class:`fastNLP.io.dataset_loader.JsonLoader`

读取json格式数据.数据必须按行存储,每行是一个包含各类属性的json对象 读取json格式数据.数据必须按行存储,每行是一个包含各类属性的json对象


:param dict fields: 需要读入的json属性名称, 和读入后在DataSet中存储的field_name :param dict fields: 需要读入的json属性名称, 和读入后在DataSet中存储的field_name
``fields`` 的`key`必须是json对象的属性名. ``fields`` 的`value`为读入后在DataSet存储的`field_name`,
`value`也可为 ``None`` , 这时读入后的`field_name`与json对象对应属性同名
``fields`` 的 `key` 必须是json对象的属性名. ``fields`` 的 `value` 为读入后在DataSet存储的 `field_name` ,
`value` 也可为 ``None`` , 这时读入后的 `field_name` 与json对象对应属性同名
``fields`` 可为 ``None`` , 这时,json对象所有属性都保存在DataSet中. Default: ``None`` ``fields`` 可为 ``None`` , 这时,json对象所有属性都保存在DataSet中. Default: ``None``
:param bool dropna: 是否忽略非法数据,若 ``True`` 则忽略,若 ``False`` ,在遇到非法数据时,抛出 ``ValueError`` . :param bool dropna: 是否忽略非法数据,若 ``True`` 则忽略,若 ``False`` ,在遇到非法数据时,抛出 ``ValueError`` .
Default: ``False`` Default: ``False``
""" """
def __init__(self, fields=None, dropna=False): def __init__(self, fields=None, dropna=False):
super(JsonLoader, self).__init__() super(JsonLoader, self).__init__()
self.dropna = dropna self.dropna = dropna
@@ -326,12 +349,12 @@ class JsonLoader(DataSetLoader):
for k, v in fields.items(): for k, v in fields.items():
self.fields[k] = k if v is None else v self.fields[k] = k if v is None else v
self.fields_list = list(self.fields.keys()) self.fields_list = list(self.fields.keys())
def load(self, path): def load(self, path):
ds = DataSet() ds = DataSet()
for idx, d in _read_json(path, fields=self.fields_list, dropna=self.dropna): for idx, d in _read_json(path, fields=self.fields_list, dropna=self.dropna):
if self.fields: if self.fields:
ins = {self.fields[k]:v for k,v in d.items()}
ins = {self.fields[k]: v for k, v in d.items()}
else: else:
ins = d ins = d
ds.append(Instance(**ins)) ds.append(Instance(**ins))
@@ -340,6 +363,8 @@ class JsonLoader(DataSetLoader):


class SNLILoader(JsonLoader): class SNLILoader(JsonLoader):
""" """
别名::class:`fastNLP.io.SNLILoader` :class:`fastNLP.io.dataset_loader.SNLILoader`

读取SNLI数据集,读取的DataSet包含fields:: 读取SNLI数据集,读取的DataSet包含fields::


words1: list(str),第一句文本, premise words1: list(str),第一句文本, premise
@@ -348,6 +373,7 @@ class SNLILoader(JsonLoader):


数据来源: https://nlp.stanford.edu/projects/snli/snli_1.0.zip 数据来源: https://nlp.stanford.edu/projects/snli/snli_1.0.zip
""" """
def __init__(self): def __init__(self):
fields = { fields = {
'sentence1_parse': 'words1', 'sentence1_parse': 'words1',
@@ -355,12 +381,14 @@ class SNLILoader(JsonLoader):
'gold_label': 'target', 'gold_label': 'target',
} }
super(SNLILoader, self).__init__(fields=fields) super(SNLILoader, self).__init__(fields=fields)
def load(self, path): def load(self, path):
ds = super(SNLILoader, self).load(path) ds = super(SNLILoader, self).load(path)
def parse_tree(x): def parse_tree(x):
t = Tree.fromstring(x) t = Tree.fromstring(x)
return t.leaves() return t.leaves()
ds.apply(lambda ins: parse_tree(ins['words1']), new_field_name='words1') ds.apply(lambda ins: parse_tree(ins['words1']), new_field_name='words1')
ds.apply(lambda ins: parse_tree(ins['words2']), new_field_name='words2') ds.apply(lambda ins: parse_tree(ins['words2']), new_field_name='words2')
ds.drop(lambda x: x['target'] == '-') ds.drop(lambda x: x['target'] == '-')
@@ -369,6 +397,8 @@ class SNLILoader(JsonLoader):


class CSVLoader(DataSetLoader): class CSVLoader(DataSetLoader):
""" """
别名::class:`fastNLP.io.CSVLoader` :class:`fastNLP.io.dataset_loader.CSVLoader`

读取CSV格式的数据集。返回 ``DataSet`` 读取CSV格式的数据集。返回 ``DataSet``


:param List[str] headers: CSV文件的文件头.定义每一列的属性名称,即返回的DataSet中`field`的名称 :param List[str] headers: CSV文件的文件头.定义每一列的属性名称,即返回的DataSet中`field`的名称
@@ -377,11 +407,12 @@ class CSVLoader(DataSetLoader):
:param bool dropna: 是否忽略非法数据,若 ``True`` 则忽略,若 ``False`` ,在遇到非法数据时,抛出 ``ValueError`` . :param bool dropna: 是否忽略非法数据,若 ``True`` 则忽略,若 ``False`` ,在遇到非法数据时,抛出 ``ValueError`` .
Default: ``False`` Default: ``False``
""" """
def __init__(self, headers=None, sep=",", dropna=False): def __init__(self, headers=None, sep=",", dropna=False):
self.headers = headers self.headers = headers
self.sep = sep self.sep = sep
self.dropna = dropna self.dropna = dropna
def load(self, path): def load(self, path):
ds = DataSet() ds = DataSet()
for idx, data in _read_csv(path, headers=self.headers, for idx, data in _read_csv(path, headers=self.headers,
@@ -396,7 +427,7 @@ def _add_seg_tag(data):
:param data: list of ([word], [pos], [heads], [head_tags]) :param data: list of ([word], [pos], [heads], [head_tags])
:return: list of ([word], [pos]) :return: list of ([word], [pos])
""" """
_processed = [] _processed = []
for word_list, pos_list, _, _ in data: for word_list, pos_list, _, _ in data:
new_sample = [] new_sample = []
@@ -410,4 +441,3 @@ def _add_seg_tag(data):
new_sample.append((word[-1], 'E-' + pos)) new_sample.append((word[-1], 'E-' + pos))
_processed.append(list(map(list, zip(*new_sample)))) _processed.append(list(map(list, zip(*new_sample))))
return _processed return _processed


+ 16
- 12
fastNLP/io/embed_loader.py View File

@@ -7,13 +7,17 @@ import os


import numpy as np import numpy as np


from fastNLP.core.vocabulary import Vocabulary
from fastNLP.io.base_loader import BaseLoader
from ..core.vocabulary import Vocabulary
from .base_loader import BaseLoader


import warnings import warnings


class EmbedLoader(BaseLoader): class EmbedLoader(BaseLoader):
"""这个类用于从预训练的Embedding中load数据。"""
"""
别名::class:`fastNLP.io.EmbedLoader` :class:`fastNLP.io.embed_loader.EmbedLoader`

这个类用于从预训练的Embedding中load数据。
"""


def __init__(self): def __init__(self):
super(EmbedLoader, self).__init__() super(EmbedLoader, self).__init__()
@@ -25,13 +29,13 @@ class EmbedLoader(BaseLoader):
word2vec(第一行只有两个元素)还是glove格式的数据。 word2vec(第一行只有两个元素)还是glove格式的数据。


:param str embed_filepath: 预训练的embedding的路径。 :param str embed_filepath: 预训练的embedding的路径。
:param Vocabulary vocab: 词表,读取出现在vocab中的词的embedding。没有出现在vocab中的词的embedding将通过找到的词的
embedding的正态分布采样出来,以使得整个Embedding是同分布的。
:param vocab: 词表 :class:`~fastNLP.Vocabulary` 类型,读取出现在vocab中的词的embedding。
没有出现在vocab中的词的embedding将通过找到的词的embedding的正态分布采样出来,以使得整个Embedding是同分布的。
:param dtype: 读出的embedding的类型 :param dtype: 读出的embedding的类型
:param bool normalize: 是否将每个vector归一化到norm为1 :param bool normalize: 是否将每个vector归一化到norm为1
:param str error: 'ignore', 'strict'; 如果'ignore',错误将自动跳过; 如果strict, 错误将抛出。这里主要可能出错的地
方在于词表有空行或者词表出现了维度不一致。
:return: numpy.ndarray, shape为 [len(vocab), dimension], dimension由pretrain的embedding决定。
:param str error: `ignore` , `strict` ; 如果 `ignore` ,错误将自动跳过; 如果 `strict` , 错误将抛出。
这里主要可能出错的地方在于词表有空行或者词表出现了维度不一致。
:return numpy.ndarray: shape为 [len(vocab), dimension], dimension由pretrain的embedding决定。
""" """
assert isinstance(vocab, Vocabulary), "Only fastNLP.Vocabulary is supported." assert isinstance(vocab, Vocabulary), "Only fastNLP.Vocabulary is supported."
if not os.path.exists(embed_filepath): if not os.path.exists(embed_filepath):
@@ -87,11 +91,11 @@ class EmbedLoader(BaseLoader):
:param str padding: the padding tag for vocabulary. :param str padding: the padding tag for vocabulary.
:param str unknown: the unknown tag for vocabulary. :param str unknown: the unknown tag for vocabulary.
:param bool normalize: 是否将每个vector归一化到norm为1 :param bool normalize: 是否将每个vector归一化到norm为1
:param str error: 'ignore', 'strict'; 如果'ignore',错误将自动跳过; 如果strict, 错误将抛出。这里主要可能出错的地
:param str error: `ignore` , `strict` ; 如果 `ignore` ,错误将自动跳过; 如果 `strict` , 错误将抛出。这里主要可能出错的地
方在于词表有空行或者词表出现了维度不一致。 方在于词表有空行或者词表出现了维度不一致。
:return: numpy.ndarray, shape为 [len(vocab), dimension], dimension由pretrain的embedding决定。
:return: numpy.ndarray,Vocabulary embedding的shape是[词表大小+x, 词表维度], "词表大小+x"是由于最终的大小还取决与
是否使用padding, 以及unknown有没有在词表中找到对应的词。Vocabulary中的词的顺序与Embedding的顺序是一一对应的。
:return numpy.ndarray: shape为 [len(vocab), dimension], dimension由pretrain的embedding决定。
:return numpy.ndarray: Vocabulary Embedding的shape是[词表大小+x, 词表维度], "词表大小+x"是由于最终的大小还取决于
是否使用padding, 以及unknown有没有在词表中找到对应的词。 Vocabulary中的词的顺序与Embedding的顺序是一一对应的。
""" """
vocab = Vocabulary(padding=padding, unknown=unknown) vocab = Vocabulary(padding=padding, unknown=unknown)
vec_dict = {} vec_dict = {}


+ 3
- 0
fastNLP/io/file_reader.py View File

@@ -1,3 +1,6 @@
"""
此模块用于给其它模块提供读取文件的函数,没有为用户提供 API
"""
import json import json






+ 21
- 14
fastNLP/io/model_io.py View File

@@ -1,16 +1,16 @@
""" """
.. _model-io:

用于载入和保存模型 用于载入和保存模型
""" """
import torch import torch


from fastNLP.io.base_loader import BaseLoader
from .base_loader import BaseLoader




class ModelLoader(BaseLoader): class ModelLoader(BaseLoader):
""" """
Loader for models.
别名::class:`fastNLP.io.ModelLoader` :class:`fastNLP.io.model_io.ModelLoader`

用于读取模型
""" """


def __init__(self): def __init__(self):
@@ -18,24 +18,30 @@ class ModelLoader(BaseLoader):


@staticmethod @staticmethod
def load_pytorch(empty_model, model_path): def load_pytorch(empty_model, model_path):
"""Load model parameters from ".pkl" files into the empty PyTorch model.
"""
从 ".pkl" 文件读取 PyTorch 模型


:param empty_model: a PyTorch model with initialized parameters.
:param str model_path: the path to the saved model.
:param empty_model: 初始化参数的 PyTorch 模型
:param str model_path: 模型保存的路径
""" """
empty_model.load_state_dict(torch.load(model_path)) empty_model.load_state_dict(torch.load(model_path))


@staticmethod @staticmethod
def load_pytorch_model(model_path): def load_pytorch_model(model_path):
"""Load the entire model.
"""
读取整个模型


:param str model_path: the path to the saved model.
:param str model_path: 模型保存的路径
""" """
return torch.load(model_path) return torch.load(model_path)




class ModelSaver(object): class ModelSaver(object):
"""Save a model
"""
别名::class:`fastNLP.io.ModelSaver` :class:`fastNLP.io.model_io.ModelSaver`

用于保存模型
Example:: Example::


saver = ModelSaver("./save/model_ckpt_100.pkl") saver = ModelSaver("./save/model_ckpt_100.pkl")
@@ -46,15 +52,16 @@ class ModelSaver(object):
def __init__(self, save_path): def __init__(self, save_path):
""" """


:param save_path: the path to the saving directory.
:param save_path: 模型保存的路径
""" """
self.save_path = save_path self.save_path = save_path


def save_pytorch(self, model, param_only=True): def save_pytorch(self, model, param_only=True):
"""Save a pytorch model into ".pkl" file.
"""
把 PyTorch 模型存入 ".pkl" 文件


:param model: a PyTorch model
:param bool param_only: whether only to save the model parameters or the entire model.
:param model: PyTorch 模型
:param bool param_only: 是否只保存模型的参数(否则保存整个模型)


""" """
if param_only is True: if param_only is True:


Loading…
Cancel
Save