@@ -11,7 +11,7 @@ import torch | |||
try: | |||
from tqdm.autonotebook import tqdm | |||
except: | |||
from fastNLP.core.utils import pseudo_tqdm as tqdm | |||
from fastNLP.core.utils import _pseudo_tqdm as tqdm | |||
from fastNLP.core.batch import Batch | |||
from fastNLP.core.callback import CallbackException | |||
@@ -115,7 +115,7 @@ class ENASTrainer(fastNLP.Trainer): | |||
def _train(self): | |||
if not self.use_tqdm: | |||
from fastNLP.core.utils import pseudo_tqdm as inner_tqdm | |||
from fastNLP.core.utils import _pseudo_tqdm as inner_tqdm | |||
else: | |||
inner_tqdm = tqdm | |||
self.step = 0 | |||
@@ -2,21 +2,21 @@ | |||
.. _Callback: | |||
""" | |||
Callback是fastNLP中被设计用于增强 Trainer_ 的类。如果Callback被传递给了 Trainer_ , 则 Trainer_ 会在对应的阶段调用Callback | |||
的函数,具体调用时机可以通过 Trainer_ 查看。 | |||
""" | |||
import os | |||
import torch | |||
from tensorboardX import SummaryWriter | |||
from fastNLP.io.model_io import ModelSaver, ModelLoader | |||
try: | |||
from tensorboardX import SummaryWriter | |||
except: | |||
pass | |||
class Callback(object): | |||
"""An Interface for all callbacks. | |||
Any customized callback should implement at least one of the following methods. | |||
"""这是Callback的基类,所有的callback必须继承自这个类。 | |||
""" | |||
@@ -26,93 +26,150 @@ class Callback(object): | |||
@property | |||
def trainer(self): | |||
""" | |||
该属性可以通过self.trainer获取到,一般情况下不需要使用这个属性。 | |||
:return: | |||
""" | |||
return self._trainer | |||
@property | |||
def step(self): | |||
"""current step number, in range(1, self.n_steps+1)""" | |||
"""当前运行到的step, 范围为[1, self.n_steps+1)""" | |||
return self._trainer.step | |||
@property | |||
def n_steps(self): | |||
"""total number of steps for training""" | |||
"""Trainer一共会运行多少步""" | |||
return self._trainer.n_steps | |||
@property | |||
def batch_size(self): | |||
"""batch size for training""" | |||
"""train和evaluate时的batch_size为多大""" | |||
return self._trainer.batch_size | |||
@property | |||
def epoch(self): | |||
"""current epoch number, in range(1, self.n_epochs+1)""" | |||
"""当前运行的epoch数,范围是[1, self.n_epochs+1)""" | |||
return self._trainer.epoch | |||
@property | |||
def n_epochs(self): | |||
"""total number of epochs""" | |||
"""一共会运行多少个epoch""" | |||
return self._trainer.n_epochs | |||
@property | |||
def optimizer(self): | |||
"""torch.optim.Optimizer for current model""" | |||
"""初始化Trainer时传递的Optimizer""" | |||
return self._trainer.optimizer | |||
@property | |||
def model(self): | |||
"""training model""" | |||
"""正在被Trainer训练的模型""" | |||
return self._trainer.model | |||
@property | |||
def pbar(self): | |||
"""If use_tqdm, return trainer's tqdm print bar, else return None.""" | |||
"""如果在Callback中需要打印内容,请使用self.pbar.write(str)。否则可能出现命令行显示效果不太好的问题。""" | |||
return self._trainer.pbar | |||
@property | |||
def update_every(self): | |||
"""The model in trainer will update parameters every `update_every` batches.""" | |||
"""Trainer中的模型多少次反向传播才进行一次梯度更新,在Trainer初始化时传入的。""" | |||
return self._trainer.update_every | |||
@property | |||
def batch_per_epoch(self): | |||
"""每个epoch一共有多少个batch,只有在on_epoch_begin之后才能调用该属性。""" | |||
return self._trainer.batch_per_epoch | |||
def on_train_begin(self): | |||
# before the main training loop | |||
""" | |||
在Train过程开始之前调用。 | |||
:return: | |||
""" | |||
pass | |||
def on_epoch_begin(self): | |||
# at the beginning of each epoch | |||
""" | |||
在每个epoch开始之前调用一次 | |||
:return: | |||
""" | |||
pass | |||
def on_batch_begin(self, batch_x, batch_y, indices): | |||
# at the beginning of each step/mini-batch | |||
""" | |||
每次采集到一个batch的数据则调用一次。这里对batch_x或batch_y删除添加内容是可以影响到Trainer中内容的。所以在这一步 | |||
可以进行一些负采样之类的操作 | |||
:param dict batch_x: DataSet中被设置为input的field的batch。 | |||
:param dict batch_y: DataSet中被设置为target的field的batch。 | |||
:param list(int) indices: 这次采样使用到的indices,可以通过DataSet[indices]获取出这个batch采出的Instance,在一些 | |||
情况下可以帮助定位是哪个Sample导致了错误。 | |||
:return: | |||
""" | |||
pass | |||
def on_loss_begin(self, batch_y, predict_y): | |||
# after data_forward, and before loss computation | |||
""" | |||
在计算loss前调用,即这里修改batch_y或predict_y的值是可以影响到loss计算的。 | |||
:param dict batch_y: 在DataSet中被设置为target的field的batch集合。 | |||
:param dict predict_y: 模型的forward()返回的结果。 | |||
:return: | |||
""" | |||
pass | |||
def on_backward_begin(self, loss): | |||
# after loss computation, and before gradient backward | |||
""" | |||
在loss得到之后,但在反向传播之前。可能可以进行loss是否为NaN的检查。 | |||
:param torch.Tensor loss: 计算得到的loss值 | |||
:return: | |||
""" | |||
pass | |||
def on_backward_end(self): | |||
""" | |||
反向梯度传播已完成,但由于update_every的设置,可能并不是每一次调用都有梯度。到这一步,还没有更新参数。 | |||
:return: | |||
""" | |||
pass | |||
def on_step_end(self): | |||
""" | |||
到这里模型的参数已经按照梯度更新。但可能受update_every影响,并不是每次都更新了。 | |||
:return: | |||
""" | |||
pass | |||
def on_batch_end(self, *args): | |||
# at the end of each step/mini-batch | |||
def on_batch_end(self): | |||
""" | |||
这一步与on_step_end是紧接着的。只是为了对称性加上了这一步。 | |||
""" | |||
pass | |||
def on_valid_begin(self): | |||
""" | |||
如果Trainer中设置了验证,则发生验证前会调用该函数 | |||
:return: | |||
""" | |||
pass | |||
def on_valid_end(self, eval_result, metric_key, optimizer, is_better_eval): | |||
""" | |||
每次执行验证机的evaluation后会调用。传入eval_result | |||
每次执行验证集的evaluation后会调用。 | |||
:param eval_result: Dict[str: Dict[str: float]], evaluation的结果 | |||
:param metric_key: str | |||
:param optimizer: optimizer passed to trainer | |||
:param is_better_eval: bool, 当前dev结果是否比之前的好 | |||
:param Dict[str: Dict[str: float]] eval_result: , evaluation的结果。一个例子为{'AccuracyMetric':{'acc':1.0}},即 | |||
传入的dict是有两层,第一层是metric的名称,第二层是metric的具体指标。 | |||
:param str metric_key: 初始化Trainer时传入的metric_key。 | |||
:param torch.Optimizer optimizer: Trainer中使用的优化器。 | |||
:param bool is_better_eval: 当前dev结果是否比之前的好。 | |||
:return: | |||
""" | |||
pass | |||
@@ -137,7 +194,7 @@ class Callback(object): | |||
pass | |||
def transfer(func): | |||
def _transfer(func): | |||
"""装饰器,将对CallbackManager的调用转发到各个Callback子类. | |||
:param func: | |||
:return: | |||
@@ -153,9 +210,7 @@ def transfer(func): | |||
class CallbackManager(Callback): | |||
"""A manager for all callbacks passed into Trainer. | |||
It collects resources inside Trainer and raise callbacks. | |||
"""内部使用的Callback管理类 | |||
""" | |||
def __init__(self, env, callbacks=None): | |||
@@ -182,104 +237,70 @@ class CallbackManager(Callback): | |||
for callback in self.callbacks: | |||
setattr(callback, '_'+env_name, env_val) # Callback.trainer | |||
@transfer | |||
@_transfer | |||
def on_train_begin(self): | |||
pass | |||
@transfer | |||
@_transfer | |||
def on_epoch_begin(self): | |||
pass | |||
@transfer | |||
@_transfer | |||
def on_batch_begin(self, batch_x, batch_y, indices): | |||
pass | |||
@transfer | |||
@_transfer | |||
def on_loss_begin(self, batch_y, predict_y): | |||
pass | |||
@transfer | |||
@_transfer | |||
def on_backward_begin(self, loss): | |||
pass | |||
@transfer | |||
@_transfer | |||
def on_backward_end(self): | |||
pass | |||
@transfer | |||
@_transfer | |||
def on_step_end(self): | |||
pass | |||
@transfer | |||
@_transfer | |||
def on_batch_end(self): | |||
pass | |||
@transfer | |||
@_transfer | |||
def on_valid_begin(self): | |||
pass | |||
@transfer | |||
@_transfer | |||
def on_valid_end(self, eval_result, metric_key, optimizer, is_better_eval): | |||
pass | |||
@transfer | |||
@_transfer | |||
def on_epoch_end(self): | |||
pass | |||
@transfer | |||
@_transfer | |||
def on_train_end(self): | |||
pass | |||
@transfer | |||
@_transfer | |||
def on_exception(self, exception): | |||
pass | |||
class DummyCallback(Callback): | |||
def on_train_begin(self, *arg): | |||
print(arg) | |||
def on_epoch_end(self): | |||
print(self.epoch, self.n_epochs) | |||
class EchoCallback(Callback): | |||
def on_train_begin(self): | |||
print("before_train") | |||
def on_epoch_begin(self): | |||
print("before_epoch") | |||
def on_batch_begin(self, batch_x, batch_y, indices): | |||
print("before_batch") | |||
def on_loss_begin(self, batch_y, predict_y): | |||
print("before_loss") | |||
def on_backward_begin(self, loss): | |||
print("before_backward") | |||
def on_batch_end(self): | |||
print("after_batch") | |||
def on_epoch_end(self): | |||
print("after_epoch") | |||
def on_train_end(self): | |||
print("after_train") | |||
class GradientClipCallback(Callback): | |||
def __init__(self, parameters=None, clip_value=1, clip_type='norm'): | |||
"""每次backward前,将parameter的gradient clip到某个范围。 | |||
:param parameters: None, torch.Tensor或List[torch.Tensor], 一般通过model.parameters()获得。如果为None则默认对Trainer | |||
:param None,torch.Tensor,List[torch.Tensor] parameters: 一般通过model.parameters()获得。如果为None则默认对Trainer | |||
的model中所有参数进行clip | |||
:param clip_value: float, 将gradient 限制到[-clip_value, clip_value]。clip_value应该为正数 | |||
:param clip_type: str, 支持'norm', 'value'两种。 | |||
(1) 'norm', 将gradient的norm rescale到[-clip_value, clip_value] | |||
(2) 'value', 将gradient限制在[-clip_value, clip_value], 小于-clip_value的gradient被赋值为-clip_value; 大于 | |||
clip_value的gradient被赋值为clip_value. | |||
:param float clip_value: 将gradient 限制到[-clip_value, clip_value]。clip_value应该为正数 | |||
:param str clip_type: 支持'norm', 'value'两种。 | |||
1. 'norm', 将gradient的norm rescale到[-clip_value, clip_value] | |||
2. 'value', 将gradient限制在[-clip_value, clip_value], 小于-clip_value的gradient被赋值为-clip_value; 大于 | |||
clip_value的gradient被赋值为clip_value. | |||
""" | |||
super().__init__() | |||
@@ -314,7 +335,7 @@ class EarlyStopCallback(Callback): | |||
def __init__(self, patience): | |||
""" | |||
:param int patience: 停止之前等待的epoch数 | |||
:param int patience: 多少个epoch没有变好就停止训练 | |||
""" | |||
super(EarlyStopCallback, self).__init__() | |||
self.patience = patience | |||
@@ -341,7 +362,7 @@ class LRScheduler(Callback): | |||
def __init__(self, lr_scheduler): | |||
"""对PyTorch LR Scheduler的包装 | |||
:param lr_scheduler: PyTorch的lr_scheduler | |||
:param torch.optim.lr_scheduler._LRScheduler lr_scheduler: PyTorch的lr_scheduler | |||
""" | |||
super(LRScheduler, self).__init__() | |||
import torch.optim | |||
@@ -358,7 +379,7 @@ class ControlC(Callback): | |||
def __init__(self, quit_all): | |||
""" | |||
:param quit_all: 若为True,则检测到control+C 直接退出程序;否则只退出Trainer | |||
:param bool quit_all: 若为True,则检测到control+C 直接退出程序;否则只退出Trainer | |||
""" | |||
super(ControlC, self).__init__() | |||
if type(quit_all) != bool: | |||
@@ -389,16 +410,16 @@ class SmoothValue(object): | |||
class LRFinder(Callback): | |||
def __init__(self, n_batch, start_lr=1e-6, end_lr=10): | |||
def __init__(self, start_lr=1e-6, end_lr=10): | |||
"""用第一个 epoch 找最佳的学习率,从第二个epoch开始应用它 | |||
:param n_batch: 一个epoch内的iteration数 | |||
:param start_lr: 学习率下界 | |||
:param end_lr: 学习率上界 | |||
:param int n_batch: 一个epoch内的iteration数 | |||
:param float start_lr: 学习率下界 | |||
:param float end_lr: 学习率上界 | |||
""" | |||
super(LRFinder, self).__init__() | |||
self.start_lr, self.end_lr = start_lr, end_lr | |||
self.num_it = n_batch | |||
self.num_it = self.batch_per_epoch | |||
self.stop = False | |||
self.best_loss = 0. | |||
self.best_lr = None | |||
@@ -514,7 +535,3 @@ class TensorboardCallback(Callback): | |||
del self._summary_writer | |||
if __name__ == "__main__": | |||
manager = CallbackManager(env={"n_epoch": 3}, callbacks=[DummyCallback(), DummyCallback()]) | |||
manager.on_train_begin() | |||
# print(manager.after_epoch()) |
@@ -277,7 +277,7 @@ import warnings | |||
from fastNLP.core.fieldarray import AutoPadder | |||
from fastNLP.core.fieldarray import FieldArray | |||
from fastNLP.core.instance import Instance | |||
from fastNLP.core.utils import get_func_signature | |||
from fastNLP.core.utils import _get_func_signature | |||
class DataSet(object): | |||
"""fastNLP的数据容器 | |||
@@ -642,7 +642,7 @@ class DataSet(object): | |||
print("Exception happens at the `{}`th instance.".format(idx)) | |||
raise e | |||
if not (new_field_name is None) and len(list(filter(lambda x: x is not None, results))) == 0: # all None | |||
raise ValueError("{} always return None.".format(get_func_signature(func=func))) | |||
raise ValueError("{} always return None.".format(_get_func_signature(func=func))) | |||
if new_field_name is not None: | |||
self._add_apply_field(results, new_field_name, kwargs) | |||
@@ -707,7 +707,7 @@ class DataSet(object): | |||
raise e | |||
# results = [func(ins) for ins in self._inner_iter()] | |||
if not (new_field_name is None) and len(list(filter(lambda x: x is not None, results))) == 0: # all None | |||
raise ValueError("{} always return None.".format(get_func_signature(func=func))) | |||
raise ValueError("{} always return None.".format(_get_func_signature(func=func))) | |||
if new_field_name is not None: | |||
self._add_apply_field(results, new_field_name, kwargs) | |||
@@ -1,5 +1,5 @@ | |||
""" | |||
FieldArray是 DataSet_ 中一列的存储方式 | |||
FieldArray是 DataSet_ 中一列的存储方式,原理部分请参考 DataSet_ 处 | |||
.. _FieldArray: | |||
@@ -11,41 +11,19 @@ from copy import deepcopy | |||
class FieldArray(object): | |||
"""``FieldArray`` is the collection of ``Instance``s of the same field. | |||
It is the basic element of ``DataSet`` class. | |||
:param str name: the name of the FieldArray | |||
:param list content: a list of int, float, str or np.ndarray, or a list of list of one, or a np.ndarray. | |||
:param bool is_target: If True, this FieldArray is used to compute loss. | |||
:param bool is_input: If True, this FieldArray is used to the model input. | |||
:param Padder padder: PadderBase类型。赋值给fieldarray的padder的对象会被deepcopy一份,需要修改padder参数必须通过 | |||
fieldarray.set_pad_val()。 | |||
默认为None,(1)如果某个field是scalar,则不进行任何padding;(2)如果为一维list, 且fieldarray的dtype为float或int类型 | |||
则会进行padding;(3)其它情况不进行padder。 | |||
假设需要对English word中character进行padding,则需要使用其他的padder。 | |||
或ignore_type为True但是需要进行padding。 | |||
:param bool ignore_type: whether to ignore type. If True, no type detection will rise for this FieldArray. | |||
(default: False) | |||
""" | |||
def __init__(self, name, content, is_target=None, is_input=None, padder=None, ignore_type=False): | |||
"""DataSet在初始化时会有两类方法对FieldArray操作: | |||
1) 如果DataSet使用dict初始化,那么在add_field中会构造FieldArray: | |||
1.1) 二维list DataSet({"x": [[1, 2], [3, 4]]}) | |||
1.2) 二维array DataSet({"x": np.array([[1, 2], [3, 4]])}) | |||
1.3) 三维list DataSet({"x": [[[1, 2], [3, 4]], [[1, 2], [3, 4]]]}) | |||
1.4) list of array: DataSet({"x": [np.array([1,2,3]), np.array([1,2,3])]}) | |||
2) 如果DataSet使用list of Instance 初始化,那么在append中会先对第一个样本初始化FieldArray; | |||
然后后面的样本使用FieldArray.append进行添加。 | |||
2.1) 一维list DataSet([Instance(x=[1, 2, 3, 4])]) | |||
2.2) 一维array DataSet([Instance(x=np.array([1, 2, 3, 4]))]) | |||
2.3) 二维list DataSet([Instance(x=[[1, 2], [3, 4]])]) | |||
2.4) 二维array DataSet([Instance(x=np.array([[1, 2], [3, 4]]))]) | |||
类型检查(dtype check)发生在当该field被设置为is_input或者is_target时。 | |||
ignore_type用来控制是否进行类型检查,如果为True,则不检查。 | |||
"""FieldArray是用于保存 DataSet_ 中一个field的实体。 | |||
:param str name: FieldArray的名称 | |||
:param list,numpy.ndarray content: 列表的元素可以为list,int,float, | |||
:param bool is_target: 这个field是否是一个target field。 | |||
:param bool is_input: 这个field是否是一个input field。 | |||
:param Padder padder: PadderBase类型。赋值给fieldarray的padder的对象会被deepcopy一份,需要修改padder参数必须通过 | |||
fieldarray.set_pad_val()。默认为None,即使用 AutoPadder_ 。 | |||
:param bool ignore_type: 是否忽略该field的type,一般如果这个field不需要转为torch.FloatTensor或torch.LongTensor, 就 | |||
可以设置为True。具体意义请参考 DataSet_ 。 | |||
""" | |||
self.name = name | |||
if isinstance(content, list): | |||
# 如果DataSet使用dict初始化, content 可能是二维list/二维array/三维list | |||
@@ -211,10 +189,10 @@ class FieldArray(object): | |||
return "FieldArray {}: {}".format(self.name, self.content.__repr__()) | |||
def append(self, val): | |||
"""将val增加到FieldArray中,若该field的ignore_type为True则直接append到这个field中;若ignore_type为False,且当前field为 | |||
input或者target,则会检查传入的content是否与之前的内容在dimension, 元素的类型上是匹配的。 | |||
"""将val append到这个field的尾部。如果这个field已经被设置为input或者target,则在append之前会检查该类型是否与已有 | |||
的内容是匹配的。 | |||
:param val: Any. | |||
:param Any val: 需要append的值。 | |||
""" | |||
if self.ignore_type is False: | |||
if isinstance(val, list): | |||
@@ -262,8 +240,8 @@ class FieldArray(object): | |||
def get(self, indices, pad=True): | |||
"""根据给定的indices返回内容 | |||
:param indices: (int, List[int]), 获取indices对应的内容。 | |||
:param pad: bool, 是否对返回的结果进行padding。仅对indices为List[int]时有效 | |||
:param int,list(int) indices:, 获取indices对应的内容。 | |||
:param bool pad: , 是否对返回的结果进行padding。仅对indices为List[int]时有效 | |||
:return: (single, List) | |||
""" | |||
if isinstance(indices, int): | |||
@@ -281,7 +259,7 @@ class FieldArray(object): | |||
""" | |||
设置padder,在这个field进行pad的时候用这个padder进行pad,如果为None则不进行pad。 | |||
:param padder: (None, Padder). 设置为None即删除padder. | |||
:param None,Padder padder:. 设置为None即删除padder。 | |||
:return: | |||
""" | |||
if padder is not None: | |||
@@ -293,7 +271,7 @@ class FieldArray(object): | |||
def set_pad_val(self, pad_val): | |||
"""修改padder的pad_val. | |||
:param pad_val: int。将该field的pad值设置为该值 | |||
:param int pad_val: 该field的pad值设置为该值。 | |||
:return: | |||
""" | |||
if self.padder is not None: | |||
@@ -312,8 +290,8 @@ class FieldArray(object): | |||
""" | |||
将other的属性复制给本FieldArray(other必须为FieldArray类型).属性包括 is_input, is_target, padder, ignore_type | |||
:param other: FieldArray | |||
:return: | |||
:param FieldArray other: 从哪个field拷贝属性 | |||
:return: FieldArray | |||
""" | |||
assert isinstance(other, FieldArray), "Only support FieldArray type, not {}.".format(type(other)) | |||
@@ -324,7 +302,7 @@ class FieldArray(object): | |||
return self | |||
def is_iterable(content): | |||
def _is_iterable(content): | |||
try: | |||
_ = (e for e in content) | |||
except TypeError: | |||
@@ -350,11 +328,10 @@ class Padder: | |||
""" | |||
传入的是List内容。假设有以下的DataSet。 | |||
:param contents: List[element]。传入的element是inplace的,即直接修改element可能导致数据变化,建议inplace修改之前 | |||
:param list(Any) contents: 传入的element是inplace的,即直接修改element可能导致数据变化,建议inplace修改之前 | |||
deepcopy一份。 | |||
:param field_name: str, field的名称。 | |||
:param field_ele_dtype: (np.int64, np.float64, np.str, None), 该field的内层元素的类型。如果该field的ignore_type | |||
为True,该这个值为None。 | |||
:param str, field_name: field的名称。 | |||
:param np.int64,np.float64,np.str,None, field_ele_dtype: 该field的内层元素的类型。如果该field的ignore_type为True,该这个值为None。 | |||
:return: np.array([padded_element]) | |||
Example:: | |||
@@ -400,10 +377,10 @@ class AutoPadder(Padder): | |||
2 如果元素类型为(np.int64, np.float64), | |||
2.1 如果该field的内容为(np.int64, np.float64),比如为seq_len, 则不进行padding | |||
2.1 如果该field的内容为(np.int64, np.float64),比如为seq_len, 则不进行padding | |||
2.2 如果该field的内容为List, 那么会将Batch中的List pad为一样长。若该List下还有里层的List需要padding,请使用其它padder。 | |||
如果某个instance中field为[1, 2, 3],则可以pad;若为[[1,2], [3,4, ...]]则不能进行pad | |||
2.2 如果该field的内容为List, 那么会将Batch中的List pad为一样长。若该List下还有里层的List需要padding,请使用其它padder。 | |||
即如果Instance中field形如[1, 2, 3, ...],则可以pad;若为[[1,2], [3,4, ...]]则不能进行pad | |||
""" | |||
def __init__(self, pad_val=0): | |||
@@ -427,7 +404,7 @@ class AutoPadder(Padder): | |||
return False | |||
def __call__(self, contents, field_name, field_ele_dtype): | |||
if not is_iterable(contents[0]): | |||
if not _is_iterable(contents[0]): | |||
array = np.array([content for content in contents], dtype=field_ele_dtype) | |||
elif field_ele_dtype in (np.int64, np.float64) and self._is_two_dimension(contents): | |||
max_len = max([len(content) for content in contents]) | |||
@@ -454,7 +431,7 @@ class EngChar2DPadder(Padder): | |||
Example:: | |||
from fastNLP import DataSet | |||
from fastNLP import EnChar2DPadder | |||
from fastNLP import EngChar2DPadder | |||
from fastNLP import Vocabulary | |||
dataset = DataSet({'sent': ['This is the first demo', 'This is the second demo']}) | |||
dataset.apply(lambda ins:[list(word) for word in ins['sent'].split()], new_field_name='chars') | |||
@@ -462,14 +439,15 @@ class EngChar2DPadder(Padder): | |||
vocab.from_dataset(dataset, field_name='chars') | |||
vocab.index_dataset(dataset, field_name='chars') | |||
dataset.set_input('chars') | |||
padder = EnChar2DPadder() | |||
padder = EngChar2DPadder() | |||
dataset.set_padder('chars', padder) # chars这个field的设置为了EnChar2DPadder | |||
""" | |||
def __init__(self, pad_val=0, pad_length=0): | |||
""" | |||
:param pad_val: int, pad的位置使用该index | |||
:param pad_length: int, 如果为0则取一个batch中最大的单词长度作为padding长度。如果为大于0的数,则将所有单词的长度都pad或截 | |||
取到该长度. | |||
:param pad_length: int, 如果为0则取一个batch中最大的单词长度作为padding长度。如果为大于0的数,则将所有单词的长度 | |||
都pad或截取到该长度. | |||
""" | |||
super().__init__(pad_val=pad_val) | |||
@@ -494,7 +472,7 @@ class EngChar2DPadder(Padder): | |||
except: | |||
raise ValueError("Field:{} only has two dimensions.".format(field_name)) | |||
if is_iterable(value): | |||
if _is_iterable(value): | |||
raise ValueError("Field:{} has more than 3 dimension.".format(field_name)) | |||
def __call__(self, contents, field_name, field_ele_dtype): | |||
@@ -3,34 +3,33 @@ Instance文档 | |||
.. _Instance: | |||
测试 | |||
Instance是fastNLP中对应于一个sample的类。一个sample可以认为是fastNLP中的一个Instance对象。一个具像化的表示类似与 DataSet_ | |||
出那个表中所展示的一行。 | |||
""" | |||
class Instance(object): | |||
"""An Instance is an example of data. | |||
Example:: | |||
ins = Instance(field_1=[1, 1, 1], field_2=[2, 2, 2]) | |||
ins["field_1"] | |||
>>[1, 1, 1] | |||
ins.add_field("field_3", [3, 3, 3]) | |||
""" | |||
def __init__(self, **fields): | |||
""" | |||
"""Instance的初始化如下面的Example所示 | |||
Example:: | |||
ins = Instance(field_1=[1, 1, 1], field_2=[2, 2, 2]) | |||
ins["field_1"] | |||
>>[1, 1, 1] | |||
ins.add_field("field_3", [3, 3, 3]) | |||
:param fields: 可能是一维或者二维的 list or np.array | |||
ins = Instance(**{'x1': 1, 'x2':np.zeros((3, 4))}) | |||
""" | |||
self.fields = fields | |||
def add_field(self, field_name, field): | |||
"""Add a new field to the instance. | |||
"""向Instance中增加一个field | |||
:param field_name: str, the name of the field. | |||
:param str field_name: 新增field的名称 | |||
:param Any field: 新增field的内容 | |||
""" | |||
self.fields[field_name] = field | |||
@@ -12,12 +12,12 @@ from collections import defaultdict | |||
import torch | |||
import torch.nn.functional as F | |||
from fastNLP.core.utils import CheckError | |||
from fastNLP.core.utils import CheckRes | |||
from fastNLP.core.utils import _CheckError | |||
from fastNLP.core.utils import _CheckRes | |||
from fastNLP.core.utils import _build_args | |||
from fastNLP.core.utils import _check_arg_dict_list | |||
from fastNLP.core.utils import _check_function_or_method | |||
from fastNLP.core.utils import get_func_signature | |||
from fastNLP.core.utils import _get_func_signature | |||
class LossBase(object): | |||
@@ -70,7 +70,7 @@ class LossBase(object): | |||
for func_param, input_param in self.param_map.items(): | |||
if func_param not in func_args: | |||
raise NameError( | |||
f"Parameter `{func_param}` is not in {get_func_signature(self.get_loss)}. Please check the " | |||
f"Parameter `{func_param}` is not in {_get_func_signature(self.get_loss)}. Please check the " | |||
f"initialization parameters, or change its signature.") | |||
# evaluate should not have varargs. | |||
@@ -111,7 +111,7 @@ class LossBase(object): | |||
func_args = set([arg for arg in func_spect.args if arg != 'self']) | |||
for func_arg, input_arg in self.param_map.items(): | |||
if func_arg not in func_args: | |||
raise NameError(f"`{func_arg}` not in {get_func_signature(self.get_loss)}.") | |||
raise NameError(f"`{func_arg}` not in {_get_func_signature(self.get_loss)}.") | |||
# 2. only part of the param_map are passed, left are not | |||
for arg in func_args: | |||
@@ -151,16 +151,16 @@ class LossBase(object): | |||
replaced_missing[idx] = f"{self.param_map[func_arg]}" + f"(assign to `{func_arg}` " \ | |||
f"in `{self.__class__.__name__}`)" | |||
check_res = CheckRes(missing=replaced_missing, | |||
unused=check_res.unused, | |||
duplicated=duplicated, | |||
required=check_res.required, | |||
all_needed=check_res.all_needed, | |||
varargs=check_res.varargs) | |||
check_res = _CheckRes(missing=replaced_missing, | |||
unused=check_res.unused, | |||
duplicated=duplicated, | |||
required=check_res.required, | |||
all_needed=check_res.all_needed, | |||
varargs=check_res.varargs) | |||
if check_res.missing or check_res.duplicated: | |||
raise CheckError(check_res=check_res, | |||
func_signature=get_func_signature(self.get_loss)) | |||
raise _CheckError(check_res=check_res, | |||
func_signature=_get_func_signature(self.get_loss)) | |||
refined_args = _build_args(self.get_loss, **mapped_pred_dict, **mapped_target_dict) | |||
loss = self.get_loss(**refined_args) | |||
@@ -289,14 +289,14 @@ class LossInForward(LossBase): | |||
def get_loss(self, **kwargs): | |||
if self.loss_key not in kwargs: | |||
check_res = CheckRes( | |||
check_res = _CheckRes( | |||
missing=[self.loss_key + f"(assign to `{self.loss_key}` in `{self.__class__.__name__}`"], | |||
unused=[], | |||
duplicated=[], | |||
required=[], | |||
all_needed=[], | |||
varargs=[]) | |||
raise CheckError(check_res=check_res, func_signature=get_func_signature(self.get_loss)) | |||
raise _CheckError(check_res=check_res, func_signature=_get_func_signature(self.get_loss)) | |||
return kwargs[self.loss_key] | |||
def __call__(self, pred_dict, target_dict, check=False): | |||
@@ -13,11 +13,11 @@ from collections import defaultdict | |||
import numpy as np | |||
import torch | |||
from fastNLP.core.utils import CheckError | |||
from fastNLP.core.utils import CheckRes | |||
from fastNLP.core.utils import _CheckError | |||
from fastNLP.core.utils import _CheckRes | |||
from fastNLP.core.utils import _build_args | |||
from fastNLP.core.utils import _check_arg_dict_list | |||
from fastNLP.core.utils import get_func_signature | |||
from fastNLP.core.utils import _get_func_signature | |||
from fastNLP.core.utils import seq_lens_to_masks | |||
from fastNLP.core.vocabulary import Vocabulary | |||
@@ -161,7 +161,7 @@ class MetricBase(object): | |||
for func_param, input_param in self.param_map.items(): | |||
if func_param not in func_args: | |||
raise NameError( | |||
f"Parameter `{func_param}` is not in {get_func_signature(self.evaluate)}. Please check the " | |||
f"Parameter `{func_param}` is not in {_get_func_signature(self.evaluate)}. Please check the " | |||
f"initialization parameters, or change its signature.") | |||
def _fast_param_map(self, pred_dict, target_dict): | |||
@@ -207,7 +207,7 @@ class MetricBase(object): | |||
func_args = set([arg for arg in func_spect.args if arg != 'self']) | |||
for func_arg, input_arg in self.param_map.items(): | |||
if func_arg not in func_args: | |||
raise NameError(f"`{func_arg}` not in {get_func_signature(self.evaluate)}.") | |||
raise NameError(f"`{func_arg}` not in {_get_func_signature(self.evaluate)}.") | |||
# 2. only part of the param_map are passed, left are not | |||
for arg in func_args: | |||
@@ -248,16 +248,16 @@ class MetricBase(object): | |||
replaced_missing[idx] = f"{self.param_map[func_arg]}" + f"(assign to `{func_arg}` " \ | |||
f"in `{self.__class__.__name__}`)" | |||
check_res = CheckRes(missing=replaced_missing, | |||
unused=check_res.unused, | |||
duplicated=duplicated, | |||
required=check_res.required, | |||
all_needed=check_res.all_needed, | |||
varargs=check_res.varargs) | |||
check_res = _CheckRes(missing=replaced_missing, | |||
unused=check_res.unused, | |||
duplicated=duplicated, | |||
required=check_res.required, | |||
all_needed=check_res.all_needed, | |||
varargs=check_res.varargs) | |||
if check_res.missing or check_res.duplicated: | |||
raise CheckError(check_res=check_res, | |||
func_signature=get_func_signature(self.evaluate)) | |||
raise _CheckError(check_res=check_res, | |||
func_signature=_get_func_signature(self.evaluate)) | |||
refined_args = _build_args(self.evaluate, **mapped_pred_dict, **mapped_target_dict) | |||
self.evaluate(**refined_args) | |||
@@ -294,14 +294,14 @@ class AccuracyMetric(MetricBase): | |||
""" | |||
# TODO 这里报错需要更改,因为pred是啥用户并不知道。需要告知用户真实的value | |||
if not isinstance(pred, torch.Tensor): | |||
raise TypeError(f"`pred` in {get_func_signature(self.evaluate)} must be torch.Tensor," | |||
raise TypeError(f"`pred` in {_get_func_signature(self.evaluate)} must be torch.Tensor," | |||
f"got {type(pred)}.") | |||
if not isinstance(target, torch.Tensor): | |||
raise TypeError(f"`target` in {get_func_signature(self.evaluate)} must be torch.Tensor," | |||
raise TypeError(f"`target` in {_get_func_signature(self.evaluate)} must be torch.Tensor," | |||
f"got {type(target)}.") | |||
if seq_len is not None and not isinstance(seq_len, torch.Tensor): | |||
raise TypeError(f"`seq_lens` in {get_func_signature(self.evaluate)} must be torch.Tensor," | |||
raise TypeError(f"`seq_lens` in {_get_func_signature(self.evaluate)} must be torch.Tensor," | |||
f"got {type(seq_lens)}.") | |||
if seq_len is not None: | |||
@@ -314,7 +314,7 @@ class AccuracyMetric(MetricBase): | |||
elif len(pred.size()) == len(target.size()) + 1: | |||
pred = pred.argmax(dim=-1) | |||
else: | |||
raise RuntimeError(f"In {get_func_signature(self.evaluate)}, when pred have " | |||
raise RuntimeError(f"In {_get_func_signature(self.evaluate)}, when pred have " | |||
f"size:{pred.size()}, target should have size: {pred.size()} or " | |||
f"{pred.size()[:-1]}, got {target.size()}.") | |||
@@ -516,14 +516,14 @@ class SpanFPreRecMetric(MetricBase): | |||
:return: | |||
""" | |||
if not isinstance(pred, torch.Tensor): | |||
raise TypeError(f"`pred` in {get_func_signature(self.evaluate)} must be torch.Tensor," | |||
raise TypeError(f"`pred` in {_get_func_signature(self.evaluate)} must be torch.Tensor," | |||
f"got {type(pred)}.") | |||
if not isinstance(target, torch.Tensor): | |||
raise TypeError(f"`target` in {get_func_signature(self.evaluate)} must be torch.Tensor," | |||
raise TypeError(f"`target` in {_get_func_signature(self.evaluate)} must be torch.Tensor," | |||
f"got {type(target)}.") | |||
if not isinstance(seq_len, torch.Tensor): | |||
raise TypeError(f"`seq_lens` in {get_func_signature(self.evaluate)} must be torch.Tensor," | |||
raise TypeError(f"`seq_lens` in {_get_func_signature(self.evaluate)} must be torch.Tensor," | |||
f"got {type(seq_len)}.") | |||
if pred.size() == target.size() and len(target.size()) == 2: | |||
@@ -535,7 +535,7 @@ class SpanFPreRecMetric(MetricBase): | |||
raise ValueError("A gold label passed to SpanBasedF1Metric contains an " | |||
"id >= {}, the number of classes.".format(num_classes)) | |||
else: | |||
raise RuntimeError(f"In {get_func_signature(self.evaluate)}, when pred have " | |||
raise RuntimeError(f"In {_get_func_signature(self.evaluate)}, when pred have " | |||
f"size:{pred.size()}, target should have size: {pred.size()} or " | |||
f"{pred.size()[:-1]}, got {target.size()}.") | |||
@@ -714,14 +714,14 @@ class BMESF1PreRecMetric(MetricBase): | |||
:return: | |||
""" | |||
if not isinstance(pred, torch.Tensor): | |||
raise TypeError(f"`pred` in {get_func_signature(self.evaluate)} must be torch.Tensor," | |||
raise TypeError(f"`pred` in {_get_func_signature(self.evaluate)} must be torch.Tensor," | |||
f"got {type(pred)}.") | |||
if not isinstance(target, torch.Tensor): | |||
raise TypeError(f"`target` in {get_func_signature(self.evaluate)} must be torch.Tensor," | |||
raise TypeError(f"`target` in {_get_func_signature(self.evaluate)} must be torch.Tensor," | |||
f"got {type(target)}.") | |||
if not isinstance(seq_len, torch.Tensor): | |||
raise TypeError(f"`seq_lens` in {get_func_signature(self.evaluate)} must be torch.Tensor," | |||
raise TypeError(f"`seq_lens` in {_get_func_signature(self.evaluate)} must be torch.Tensor," | |||
f"got {type(seq_len)}.") | |||
if pred.size() == target.size() and len(target.size()) == 2: | |||
@@ -729,7 +729,7 @@ class BMESF1PreRecMetric(MetricBase): | |||
elif len(pred.size()) == len(target.size()) + 1 and len(target.size()) == 2: | |||
pred = pred.argmax(dim=-1) | |||
else: | |||
raise RuntimeError(f"In {get_func_signature(self.evaluate)}, when pred have " | |||
raise RuntimeError(f"In {_get_func_signature(self.evaluate)}, when pred have " | |||
f"size:{pred.size()}, target should have size: {pred.size()} or " | |||
f"{pred.size()[:-1]}, got {target.size()}.") | |||
@@ -5,11 +5,11 @@ from fastNLP.core.batch import Batch | |||
from fastNLP.core.dataset import DataSet | |||
from fastNLP.core.metrics import _prepare_metrics | |||
from fastNLP.core.sampler import SequentialSampler | |||
from fastNLP.core.utils import CheckError | |||
from fastNLP.core.utils import _CheckError | |||
from fastNLP.core.utils import _build_args | |||
from fastNLP.core.utils import _check_loss_evaluate | |||
from fastNLP.core.utils import _move_dict_value_to_device | |||
from fastNLP.core.utils import get_func_signature | |||
from fastNLP.core.utils import _get_func_signature | |||
from fastNLP.core.utils import _get_device | |||
@@ -75,19 +75,19 @@ class Tester(object): | |||
_move_dict_value_to_device(batch_x, batch_y, device=self._model_device) | |||
pred_dict = self._data_forward(self._predict_func, batch_x) | |||
if not isinstance(pred_dict, dict): | |||
raise TypeError(f"The return value of {get_func_signature(self._predict_func)} " | |||
raise TypeError(f"The return value of {_get_func_signature(self._predict_func)} " | |||
f"must be `dict`, got {type(pred_dict)}.") | |||
for metric in self.metrics: | |||
metric(pred_dict, batch_y) | |||
for metric in self.metrics: | |||
eval_result = metric.get_metric() | |||
if not isinstance(eval_result, dict): | |||
raise TypeError(f"The return value of {get_func_signature(metric.get_metric)} must be " | |||
raise TypeError(f"The return value of {_get_func_signature(metric.get_metric)} must be " | |||
f"`dict`, got {type(eval_result)}") | |||
metric_name = metric.__class__.__name__ | |||
eval_results[metric_name] = eval_result | |||
except CheckError as e: | |||
prev_func_signature = get_func_signature(self._predict_func) | |||
except _CheckError as e: | |||
prev_func_signature = _get_func_signature(self._predict_func) | |||
_check_loss_evaluate(prev_func_signature=prev_func_signature, func_signature=e.func_signature, | |||
check_res=e.check_res, pred_dict=pred_dict, target_dict=batch_y, | |||
dataset=self.data, check_level=0) | |||
@@ -85,17 +85,17 @@ Trainer在fastNLP中用于组织单任务的训练过程,可以避免用户在 | |||
对应的值,比如这里 CrossEntropyLoss_ 将尝试找到名为'label'的内容来作为真实值得到loss;而pred=None, 则 CrossEntropyLoss_ | |||
使用'pred'作为名称匹配预测值,正好forward的返回值也叫pred,所以这里不需要申明pred。 | |||
尽管fastNLP使用了映射机制来使得loss的计算变得比较灵活,但有些情况下loss必须在模型中进行计算,比如使用了CRF的模型。fastNLP中提供了 | |||
LossInForward_ 这个loss。这个loss的原理是直接在forward()的返回结果中找到loss_key(默认寻找'loss')指定的那个tensor, | |||
并使用它作为loss。 如果Trainer初始化没有提供loss则使用这个loss TODO 补充一个例子 | |||
尽管fastNLP使用了映射机制来使得loss的计算变得比较灵活,但有些情况下loss必须在模型中进行计算,比如使用了CRF的模型。fastNLP中提供了 LossInForward_ 这 | |||
个loss。这个loss的原理是直接在forward()的返回结果中找到loss_key(默认寻找'loss')指定的那个tensor, | |||
并使用它作为loss。 如果Trainer初始化没有提供loss则默认使用 LossInForward_ 。详细例子可以参照 TODO 补充一个例子 | |||
3. Metric | |||
Metric_ 使用了与上述Loss一样的策略,即使用名称进行匹配。AccuracyMetric(target='label')的情况与CrossEntropyLoss 是同理的。 | |||
在进行验证时,可能用到的计算与forward()中不太一致,没有办法直接从forward()的结果中得到预测值,这时模型可以提供一个predict()方法, | |||
如果提供的模型具有predict方法,则在模型验证时将调用predict()方法获取预测结果,传入到predict()的参数也是从DataSet中的input的选择 | |||
出来的; 与forward()一样,返回值需要为一个dict。具体例子可以参考 TODO 补充一个例子 | |||
如果提供的模型具有predict方法,则在模型验证时将调用predict()方法获取预测结果,传入到predict()的参数也是从DataSet中被设置为input | |||
的field中选择出来的; 与forward()一样,返回值需要为一个dict。具体例子可以参考 TODO 补充一个例子 | |||
2. Trainer的代码检查 | |||
@@ -112,12 +112,12 @@ Trainer在fastNLP中用于组织单任务的训练过程,可以避免用户在 | |||
from fastNLP import DataSet | |||
class Model(nn.Module): | |||
def __init__(self): | |||
super().__init__() | |||
self.fc = nn.Linear(1, 1) | |||
def forward(self, x, b): | |||
loss = torch.mean((self.fc(x)-b)**2) | |||
return {'loss': loss} | |||
def __init__(self): | |||
super().__init__() | |||
self.fc = nn.Linear(1, 1) | |||
def forward(self, x, b): | |||
loss = torch.mean((self.fc(x)-b)**2) | |||
return {'loss': loss} | |||
model = Model() | |||
dataset = DataSet({'a': np.arange(10), 'b':np.arange(10)*2}) | |||
@@ -138,16 +138,15 @@ Trainer在fastNLP中用于组织单任务的训练过程,可以避免用户在 | |||
# unused field: ['a'] | |||
# Suggestion: You need to provide ['x'] in DataSet and set it as input. | |||
这里就是由于在Trainer初始化的时候,fastNLP会尝试使用一个batch_size=2的batch去运行一遍forward()以及backward()。这里由两类 | |||
这里就是由于在Trainer初始化的时候,fastNLP会尝试使用一个batch_size=2的batch去运行一遍forward()以及backward()。这里有两类 | |||
信息可以为你提供参考 | |||
1. 'input fields after batch...'这部分显示的是train dataset经过Batch操作后,每个field对应的类型以及进行shape。这里 | |||
因为train dataset没有target所以没有显示。根据这里你可以看出是否正确将需要的内容设置为了input或target。 | |||
因为train dataset没有target所以没有显示。根据这里可以看出是否正确将需要的内容设置为了input或target。 | |||
2. 如果出现了映射错误,出现NameError。这里报错的原因是由于尝试进行forward计算时(可以通过Model.forward(self, x, b)判断 | |||
出当前是在调取forward出错),却没有获取到forward()函数中需要的'x';在报错信息中同时指出了缺'x',而'a'没有被使用,那么可能 | |||
就是由于field的名称不对。这里将dataset中'a'这个field的名称改为'x' | |||
,或者model的参数从'x'修改为'a'都可以解决问题。 | |||
2. NameError,NameError发生在映射出错的情况。这里报错的原因是由于尝试进行forward计算时(可以通过Model.forward(self, x, b)判断 | |||
出当前是在调取forward),却没有获取到forward()函数中需要的'x';在报错信息中同时指出了缺'x',而'a'没有被使用,那么可能 | |||
就是由于field的名称不对。这里将dataset中'a'这个field的名称改为'x',或者model的参数从'x'修改为'a'都可以解决问题。 | |||
下面的例子是由于loss计算的时候找不到需要的值 | |||
@@ -249,18 +248,18 @@ Trainer在fastNLP中用于组织单任务的训练过程,可以避免用户在 | |||
# (2). Check key assignment for `target` when initialize AccuracyMetric. Or provide `target` in DataSet or output of Model.predict(self, a). | |||
报错信息和前面都是类似的,但是可以通过'AccuracyMetric.evaluate(self, pred, target, seq_len=None)'看出这里是evaluation | |||
的时候发生了错误。这样避免了需要在完成一整个epoch的训练才能发现evaluation的弄错的情况。这里的修改是通过在初始化metric的时候 | |||
指明通过'output'获取`pred`, 即AccuracyMetric(pred='output'). | |||
的时候发生了错误。这样避免了需要在完成一整个epoch的训练才能发现evaluation弄错的情况。这里的修改是通过在初始化metric的时候 | |||
指明通过'output'获取`pred`, 即AccuracyMetric(pred='output')。 | |||
可以通过check_code_level调节检查的强度。默认为0,即进行检查。 | |||
3. Trainer与callback | |||
虽然Trainer本身已经集成了一些功能,但仍然不足以囊括训练过程中可能需要到的功能,比如负采样,learning rate decay, Early Stop等。 | |||
为了解决这个问题fastNLP引入了callback的机制,Callback_ 是一种在Trainer训练过程中特定阶段会运行的类,所有的 Callback_ 都具有 | |||
为了解决这个问题fastNLP引入了callback的机制,Callback_ 是一种在Trainer训练过程中特定阶段会运行的函数集合,所有的 Callback_ 都具有 | |||
on_*(比如on_train_start, on_backward_begin)等函数。如果 Callback 实现了该函数,则Trainer运行至对应阶段,会进行调用。 | |||
我们将Train.train()这个函数内部分为以下的阶段 | |||
我们将Train.train()这个函数内部分为以下的阶段,在对应阶段会触发相应的调用。 | |||
Example:: | |||
@@ -305,7 +304,7 @@ import warnings | |||
try: | |||
from tqdm.autonotebook import tqdm | |||
except: | |||
from fastNLP.core.utils import pseudo_tqdm as tqdm | |||
from fastNLP.core.utils import _pseudo_tqdm as tqdm | |||
from fastNLP.core.batch import Batch | |||
from fastNLP.core.callback import CallbackManager, CallbackException | |||
@@ -316,12 +315,12 @@ from fastNLP.core.sampler import Sampler | |||
from fastNLP.core.sampler import RandomSampler | |||
from fastNLP.core.sampler import SequentialSampler | |||
from fastNLP.core.tester import Tester | |||
from fastNLP.core.utils import CheckError | |||
from fastNLP.core.utils import _CheckError | |||
from fastNLP.core.utils import _build_args | |||
from fastNLP.core.utils import _check_forward_error | |||
from fastNLP.core.utils import _check_loss_evaluate | |||
from fastNLP.core.utils import _move_dict_value_to_device | |||
from fastNLP.core.utils import get_func_signature | |||
from fastNLP.core.utils import _get_func_signature | |||
from fastNLP.core.utils import _get_device | |||
@@ -466,34 +465,11 @@ class Trainer(object): | |||
def train(self, load_best_model=True): | |||
""" | |||
开始训练过程。主要有以下几个步骤:: | |||
for epoch in range(num_epochs): | |||
# 使用Batch从DataSet中按批取出数据,并自动对DataSet中dtype为(float, int)的fields进行padding。并转换为Tensor。 | |||
非float,int类型的参数将不会被转换为Tensor,且不进行padding。 | |||
for batch_x, batch_y in Batch(DataSet) | |||
# batch_x是一个dict, 被设为input的field会出现在这个dict中, | |||
key为DataSet中的field_name, value为该field的value | |||
# batch_y也是一个dict,被设为target的field会出现在这个dict中, | |||
key为DataSet中的field_name, value为该field的value | |||
2. 将batch_x的数据送入到model.forward函数中,并获取结果。这里我们就是通过匹配batch_x中的key与forward函数的形 | |||
参完成参数传递。例如, | |||
forward(self, x, seq_lens) # fastNLP会在batch_x中找到key为"x"的value传递给x,key为"seq_lens"的 | |||
value传递给seq_lens。若在batch_x中没有找到所有必须要传递的参数,就会报错。如果forward存在默认参数 | |||
而且默认参数这个key没有在batch_x中,则使用默认参数。 | |||
3. 将batch_y与model.forward的结果一并送入loss中计算loss。loss计算时一般都涉及到pred与target。但是在不同情况 | |||
中,可能pred称为output或prediction, target称为y或label。fastNLP通过初始化loss时传入的映射找到pred或 | |||
target。比如在初始化Trainer时初始化loss为CrossEntropyLoss(pred='output', target='y'), 那么fastNLP计 | |||
算loss时,就会使用"output"在batch_y与forward的结果中找到pred;使用"y"在batch_y与forward的结果中找target | |||
, 并完成loss的计算。 | |||
4. 获取到loss之后,进行反向求导并更新梯度 | |||
根据需要适时进行验证机测试 | |||
根据metrics进行evaluation,并根据是否提供了save_path判断是否存储模型 | |||
使用该函数使Trainer开始训练。 | |||
:param bool load_best_model: 该参数只有在初始化提供了dev_data的情况下有效,如果True, trainer将在返回之前重新加载dev表现 | |||
最好的模型参数。 | |||
:return results: 返回一个字典类型的数据, | |||
:return dict: 返回一个字典类型的数据, | |||
内含以下内容:: | |||
seconds: float, 表示训练时长 | |||
@@ -547,7 +523,7 @@ class Trainer(object): | |||
def _train(self): | |||
if not self.use_tqdm: | |||
from fastNLP.core.utils import pseudo_tqdm as inner_tqdm | |||
from fastNLP.core.utils import _pseudo_tqdm as inner_tqdm | |||
else: | |||
inner_tqdm = tqdm | |||
self.step = 0 | |||
@@ -559,6 +535,7 @@ class Trainer(object): | |||
avg_loss = 0 | |||
data_iterator = Batch(self.train_data, batch_size=self.batch_size, sampler=self.sampler, as_numpy=False, | |||
prefetch=self.prefetch) | |||
self.batch_per_epoch = data_iterator.num_batches | |||
for epoch in range(1, self.n_epochs + 1): | |||
self.epoch = epoch | |||
pbar.set_description_str(desc="Epoch {}/{}".format(epoch, self.n_epochs)) | |||
@@ -660,7 +637,7 @@ class Trainer(object): | |||
x = _build_args(network.forward, **x) | |||
y = network(**x) | |||
if not isinstance(y, dict): | |||
raise TypeError(f"The return value of {get_func_signature(network.forward)} should be dict, got {type(y)}.") | |||
raise TypeError(f"The return value of {_get_func_signature(network.forward)} should be dict, got {type(y)}.") | |||
return y | |||
def _grad_backward(self, loss): | |||
@@ -796,7 +773,7 @@ def _check_code(dataset, model, losser, metrics, batch_size=DEFAULT_CHECK_BATCH_ | |||
refined_batch_x = _build_args(model.forward, **batch_x) | |||
pred_dict = model(**refined_batch_x) | |||
func_signature = get_func_signature(model.forward) | |||
func_signature = _get_func_signature(model.forward) | |||
if not isinstance(pred_dict, dict): | |||
raise TypeError(f"The return value of {func_signature} should be `dict`, not `{type(pred_dict)}`.") | |||
@@ -807,16 +784,16 @@ def _check_code(dataset, model, losser, metrics, batch_size=DEFAULT_CHECK_BATCH_ | |||
if batch_count == 0: | |||
if not isinstance(loss, torch.Tensor): | |||
raise TypeError( | |||
f"The return value of {get_func_signature(losser.get_loss)} should be `torch.Tensor`, " | |||
f"The return value of {_get_func_signature(losser.get_loss)} should be `torch.Tensor`, " | |||
f"but got `{type(loss)}`.") | |||
if len(loss.size()) != 0: | |||
raise ValueError( | |||
f"The size of return value of {get_func_signature(losser.get_loss)} is {loss.size()}, " | |||
f"The size of return value of {_get_func_signature(losser.get_loss)} is {loss.size()}, " | |||
f"should be torch.size([])") | |||
loss.backward() | |||
except CheckError as e: | |||
# TODO: another error raised if CheckError caught | |||
pre_func_signature = get_func_signature(model.forward) | |||
except _CheckError as e: | |||
# TODO: another error raised if _CheckError caught | |||
pre_func_signature = _get_func_signature(model.forward) | |||
_check_loss_evaluate(prev_func_signature=pre_func_signature, func_signature=e.func_signature, | |||
check_res=e.check_res, pred_dict=pred_dict, target_dict=batch_y, | |||
dataset=dataset, check_level=check_level) | |||
@@ -9,7 +9,7 @@ import numpy as np | |||
import torch | |||
from torch import nn | |||
CheckRes = namedtuple('CheckRes', ['missing', 'unused', 'duplicated', 'required', 'all_needed', | |||
_CheckRes = namedtuple('_CheckRes', ['missing', 'unused', 'duplicated', 'required', 'all_needed', | |||
'varargs']) | |||
def _prepare_cache_filepath(filepath): | |||
@@ -28,6 +28,57 @@ def _prepare_cache_filepath(filepath): | |||
# TODO 可以保存下缓存时的参数,如果load的时候发现参数不一致,发出警告。 | |||
def cache_results(_cache_fp, _refresh=False, _verbose=1): | |||
""" | |||
cache_results是fastNLP中用于cache数据的装饰器。通过下面的例子看一下如何使用 | |||
Example:: | |||
import time | |||
import numpy as np | |||
from fastNLP import cache_results | |||
@cache_results('cache.pkl') | |||
def process_data(): | |||
# 一些比较耗时的工作,比如读取数据,预处理数据等,这里用time.sleep()代替耗时 | |||
time.sleep(1) | |||
return np.random.randint(5, size=(10, 20)) | |||
start_time = time.time() | |||
process_data() | |||
print(time.time() - start_time) | |||
start_time = time.time() | |||
process_data() | |||
print(time.time() - start_time) | |||
# 输出内容如下 | |||
# Save cache to cache.pkl. | |||
# 1.0015439987182617 | |||
# Read cache from cache.pkl. | |||
# 0.00013065338134765625 | |||
可以看到第二次运行的时候,只用了0.0001s左右,是由于第二次运行将直接从cache.pkl这个文件读取数据,而不会经过再次预处理 | |||
Example:: | |||
# 还是以上面的例子为例,如果需要重新生成另一个cache,比如另一个数据集的内容,通过如下的方式调用即可 | |||
process_data(_cache_fp='cache2.pkl') # 完全不影响之前的‘cache.pkl' | |||
上面的_cache_fp是cache_results会识别的参数,它将从'cache2.pkl'这里缓存/读取数据,即这里的'cache2.pkl'覆盖默认的 | |||
'cache.pkl'。如果在你的函数前面加上了@cache_results()则你的函数会增加三个参数[_cache_fp, _refresh, _verbose]。 | |||
上面的例子即为使用_cache_fp的情况,这三个参数不会传入到你的函数中,当然你写的函数参数名也不可能包含这三个名称。 | |||
Example:: | |||
process_data(_cache_fp='cache2.pkl', _refresh=True) # 这里强制重新生成一份对预处理的cache。 | |||
# _verbose是用于控制输出信息的,如果为0,则不输出任何内容;如果为1,则会提醒当前步骤是读取的cache还是生成了新的cache | |||
:param str _cache_fp: 将返回结果缓存到什么位置;或从什么位置读取缓存。如果为None,cache_results没有任何效用,除非在 | |||
函数调用的时候传入_cache_fp这个参数。 | |||
:param bool _refresh: 是否重新生成cache。 | |||
:param int _verbose: 是否打印cache的信息。 | |||
:return: | |||
""" | |||
def wrapper_(func): | |||
signature = inspect.signature(func) | |||
for key, _ in signature.parameters.items(): | |||
@@ -74,48 +125,48 @@ def cache_results(_cache_fp, _refresh=False, _verbose=1): | |||
return wrapper | |||
return wrapper_ | |||
def save_pickle(obj, pickle_path, file_name): | |||
"""Save an object into a pickle file. | |||
:param obj: an object | |||
:param pickle_path: str, the directory where the pickle file is to be saved | |||
:param file_name: str, the name of the pickle file. In general, it should be ended by "pkl". | |||
""" | |||
if not os.path.exists(pickle_path): | |||
os.mkdir(pickle_path) | |||
print("make dir {} before saving pickle file".format(pickle_path)) | |||
with open(os.path.join(pickle_path, file_name), "wb") as f: | |||
_pickle.dump(obj, f) | |||
print("{} saved in {}".format(file_name, pickle_path)) | |||
def load_pickle(pickle_path, file_name): | |||
"""Load an object from a given pickle file. | |||
:param pickle_path: str, the directory where the pickle file is. | |||
:param file_name: str, the name of the pickle file. | |||
:return obj: an object stored in the pickle | |||
""" | |||
with open(os.path.join(pickle_path, file_name), "rb") as f: | |||
obj = _pickle.load(f) | |||
print("{} loaded from {}".format(file_name, pickle_path)) | |||
return obj | |||
def pickle_exist(pickle_path, pickle_name): | |||
"""Check if a given pickle file exists in the directory. | |||
:param pickle_path: the directory of target pickle file | |||
:param pickle_name: the filename of target pickle file | |||
:return: True if file exists else False | |||
""" | |||
if not os.path.exists(pickle_path): | |||
os.makedirs(pickle_path) | |||
file_name = os.path.join(pickle_path, pickle_name) | |||
if os.path.exists(file_name): | |||
return True | |||
else: | |||
return False | |||
# def save_pickle(obj, pickle_path, file_name): | |||
# """Save an object into a pickle file. | |||
# | |||
# :param obj: an object | |||
# :param pickle_path: str, the directory where the pickle file is to be saved | |||
# :param file_name: str, the name of the pickle file. In general, it should be ended by "pkl". | |||
# """ | |||
# if not os.path.exists(pickle_path): | |||
# os.mkdir(pickle_path) | |||
# print("make dir {} before saving pickle file".format(pickle_path)) | |||
# with open(os.path.join(pickle_path, file_name), "wb") as f: | |||
# _pickle.dump(obj, f) | |||
# print("{} saved in {}".format(file_name, pickle_path)) | |||
# | |||
# | |||
# def load_pickle(pickle_path, file_name): | |||
# """Load an object from a given pickle file. | |||
# | |||
# :param pickle_path: str, the directory where the pickle file is. | |||
# :param file_name: str, the name of the pickle file. | |||
# :return obj: an object stored in the pickle | |||
# """ | |||
# with open(os.path.join(pickle_path, file_name), "rb") as f: | |||
# obj = _pickle.load(f) | |||
# print("{} loaded from {}".format(file_name, pickle_path)) | |||
# return obj | |||
# | |||
# | |||
# def pickle_exist(pickle_path, pickle_name): | |||
# """Check if a given pickle file exists in the directory. | |||
# | |||
# :param pickle_path: the directory of target pickle file | |||
# :param pickle_name: the filename of target pickle file | |||
# :return: True if file exists else False | |||
# """ | |||
# if not os.path.exists(pickle_path): | |||
# os.makedirs(pickle_path) | |||
# file_name = os.path.join(pickle_path, pickle_name) | |||
# if os.path.exists(file_name): | |||
# return True | |||
# else: | |||
# return False | |||
def _get_device(device, check_exist=False): | |||
""" | |||
@@ -232,15 +283,15 @@ def _check_arg_dict_list(func, args): | |||
missing = list(require_args - input_args) | |||
unused = list(input_args - all_args) | |||
varargs = [] if not spect.varargs else [spect.varargs] | |||
return CheckRes(missing=missing, | |||
unused=unused, | |||
duplicated=duplicated, | |||
required=list(require_args), | |||
all_needed=list(all_args), | |||
varargs=varargs) | |||
return _CheckRes(missing=missing, | |||
unused=unused, | |||
duplicated=duplicated, | |||
required=list(require_args), | |||
all_needed=list(all_args), | |||
varargs=varargs) | |||
def get_func_signature(func): | |||
def _get_func_signature(func): | |||
""" | |||
Given a function or method, return its signature. | |||
@@ -318,13 +369,13 @@ def _move_dict_value_to_device(*args, device: torch.device, non_blocking=False): | |||
raise TypeError("Only support `dict` type right now.") | |||
class CheckError(Exception): | |||
class _CheckError(Exception): | |||
""" | |||
CheckError. Used in losses.LossBase, metrics.MetricBase. | |||
_CheckError. Used in losses.LossBase, metrics.MetricBase. | |||
""" | |||
def __init__(self, check_res: CheckRes, func_signature: str): | |||
def __init__(self, check_res: _CheckRes, func_signature: str): | |||
errs = [f'Problems occurred when calling `{func_signature}`'] | |||
if check_res.varargs: | |||
@@ -347,7 +398,7 @@ WARNING_CHECK_LEVEL = 1 | |||
STRICT_CHECK_LEVEL = 2 | |||
def _check_loss_evaluate(prev_func_signature: str, func_signature: str, check_res: CheckRes, | |||
def _check_loss_evaluate(prev_func_signature: str, func_signature: str, check_res: _CheckRes, | |||
pred_dict: dict, target_dict: dict, dataset, check_level=0): | |||
errs = [] | |||
unuseds = [] | |||
@@ -449,7 +500,7 @@ def _check_loss_evaluate(prev_func_signature: str, func_signature: str, check_re | |||
def _check_forward_error(forward_func, batch_x, dataset, check_level): | |||
check_res = _check_arg_dict_list(forward_func, batch_x) | |||
func_signature = get_func_signature(forward_func) | |||
func_signature = _get_func_signature(forward_func) | |||
errs = [] | |||
suggestions = [] | |||
@@ -543,7 +594,7 @@ def seq_mask(seq_len, max_len): | |||
return torch.gt(seq_len, seq_range) # [batch_size, max_len] | |||
class pseudo_tqdm: | |||
class _pseudo_tqdm: | |||
""" | |||
当无法引入tqdm,或者Trainer中设置use_tqdm为false的时候,用该方法打印数据 | |||
""" | |||
@@ -2,7 +2,7 @@ from functools import wraps | |||
from collections import Counter | |||
from fastNLP.core.dataset import DataSet | |||
def check_build_vocab(func): | |||
def _check_build_vocab(func): | |||
"""A decorator to make sure the indexing is built before used. | |||
""" | |||
@@ -15,7 +15,7 @@ def check_build_vocab(func): | |||
return _wrapper | |||
def check_build_status(func): | |||
def _check_build_status(func): | |||
"""A decorator to check whether the vocabulary updates after the last build. | |||
""" | |||
@@ -67,7 +67,7 @@ class Vocabulary(object): | |||
self.idx2word = None | |||
self.rebuild = True | |||
@check_build_status | |||
@_check_build_status | |||
def update(self, word_lst): | |||
"""依次增加序列中词在词典中的出现频率 | |||
@@ -75,7 +75,7 @@ class Vocabulary(object): | |||
""" | |||
self.word_count.update(word_lst) | |||
@check_build_status | |||
@_check_build_status | |||
def add(self, word): | |||
""" | |||
增加一个新词在词典中的出现频率 | |||
@@ -84,7 +84,7 @@ class Vocabulary(object): | |||
""" | |||
self.word_count[word] += 1 | |||
@check_build_status | |||
@_check_build_status | |||
def add_word(self, word): | |||
""" | |||
增加一个新词在词典中的出现频率 | |||
@@ -93,7 +93,7 @@ class Vocabulary(object): | |||
""" | |||
self.add(word) | |||
@check_build_status | |||
@_check_build_status | |||
def add_word_lst(self, word_lst): | |||
""" | |||
依次增加序列中词在词典中的出现频率 | |||
@@ -132,11 +132,11 @@ class Vocabulary(object): | |||
""" | |||
self.idx2word = {i: w for w, i in self.word2idx.items()} | |||
@check_build_vocab | |||
@_check_build_vocab | |||
def __len__(self): | |||
return len(self.word2idx) | |||
@check_build_vocab | |||
@_check_build_vocab | |||
def __contains__(self, item): | |||
""" | |||
检查词是否被记录 | |||
@@ -161,7 +161,7 @@ class Vocabulary(object): | |||
""" | |||
return self.__contains__(w) | |||
@check_build_vocab | |||
@_check_build_vocab | |||
def __getitem__(self, w): | |||
""" | |||
To support usage like:: | |||
@@ -175,7 +175,7 @@ class Vocabulary(object): | |||
else: | |||
raise ValueError("word {} not in vocabulary".format(w)) | |||
@check_build_vocab | |||
@_check_build_vocab | |||
def index_dataset(self, *datasets, field_name, new_field_name=None): | |||
""" | |||
将DataSet中对应field的词转为数字. | |||
@@ -275,7 +275,7 @@ class Vocabulary(object): | |||
return self.__getitem__(w) | |||
@property | |||
@check_build_vocab | |||
@_check_build_vocab | |||
def unknown_idx(self): | |||
""" | |||
unknown 对应的数字. | |||
@@ -285,7 +285,7 @@ class Vocabulary(object): | |||
return self.word2idx[self.unknown] | |||
@property | |||
@check_build_vocab | |||
@_check_build_vocab | |||
def padding_idx(self): | |||
""" | |||
padding 对应的数字 | |||
@@ -294,7 +294,7 @@ class Vocabulary(object): | |||
return None | |||
return self.word2idx[self.padding] | |||
@check_build_vocab | |||
@_check_build_vocab | |||
def to_word(self, idx): | |||
""" | |||
给定一个数字, 将其转为对应的词. | |||
@@ -13,12 +13,12 @@ from torch import nn | |||
try: | |||
from tqdm.autonotebook import tqdm | |||
except: | |||
from fastNLP.core.utils import pseudo_tqdm as tqdm | |||
from fastNLP.core.utils import _pseudo_tqdm as tqdm | |||
from fastNLP.core.batch import Batch | |||
from fastNLP.core.callback import CallbackManager, CallbackException | |||
from fastNLP.core.dataset import DataSet | |||
from fastNLP.core.utils import CheckError | |||
from fastNLP.core.utils import _CheckError | |||
from fastNLP.core.utils import _move_dict_value_to_device | |||
import fastNLP | |||
import fastNLP.models.enas_utils as utils | |||
@@ -118,7 +118,7 @@ class ENASTrainer(fastNLP.Trainer): | |||
def _train(self): | |||
if not self.use_tqdm: | |||
from fastNLP.core.utils import pseudo_tqdm as inner_tqdm | |||
from fastNLP.core.utils import _pseudo_tqdm as inner_tqdm | |||
else: | |||
inner_tqdm = tqdm | |||
self.step = 0 | |||
@@ -8,7 +8,7 @@ import numpy as np | |||
import torch.nn.functional as F | |||
from torch import nn | |||
import time | |||
from fastNLP.core.utils import CheckError | |||
from fastNLP.core.utils import _CheckError | |||
from fastNLP.core.dataset import DataSet | |||
from fastNLP.core.instance import Instance | |||
from fastNLP.core.losses import BCELoss | |||