@@ -291,6 +291,7 @@ import _pickle as pickle | |||||
from copy import deepcopy | from copy import deepcopy | ||||
import numpy as np | import numpy as np | ||||
from prettytable import PrettyTable | |||||
from ._logger import logger | from ._logger import logger | ||||
from .const import Const | from .const import Const | ||||
@@ -301,7 +302,6 @@ from .field import SetInputOrTargetException | |||||
from .instance import Instance | from .instance import Instance | ||||
from .utils import _get_func_signature | from .utils import _get_func_signature | ||||
from .utils import pretty_table_printer | from .utils import pretty_table_printer | ||||
from prettytable import PrettyTable | |||||
class DataSet(object): | class DataSet(object): | ||||
@@ -428,23 +428,22 @@ class DataSet(object): | |||||
def print_field_meta(self): | def print_field_meta(self): | ||||
""" | """ | ||||
输出当前field的meta信息, 形似下列的输出 | |||||
+-------------+-------+-------+ | |||||
| field_names | x | y | | |||||
+-------------+-------+-------+ | |||||
| is_input | True | False | | |||||
| is_target | False | False | | |||||
| ignore_type | False | | | |||||
| pad_value | 0 | | | |||||
+-------------+-------+-------+ | |||||
field_names: DataSet中field的名称 | |||||
is_input: field是否为input | |||||
is_target: field是否为target | |||||
ignore_type: 是否忽略该field的type, 一般仅在该field至少为input或target时才有意义 | |||||
pad_value: 该field的pad的值,仅在该field为input或target时有意义 | |||||
输出当前field的meta信息, 形似下列的输出:: | |||||
+-------------+-------+-------+ | |||||
| field_names | x | y | | |||||
+=============+=======+=======+ | |||||
| is_input | True | False | | |||||
| is_target | False | False | | |||||
| ignore_type | False | | | |||||
| pad_value | 0 | | | |||||
+-------------+-------+-------+ | |||||
:param field_names: DataSet中field的名称 | |||||
:param is_input: field是否为input | |||||
:param is_target: field是否为target | |||||
:param ignore_type: 是否忽略该field的type, 一般仅在该field至少为input或target时才有意义 | |||||
:param pad_value: 该field的pad的值,仅在该field为input或target时有意义 | |||||
:return: | :return: | ||||
""" | """ | ||||
if len(self.field_arrays)>0: | if len(self.field_arrays)>0: | ||||
@@ -37,6 +37,7 @@ class Optimizer(object): | |||||
def _get_require_grads_param(self, params): | def _get_require_grads_param(self, params): | ||||
""" | """ | ||||
将params中不需要gradient的删除 | 将params中不需要gradient的删除 | ||||
:param iterable params: parameters | :param iterable params: parameters | ||||
:return: list(nn.Parameters) | :return: list(nn.Parameters) | ||||
""" | """ | ||||
@@ -85,7 +86,7 @@ class SGD(Optimizer): | |||||
class Adam(Optimizer): | class Adam(Optimizer): | ||||
""" | """ | ||||
Adam | |||||
""" | """ | ||||
def __init__(self, lr=0.001, weight_decay=0, betas=(0.9, 0.999), eps=1e-8, amsgrad=False, model_params=None): | def __init__(self, lr=0.001, weight_decay=0, betas=(0.9, 0.999), eps=1e-8, amsgrad=False, model_params=None): | ||||
@@ -7,19 +7,19 @@ __all__ = [ | |||||
"StaticEmbedding" | "StaticEmbedding" | ||||
] | ] | ||||
import os | import os | ||||
import warnings | |||||
from collections import defaultdict | |||||
from copy import deepcopy | |||||
import numpy as np | |||||
import torch | import torch | ||||
import torch.nn as nn | import torch.nn as nn | ||||
import numpy as np | |||||
import warnings | |||||
from .embedding import TokenEmbedding | |||||
from ..core import logger | |||||
from ..core.vocabulary import Vocabulary | from ..core.vocabulary import Vocabulary | ||||
from ..io.file_utils import PRETRAIN_STATIC_FILES, _get_embedding_url, cached_path | from ..io.file_utils import PRETRAIN_STATIC_FILES, _get_embedding_url, cached_path | ||||
from .embedding import TokenEmbedding | |||||
from ..modules.utils import _get_file_name_base_on_postfix | from ..modules.utils import _get_file_name_base_on_postfix | ||||
from copy import deepcopy | |||||
from collections import defaultdict | |||||
from ..core import logger | |||||
class StaticEmbedding(TokenEmbedding): | class StaticEmbedding(TokenEmbedding): | ||||
@@ -62,8 +62,7 @@ class StaticEmbedding(TokenEmbedding): | |||||
:param float word_dropout: 以多大的概率将一个词替换为unk。这样既可以训练unk也是一定的regularize。 | :param float word_dropout: 以多大的概率将一个词替换为unk。这样既可以训练unk也是一定的regularize。 | ||||
:param bool normalize: 是否对vector进行normalize,使得每个vector的norm为1。 | :param bool normalize: 是否对vector进行normalize,使得每个vector的norm为1。 | ||||
:param int min_freq: Vocabulary词频数小于这个数量的word将被指向unk。 | :param int min_freq: Vocabulary词频数小于这个数量的word将被指向unk。 | ||||
:param dict **kwarngs: only_train_min_freq, 仅对train中的词语使用min_freq筛选; only_norm_found_vector是否仅对在预训练中 | |||||
找到的词语使用normalize。 | |||||
:param dict kwarngs: only_train_min_freq, 仅对train中的词语使用min_freq筛选; only_norm_found_vector是否仅对在预训练中找到的词语使用normalize。 | |||||
""" | """ | ||||
def __init__(self, vocab: Vocabulary, model_dir_or_name: str = 'en', embedding_dim=-1, requires_grad: bool = True, | def __init__(self, vocab: Vocabulary, model_dir_or_name: str = 'en', embedding_dim=-1, requires_grad: bool = True, | ||||
@@ -31,27 +31,27 @@ class Loader: | |||||
raise NotImplementedError | raise NotImplementedError | ||||
def load(self, paths: Union[str, Dict[str, str]] = None) -> DataBundle: | def load(self, paths: Union[str, Dict[str, str]] = None) -> DataBundle: | ||||
""" | |||||
r""" | |||||
从指定一个或多个路径中的文件中读取数据,返回 :class:`~fastNLP.io.DataBundle` 。 | 从指定一个或多个路径中的文件中读取数据,返回 :class:`~fastNLP.io.DataBundle` 。 | ||||
:param Union[str, Dict[str, str]] paths: 支持以下的几种输入方式 | |||||
(0) 如果为None,则先查看本地是否有缓存,如果没有则自动下载并缓存。 | |||||
:param Union[str, Dict[str, str]] paths: 支持以下的几种输入方式: | |||||
0.如果为None,则先查看本地是否有缓存,如果没有则自动下载并缓存。 | |||||
(1) 传入一个目录, 该目录下名称包含train的被认为是train,包含test的被认为是test,包含dev的被认为是dev,如果检测到多个文件 | |||||
名包含'train'、 'dev'、 'test'则会报错:: | |||||
1.传入一个目录, 该目录下名称包含train的被认为是train,包含test的被认为是test,包含dev的被认为是dev,如果检测到多个文件名包含'train'、 'dev'、 'test'则会报错:: | |||||
data_bundle = xxxLoader().load('/path/to/dir') # 返回的DataBundle中datasets根据目录下是否检测到train、 | |||||
data_bundle = xxxLoader().load('/path/to/dir') # 返回的DataBundle中datasets根据目录下是否检测到train | |||||
# dev、 test等有所变化,可以通过以下的方式取出DataSet | # dev、 test等有所变化,可以通过以下的方式取出DataSet | ||||
tr_data = data_bundle.get_dataset('train') | tr_data = data_bundle.get_dataset('train') | ||||
te_data = data_bundle.get_dataset('test') # 如果目录下有文件包含test这个字段 | te_data = data_bundle.get_dataset('test') # 如果目录下有文件包含test这个字段 | ||||
(2) 传入一个dict,比如train,dev,test不在同一个目录下,或者名称中不包含train, dev, test:: | |||||
2.传入一个dict,比如train,dev,test不在同一个目录下,或者名称中不包含train, dev, test:: | |||||
paths = {'train':"/path/to/tr.conll", 'dev':"/to/validate.conll", "test":"/to/te.conll"} | paths = {'train':"/path/to/tr.conll", 'dev':"/to/validate.conll", "test":"/to/te.conll"} | ||||
data_bundle = xxxLoader().load(paths) # 返回的DataBundle中的dataset中包含"train", "dev", "test" | data_bundle = xxxLoader().load(paths) # 返回的DataBundle中的dataset中包含"train", "dev", "test" | ||||
dev_data = data_bundle.get_dataset('dev') | dev_data = data_bundle.get_dataset('dev') | ||||
(3) 传入文件路径:: | |||||
3.传入文件路径:: | |||||
data_bundle = xxxLoader().load("/path/to/a/train.conll") # 返回DataBundle对象, datasets中仅包含'train' | data_bundle = xxxLoader().load("/path/to/a/train.conll") # 返回DataBundle对象, datasets中仅包含'train' | ||||
tr_data = data_bundle.get_dataset('train') # 取出DataSet | tr_data = data_bundle.get_dataset('train') # 取出DataSet | ||||
@@ -28,12 +28,14 @@ class _NERPipe(Pipe): | |||||
raw_words列为List[str], 是未转换的原始数据; words列为List[int],是转换为index的输入数据; target列是List[int],是转换为index的 | raw_words列为List[str], 是未转换的原始数据; words列为List[int],是转换为index的输入数据; target列是List[int],是转换为index的 | ||||
target。返回的DataSet中被设置为input有words, target, seq_len; 设置为target有target, seq_len。 | target。返回的DataSet中被设置为input有words, target, seq_len; 设置为target有target, seq_len。 | ||||
:param: str encoding_type: target列使用什么类型的encoding方式,支持bioes, bio两种。 | |||||
:param bool lower: 是否将words小写化后再建立词表,绝大多数情况都不需要设置为True。 | |||||
""" | """ | ||||
def __init__(self, encoding_type: str = 'bio', lower: bool = False): | def __init__(self, encoding_type: str = 'bio', lower: bool = False): | ||||
""" | |||||
:param: str encoding_type: target列使用什么类型的encoding方式,支持bioes, bio两种。 | |||||
:param bool lower: 是否将words小写化后再建立词表,绝大多数情况都不需要设置为True。 | |||||
""" | |||||
if encoding_type == 'bio': | if encoding_type == 'bio': | ||||
self.convert_tag = iob2 | self.convert_tag = iob2 | ||||
else: | else: | ||||
@@ -51,9 +53,8 @@ class _NERPipe(Pipe): | |||||
"[AL-AIN, United, Arab, ...]", "[B-LOC, B-LOC, I-LOC, ...]" | "[AL-AIN, United, Arab, ...]", "[B-LOC, B-LOC, I-LOC, ...]" | ||||
"[...]", "[...]" | "[...]", "[...]" | ||||
:param ~fastNLP.DataBundle data_bundle: 传入的DataBundle中的DataSet必须包含raw_words和ner两个field,且两个field的内容均为List[str]。 | |||||
在传入DataBundle基础上原位修改。 | |||||
:return: DataBundle | |||||
:param ~fastNLP.DataBundle data_bundle: 传入的DataBundle中的DataSet必须包含raw_words和ner两个field,且两个field的内容均为List[str]在传入DataBundle基础上原位修改。 | |||||
:return DataBundle: | |||||
""" | """ | ||||
# 转换tag | # 转换tag | ||||
for name, dataset in data_bundle.datasets.items(): | for name, dataset in data_bundle.datasets.items(): | ||||
@@ -253,8 +254,7 @@ class _CNNERPipe(Pipe): | |||||
raw_chars列为List[str], 是未转换的原始数据; chars列为List[int],是转换为index的输入数据; target列是List[int], | raw_chars列为List[str], 是未转换的原始数据; chars列为List[int],是转换为index的输入数据; target列是List[int], | ||||
是转换为index的target。返回的DataSet中被设置为input有chars, target, seq_len; 设置为target有target。 | 是转换为index的target。返回的DataSet中被设置为input有chars, target, seq_len; 设置为target有target。 | ||||
:param ~fastNLP.DataBundle data_bundle: 传入的DataBundle中的DataSet必须包含raw_words和ner两个field,且两个field | |||||
的内容均为List[str]。在传入DataBundle基础上原位修改。 | |||||
:param ~fastNLP.DataBundle data_bundle: 传入的DataBundle中的DataSet必须包含raw_words和ner两个field,且两个field的内容均为List[str]。在传入DataBundle基础上原位修改。 | |||||
:return: DataBundle | :return: DataBundle | ||||
""" | """ | ||||
# 转换tag | # 转换tag | ||||
@@ -24,7 +24,7 @@ class Pipe: | |||||
def process_from_file(self, paths) -> DataBundle: | def process_from_file(self, paths) -> DataBundle: | ||||
""" | """ | ||||
传入文件路径,生成处理好的DataBundle对象。paths支持的路径形式可以参考 `fastNLP.io.loader.Loader.load()` | |||||
传入文件路径,生成处理好的DataBundle对象。paths支持的路径形式可以参考 ::meth:`fastNLP.io.Loader.load()` | |||||
:param paths: | :param paths: | ||||
:return: DataBundle | :return: DataBundle | ||||