@@ -11,7 +11,7 @@ import torch
 try:
     from tqdm.autonotebook import tqdm
 except:
-    from fastNLP.core.utils import pseudo_tqdm as tqdm
+    from fastNLP.core.utils import _pseudo_tqdm as tqdm
 from fastNLP.core.batch import Batch
 from fastNLP.core.callback import CallbackException
@@ -115,7 +115,7 @@ class ENASTrainer(fastNLP.Trainer):
     def _train(self):
         if not self.use_tqdm:
-            from fastNLP.core.utils import pseudo_tqdm as inner_tqdm
+            from fastNLP.core.utils import _pseudo_tqdm as inner_tqdm
         else:
             inner_tqdm = tqdm
         self.step = 0
@@ -2,21 +2,21 @@
 .. _Callback:
+"""
+Callback is the class fastNLP designed for extending Trainer_. If a Callback is passed to a Trainer_, the Trainer_
+invokes the Callback's methods at the corresponding stages; see Trainer_ for exactly when each hook is called.
+"""
 import os
 import torch
-from tensorboardX import SummaryWriter
 from fastNLP.io.model_io import ModelSaver, ModelLoader
+try:
+    from tensorboardX import SummaryWriter
+except:
+    pass
 class Callback(object):
-    """An Interface for all callbacks.
-    Any customized callback should implement at least one of the following methods.
+    """The base class of Callback; every callback must inherit from this class.
     """
@@ -26,93 +26,150 @@ class Callback(object):
     @property
     def trainer(self):
+        """
+        Available as self.trainer; normally there is no need to use this property.
+        :return:
+        """
         return self._trainer
     @property
     def step(self):
-        """current step number, in range(1, self.n_steps+1)"""
+        """The step currently being run, in the range [1, self.n_steps+1)."""
         return self._trainer.step
     @property
     def n_steps(self):
-        """total number of steps for training"""
+        """The total number of steps the Trainer will run."""
         return self._trainer.n_steps
     @property
     def batch_size(self):
-        """batch size for training"""
+        """The batch_size used for training and evaluation."""
         return self._trainer.batch_size
     @property
     def epoch(self):
-        """current epoch number, in range(1, self.n_epochs+1)"""
+        """The epoch currently being run, in the range [1, self.n_epochs+1)."""
         return self._trainer.epoch
     @property
     def n_epochs(self):
-        """total number of epochs"""
+        """The total number of epochs that will be run."""
         return self._trainer.n_epochs
     @property
     def optimizer(self):
-        """torch.optim.Optimizer for current model"""
+        """The Optimizer passed when the Trainer was initialized."""
        return self._trainer.optimizer
     @property
     def model(self):
-        """training model"""
+        """The model the Trainer is currently training."""
        return self._trainer.model
     @property
     def pbar(self):
-        """If use_tqdm, return trainer's tqdm print bar, else return None."""
+        """To print inside a Callback, use self.pbar.write(str); otherwise the command-line display may be garbled."""
        return self._trainer.pbar
     @property
     def update_every(self):
-        """The model in trainer will update parameters every `update_every` batches."""
+        """How many backward passes the Trainer's model accumulates per gradient update; passed at Trainer initialization."""
         return self._trainer.update_every
+    @property
+    def batch_per_epoch(self):
+        """The number of batches per epoch; only available after on_epoch_begin has been called."""
+        return self._trainer.batch_per_epoch
     def on_train_begin(self):
-        # before the main training loop
+        """
+        Called before the training process starts.
+        :return:
+        """
         pass
     def on_epoch_begin(self):
-        # at the beginning of each epoch
+        """
+        Called once before each epoch starts.
+        :return:
+        """
         pass
     def on_batch_begin(self, batch_x, batch_y, indices):
-        # at the beginning of each step/mini-batch
+        """
+        Called each time a batch of data has been sampled. Deleting from or adding to batch_x or batch_y here does
+        affect what the Trainer sees, so operations such as negative sampling can be done at this step.
+        :param dict batch_x: the batch of the fields set as input in the DataSet.
+        :param dict batch_y: the batch of the fields set as target in the DataSet.
+        :param list(int) indices: the indices sampled for this batch; DataSet[indices] retrieves the Instances drawn
+            for this batch, which can sometimes help locate which sample caused an error.
+        :return:
+        """
         pass
     def on_loss_begin(self, batch_y, predict_y):
-        # after data_forward, and before loss computation
+        """
+        Called before the loss is computed, i.e. modifying batch_y or predict_y here can affect the loss computation.
+        :param dict batch_y: the batch of the fields set as target in the DataSet.
+        :param dict predict_y: the result returned by the model's forward().
+        :return:
+        """
         pass
     def on_backward_begin(self, loss):
-        # after loss computation, and before gradient backward
+        """
+        Called after the loss is obtained but before backpropagation. A check that the loss is not NaN could be done here.
+        :param torch.Tensor loss: the computed loss value
+        :return:
+        """
         pass
     def on_backward_end(self):
+        """
+        Backpropagation has finished, but because of the update_every setting not every call necessarily has gradients.
+        At this point the parameters have not been updated yet.
+        :return:
+        """
         pass
     def on_step_end(self):
+        """
+        At this point the model parameters have been updated with the gradients, although due to update_every the
+        update may not happen on every call.
+        :return:
+        """
         pass
-    def on_batch_end(self, *args):
-        # at the end of each step/mini-batch
+    def on_batch_end(self):
+        """
+        This comes immediately after on_step_end; the hook exists only for symmetry.
+        """
         pass
     def on_valid_begin(self):
+        """
+        If validation is configured in the Trainer, this is called before each validation run.
+        :return:
+        """
         pass
     def on_valid_end(self, eval_result, metric_key, optimizer, is_better_eval):
         """
-        Called after each evaluation run on the validation set. Passes in eval_result
+        Called after each evaluation run on the validation set.
-        :param eval_result: Dict[str: Dict[str: float]], the evaluation result
-        :param metric_key: str
-        :param optimizer: optimizer passed to trainer
-        :param is_better_eval: bool, whether the current dev result is better than before
+        :param Dict[str: Dict[str: float]] eval_result: the evaluation result, e.g. {'AccuracyMetric': {'acc': 1.0}};
+            the dict has two levels, the first being the metric name and the second the metric's concrete scores.
+        :param str metric_key: the metric_key passed when the Trainer was initialized.
+        :param torch.Optimizer optimizer: the optimizer used in the Trainer.
+        :param bool is_better_eval: whether the current dev result is better than the previous ones.
         :return:
         """
         pass
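The hooks added above define the Trainer's per-batch life cycle. A minimal sketch of a custom callback built on them; the hook names and properties come from this hunk, while the Trainer wiring via a callbacks list is assumed from the callback machinery elsewhere in this diff::

    from fastNLP.core.callback import Callback

    class LossReportCallback(Callback):
        def on_backward_begin(self, loss):
            # loss is the scalar tensor computed for the current batch
            self.latest_loss = loss.item()

        def on_epoch_end(self):
            # write through self.pbar, as the pbar docstring above advises
            self.pbar.write("epoch {}/{} done, last batch loss {:.4f}".format(
                self.epoch, self.n_epochs, self.latest_loss))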
@@ -137,7 +194,7 @@ class Callback(object):
     pass
-def transfer(func):
+def _transfer(func):
    """A decorator that forwards a call on the CallbackManager to every Callback subclass.
    :param func:
    :return:
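For readers new to the pattern, a standalone sketch of what a forwarding decorator like _transfer does; this body is illustrative only, not the code from this diff::

    def _transfer(func):
        def wrapper(manager, *args, **kwargs):
            # invoke the hook of the same name on every registered callback
            for callback in manager.callbacks:
                getattr(callback, func.__name__)(*args, **kwargs)
        return wrapper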
@@ -153,9 +210,7 @@ def transfer(func):
 class CallbackManager(Callback):
-    """A manager for all callbacks passed into Trainer.
-    It collects resources inside Trainer and raise callbacks.
+    """The Callback manager class used internally.
     """
     def __init__(self, env, callbacks=None):
@@ -182,104 +237,70 @@ class CallbackManager(Callback):
         for callback in self.callbacks:
             setattr(callback, '_'+env_name, env_val)  # Callback.trainer
-    @transfer
+    @_transfer
     def on_train_begin(self):
         pass
-    @transfer
+    @_transfer
     def on_epoch_begin(self):
         pass
-    @transfer
+    @_transfer
     def on_batch_begin(self, batch_x, batch_y, indices):
         pass
-    @transfer
+    @_transfer
     def on_loss_begin(self, batch_y, predict_y):
         pass
-    @transfer
+    @_transfer
     def on_backward_begin(self, loss):
         pass
-    @transfer
+    @_transfer
     def on_backward_end(self):
         pass
-    @transfer
+    @_transfer
     def on_step_end(self):
         pass
-    @transfer
+    @_transfer
     def on_batch_end(self):
         pass
-    @transfer
+    @_transfer
     def on_valid_begin(self):
         pass
-    @transfer
+    @_transfer
     def on_valid_end(self, eval_result, metric_key, optimizer, is_better_eval):
         pass
-    @transfer
+    @_transfer
     def on_epoch_end(self):
         pass
-    @transfer
+    @_transfer
     def on_train_end(self):
         pass
-    @transfer
+    @_transfer
     def on_exception(self, exception):
         pass
-class DummyCallback(Callback):
-    def on_train_begin(self, *arg):
-        print(arg)
-    def on_epoch_end(self):
-        print(self.epoch, self.n_epochs)
-class EchoCallback(Callback):
-    def on_train_begin(self):
-        print("before_train")
-    def on_epoch_begin(self):
-        print("before_epoch")
-    def on_batch_begin(self, batch_x, batch_y, indices):
-        print("before_batch")
-    def on_loss_begin(self, batch_y, predict_y):
-        print("before_loss")
-    def on_backward_begin(self, loss):
-        print("before_backward")
-    def on_batch_end(self):
-        print("after_batch")
-    def on_epoch_end(self):
-        print("after_epoch")
-    def on_train_end(self):
-        print("after_train")
 class GradientClipCallback(Callback):
     def __init__(self, parameters=None, clip_value=1, clip_type='norm'):
         """After each backward pass and before the update, clips the parameters' gradients into a given range.
-        :param parameters: None, torch.Tensor or List[torch.Tensor], usually obtained via model.parameters(). If None, the default is to clip
+        :param None,torch.Tensor,List[torch.Tensor] parameters: usually obtained via model.parameters(). If None, the default is to clip
         all parameters of the Trainer's model.
-        :param clip_value: float, clips the gradient into [-clip_value, clip_value]. clip_value should be positive.
-        :param clip_type: str, supports 'norm' and 'value'.
-            (1) 'norm': rescales the gradient norm into [-clip_value, clip_value]
-            (2) 'value': clamps the gradient into [-clip_value, clip_value]; gradients below -clip_value are set to
-            -clip_value, and gradients above clip_value are set to clip_value.
+        :param float clip_value: clips the gradient into [-clip_value, clip_value]. clip_value should be positive.
+        :param str clip_type: supports 'norm' and 'value'.
+            1. 'norm': rescales the gradient norm into [-clip_value, clip_value]
+            2. 'value': clamps the gradient into [-clip_value, clip_value]; gradients below -clip_value are set to
+            -clip_value, and gradients above clip_value are set to clip_value.
         """
         super().__init__()
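A usage sketch for the parameters documented above; passing the callback to a Trainer through a callbacks list is assumed from the rest of this diff::

    from fastNLP.core.callback import GradientClipCallback

    # clamp every gradient entry into [-5, 5]; clip_type='norm' would rescale the overall norm instead
    clip_cb = GradientClipCallback(parameters=None, clip_value=5, clip_type='value')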
@@ -314,7 +335,7 @@ class EarlyStopCallback(Callback):
     def __init__(self, patience):
         """
-        :param int patience: the number of epochs to wait before stopping
+        :param int patience: stop training after this many epochs without improvement
         """
         super(EarlyStopCallback, self).__init__()
         self.patience = patience
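Under the reworded docstring, patience counts epochs without improvement, e.g.::

    from fastNLP.core.callback import EarlyStopCallback

    # stop training once 10 consecutive epochs bring no improvement on dev
    early_stop = EarlyStopCallback(patience=10)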
@@ -341,7 +362,7 @@ class LRScheduler(Callback):
     def __init__(self, lr_scheduler):
         """A wrapper around PyTorch LR schedulers.
-        :param lr_scheduler: a PyTorch lr_scheduler
+        :param torch.optim.lr_scheduler._LRScheduler lr_scheduler: a PyTorch lr_scheduler
         """
         super(LRScheduler, self).__init__()
         import torch.optim
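A sketch of wrapping a standard PyTorch scheduler with the type documented above; `model` is assumed to exist::

    import torch.optim as optim
    from torch.optim.lr_scheduler import StepLR
    from fastNLP.core.callback import LRScheduler

    optimizer = optim.SGD(model.parameters(), lr=0.1)
    # multiply the learning rate by 0.9 after every epoch
    scheduler_cb = LRScheduler(StepLR(optimizer, step_size=1, gamma=0.9))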
@@ -358,7 +379,7 @@ class ControlC(Callback):
     def __init__(self, quit_all):
         """
-        :param quit_all: if True, exit the whole program when control+C is detected; otherwise only quit the Trainer
+        :param bool quit_all: if True, exit the whole program when control+C is detected; otherwise only quit the Trainer
         """
         super(ControlC, self).__init__()
         if type(quit_all) != bool:
@@ -389,16 +410,16 @@ class SmoothValue(object):
 class LRFinder(Callback):
-    def __init__(self, n_batch, start_lr=1e-6, end_lr=10):
+    def __init__(self, start_lr=1e-6, end_lr=10):
         """Uses the first epoch to find the best learning rate, and applies it from the second epoch on.
-        :param n_batch: the number of iterations in one epoch
-        :param start_lr: the lower bound of the learning rate
-        :param end_lr: the upper bound of the learning rate
+        :param float start_lr: the lower bound of the learning rate
+        :param float end_lr: the upper bound of the learning rate
         """
         super(LRFinder, self).__init__()
         self.start_lr, self.end_lr = start_lr, end_lr
-        self.num_it = n_batch
+        self.num_it = self.batch_per_epoch
         self.stop = False
         self.best_loss = 0.
         self.best_lr = None
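With n_batch dropped from the signature, construction only needs the search range; the iteration count is now read from the batch_per_epoch property introduced earlier in this diff::

    from fastNLP.core.callback import LRFinder

    # sweep learning rates from 1e-6 up to 10 during the first epoch
    lr_finder = LRFinder(start_lr=1e-6, end_lr=10)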
@@ -514,7 +535,3 @@ class TensorboardCallback(Callback):
         del self._summary_writer
-if __name__ == "__main__":
-    manager = CallbackManager(env={"n_epoch": 3}, callbacks=[DummyCallback(), DummyCallback()])
-    manager.on_train_begin()
-    # print(manager.after_epoch())
@@ -277,7 +277,7 @@ import warnings
 from fastNLP.core.fieldarray import AutoPadder
 from fastNLP.core.fieldarray import FieldArray
 from fastNLP.core.instance import Instance
-from fastNLP.core.utils import get_func_signature
+from fastNLP.core.utils import _get_func_signature
 class DataSet(object):
     """fastNLP's data container
@@ -642,7 +642,7 @@ class DataSet(object):
             print("Exception happens at the `{}`th instance.".format(idx))
             raise e
         if not (new_field_name is None) and len(list(filter(lambda x: x is not None, results))) == 0:  # all None
-            raise ValueError("{} always return None.".format(get_func_signature(func=func)))
+            raise ValueError("{} always return None.".format(_get_func_signature(func=func)))
         if new_field_name is not None:
             self._add_apply_field(results, new_field_name, kwargs)
@@ -707,7 +707,7 @@ class DataSet(object):
             raise e
         # results = [func(ins) for ins in self._inner_iter()]
         if not (new_field_name is None) and len(list(filter(lambda x: x is not None, results))) == 0:  # all None
-            raise ValueError("{} always return None.".format(get_func_signature(func=func)))
+            raise ValueError("{} always return None.".format(_get_func_signature(func=func)))
         if new_field_name is not None:
             self._add_apply_field(results, new_field_name, kwargs)
@@ -1,5 +1,5 @@
 """
-FieldArray is how one column of a DataSet_ is stored
+FieldArray is how one column of a DataSet_ is stored; for the underlying principles, see DataSet_
 .. _FieldArray:
@@ -11,41 +11,19 @@ from copy import deepcopy
 class FieldArray(object):
-    """``FieldArray`` is the collection of ``Instance``s of the same field.
-    It is the basic element of ``DataSet`` class.
-    :param str name: the name of the FieldArray
-    :param list content: a list of int, float, str or np.ndarray, or a list of list of one, or a np.ndarray.
-    :param bool is_target: If True, this FieldArray is used to compute loss.
-    :param bool is_input: If True, this FieldArray is used to the model input.
-    :param Padder padder: of type PadderBase. The padder object assigned to a fieldarray is deepcopied; to modify the
-        padder's parameters you must go through fieldarray.set_pad_val().
-        Defaults to None: (1) if a field is a scalar, no padding is done; (2) if it is a one-dimensional list and the
-        fieldarray's dtype is float or int, padding is done; (3) otherwise no padding is done.
-        To pad, say, the characters within English words, another padder is needed,
-        as it is when ignore_type is True but padding is still required.
-    :param bool ignore_type: whether to ignore type. If True, no type detection will rise for this FieldArray.
-        (default: False)
-    """
     def __init__(self, name, content, is_target=None, is_input=None, padder=None, ignore_type=False):
-        """When a DataSet is initialized, it operates on FieldArray in two ways:
-        1) if the DataSet is initialized from a dict, the FieldArray is constructed in add_field:
-            1.1) two-dimensional list   DataSet({"x": [[1, 2], [3, 4]]})
-            1.2) two-dimensional array  DataSet({"x": np.array([[1, 2], [3, 4]])})
-            1.3) three-dimensional list DataSet({"x": [[[1, 2], [3, 4]], [[1, 2], [3, 4]]]})
-            1.4) list of array:         DataSet({"x": [np.array([1,2,3]), np.array([1,2,3])]})
-        2) if the DataSet is initialized from a list of Instance, append first initializes the FieldArray from the
-           first sample, and later samples are added with FieldArray.append:
-            2.1) one-dimensional list   DataSet([Instance(x=[1, 2, 3, 4])])
-            2.2) one-dimensional array  DataSet([Instance(x=np.array([1, 2, 3, 4]))])
-            2.3) two-dimensional list   DataSet([Instance(x=[[1, 2], [3, 4]])])
-            2.4) two-dimensional array  DataSet([Instance(x=np.array([[1, 2], [3, 4]]))])
-        The dtype check happens when the field is set as is_input or is_target.
-        ignore_type controls whether the type check is performed; if True, there is no check.
+        """FieldArray is the entity that stores one field of a DataSet_.
+        :param str name: the name of the FieldArray
+        :param list,numpy.ndarray content: the elements of the list may be list, int, or float.
+        :param bool is_target: whether this field is a target field.
+        :param bool is_input: whether this field is an input field.
+        :param Padder padder: of type PadderBase. The padder object assigned to a fieldarray is deepcopied; to modify
+            the padder's parameters you must go through fieldarray.set_pad_val(). Defaults to None, i.e. AutoPadder_ is used.
+        :param bool ignore_type: whether to ignore this field's type; if the field never needs to be converted to
+            torch.FloatTensor or torch.LongTensor, this can generally be set to True. See DataSet_ for the exact meaning.
         """
         self.name = name
         if isinstance(content, list):
             # if the DataSet was initialized from a dict, content may be a 2-D list / 2-D array / 3-D list
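The two construction paths the old docstring enumerated still exist; a small sketch using the module paths shown in this diff::

    import numpy as np
    from fastNLP.core.dataset import DataSet
    from fastNLP.core.instance import Instance

    # dict initialization: the FieldArray is built in add_field from the whole column
    ds1 = DataSet({"x": [[1, 2], [3, 4]]})
    # Instance initialization: the first sample creates the FieldArray, later ones use FieldArray.append
    ds2 = DataSet([Instance(x=np.array([1, 2, 3, 4]))])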
@@ -211,10 +189,10 @@ class FieldArray(object):
         return "FieldArray {}: {}".format(self.name, self.content.__repr__())
     def append(self, val):
-        """Appends val to the FieldArray. If this field's ignore_type is True, val is appended directly; if
-        ignore_type is False and the field is input or target, the incoming content is checked against the existing
-        content for matching dimension and element type.
+        """Appends val to the end of this field. If this field has already been set as input or target, the type is
+        checked for compatibility with the existing content before appending.
-        :param val: Any.
+        :param Any val: the value to append.
         """
         if self.ignore_type is False:
             if isinstance(val, list):
@@ -262,8 +240,8 @@ class FieldArray(object):
     def get(self, indices, pad=True):
         """Returns the content at the given indices.
-        :param indices: (int, List[int]), get the content corresponding to these indices.
-        :param pad: bool, whether to pad the returned result. Only effective when indices is a List[int].
+        :param int,list(int) indices: get the content corresponding to these indices.
+        :param bool pad: whether to pad the returned result. Only effective when indices is a List[int].
         :return: (single, List)
         """
         if isinstance(indices, int):
@@ -281,7 +259,7 @@ class FieldArray(object):
         """
         Sets the padder; when this field is padded, this padder is used, and if it is None no padding is done.
-        :param padder: (None, Padder). Setting it to None removes the padder.
+        :param None,Padder padder: setting it to None removes the padder.
         :return:
         """
         if padder is not None:
@@ -293,7 +271,7 @@ class FieldArray(object):
     def set_pad_val(self, pad_val):
         """Modifies the padder's pad_val.
-        :param pad_val: int. Sets this field's pad value to this value.
+        :param int pad_val: sets this field's pad value to this value.
         :return:
         """
         if self.padder is not None:
@@ -312,8 +290,8 @@ class FieldArray(object):
         """
         Copies other's attributes to this FieldArray (other must be of type FieldArray); the attributes include
         is_input, is_target, padder and ignore_type.
-        :param other: FieldArray
-        :return:
+        :param FieldArray other: the field to copy attributes from
+        :return: FieldArray
         """
         assert isinstance(other, FieldArray), "Only support FieldArray type, not {}.".format(type(other))
return self | return self | ||||
def is_iterable(content): | |||||
def _is_iterable(content): | |||||
try: | try: | ||||
_ = (e for e in content) | _ = (e for e in content) | ||||
except TypeError: | except TypeError: | ||||
@@ -350,11 +328,10 @@ class Padder:
         """
         The input is the List content. Suppose we have the following DataSet.
-        :param contents: List[element]. The elements passed in are in place, i.e. modifying an element directly may change the data, so before any in-place modification
+        :param list(Any) contents: the elements passed in are in place, i.e. modifying an element directly may change the data, so before any in-place modification
         make a deepcopy first.
-        :param field_name: str, the name of the field.
-        :param field_ele_dtype: (np.int64, np.float64, np.str, None), the dtype of the field's inner elements. If the
-            field's ignore_type is True, this value is None.
+        :param str field_name: the name of the field.
+        :param np.int64,np.float64,np.str,None field_ele_dtype: the dtype of the field's inner elements; None if the field's ignore_type is True.
         :return: np.array([padded_element])
         Example::
@@ -400,10 +377,10 @@ class AutoPadder(Padder):
     2 if the element dtype is (np.int64, np.float64),
-        2.1 if the field's content is (np.int64, np.float64), e.g. a seq_len, no padding is done
+        2.1 if the field's content is (np.int64, np.float64), e.g. a seq_len, no padding is done
-        2.2 if the field's content is a List, the Lists within a batch are padded to the same length. If the inner
-            Lists also need padding, use another padder; if an instance's field is [1, 2, 3] it can be padded, while
-            [[1,2], [3,4, ...]] cannot be padded
+        2.2 if the field's content is a List, the Lists within a batch are padded to the same length. If the inner
+            Lists also need padding, use another padder; that is, if an Instance's field looks like [1, 2, 3, ...] it
+            can be padded, while [[1,2], [3,4, ...]] cannot be padded
     """
     def __init__(self, pad_val=0):
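A quick demonstration of rule 2.2, calling an AutoPadder directly with the __call__ signature shown below in this diff::

    import numpy as np
    from fastNLP.core.fieldarray import AutoPadder

    padder = AutoPadder(pad_val=0)
    # 1-D int lists within a batch are padded to the longest length
    print(padder([[1, 2], [3, 4, 5]], field_name='words', field_ele_dtype=np.int64))
    # expected: [[1 2 0]
    #            [3 4 5]]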
@@ -427,7 +404,7 @@ class AutoPadder(Padder):
         return False
     def __call__(self, contents, field_name, field_ele_dtype):
-        if not is_iterable(contents[0]):
+        if not _is_iterable(contents[0]):
             array = np.array([content for content in contents], dtype=field_ele_dtype)
         elif field_ele_dtype in (np.int64, np.float64) and self._is_two_dimension(contents):
             max_len = max([len(content) for content in contents])
@@ -454,7 +431,7 @@ class EngChar2DPadder(Padder):
     Example::
         from fastNLP import DataSet
-        from fastNLP import EnChar2DPadder
+        from fastNLP import EngChar2DPadder
         from fastNLP import Vocabulary
         dataset = DataSet({'sent': ['This is the first demo', 'This is the second demo']})
         dataset.apply(lambda ins:[list(word) for word in ins['sent'].split()], new_field_name='chars')
@@ -462,14 +439,15 @@ class EngChar2DPadder(Padder):
         vocab.from_dataset(dataset, field_name='chars')
         vocab.index_dataset(dataset, field_name='chars')
         dataset.set_input('chars')
-        padder = EnChar2DPadder()
+        padder = EngChar2DPadder()
         dataset.set_padder('chars', padder)  # set the 'chars' field to use EngChar2DPadder
     """
     def __init__(self, pad_val=0, pad_length=0):
         """
         :param pad_val: int, this index is used at the padded positions
-        :param pad_length: int, if 0, the longest word length in a batch is used as the padding length; if greater than 0, all words are padded or
-            truncated to that length.
+        :param pad_length: int, if 0, the longest word length in a batch is used as the padding length; if greater
+            than 0, all words are padded or truncated to that length.
         """
         super().__init__(pad_val=pad_val)
@@ -494,7 +472,7 @@ class EngChar2DPadder(Padder):
         except:
             raise ValueError("Field:{} only has two dimensions.".format(field_name))
-        if is_iterable(value):
+        if _is_iterable(value):
             raise ValueError("Field:{} has more than 3 dimension.".format(field_name))
     def __call__(self, contents, field_name, field_ele_dtype):
@@ -3,34 +3,33 @@ Instance documentation
 .. _Instance:
-test
+An Instance is the class in fastNLP that corresponds to one sample; a sample can be thought of as one Instance object.
+A concrete picture is one row of the table shown at DataSet_.
 """
 class Instance(object):
-    """An Instance is an example of data.
-    Example::
-        ins = Instance(field_1=[1, 1, 1], field_2=[2, 2, 2])
-        ins["field_1"]
-        >>[1, 1, 1]
-        ins.add_field("field_3", [3, 3, 3])
-    """
     def __init__(self, **fields):
-        """
+        """An Instance is initialized as the Example below shows
+        Example::
+            ins = Instance(field_1=[1, 1, 1], field_2=[2, 2, 2])
+            ins["field_1"]
+            >>[1, 1, 1]
+            ins.add_field("field_3", [3, 3, 3])
+        :param fields: possibly one- or two-dimensional list or np.array
+            ins = Instance(**{'x1': 1, 'x2': np.zeros((3, 4))})
         """
         self.fields = fields
     def add_field(self, field_name, field):
-        """Add a new field to the instance.
+        """Adds a field to the Instance.
-        :param field_name: str, the name of the field.
+        :param str field_name: the name of the new field
+        :param Any field: the content of the new field
         """
         self.fields[field_name] = field
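A short usage sketch matching the relocated docstring::

    from fastNLP.core.instance import Instance

    ins = Instance(field_1=[1, 1, 1], field_2=[2, 2, 2])
    ins.add_field("field_3", [3, 3, 3])  # stored under ins.fields
    print(ins["field_1"])  # [1, 1, 1]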
@@ -12,12 +12,12 @@ from collections import defaultdict
 import torch
 import torch.nn.functional as F
-from fastNLP.core.utils import CheckError
-from fastNLP.core.utils import CheckRes
+from fastNLP.core.utils import _CheckError
+from fastNLP.core.utils import _CheckRes
 from fastNLP.core.utils import _build_args
 from fastNLP.core.utils import _check_arg_dict_list
 from fastNLP.core.utils import _check_function_or_method
-from fastNLP.core.utils import get_func_signature
+from fastNLP.core.utils import _get_func_signature
 class LossBase(object):
@@ -70,7 +70,7 @@ class LossBase(object):
         for func_param, input_param in self.param_map.items():
             if func_param not in func_args:
                 raise NameError(
-                    f"Parameter `{func_param}` is not in {get_func_signature(self.get_loss)}. Please check the "
+                    f"Parameter `{func_param}` is not in {_get_func_signature(self.get_loss)}. Please check the "
                     f"initialization parameters, or change its signature.")
         # evaluate should not have varargs.
@@ -111,7 +111,7 @@ class LossBase(object):
         func_args = set([arg for arg in func_spect.args if arg != 'self'])
         for func_arg, input_arg in self.param_map.items():
             if func_arg not in func_args:
-                raise NameError(f"`{func_arg}` not in {get_func_signature(self.get_loss)}.")
+                raise NameError(f"`{func_arg}` not in {_get_func_signature(self.get_loss)}.")
         # 2. only part of the param_map are passed, left are not
         for arg in func_args:
@@ -151,16 +151,16 @@ class LossBase(object):
             replaced_missing[idx] = f"{self.param_map[func_arg]}" + f"(assign to `{func_arg}` " \
                                     f"in `{self.__class__.__name__}`)"
-        check_res = CheckRes(missing=replaced_missing,
-                             unused=check_res.unused,
-                             duplicated=duplicated,
-                             required=check_res.required,
-                             all_needed=check_res.all_needed,
-                             varargs=check_res.varargs)
+        check_res = _CheckRes(missing=replaced_missing,
+                              unused=check_res.unused,
+                              duplicated=duplicated,
+                              required=check_res.required,
+                              all_needed=check_res.all_needed,
+                              varargs=check_res.varargs)
         if check_res.missing or check_res.duplicated:
-            raise CheckError(check_res=check_res,
-                             func_signature=get_func_signature(self.get_loss))
+            raise _CheckError(check_res=check_res,
+                              func_signature=_get_func_signature(self.get_loss))
         refined_args = _build_args(self.get_loss, **mapped_pred_dict, **mapped_target_dict)
         loss = self.get_loss(**refined_args)
@@ -289,14 +289,14 @@ class LossInForward(LossBase):
     def get_loss(self, **kwargs):
         if self.loss_key not in kwargs:
-            check_res = CheckRes(
+            check_res = _CheckRes(
                 missing=[self.loss_key + f"(assign to `{self.loss_key}` in `{self.__class__.__name__}`"],
                 unused=[],
                 duplicated=[],
                 required=[],
                 all_needed=[],
                 varargs=[])
-            raise CheckError(check_res=check_res, func_signature=get_func_signature(self.get_loss))
+            raise _CheckError(check_res=check_res, func_signature=_get_func_signature(self.get_loss))
         return kwargs[self.loss_key]
     def __call__(self, pred_dict, target_dict, check=False):
@@ -13,11 +13,11 @@ from collections import defaultdict
 import numpy as np
 import torch
-from fastNLP.core.utils import CheckError
-from fastNLP.core.utils import CheckRes
+from fastNLP.core.utils import _CheckError
+from fastNLP.core.utils import _CheckRes
 from fastNLP.core.utils import _build_args
 from fastNLP.core.utils import _check_arg_dict_list
-from fastNLP.core.utils import get_func_signature
+from fastNLP.core.utils import _get_func_signature
 from fastNLP.core.utils import seq_lens_to_masks
 from fastNLP.core.vocabulary import Vocabulary
@@ -161,7 +161,7 @@ class MetricBase(object):
         for func_param, input_param in self.param_map.items():
             if func_param not in func_args:
                 raise NameError(
-                    f"Parameter `{func_param}` is not in {get_func_signature(self.evaluate)}. Please check the "
+                    f"Parameter `{func_param}` is not in {_get_func_signature(self.evaluate)}. Please check the "
                     f"initialization parameters, or change its signature.")
     def _fast_param_map(self, pred_dict, target_dict):
@@ -207,7 +207,7 @@ class MetricBase(object):
         func_args = set([arg for arg in func_spect.args if arg != 'self'])
         for func_arg, input_arg in self.param_map.items():
             if func_arg not in func_args:
-                raise NameError(f"`{func_arg}` not in {get_func_signature(self.evaluate)}.")
+                raise NameError(f"`{func_arg}` not in {_get_func_signature(self.evaluate)}.")
         # 2. only part of the param_map are passed, left are not
         for arg in func_args:
@@ -248,16 +248,16 @@ class MetricBase(object):
             replaced_missing[idx] = f"{self.param_map[func_arg]}" + f"(assign to `{func_arg}` " \
                                     f"in `{self.__class__.__name__}`)"
-        check_res = CheckRes(missing=replaced_missing,
-                             unused=check_res.unused,
-                             duplicated=duplicated,
-                             required=check_res.required,
-                             all_needed=check_res.all_needed,
-                             varargs=check_res.varargs)
+        check_res = _CheckRes(missing=replaced_missing,
+                              unused=check_res.unused,
+                              duplicated=duplicated,
+                              required=check_res.required,
+                              all_needed=check_res.all_needed,
+                              varargs=check_res.varargs)
         if check_res.missing or check_res.duplicated:
-            raise CheckError(check_res=check_res,
-                             func_signature=get_func_signature(self.evaluate))
+            raise _CheckError(check_res=check_res,
+                              func_signature=_get_func_signature(self.evaluate))
         refined_args = _build_args(self.evaluate, **mapped_pred_dict, **mapped_target_dict)
         self.evaluate(**refined_args)
@@ -294,14 +294,14 @@ class AccuracyMetric(MetricBase):
         """
         # TODO this error message needs improving: the user does not know what pred is; report the actual value
         if not isinstance(pred, torch.Tensor):
-            raise TypeError(f"`pred` in {get_func_signature(self.evaluate)} must be torch.Tensor,"
+            raise TypeError(f"`pred` in {_get_func_signature(self.evaluate)} must be torch.Tensor,"
                             f"got {type(pred)}.")
         if not isinstance(target, torch.Tensor):
-            raise TypeError(f"`target` in {get_func_signature(self.evaluate)} must be torch.Tensor,"
+            raise TypeError(f"`target` in {_get_func_signature(self.evaluate)} must be torch.Tensor,"
                             f"got {type(target)}.")
         if seq_len is not None and not isinstance(seq_len, torch.Tensor):
-            raise TypeError(f"`seq_lens` in {get_func_signature(self.evaluate)} must be torch.Tensor,"
+            raise TypeError(f"`seq_lens` in {_get_func_signature(self.evaluate)} must be torch.Tensor,"
                             f"got {type(seq_lens)}.")
         if seq_len is not None:
@@ -314,7 +314,7 @@ class AccuracyMetric(MetricBase):
         elif len(pred.size()) == len(target.size()) + 1:
             pred = pred.argmax(dim=-1)
         else:
-            raise RuntimeError(f"In {get_func_signature(self.evaluate)}, when pred have "
+            raise RuntimeError(f"In {_get_func_signature(self.evaluate)}, when pred have "
                                f"size:{pred.size()}, target should have size: {pred.size()} or "
                                f"{pred.size()[:-1]}, got {target.size()}.")
@@ -516,14 +516,14 @@ class SpanFPreRecMetric(MetricBase):
         :return:
         """
         if not isinstance(pred, torch.Tensor):
-            raise TypeError(f"`pred` in {get_func_signature(self.evaluate)} must be torch.Tensor,"
+            raise TypeError(f"`pred` in {_get_func_signature(self.evaluate)} must be torch.Tensor,"
                             f"got {type(pred)}.")
         if not isinstance(target, torch.Tensor):
-            raise TypeError(f"`target` in {get_func_signature(self.evaluate)} must be torch.Tensor,"
+            raise TypeError(f"`target` in {_get_func_signature(self.evaluate)} must be torch.Tensor,"
                             f"got {type(target)}.")
         if not isinstance(seq_len, torch.Tensor):
-            raise TypeError(f"`seq_lens` in {get_func_signature(self.evaluate)} must be torch.Tensor,"
+            raise TypeError(f"`seq_lens` in {_get_func_signature(self.evaluate)} must be torch.Tensor,"
                             f"got {type(seq_len)}.")
         if pred.size() == target.size() and len(target.size()) == 2:
@@ -535,7 +535,7 @@ class SpanFPreRecMetric(MetricBase):
             raise ValueError("A gold label passed to SpanBasedF1Metric contains an "
                              "id >= {}, the number of classes.".format(num_classes))
         else:
-            raise RuntimeError(f"In {get_func_signature(self.evaluate)}, when pred have "
+            raise RuntimeError(f"In {_get_func_signature(self.evaluate)}, when pred have "
                                f"size:{pred.size()}, target should have size: {pred.size()} or "
                                f"{pred.size()[:-1]}, got {target.size()}.")
@@ -714,14 +714,14 @@ class BMESF1PreRecMetric(MetricBase):
         :return:
         """
         if not isinstance(pred, torch.Tensor):
-            raise TypeError(f"`pred` in {get_func_signature(self.evaluate)} must be torch.Tensor,"
+            raise TypeError(f"`pred` in {_get_func_signature(self.evaluate)} must be torch.Tensor,"
                             f"got {type(pred)}.")
         if not isinstance(target, torch.Tensor):
-            raise TypeError(f"`target` in {get_func_signature(self.evaluate)} must be torch.Tensor,"
+            raise TypeError(f"`target` in {_get_func_signature(self.evaluate)} must be torch.Tensor,"
                             f"got {type(target)}.")
         if not isinstance(seq_len, torch.Tensor):
-            raise TypeError(f"`seq_lens` in {get_func_signature(self.evaluate)} must be torch.Tensor,"
+            raise TypeError(f"`seq_lens` in {_get_func_signature(self.evaluate)} must be torch.Tensor,"
                             f"got {type(seq_len)}.")
         if pred.size() == target.size() and len(target.size()) == 2:
@@ -729,7 +729,7 @@ class BMESF1PreRecMetric(MetricBase):
         elif len(pred.size()) == len(target.size()) + 1 and len(target.size()) == 2:
             pred = pred.argmax(dim=-1)
         else:
-            raise RuntimeError(f"In {get_func_signature(self.evaluate)}, when pred have "
+            raise RuntimeError(f"In {_get_func_signature(self.evaluate)}, when pred have "
                                f"size:{pred.size()}, target should have size: {pred.size()} or "
                                f"{pred.size()[:-1]}, got {target.size()}.")
@@ -5,11 +5,11 @@ from fastNLP.core.batch import Batch
 from fastNLP.core.dataset import DataSet
 from fastNLP.core.metrics import _prepare_metrics
 from fastNLP.core.sampler import SequentialSampler
-from fastNLP.core.utils import CheckError
+from fastNLP.core.utils import _CheckError
 from fastNLP.core.utils import _build_args
 from fastNLP.core.utils import _check_loss_evaluate
 from fastNLP.core.utils import _move_dict_value_to_device
-from fastNLP.core.utils import get_func_signature
+from fastNLP.core.utils import _get_func_signature
 from fastNLP.core.utils import _get_device
@@ -75,19 +75,19 @@ class Tester(object):
                     _move_dict_value_to_device(batch_x, batch_y, device=self._model_device)
                     pred_dict = self._data_forward(self._predict_func, batch_x)
                     if not isinstance(pred_dict, dict):
-                        raise TypeError(f"The return value of {get_func_signature(self._predict_func)} "
+                        raise TypeError(f"The return value of {_get_func_signature(self._predict_func)} "
                                         f"must be `dict`, got {type(pred_dict)}.")
                     for metric in self.metrics:
                         metric(pred_dict, batch_y)
                 for metric in self.metrics:
                     eval_result = metric.get_metric()
                     if not isinstance(eval_result, dict):
-                        raise TypeError(f"The return value of {get_func_signature(metric.get_metric)} must be "
+                        raise TypeError(f"The return value of {_get_func_signature(metric.get_metric)} must be "
                                         f"`dict`, got {type(eval_result)}")
                     metric_name = metric.__class__.__name__
                     eval_results[metric_name] = eval_result
-        except CheckError as e:
-            prev_func_signature = get_func_signature(self._predict_func)
+        except _CheckError as e:
+            prev_func_signature = _get_func_signature(self._predict_func)
             _check_loss_evaluate(prev_func_signature=prev_func_signature, func_signature=e.func_signature,
                                  check_res=e.check_res, pred_dict=pred_dict, target_dict=batch_y,
                                  dataset=self.data, check_level=0)
@@ -85,17 +85,17 @@ The Trainer in fastNLP organizes the single-task training process, sparing the user from
 the corresponding values. For example, here CrossEntropyLoss_ will try to find the content named 'label' to use as the
 ground truth for the loss; and with pred=None, CrossEntropyLoss_ matches the prediction by the name 'pred', which happens
 to be exactly what forward's return value is called, so pred does not need to be declared here.
-Although fastNLP uses a mapping mechanism to make loss computation flexible, in some cases the loss must be computed
-inside the model, for example in models that use a CRF. For this, fastNLP provides the LossInForward_ loss, which simply
-locates the tensor named loss_key ('loss' by default) in the result of forward() and uses it as the loss. It is used
-when the Trainer is initialized without a loss  TODO add an example
+Although fastNLP uses a mapping mechanism to make loss computation flexible, in some cases the loss must be computed
+inside the model, for example in models that use a CRF. For this, fastNLP provides the LossInForward_ loss, which simply
+locates the tensor named loss_key ('loss' by default) in the result of forward() and uses it as the loss. If the Trainer
+is initialized without a loss, LossInForward_ is used by default. For a detailed example see  TODO add an example
 3. Metric
 Metric_ uses the same strategy as the Loss above, i.e. matching by name. AccuracyMetric(target='label') works just like
 CrossEntropyLoss. During validation the computation needed may differ from forward(), so the prediction cannot always
 be read directly from forward()'s result; in that case the model can provide a predict() method,
-and if it does, predict() is called during validation to obtain the predictions; the arguments passed to predict() are
-likewise selected from the DataSet's input; as with forward(), the return value must be a dict. See  TODO add an example
+and if it does, predict() is called during validation to obtain the predictions; the arguments passed to predict() are
+likewise selected from the fields set as input in the DataSet; as with forward(), the return value must be a dict.
+For a concrete example see  TODO add an example
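A sketch of the name matching described above, with hypothetical field and key names::

    from fastNLP.core.losses import CrossEntropyLoss
    from fastNLP.core.metrics import AccuracyMetric

    # forward() returns {'pred': ...} and the DataSet's target field is called 'label'
    loss = CrossEntropyLoss(target='label')  # pred needs no mapping, the default name matches
    metric = AccuracyMetric(target='label')  # metrics use the same matching strategy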
 2. The Trainer's code check
@@ -112,12 +112,12 @@ The Trainer in fastNLP organizes the single-task training process, sparing the user from
     from fastNLP import DataSet
     class Model(nn.Module):
-        def __init__(self):
-            super().__init__()
-            self.fc = nn.Linear(1, 1)
-        def forward(self, x, b):
-            loss = torch.mean((self.fc(x)-b)**2)
-            return {'loss': loss}
+        def __init__(self):
+            super().__init__()
+            self.fc = nn.Linear(1, 1)
+        def forward(self, x, b):
+            loss = torch.mean((self.fc(x)-b)**2)
+            return {'loss': loss}
     model = Model()
     dataset = DataSet({'a': np.arange(10), 'b':np.arange(10)*2})
@@ -138,16 +138,15 @@ The Trainer in fastNLP organizes the single-task training process, sparing the user from
     # unused field: ['a']
     # Suggestion: You need to provide ['x'] in DataSet and set it as input.
-This is because, when the Trainer is initialized, fastNLP tries to run forward() and backward() once with a batch of
-batch_size=2. Here by two kinds of
+This is because, when the Trainer is initialized, fastNLP tries to run forward() and backward() once with a batch of
+batch_size=2. Here two kinds of
 information can serve as a reference:
 1. The 'input fields after batch...' part shows the type and shape of every field after the train dataset has gone
    through the Batch operation. Here
-   nothing is shown because the train dataset has no target. From this you can check whether the needed content was
-   correctly set as input or target.
+   nothing is shown because the train dataset has no target. From this it can be checked whether the needed content
+   was correctly set as input or target.
-2. If a mapping error occurs, a NameError appears. The error here is raised because, when the forward computation was
-   attempted (Model.forward(self, x, b) tells you the failure happened while calling forward), the 'x' required by
-   forward() was not found; the message also points out that 'x' is missing while 'a' is unused, so the field name is
-   probably wrong. Renaming the field 'a' in the dataset to 'x'
-   , or changing the model's parameter from 'x' to 'a', both solve the problem.
+2. NameError: a NameError occurs when the mapping goes wrong. The error here is raised because, when the forward
+   computation was attempted (Model.forward(self, x, b) tells you forward is being called), the 'x' required by
+   forward() was not found; the message also points out that 'x' is missing while 'a' is unused, so the field name is
+   probably wrong. Renaming the field 'a' in the dataset to 'x', or changing the model's parameter from 'x' to 'a', both solve the problem.
 The following example results from a required value not being found during loss computation
@@ -249,18 +248,18 @@ The Trainer in fastNLP organizes the single-task training process, sparing the user from
     # (2). Check key assignment for `target` when initialize AccuracyMetric. Or provide `target` in DataSet or output of Model.predict(self, a).
 The error message is similar to the earlier ones, but 'AccuracyMetric.evaluate(self, pred, target, seq_len=None)' shows
-that the error happened during evaluation. This avoids the situation of needing to finish a whole epoch of training
-before discovering the evaluation's mistake. The fix here is to state, when initializing the metric, that `pred` is
-obtained via 'output', i.e. AccuracyMetric(pred='output').
+that the error happened during evaluation. This avoids having to finish a whole epoch of training before discovering
+that the evaluation was mis-configured. The fix here is to state, when initializing the metric, that `pred` is
+obtained via 'output', i.e. AccuracyMetric(pred='output').
 The strictness of the check can be tuned via check_code_level. The default is 0, i.e. the check is performed.
 3. Trainer and callbacks
 Although the Trainer itself already integrates some functionality, it is still not enough to cover everything a
 training run may need, such as negative sampling, learning rate decay, early stopping, and so on.
-To solve this, fastNLP introduces the callback mechanism: Callback_ is a class that runs at specific stages of the
-Trainer's training; every Callback_ has
+To solve this, fastNLP introduces the callback mechanism: Callback_ is a collection of functions that run at specific
+stages of the Trainer's training; every Callback_ has
 on_* methods (e.g. on_train_start, on_backward_begin). If a Callback implements such a method, the Trainer calls it
 when it reaches the corresponding stage.
-We divide the inside of Train.train() into the following stages
+We divide the inside of Train.train() into the following stages; the corresponding calls are triggered at each stage.
 Example::
@@ -305,7 +304,7 @@ import warnings
 try:
     from tqdm.autonotebook import tqdm
 except:
-    from fastNLP.core.utils import pseudo_tqdm as tqdm
+    from fastNLP.core.utils import _pseudo_tqdm as tqdm
 from fastNLP.core.batch import Batch
 from fastNLP.core.callback import CallbackManager, CallbackException
@@ -316,12 +315,12 @@ from fastNLP.core.sampler import Sampler
 from fastNLP.core.sampler import RandomSampler
 from fastNLP.core.sampler import SequentialSampler
 from fastNLP.core.tester import Tester
-from fastNLP.core.utils import CheckError
+from fastNLP.core.utils import _CheckError
 from fastNLP.core.utils import _build_args
 from fastNLP.core.utils import _check_forward_error
 from fastNLP.core.utils import _check_loss_evaluate
 from fastNLP.core.utils import _move_dict_value_to_device
-from fastNLP.core.utils import get_func_signature
+from fastNLP.core.utils import _get_func_signature
 from fastNLP.core.utils import _get_device
@@ -466,34 +465,11 @@ class Trainer(object):
     def train(self, load_best_model=True):
         """
-        Starts the training process. The main steps are::
-        for epoch in range(num_epochs):
-            # Batch takes data from the DataSet batch by batch, automatically pads the fields whose dtype is
-              (float, int), and converts them to Tensor. Fields that are not float or int are not converted to
-              Tensor and are not padded.
-            for batch_x, batch_y in Batch(DataSet)
-                # batch_x is a dict; the fields set as input appear in this dict,
-                  with the DataSet field_name as key and that field's value as value
-                # batch_y is also a dict; the fields set as target appear in this dict,
-                  with the DataSet field_name as key and that field's value as value
-        2. The data in batch_x is fed to model.forward and the result collected. Arguments are passed by matching the
-           keys in batch_x against the formal parameters of forward. For example,
-           forward(self, x, seq_lens)  # fastNLP finds the value keyed "x" in batch_x and passes it to x, and the value
-               keyed "seq_lens" to seq_lens. If a required argument is not found in batch_x, an error is raised. If
-               forward has a default argument whose key is absent from batch_x, the default is used.
-        3. batch_y and the result of model.forward are passed to the loss together. Loss computation generally involves
-           pred and target, but depending on the setup pred may be called output or prediction, and target may be
-           called y or label. fastNLP finds pred and target through the mapping given when the loss was initialized.
-           For example, if the Trainer's loss was initialized as CrossEntropyLoss(pred='output', target='y'), fastNLP
-           uses "output" to locate pred in batch_y and the forward result, and "y" to locate target there, and then
-           computes the loss.
-        4. Once the loss is obtained, backpropagation is run and the gradients are updated.
-        Validation tests are run when due;
-        evaluation follows the metrics, and whether the model is saved depends on whether save_path was provided.
+        Calling this function makes the Trainer start training.
         :param bool load_best_model: only effective if dev_data was provided at initialization; if True, the trainer
             reloads the model parameters that performed best on dev before returning.
-        :return results: returns dict-typed data,
+        :return dict: returns dict-typed data,
         containing the following::
             seconds: float, the training duration
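A call-site sketch for the shortened docstring; the `trainer` here is assumed to have been constructed with dev_data, and keys other than 'seconds' follow the dict described above::

    results = trainer.train(load_best_model=True)
    print(results['seconds'])  # float, total training time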
@@ -547,7 +523,7 @@ class Trainer(object):
     def _train(self):
         if not self.use_tqdm:
-            from fastNLP.core.utils import pseudo_tqdm as inner_tqdm
+            from fastNLP.core.utils import _pseudo_tqdm as inner_tqdm
         else:
             inner_tqdm = tqdm
         self.step = 0
@@ -559,6 +535,7 @@ class Trainer(object):
             avg_loss = 0
             data_iterator = Batch(self.train_data, batch_size=self.batch_size, sampler=self.sampler, as_numpy=False,
                                   prefetch=self.prefetch)
+            self.batch_per_epoch = data_iterator.num_batches
             for epoch in range(1, self.n_epochs + 1):
                 self.epoch = epoch
                 pbar.set_description_str(desc="Epoch {}/{}".format(epoch, self.n_epochs))
@@ -660,7 +637,7 @@ class Trainer(object):
         x = _build_args(network.forward, **x)
         y = network(**x)
         if not isinstance(y, dict):
-            raise TypeError(f"The return value of {get_func_signature(network.forward)} should be dict, got {type(y)}.")
+            raise TypeError(f"The return value of {_get_func_signature(network.forward)} should be dict, got {type(y)}.")
         return y
     def _grad_backward(self, loss):
@@ -796,7 +773,7 @@ def _check_code(dataset, model, losser, metrics, batch_size=DEFAULT_CHECK_BATCH_ | |||||
refined_batch_x = _build_args(model.forward, **batch_x) | refined_batch_x = _build_args(model.forward, **batch_x) | ||||
pred_dict = model(**refined_batch_x) | pred_dict = model(**refined_batch_x) | ||||
func_signature = get_func_signature(model.forward) | |||||
func_signature = _get_func_signature(model.forward) | |||||
if not isinstance(pred_dict, dict): | if not isinstance(pred_dict, dict): | ||||
raise TypeError(f"The return value of {func_signature} should be `dict`, not `{type(pred_dict)}`.") | raise TypeError(f"The return value of {func_signature} should be `dict`, not `{type(pred_dict)}`.") | ||||
@@ -807,16 +784,16 @@ def _check_code(dataset, model, losser, metrics, batch_size=DEFAULT_CHECK_BATCH_
             if batch_count == 0:
                 if not isinstance(loss, torch.Tensor):
                     raise TypeError(
-                        f"The return value of {get_func_signature(losser.get_loss)} should be `torch.Tensor`, "
+                        f"The return value of {_get_func_signature(losser.get_loss)} should be `torch.Tensor`, "
                         f"but got `{type(loss)}`.")
                 if len(loss.size()) != 0:
                     raise ValueError(
-                        f"The size of return value of {get_func_signature(losser.get_loss)} is {loss.size()}, "
+                        f"The size of return value of {_get_func_signature(losser.get_loss)} is {loss.size()}, "
                         f"should be torch.size([])")
             loss.backward()
-        except CheckError as e:
-            # TODO: another error raised if CheckError caught
-            pre_func_signature = get_func_signature(model.forward)
+        except _CheckError as e:
+            # TODO: another error raised if _CheckError caught
+            pre_func_signature = _get_func_signature(model.forward)
             _check_loss_evaluate(prev_func_signature=pre_func_signature, func_signature=e.func_signature,
                                  check_res=e.check_res, pred_dict=pred_dict, target_dict=batch_y,
                                  dataset=dataset, check_level=check_level)
@@ -9,7 +9,7 @@ import numpy as np
 import torch
 from torch import nn

-CheckRes = namedtuple('CheckRes', ['missing', 'unused', 'duplicated', 'required', 'all_needed',
+_CheckRes = namedtuple('_CheckRes', ['missing', 'unused', 'duplicated', 'required', 'all_needed',
                                     'varargs'])
 def _prepare_cache_filepath(filepath):

@@ -28,6 +28,57 @@ def _prepare_cache_filepath(filepath):
 # TODO: we could also save the parameters in effect when the cache was written, and warn at load time if they differ.
 def cache_results(_cache_fp, _refresh=False, _verbose=1):
+    """
+    cache_results is fastNLP's decorator for caching results. The examples below show how to use it
+
+    Example::
+
+        import time
+        import numpy as np
+        from fastNLP import cache_results
+
+        @cache_results('cache.pkl')
+        def process_data():
+            # some time-consuming work such as reading and preprocessing data; time.sleep() stands in for it here
+            time.sleep(1)
+            return np.random.randint(5, size=(10, 20))
+
+        start_time = time.time()
+        process_data()
+        print(time.time() - start_time)
+        start_time = time.time()
+        process_data()
+        print(time.time() - start_time)
+
+        # the output is
+        # Save cache to cache.pkl.
+        # 1.0015439987182617
+        # Read cache from cache.pkl.
+        # 0.00013065338134765625
+
+    Note that the second run takes only about 0.0001s, because it reads the data straight from the file
+    cache.pkl instead of preprocessing it again
+
+    Example::
+
+        # continuing the example above: to generate a separate cache, say for another dataset, call it as follows
+        process_data(_cache_fp='cache2.pkl')  # leaves the earlier 'cache.pkl' completely untouched
+
+    The _cache_fp above is a parameter that cache_results recognizes: the data is cached to / read from
+    'cache2.pkl', i.e. 'cache2.pkl' here overrides the default 'cache.pkl'. Once a function is decorated
+    with @cache_results(), it gains three extra parameters [_cache_fp, _refresh, _verbose]. The example
+    above uses _cache_fp; these three parameters are never passed into your function, and of course your
+    own function's parameter names must not include any of these three names.
+
+    Example::
+
+        process_data(_cache_fp='cache2.pkl', _refresh=True)  # force the preprocessing cache to be regenerated
+        # _verbose controls the output: 0 prints nothing; 1 reports whether the step read an existing cache or generated a new one
+
+    :param str _cache_fp: where to cache the returned result, or where to read the cache from. If None,
+        cache_results has no effect, unless _cache_fp is passed as an argument at call time.
+    :param bool _refresh: whether to regenerate the cache.
+    :param int _verbose: whether to print cache information.
+    :return:
+    """
     def wrapper_(func):
         signature = inspect.signature(func)
         for key, _ in signature.parameters.items():

@@ -74,48 +125,48 @@ def cache_results(_cache_fp, _refresh=False, _verbose=1):
         return wrapper
     return wrapper_
-def save_pickle(obj, pickle_path, file_name):
-    """Save an object into a pickle file.
-
-    :param obj: an object
-    :param pickle_path: str, the directory where the pickle file is to be saved
-    :param file_name: str, the name of the pickle file. In general, it should be ended by "pkl".
-    """
-    if not os.path.exists(pickle_path):
-        os.mkdir(pickle_path)
-        print("make dir {} before saving pickle file".format(pickle_path))
-    with open(os.path.join(pickle_path, file_name), "wb") as f:
-        _pickle.dump(obj, f)
-    print("{} saved in {}".format(file_name, pickle_path))
-
-
-def load_pickle(pickle_path, file_name):
-    """Load an object from a given pickle file.
-
-    :param pickle_path: str, the directory where the pickle file is.
-    :param file_name: str, the name of the pickle file.
-    :return obj: an object stored in the pickle
-    """
-    with open(os.path.join(pickle_path, file_name), "rb") as f:
-        obj = _pickle.load(f)
-    print("{} loaded from {}".format(file_name, pickle_path))
-    return obj
-
-
-def pickle_exist(pickle_path, pickle_name):
-    """Check if a given pickle file exists in the directory.
-
-    :param pickle_path: the directory of target pickle file
-    :param pickle_name: the filename of target pickle file
-    :return: True if file exists else False
-    """
-    if not os.path.exists(pickle_path):
-        os.makedirs(pickle_path)
-    file_name = os.path.join(pickle_path, pickle_name)
-    if os.path.exists(file_name):
-        return True
-    else:
-        return False
+# def save_pickle(obj, pickle_path, file_name):
+#     """Save an object into a pickle file.
+#
+#     :param obj: an object
+#     :param pickle_path: str, the directory where the pickle file is to be saved
+#     :param file_name: str, the name of the pickle file. In general, it should be ended by "pkl".
+#     """
+#     if not os.path.exists(pickle_path):
+#         os.mkdir(pickle_path)
+#         print("make dir {} before saving pickle file".format(pickle_path))
+#     with open(os.path.join(pickle_path, file_name), "wb") as f:
+#         _pickle.dump(obj, f)
+#     print("{} saved in {}".format(file_name, pickle_path))
+#
+#
+# def load_pickle(pickle_path, file_name):
+#     """Load an object from a given pickle file.
+#
+#     :param pickle_path: str, the directory where the pickle file is.
+#     :param file_name: str, the name of the pickle file.
+#     :return obj: an object stored in the pickle
+#     """
+#     with open(os.path.join(pickle_path, file_name), "rb") as f:
+#         obj = _pickle.load(f)
+#     print("{} loaded from {}".format(file_name, pickle_path))
+#     return obj
+#
+#
+# def pickle_exist(pickle_path, pickle_name):
+#     """Check if a given pickle file exists in the directory.
+#
+#     :param pickle_path: the directory of target pickle file
+#     :param pickle_name: the filename of target pickle file
+#     :return: True if file exists else False
+#     """
+#     if not os.path.exists(pickle_path):
+#         os.makedirs(pickle_path)
+#     file_name = os.path.join(pickle_path, pickle_name)
+#     if os.path.exists(file_name):
+#         return True
+#     else:
+#         return False
 def _get_device(device, check_exist=False):
     """
@@ -232,15 +283,15 @@ def _check_arg_dict_list(func, args):
     missing = list(require_args - input_args)
     unused = list(input_args - all_args)
     varargs = [] if not spect.varargs else [spect.varargs]
-    return CheckRes(missing=missing,
-                    unused=unused,
-                    duplicated=duplicated,
-                    required=list(require_args),
-                    all_needed=list(all_args),
-                    varargs=varargs)
+    return _CheckRes(missing=missing,
+                     unused=unused,
+                     duplicated=duplicated,
+                     required=list(require_args),
+                     all_needed=list(all_args),
+                     varargs=varargs)
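Only the tail of _check_arg_dict_list is visible in this hunk. As a rough, self-contained sketch of
what such a signature-versus-dict check can look like; the helper name check_args and the exact field
computations are assumptions for illustration, not the library's code::

    import inspect
    from collections import namedtuple

    _CheckRes = namedtuple('_CheckRes', ['missing', 'unused', 'duplicated', 'required',
                                         'all_needed', 'varargs'])

    def check_args(func, arg_dict):
        # hypothetical helper: compare arg_dict's keys against func's parameters
        spect = inspect.getfullargspec(func)
        all_args = {a for a in spect.args if a != 'self'}
        n_defaults = len(spect.defaults) if spect.defaults else 0
        # parameters without a default value are required
        require_args = set(spect.args[:len(spect.args) - n_defaults]) - {'self'}
        input_args = set(arg_dict)
        return _CheckRes(missing=list(require_args - input_args),
                         unused=list(input_args - all_args),
                         duplicated=[],  # tracking duplicates needs extra bookkeeping
                         required=list(require_args),
                         all_needed=list(all_args),
                         varargs=[] if not spect.varargs else [spect.varargs])

For example, check_args(lambda x, y=1: x, {'x': 0, 'z': 2}) reports missing=[] and unused=['z'].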
-def get_func_signature(func):
+def _get_func_signature(func):
     """

     Given a function or method, return its signature.
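The function body is cut off by this hunk. Assuming it behaves like the standard inspect machinery, a
minimal stand-in could be::

    import inspect

    def get_func_signature_sketch(func):
        # a rough approximation: qualified name plus the parameter list
        if inspect.ismethod(func):
            class_name = func.__self__.__class__.__name__
            return f"{class_name}.{func.__name__}{inspect.signature(func)}"
        return f"{func.__name__}{inspect.signature(func)}"

so that a bound model.forward renders roughly as Model.forward(x, seq_lens=None) in the error messages
above; the library's exact formatting may differ, e.g. it may include self.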
@@ -318,13 +369,13 @@ def _move_dict_value_to_device(*args, device: torch.device, non_blocking=False):
         raise TypeError("Only support `dict` type right now.")

-class CheckError(Exception):
+class _CheckError(Exception):
     """
-    CheckError. Used in losses.LossBase, metrics.MetricBase.
+    _CheckError. Used in losses.LossBase, metrics.MetricBase.
     """

-    def __init__(self, check_res: CheckRes, func_signature: str):
+    def __init__(self, check_res: _CheckRes, func_signature: str):
         errs = [f'Problems occurred when calling `{func_signature}`']

         if check_res.varargs:
@@ -347,7 +398,7 @@ WARNING_CHECK_LEVEL = 1
 STRICT_CHECK_LEVEL = 2

-def _check_loss_evaluate(prev_func_signature: str, func_signature: str, check_res: CheckRes,
+def _check_loss_evaluate(prev_func_signature: str, func_signature: str, check_res: _CheckRes,
                          pred_dict: dict, target_dict: dict, dataset, check_level=0):
     errs = []
     unuseds = []

@@ -449,7 +500,7 @@ def _check_loss_evaluate(prev_func_signature: str, func_signature: str, check_re
 def _check_forward_error(forward_func, batch_x, dataset, check_level):
     check_res = _check_arg_dict_list(forward_func, batch_x)
-    func_signature = get_func_signature(forward_func)
+    func_signature = _get_func_signature(forward_func)
     errs = []
     suggestions = []
@@ -543,7 +594,7 @@ def seq_mask(seq_len, max_len):
     return torch.gt(seq_len, seq_range)  # [batch_size, max_len]

-class pseudo_tqdm:
+class _pseudo_tqdm:
     """
     Used to print progress when tqdm cannot be imported, or when use_tqdm is set to False in the Trainer
     """
@@ -2,7 +2,7 @@ from functools import wraps
 from collections import Counter
 from fastNLP.core.dataset import DataSet

-def check_build_vocab(func):
+def _check_build_vocab(func):
     """A decorator to make sure the indexing is built before used.

     """

@@ -15,7 +15,7 @@ def check_build_vocab(func):
     return _wrapper

-def check_build_status(func):
+def _check_build_status(func):
     """A decorator to check whether the vocabulary updates after the last build.

     """
@@ -67,7 +67,7 @@ class Vocabulary(object):
         self.idx2word = None
         self.rebuild = True

-    @check_build_status
+    @_check_build_status
     def update(self, word_lst):
         """Increment the recorded frequency of each word in the sequence, in order

@@ -75,7 +75,7 @@ class Vocabulary(object):
         """
         self.word_count.update(word_lst)

-    @check_build_status
+    @_check_build_status
     def add(self, word):
         """
         Increment the recorded frequency of a new word

@@ -84,7 +84,7 @@ class Vocabulary(object):
         """
         self.word_count[word] += 1

-    @check_build_status
+    @_check_build_status
     def add_word(self, word):
         """
         Increment the recorded frequency of a new word

@@ -93,7 +93,7 @@ class Vocabulary(object):
         """
         self.add(word)

-    @check_build_status
+    @_check_build_status
     def add_word_lst(self, word_lst):
         """
         Increment the recorded frequency of each word in the sequence, in order

@@ -132,11 +132,11 @@ class Vocabulary(object):
         """
         self.idx2word = {i: w for w, i in self.word2idx.items()}

-    @check_build_vocab
+    @_check_build_vocab
     def __len__(self):
         return len(self.word2idx)

-    @check_build_vocab
+    @_check_build_vocab
     def __contains__(self, item):
         """
         Check whether a word has been recorded

@@ -161,7 +161,7 @@ class Vocabulary(object):
         """
         return self.__contains__(w)

-    @check_build_vocab
+    @_check_build_vocab
     def __getitem__(self, w):
         """
         To support usage like::
@@ -175,7 +175,7 @@ class Vocabulary(object):
         else:
             raise ValueError("word {} not in vocabulary".format(w))

-    @check_build_vocab
+    @_check_build_vocab
     def index_dataset(self, *datasets, field_name, new_field_name=None):
         """
         Convert the words in the corresponding field of the DataSet(s) to indices.
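The rest of this docstring is cut off by the hunk. A hedged usage sketch, with invented field names
and assuming the converted indices are written to new_field_name when it is given::

    # vocab is a built Vocabulary; train_data and dev_data are DataSets with a 'words' field
    vocab.index_dataset(train_data, dev_data, field_name='words', new_field_name='word_ids')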
@@ -275,7 +275,7 @@ class Vocabulary(object):
         return self.__getitem__(w)

     @property
-    @check_build_vocab
+    @_check_build_vocab
     def unknown_idx(self):
         """
         The index corresponding to unknown.

@@ -285,7 +285,7 @@ class Vocabulary(object):
         return self.word2idx[self.unknown]

     @property
-    @check_build_vocab
+    @_check_build_vocab
     def padding_idx(self):
         """
         The index corresponding to padding

@@ -294,7 +294,7 @@ class Vocabulary(object):
             return None
         return self.word2idx[self.padding]

-    @check_build_vocab
+    @_check_build_vocab
     def to_word(self, idx):
         """
         Given an index, convert it to the corresponding word.
@@ -13,12 +13,12 @@ from torch import nn

 try:
     from tqdm.autonotebook import tqdm
 except:
-    from fastNLP.core.utils import pseudo_tqdm as tqdm
+    from fastNLP.core.utils import _pseudo_tqdm as tqdm

 from fastNLP.core.batch import Batch
 from fastNLP.core.callback import CallbackManager, CallbackException
 from fastNLP.core.dataset import DataSet
-from fastNLP.core.utils import CheckError
+from fastNLP.core.utils import _CheckError
 from fastNLP.core.utils import _move_dict_value_to_device
 import fastNLP
 import fastNLP.models.enas_utils as utils
@@ -118,7 +118,7 @@ class ENASTrainer(fastNLP.Trainer):

     def _train(self):
         if not self.use_tqdm:
-            from fastNLP.core.utils import pseudo_tqdm as inner_tqdm
+            from fastNLP.core.utils import _pseudo_tqdm as inner_tqdm
         else:
             inner_tqdm = tqdm
         self.step = 0
@@ -8,7 +8,7 @@ import numpy as np
 import torch.nn.functional as F
 from torch import nn
 import time

-from fastNLP.core.utils import CheckError
+from fastNLP.core.utils import _CheckError
 from fastNLP.core.dataset import DataSet
 from fastNLP.core.instance import Instance
 from fastNLP.core.losses import BCELoss