@@ -291,6 +291,7 @@ import _pickle as pickle | |||
from copy import deepcopy | |||
import numpy as np | |||
from prettytable import PrettyTable | |||
from ._logger import logger | |||
from .const import Const | |||
@@ -301,7 +302,6 @@ from .field import SetInputOrTargetException | |||
from .instance import Instance | |||
from .utils import _get_func_signature | |||
from .utils import pretty_table_printer | |||
from prettytable import PrettyTable | |||
class DataSet(object): | |||
@@ -428,23 +428,22 @@ class DataSet(object): | |||
def print_field_meta(self): | |||
""" | |||
输出当前field的meta信息, 形似下列的输出 | |||
+-------------+-------+-------+ | |||
| field_names | x | y | | |||
+-------------+-------+-------+ | |||
| is_input | True | False | | |||
| is_target | False | False | | |||
| ignore_type | False | | | |||
| pad_value | 0 | | | |||
+-------------+-------+-------+ | |||
field_names: DataSet中field的名称 | |||
is_input: field是否为input | |||
is_target: field是否为target | |||
ignore_type: 是否忽略该field的type, 一般仅在该field至少为input或target时才有意义 | |||
pad_value: 该field的pad的值,仅在该field为input或target时有意义 | |||
输出当前field的meta信息, 形似下列的输出:: | |||
+-------------+-------+-------+ | |||
| field_names | x | y | | |||
+=============+=======+=======+ | |||
| is_input | True | False | | |||
| is_target | False | False | | |||
| ignore_type | False | | | |||
| pad_value | 0 | | | |||
+-------------+-------+-------+ | |||
:param field_names: DataSet中field的名称 | |||
:param is_input: field是否为input | |||
:param is_target: field是否为target | |||
:param ignore_type: 是否忽略该field的type, 一般仅在该field至少为input或target时才有意义 | |||
:param pad_value: 该field的pad的值,仅在该field为input或target时有意义 | |||
:return: | |||
""" | |||
if len(self.field_arrays)>0: | |||
@@ -37,6 +37,7 @@ class Optimizer(object): | |||
def _get_require_grads_param(self, params): | |||
""" | |||
将params中不需要gradient的删除 | |||
:param iterable params: parameters | |||
:return: list(nn.Parameters) | |||
""" | |||
@@ -85,7 +86,7 @@ class SGD(Optimizer): | |||
class Adam(Optimizer): | |||
""" | |||
Adam | |||
""" | |||
def __init__(self, lr=0.001, weight_decay=0, betas=(0.9, 0.999), eps=1e-8, amsgrad=False, model_params=None): | |||
@@ -7,19 +7,19 @@ __all__ = [ | |||
"StaticEmbedding" | |||
] | |||
import os | |||
import warnings | |||
from collections import defaultdict | |||
from copy import deepcopy | |||
import numpy as np | |||
import torch | |||
import torch.nn as nn | |||
import numpy as np | |||
import warnings | |||
from .embedding import TokenEmbedding | |||
from ..core import logger | |||
from ..core.vocabulary import Vocabulary | |||
from ..io.file_utils import PRETRAIN_STATIC_FILES, _get_embedding_url, cached_path | |||
from .embedding import TokenEmbedding | |||
from ..modules.utils import _get_file_name_base_on_postfix | |||
from copy import deepcopy | |||
from collections import defaultdict | |||
from ..core import logger | |||
class StaticEmbedding(TokenEmbedding): | |||
@@ -62,8 +62,7 @@ class StaticEmbedding(TokenEmbedding): | |||
:param float word_dropout: 以多大的概率将一个词替换为unk。这样既可以训练unk也是一定的regularize。 | |||
:param bool normalize: 是否对vector进行normalize,使得每个vector的norm为1。 | |||
:param int min_freq: Vocabulary词频数小于这个数量的word将被指向unk。 | |||
:param dict **kwarngs: only_train_min_freq, 仅对train中的词语使用min_freq筛选; only_norm_found_vector是否仅对在预训练中 | |||
找到的词语使用normalize。 | |||
:param dict kwarngs: only_train_min_freq, 仅对train中的词语使用min_freq筛选; only_norm_found_vector是否仅对在预训练中找到的词语使用normalize。 | |||
""" | |||
def __init__(self, vocab: Vocabulary, model_dir_or_name: str = 'en', embedding_dim=-1, requires_grad: bool = True, | |||
@@ -31,27 +31,27 @@ class Loader: | |||
raise NotImplementedError | |||
def load(self, paths: Union[str, Dict[str, str]] = None) -> DataBundle: | |||
""" | |||
r""" | |||
从指定一个或多个路径中的文件中读取数据,返回 :class:`~fastNLP.io.DataBundle` 。 | |||
:param Union[str, Dict[str, str]] paths: 支持以下的几种输入方式 | |||
(0) 如果为None,则先查看本地是否有缓存,如果没有则自动下载并缓存。 | |||
:param Union[str, Dict[str, str]] paths: 支持以下的几种输入方式: | |||
0.如果为None,则先查看本地是否有缓存,如果没有则自动下载并缓存。 | |||
(1) 传入一个目录, 该目录下名称包含train的被认为是train,包含test的被认为是test,包含dev的被认为是dev,如果检测到多个文件 | |||
名包含'train'、 'dev'、 'test'则会报错:: | |||
1.传入一个目录, 该目录下名称包含train的被认为是train,包含test的被认为是test,包含dev的被认为是dev,如果检测到多个文件名包含'train'、 'dev'、 'test'则会报错:: | |||
data_bundle = xxxLoader().load('/path/to/dir') # 返回的DataBundle中datasets根据目录下是否检测到train、 | |||
data_bundle = xxxLoader().load('/path/to/dir') # 返回的DataBundle中datasets根据目录下是否检测到train | |||
# dev、 test等有所变化,可以通过以下的方式取出DataSet | |||
tr_data = data_bundle.get_dataset('train') | |||
te_data = data_bundle.get_dataset('test') # 如果目录下有文件包含test这个字段 | |||
(2) 传入一个dict,比如train,dev,test不在同一个目录下,或者名称中不包含train, dev, test:: | |||
2.传入一个dict,比如train,dev,test不在同一个目录下,或者名称中不包含train, dev, test:: | |||
paths = {'train':"/path/to/tr.conll", 'dev':"/to/validate.conll", "test":"/to/te.conll"} | |||
data_bundle = xxxLoader().load(paths) # 返回的DataBundle中的dataset中包含"train", "dev", "test" | |||
dev_data = data_bundle.get_dataset('dev') | |||
(3) 传入文件路径:: | |||
3.传入文件路径:: | |||
data_bundle = xxxLoader().load("/path/to/a/train.conll") # 返回DataBundle对象, datasets中仅包含'train' | |||
tr_data = data_bundle.get_dataset('train') # 取出DataSet | |||
@@ -28,12 +28,14 @@ class _NERPipe(Pipe): | |||
raw_words列为List[str], 是未转换的原始数据; words列为List[int],是转换为index的输入数据; target列是List[int],是转换为index的 | |||
target。返回的DataSet中被设置为input有words, target, seq_len; 设置为target有target, seq_len。 | |||
:param: str encoding_type: target列使用什么类型的encoding方式,支持bioes, bio两种。 | |||
:param bool lower: 是否将words小写化后再建立词表,绝大多数情况都不需要设置为True。 | |||
""" | |||
def __init__(self, encoding_type: str = 'bio', lower: bool = False): | |||
""" | |||
:param: str encoding_type: target列使用什么类型的encoding方式,支持bioes, bio两种。 | |||
:param bool lower: 是否将words小写化后再建立词表,绝大多数情况都不需要设置为True。 | |||
""" | |||
if encoding_type == 'bio': | |||
self.convert_tag = iob2 | |||
else: | |||
@@ -51,9 +53,8 @@ class _NERPipe(Pipe): | |||
"[AL-AIN, United, Arab, ...]", "[B-LOC, B-LOC, I-LOC, ...]" | |||
"[...]", "[...]" | |||
:param ~fastNLP.DataBundle data_bundle: 传入的DataBundle中的DataSet必须包含raw_words和ner两个field,且两个field的内容均为List[str]。 | |||
在传入DataBundle基础上原位修改。 | |||
:return: DataBundle | |||
:param ~fastNLP.DataBundle data_bundle: 传入的DataBundle中的DataSet必须包含raw_words和ner两个field,且两个field的内容均为List[str]在传入DataBundle基础上原位修改。 | |||
:return DataBundle: | |||
""" | |||
# 转换tag | |||
for name, dataset in data_bundle.datasets.items(): | |||
@@ -253,8 +254,7 @@ class _CNNERPipe(Pipe): | |||
raw_chars列为List[str], 是未转换的原始数据; chars列为List[int],是转换为index的输入数据; target列是List[int], | |||
是转换为index的target。返回的DataSet中被设置为input有chars, target, seq_len; 设置为target有target。 | |||
:param ~fastNLP.DataBundle data_bundle: 传入的DataBundle中的DataSet必须包含raw_words和ner两个field,且两个field | |||
的内容均为List[str]。在传入DataBundle基础上原位修改。 | |||
:param ~fastNLP.DataBundle data_bundle: 传入的DataBundle中的DataSet必须包含raw_words和ner两个field,且两个field的内容均为List[str]。在传入DataBundle基础上原位修改。 | |||
:return: DataBundle | |||
""" | |||
# 转换tag | |||
@@ -24,7 +24,7 @@ class Pipe: | |||
def process_from_file(self, paths) -> DataBundle: | |||
""" | |||
传入文件路径,生成处理好的DataBundle对象。paths支持的路径形式可以参考 `fastNLP.io.loader.Loader.load()` | |||
传入文件路径,生成处理好的DataBundle对象。paths支持的路径形式可以参考 ::meth:`fastNLP.io.Loader.load()` | |||
:param paths: | |||
:return: DataBundle | |||