@@ -1,5 +1,15 @@
"""
fastNLP is made up of submodules such as :mod:`~fastNLP.core`, :mod:`~fastNLP.io`, and :mod:`~fastNLP.modules`,
but the commonly used components can be imported directly. The common components are:
"""
__all__ = ["Instance", "FieldArray", "Batch", "Vocabulary", "DataSet",
           "Trainer", "Tester", "Callback",
           "Padder", "AutoPadder", "EngChar2DPadder",
           "AccuracyMetric", "Optimizer", "SGD", "Adam",
           "Sampler", "SequentialSampler", "BucketSampler", "RandomSampler",
           "LossFunc", "CrossEntropyLoss", "L1Loss", "BCELoss", "NLLLoss", "LossInForward",
           "cache_results"]
from .core import *
from . import models
from . import modules

__version__ = '0.4.0'
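# Illustrative sketch (not part of the diff): thanks to __all__ and `from .core import *`
# above, the common components are importable straight from the top-level package.
from fastNLP import DataSet, Instance

ds = DataSet()
ds.append(Instance(raw_words="this is a test".split(), label=1))  # field names are arbitrary
print(len(ds))  # 1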
@@ -1,87 +1,89 @@
"""
The callback module implements fastNLP's Callback classes, which are used to extend the :class:`~fastNLP.Trainer`.
For detailed documentation on the Trainer, see :doc:`the trainer module <fastNLP.core.trainer>`.
"""
import os

import torch

from ..io.model_io import ModelSaver, ModelLoader
try:
    from tensorboardX import SummaryWriter
except ImportError:
    # tensorboardX is optional; TensorboardCallback below will not work without it.
    pass
class Callback(object):
    """
    Alias: :class:`fastNLP.Callback` :class:`fastNLP.core.callback.Callback`

    Callback is the class fastNLP provides for extending the :class:`~fastNLP.Trainer`.
    If a Callback is passed to the Trainer, the Trainer will invoke the callback's methods at the
    corresponding stages; the exact timing of these calls is documented in
    :doc:`the trainer module <fastNLP.core.trainer>`.
    This is the base class of callbacks; every callback must inherit from it
    (see :doc:`the callback module <fastNLP.core.callback>`).
    """

    def __init__(self):
        super(Callback, self).__init__()
        self._trainer = None  # reassigned inside the Trainer
    @property
    def trainer(self):
        """
        The Trainer this callback is attached to, available as self.trainer; in most cases you will not need it.
        """
        return self._trainer

    @property
    def step(self):
        """The current step, in the range [1, self.n_steps+1)."""
        return self._trainer.step

    @property
    def n_steps(self):
        """The total number of steps the Trainer will run."""
        return self._trainer.n_steps

    @property
    def batch_size(self):
        """The batch size used for training and evaluation."""
        return self._trainer.batch_size

    @property
    def epoch(self):
        """The current epoch, in the range [1, self.n_epochs+1)."""
        return self._trainer.epoch

    @property
    def n_epochs(self):
        """The total number of epochs that will be run."""
        return self._trainer.n_epochs

    @property
    def optimizer(self):
        """The Optimizer passed when the Trainer was initialized."""
        return self._trainer.optimizer

    @property
    def model(self):
        """The model being trained by the Trainer."""
        return self._trainer.model

    @property
    def pbar(self):
        """If you need to print inside a Callback, use self.pbar.write(str); printing any other way may garble the command-line output."""
        return self._trainer.pbar

    @property
    def update_every(self):
        """How many backward passes are accumulated before each gradient update; passed in when the Trainer is initialized."""
        return self._trainer.update_every

    @property
    def batch_per_epoch(self):
        """How many batches there are in each epoch; only available after on_epoch_begin has been called."""
        return self._trainer.batch_per_epoch
    def on_train_begin(self):
        """
        Called before the training process starts.
@@ -89,7 +91,7 @@ class Callback(object):
        :return:
        """
        pass

    def on_epoch_begin(self):
        """
        Called once before each epoch begins.
@@ -97,7 +99,7 @@ class Callback(object):
        :return:
        """
        pass

    def on_batch_begin(self, batch_x, batch_y, indices):
        """
        Called once each time a batch of data has been collected. Removing content from or adding content to
        batch_x or batch_y here does affect what the Trainer sees, so at this step
@@ -110,7 +112,7 @@ class Callback(object):
        :return:
        """
        pass

    def on_loss_begin(self, batch_y, predict_y):
        """
        Called before the loss is computed; modifying batch_y or predict_y here affects the loss computation.
@@ -120,7 +122,7 @@ class Callback(object):
        :return:
        """
        pass

    def on_backward_begin(self, loss):
        """
        Called after the loss has been obtained but before backpropagation. A check for NaN loss can be done here.
@@ -129,7 +131,7 @@ class Callback(object):
        :return:
        """
        pass

    def on_backward_end(self):
        """
        Backpropagation has completed, but because of the update_every setting, not every call will carry
        gradients. At this point the parameters have not been updated yet.
@@ -137,7 +139,7 @@ class Callback(object):
        :return:
        """
        pass

    def on_step_end(self):
        """
        At this point the model parameters have been updated according to the gradients, though because of
        update_every they are not necessarily updated on every call.
@@ -145,14 +147,14 @@ class Callback(object):
        :return:
        """
        pass

    def on_batch_end(self):
        """
        This step immediately follows on_step_end; it exists only for symmetry.
        """
        pass

    def on_valid_begin(self):
        """
        If validation is configured in the Trainer, this is called before each validation run.
@@ -160,7 +162,7 @@ class Callback(object):
        :return:
        """
        pass

    def on_valid_end(self, eval_result, metric_key, optimizer, is_better_eval):
        """
        Called after each evaluation on the validation set.
@@ -173,19 +175,19 @@ class Callback(object):
        :return:
        """
        pass

    def on_epoch_end(self):
        """
        Called at the end of each epoch.
        """
        pass

    def on_train_end(self):
        """
        Called when training has finished.
        """
        pass

    def on_exception(self, exception):
        """
        Triggered when an exception occurs during training.
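# An illustrative sketch (not part of the diff): a minimal custom callback built on the hooks
# above; everything it touches (pbar, epoch, n_epochs) comes from the properties defined earlier.
class PrintingCallback(Callback):
    def on_epoch_end(self):
        self.pbar.write("finished epoch {}/{}".format(self.epoch, self.n_epochs))

    def on_valid_end(self, eval_result, metric_key, optimizer, is_better_eval):
        self.pbar.write("validation result: {}".format(eval_result))
# Passed to the Trainer via Trainer(..., callbacks=[PrintingCallback()]).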
@@ -196,32 +198,31 @@ class Callback(object):

def _transfer(func):
    """Decorator that forwards a call on the CallbackManager to every registered Callback.

    :param func:
    :return:
    """

    def wrapper(manager, *arg):
        returns = []
        for callback in manager.callbacks:
            returns.append(getattr(callback, func.__name__)(*arg))
        return returns

    return wrapper
class CallbackManager(Callback):
    """Callback manager class for internal use.
    """

    def __init__(self, env, callbacks=None):
        """
        :param dict env: The key is the name of a Trainer attribute (str). The value is the attribute itself.
        :param List[Callback] callbacks:
        """
        super(CallbackManager, self).__init__()

        # set attribute of trainer environment
        self.callbacks = []
        if callbacks is not None:
            if isinstance(callbacks, list):
@@ -232,78 +233,82 @@ class CallbackManager(Callback):
                    raise TypeError(f"Expect sub-classes of Callback. Got {type(obj)}")
            else:
                raise TypeError(f"Expect callbacks in CallbackManager(callbacks) to be list. Got {type(callbacks)}.")

        for env_name, env_val in env.items():
            for callback in self.callbacks:
                setattr(callback, '_' + env_name, env_val)  # Callback.trainer
    @_transfer
    def on_train_begin(self):
        pass

    @_transfer
    def on_epoch_begin(self):
        pass

    @_transfer
    def on_batch_begin(self, batch_x, batch_y, indices):
        pass

    @_transfer
    def on_loss_begin(self, batch_y, predict_y):
        pass

    @_transfer
    def on_backward_begin(self, loss):
        pass

    @_transfer
    def on_backward_end(self):
        pass

    @_transfer
    def on_step_end(self):
        pass

    @_transfer
    def on_batch_end(self):
        pass

    @_transfer
    def on_valid_begin(self):
        pass

    @_transfer
    def on_valid_end(self, eval_result, metric_key, optimizer, is_better_eval):
        pass

    @_transfer
    def on_epoch_end(self):
        pass

    @_transfer
    def on_train_end(self):
        pass

    @_transfer
    def on_exception(self, exception):
        pass
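# Illustrative sketch (not part of the diff) of how the empty bodies above work: _transfer
# replaces each method with a wrapper that fans the call out to every registered callback.
# The toy classes here are hypothetical.
class _ToyCallback:
    def on_epoch_begin(self):
        return "called"

class _ToyManager:
    def __init__(self, callbacks):
        self.callbacks = callbacks

    @_transfer
    def on_epoch_begin(self):
        pass

assert _ToyManager([_ToyCallback()]).on_epoch_begin() == ["called"]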
class GradientClipCallback(Callback):
    """Before each backward pass, clip the parameters' gradients to a given range.

    :param None,torch.Tensor,List[torch.Tensor] parameters: usually obtained via model.parameters().
        If None, all parameters of the Trainer's model are clipped by default.
    :param float clip_value: clip the gradient to [-clip_value, clip_value]; clip_value should be a positive number.
    :param str clip_type: supports two modes, 'norm' and 'value'::

        1 'norm': rescale the gradient norm into [-clip_value, clip_value]
        2 'value': clamp the gradient to [-clip_value, clip_value]; gradients below -clip_value are set to
          -clip_value, and gradients above clip_value are set to clip_value.
    """

    def __init__(self, parameters=None, clip_value=1, clip_type='norm'):
        super().__init__()

        from torch import nn
        if clip_type == 'norm':
            self.clip_fun = nn.utils.clip_grad_norm_
@@ -313,7 +318,7 @@ class GradientClipCallback(Callback):
            raise ValueError("Only supports `norm` or `value` right now.")
        self.parameters = parameters
        self.clip_value = clip_value

    def on_backward_end(self):
        if self.parameters is None:
            self.clip_fun(self.model.parameters(), self.clip_value)
@@ -321,31 +326,17 @@
            self.clip_fun(self.parameters, self.clip_value)
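# Illustrative usage (not part of the diff): clamp every gradient of the Trainer's model to
# [-5, 5] before each update step.
clip_cb = GradientClipCallback(clip_value=5, clip_type='value')
# trainer = Trainer(..., callbacks=[clip_cb])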
class EarlyStopCallback(Callback):
    """
    :param int patience: stop training after this many epochs without improvement.
    """

    def __init__(self, patience):
        super(EarlyStopCallback, self).__init__()
        self.patience = patience
        self.wait = 0

    def on_valid_end(self, eval_result, metric_key, optimizer, is_better_eval):
        if not is_better_eval:
            # current result is getting worse
@@ -355,7 +346,7 @@ class EarlyStopCallback(Callback):
            self.wait += 1
        else:
            self.wait = 0

    def on_exception(self, exception):
        if isinstance(exception, EarlyStopError):
            print("Early Stopping triggered in epoch {}!".format(self.epoch))
@@ -364,39 +355,41 @@

class LRScheduler(Callback):
    """A wrapper around a PyTorch LR scheduler so that it can be used with the Trainer.

    Example::

        from fastNLP import LRScheduler

    :param torch.optim.lr_scheduler._LRScheduler lr_scheduler: a PyTorch lr_scheduler
    """

    def __init__(self, lr_scheduler):
        super(LRScheduler, self).__init__()
        import torch.optim
        if isinstance(lr_scheduler, torch.optim.lr_scheduler._LRScheduler):
            self.scheduler = lr_scheduler
        else:
            raise ValueError(f"Expect torch.optim.lr_scheduler for LRScheduler. Got {type(lr_scheduler)}.")

    def on_epoch_begin(self):
        self.scheduler.step()
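# Illustrative usage (not part of the diff): halve the learning rate every 10 epochs. The toy
# parameter list stands in for a real model's parameters.
import torch.optim
_params = [torch.nn.Parameter(torch.zeros(1))]
_sgd = torch.optim.SGD(_params, lr=0.1)
scheduler_cb = LRScheduler(torch.optim.lr_scheduler.StepLR(_sgd, step_size=10, gamma=0.5))
# trainer = Trainer(..., callbacks=[scheduler_cb])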
class ControlC(Callback):
    """
    :param bool quit_all: if True, exit the whole program when Ctrl+C is detected; otherwise only exit the Trainer.
    """

    def __init__(self, quit_all):
        super(ControlC, self).__init__()
        if type(quit_all) != bool:
            raise ValueError("In KeyBoardInterrupt, quit_all argument must be a bool.")
        self.quit_all = quit_all

    def on_exception(self, exception):
        if isinstance(exception, KeyboardInterrupt):
            if self.quit_all is True:
@@ -412,7 +405,7 @@ class SmoothValue(object):
    def __init__(self, beta: float):
        self.beta, self.n, self.mov_avg = beta, 0, 0
        self.smooth = None

    def add_value(self, val: float) -> None:
        "Add `val` to calculate updated smoothed value."
        self.n += 1
@@ -421,13 +414,15 @@
class LRFinder(Callback):
    """
    Use the first epoch to find the best learning rate, and apply it from the second epoch on.

    :param float start_lr: lower bound of the learning rate
    :param float end_lr: upper bound of the learning rate
    """

    def __init__(self, start_lr=1e-6, end_lr=10):
        super(LRFinder, self).__init__()
        self.start_lr, self.end_lr = start_lr, end_lr
        self.num_it = self.batch_per_epoch
@@ -438,19 +433,19 @@ class LRFinder(Callback):
        self.smooth_value = SmoothValue(0.8)
        self.opt = None
        scale = (self.end_lr - self.start_lr) / self.num_it

        self.lr_gen = (self.start_lr + scale * (step + 1) for step in range(self.num_it))
        self.find = None
        self.loader = ModelLoader()
    def on_epoch_begin(self):
        if self.epoch == 1:  # first epoch
            self.opt = self.trainer.optimizer  # pytorch optimizer
            self.opt.param_groups[0]["lr"] = self.start_lr
            # save model
            ModelSaver("tmp").save_pytorch(self.trainer.model, param_only=True)
            self.find = True

    def on_backward_begin(self, loss):
        if self.find:
            if torch.isnan(loss) or self.stop is True:
@@ -462,7 +457,7 @@ class LRFinder(Callback):
            if self.best_loss == 0. or self.smooth_value.smooth < self.best_loss:
                self.best_loss = self.smooth_value.smooth
                self.best_lr = self.opt.param_groups[0]["lr"]

    def on_batch_end(self, *args):
        if self.find:
            lr = next(self.lr_gen, None)
@@ -471,9 +466,9 @@ class LRFinder(Callback):
                return
            self.opt.param_groups[0]["lr"] = lr
            # self.loader.load_pytorch(self.trainer.model, "tmp")

    def on_epoch_end(self):
        if self.epoch == 1:  # first epoch
            self.opt.param_groups[0]["lr"] = self.best_lr
            self.find = False
            # reset model
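# Illustrative usage (not part of the diff): sweep the learning rate over the first epoch and
# keep the value with the lowest smoothed loss from the second epoch on.
lr_finder_cb = LRFinder(start_lr=1e-6, end_lr=10)
# trainer = Trainer(..., callbacks=[lr_finder_cb])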
@@ -483,12 +478,12 @@

class TensorboardCallback(Callback):
    """
    Accepts one or more of the following strings as arguments:
    - "model"
    - "loss"
    - "metric"
    """

    def __init__(self, *options):
        super(TensorboardCallback, self).__init__()
        args = {"model", "loss", "metric"}
@@ -498,7 +493,7 @@ class TensorboardCallback(Callback):
        self.options = options
        self._summary_writer = None
        self.graph_added = False

    def on_train_begin(self):
        save_dir = self.trainer.save_path
        if save_dir is None:
@@ -506,7 +501,7 @@ class TensorboardCallback(Callback):
        else:
            path = os.path.join(save_dir, 'tensorboard_logs_{}'.format(self.trainer.start_time))
        self._summary_writer = SummaryWriter(path)

    def on_batch_begin(self, batch_x, batch_y, indices):
        if "model" in self.options and self.graph_added is False:
            # tensorboardX has a serious bug here; drawing the model graph is not possible for now
@@ -516,11 +511,11 @@ class TensorboardCallback(Callback):
            # args = args[0] if len(args) == 1 else args
            # self._summary_writer.add_graph(self.trainer.model, torch.zeros(32, 2))
            self.graph_added = True

    def on_backward_begin(self, loss):
        if "loss" in self.options:
            self._summary_writer.add_scalar("loss", loss.item(), global_step=self.trainer.step)

        if "model" in self.options:
            for name, param in self.trainer.model.named_parameters():
                if param.requires_grad:
@@ -528,21 +523,40 @@ class TensorboardCallback(Callback):
                    # self._summary_writer.add_scalar(name + "_std", param.std(), global_step=self.trainer.step)
                    self._summary_writer.add_scalar(name + "_grad_mean", param.grad.mean(),
                                                    global_step=self.trainer.step)

    def on_valid_end(self, eval_result, metric_key, optimizer, is_better_eval):
        if "metric" in self.options:
            for name, metric in eval_result.items():
                for metric_key, metric_val in metric.items():
                    self._summary_writer.add_scalar("valid_{}_{}".format(name, metric_key), metric_val,
                                                    global_step=self.trainer.step)

    def on_train_end(self):
        self._summary_writer.close()
        del self._summary_writer

    def on_exception(self, exception):
        if hasattr(self, "_summary_writer"):
            self._summary_writer.close()
            del self._summary_writer
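# Illustrative usage (not part of the diff): log the loss curve and validation metrics to
# TensorBoard; requires the optional tensorboardX dependency imported at the top of this module.
tb_cb = TensorboardCallback("loss", "metric")
# trainer = Trainer(..., save_path="save/", callbacks=[tb_cb])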
class CallbackException(BaseException):
    """
    To break out of training from a callback, raise a CallbackException and catch it in on_exception.

    :param str msg: the exception message.
    """

    def __init__(self, msg):
        super(CallbackException, self).__init__(msg)


class EarlyStopError(CallbackException):
    """
    Used to break out of the Trainer's training loop when early stopping.
    """

    def __init__(self, msg):
        super(EarlyStopError, self).__init__(msg)
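# Illustrative sketch (not part of the diff) of the mechanism described above: a callback ends
# training early by raising CallbackException and handling it in its own on_exception hook,
# the same pattern EarlyStopCallback uses with EarlyStopError.
class StopAfterStep(Callback):
    def __init__(self, max_step):
        super().__init__()
        self.max_step = max_step

    def on_step_end(self):
        if self.step >= self.max_step:
            raise CallbackException("reached step {}".format(self.step))

    def on_exception(self, exception):
        if isinstance(exception, CallbackException):
            print(str(exception))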
@@ -1,4 +1,4 @@
class Const:
    """Field-name constants used in fastNLP.

    The full list::
@@ -1,5 +1,5 @@
"""
:class:`~fastNLP.core.dataset.DataSet` is the container fastNLP uses to hold data. A DataSet can be thought
of as a table in which every row is a sample (called an :mod:`~.instance` in fastNLP)
and every column is a feature (called a :mod:`.field` in fastNLP).
@@ -294,7 +294,8 @@ class DataSet(object):
    fastNLP's data container; see the documentation :doc:`fastNLP.core.dataset` for detailed usage.

    :param data: if of type dict, the value of every key should be a list of equal length; if a list,
        every element should be an :class:`~fastNLP.Instance` with the same fields.
    """

    def __init__(self, data=None):
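# Illustrative usage (not part of the diff): both construction styles described in the
# docstring above; the field names 'x' and 'y' are arbitrary.
from fastNLP import DataSet, Instance

ds_from_dict = DataSet({'x': [[1, 2], [3, 4]], 'y': [0, 1]})                 # dict of equal-length lists
ds_from_list = DataSet([Instance(x=[1, 2], y=0), Instance(x=[3, 4], y=1)])   # list of Instances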
@@ -1,36 +1,34 @@
"""
The losses module defines the loss functions fastNLP needs; they are generally used as a
parameter of :class:`~fastNLP.Trainer`.
"""
__all__ = ["LossBase", "L1Loss", "LossFunc", "LossInForward", "BCELoss", "CrossEntropyLoss", "NLLLoss"]

import inspect
from collections import defaultdict

import torch
import torch.nn.functional as F

from .utils import _CheckError
from .utils import _CheckRes
from .utils import _build_args
from .utils import _check_arg_dict_list
from .utils import _check_function_or_method
from .utils import _get_func_signature
class LossBase(object):
    """
    Base class of all losses. To understand how it works, please read the source code.
    """

    def __init__(self):
        self.param_map = {}
        self._checked = False

    def get_loss(self, *args, **kwargs):
        raise NotImplementedError
    def _init_param_map(self, key_map=None, **kwargs):
        """Check key_map and the other parameter maps, and add these mappings to self.param_map.
@@ -63,7 +61,7 @@ class LossBase(object):
        for value, key_set in value_counter.items():
            if len(key_set) > 1:
                raise ValueError(f"Several parameters:{key_set} are provided with one output {value}.")

        # check consistency between signature and param_map
        func_spect = inspect.getfullargspec(self.get_loss)
        func_args = [arg for arg in func_spect.args if arg != 'self']
@@ -72,12 +70,12 @@ class LossBase(object):
                raise NameError(
                    f"Parameter `{func_param}` is not in {_get_func_signature(self.get_loss)}. Please check the "
                    f"initialization parameters, or change its signature.")

        # evaluate should not have varargs.
        # if func_spect.varargs:
        #     raise NameError(f"Delete `*{func_spect.varargs}` in {get_func_signature(self.get_loss)}(Do not use "
        #                     f"positional argument.).")

    def _fast_param_map(self, pred_dict, target_dict):
        """Only used as an inner function, for when pred_dict and target_dict are unequivocal and users do not
        need to pass a key_map, e.g. when pred_dict and target_dict each have exactly one element.
@@ -92,7 +90,7 @@ class LossBase(object):
            fast_param['target'] = list(target_dict.values())[0]
            return fast_param
        return fast_param

    def __call__(self, pred_dict, target_dict, check=False):
        """
        :param dict pred_dict: the dict returned by the model's forward function
@@ -104,7 +102,7 @@ class LossBase(object):
        if fast_param:
            loss = self.get_loss(**fast_param)
            return loss

        if not self._checked:
            # 1. check consistency between signature and param_map
            func_spect = inspect.getfullargspec(self.get_loss)
@@ -112,14 +110,14 @@ class LossBase(object):
            for func_arg, input_arg in self.param_map.items():
                if func_arg not in func_args:
                    raise NameError(f"`{func_arg}` not in {_get_func_signature(self.get_loss)}.")

            # 2. only part of the param_map are passed, left are not
            for arg in func_args:
                if arg not in self.param_map:
                    self.param_map[arg] = arg  # This param does not need mapping.
            self._evaluate_args = func_args
            self._reverse_param_map = {input_arg: func_arg for func_arg, input_arg in self.param_map.items()}

        # need to wrap inputs in dict.
        mapped_pred_dict = {}
        mapped_target_dict = {}
@@ -139,7 +137,7 @@ class LossBase(object):
                    not_duplicate_flag += 1
                if not_duplicate_flag == 3:
                    duplicated.append(input_arg)

        # missing
        if not self._checked:
            check_res = _check_arg_dict_list(self.get_loss, [mapped_pred_dict, mapped_target_dict])
@@ -149,47 +147,50 @@ class LossBase(object):
            for idx, func_arg in enumerate(missing):
                # Don't delete `` in this information, nor add ``
                replaced_missing[idx] = f"{self.param_map[func_arg]}" + f"(assign to `{func_arg}` " \
                                        f"in `{self.__class__.__name__}`)"

            check_res = _CheckRes(missing=replaced_missing,
                                  unused=check_res.unused,
                                  duplicated=duplicated,
                                  required=check_res.required,
                                  all_needed=check_res.all_needed,
                                  varargs=check_res.varargs)

            if check_res.missing or check_res.duplicated:
                raise _CheckError(check_res=check_res,
                                  func_signature=_get_func_signature(self.get_loss))
        refined_args = _build_args(self.get_loss, **mapped_pred_dict, **mapped_target_dict)

        loss = self.get_loss(**refined_args)
        self._checked = True

        return loss
class LossFunc(LossBase):
    """
    Alias: :class:`fastNLP.LossFunc` :class:`fastNLP.core.losses.LossFunc`

    A class that lets users plug in a custom loss function.

    :param func: a user-defined loss function; it should be a function or an object for which callable(func) is True
    :param dict key_map: the parameter map. Keys are Model/DataSet parameter names, values are loss-function
        parameter names. During training, fastNLP's trainer looks in the model's return value, or in the fields
        of the training DataSet marked target=True, for the parameter named by each value, and passes it to
        func as the parameter named by the corresponding key
    :param kwargs: besides key_map, the parameter mapping can also be given as keyword arguments

    Example::

        >>> func = torch.nn.CrossEntropyLoss()
        >>> loss_func = LossFunc(func, input="pred", target="label")
        >>> # This builds a loss class whose loss is computed by func. From the model's return value or the
        >>> # DataSet fields with target=True, a parameter named `pred` is found and passed to func as the
        >>> # parameter named `input`, and a parameter named `label` is found and passed to func as the
        >>> # parameter named `target`.
    """

    def __init__(self, func, key_map=None, **kwargs):
        super(LossFunc, self).__init__()
        _check_function_or_method(func)
        if key_map is not None:
@@ -199,94 +200,108 @@ class LossFunc(LossBase):
        if len(kwargs) > 0:
            for key, val in kwargs.items():
                self.param_map.update({key: val})

        self.get_loss = func
class CrossEntropyLoss(LossBase):
    """
    Alias: :class:`fastNLP.CrossEntropyLoss` :class:`fastNLP.core.losses.CrossEntropyLoss`

    Cross-entropy loss.

    :param pred: the mapping for `pred` in the parameter map; None means the mapping is `pred` -> `pred`
    :param target: the mapping for `target` in the parameter map; None means the mapping is `target` -> `target`
    :param padding_idx: the padding index; entries of target equal to padding_idx are ignored when computing the loss

    Example::

        >>> loss = CrossEntropyLoss(pred='pred', target='label', padding_idx=0)
    """

    def __init__(self, pred=None, target=None, padding_idx=-100):
        # TODO some checks are needed: when F.cross_entropy computes with pred of shape (16, 10, 4), target should
        # TODO in principle have shape (16, 10), but it actually requires (16, 4)
        super(CrossEntropyLoss, self).__init__()
        self._init_param_map(pred=pred, target=target)
        self.padding_idx = padding_idx

    def get_loss(self, pred, target):
        return F.cross_entropy(input=pred, target=target,
                               ignore_index=self.padding_idx)
class L1Loss(LossBase):
    """
    Alias: :class:`fastNLP.L1Loss` :class:`fastNLP.core.losses.L1Loss`

    L1 loss.

    :param pred: the mapping for `pred` in the parameter map; None means the mapping is `pred` -> `pred`
    :param target: the mapping for `target` in the parameter map; None means the mapping is `target` -> `target`
    """

    def __init__(self, pred=None, target=None):
        super(L1Loss, self).__init__()
        self._init_param_map(pred=pred, target=target)

    def get_loss(self, pred, target):
        return F.l1_loss(input=pred, target=target)
class BCELoss(LossBase):
    """
    Alias: :class:`fastNLP.BCELoss` :class:`fastNLP.core.losses.BCELoss`

    Binary cross-entropy loss.

    :param pred: the mapping for `pred` in the parameter map; None means the mapping is `pred` -> `pred`
    :param target: the mapping for `target` in the parameter map; None means the mapping is `target` -> `target`
    """

    def __init__(self, pred=None, target=None):
        super(BCELoss, self).__init__()
        self._init_param_map(pred=pred, target=target)

    def get_loss(self, pred, target):
        return F.binary_cross_entropy(input=pred, target=target)
class NLLLoss(LossBase):
    """
    Alias: :class:`fastNLP.NLLLoss` :class:`fastNLP.core.losses.NLLLoss`

    Negative log-likelihood loss.

    :param pred: the mapping for `pred` in the parameter map; None means the mapping is `pred` -> `pred`
    :param target: the mapping for `target` in the parameter map; None means the mapping is `target` -> `target`
    """

    def __init__(self, pred=None, target=None):
        super(NLLLoss, self).__init__()
        self._init_param_map(pred=pred, target=target)

    def get_loss(self, pred, target):
        return F.nll_loss(input=pred, target=target)
class LossInForward(LossBase):
    """
    Alias: :class:`fastNLP.LossInForward` :class:`fastNLP.core.losses.LossInForward`

    Take the loss from the result returned by forward().

    :param str loss_key: the key of the loss in the forward function's return dict; defaults to 'loss'
    """

    def __init__(self, loss_key='loss'):
        super().__init__()
        if not isinstance(loss_key, str):
            raise TypeError(f"Only str allowed for loss_key, got {type(loss_key)}.")
        self.loss_key = loss_key

    def get_loss(self, **kwargs):
        if self.loss_key not in kwargs:
            check_res = _CheckRes(
@@ -298,17 +313,17 @@ class LossInForward(LossBase):
                varargs=[])
            raise _CheckError(check_res=check_res, func_signature=_get_func_signature(self.get_loss))
        return kwargs[self.loss_key]

    def __call__(self, pred_dict, target_dict, check=False):

        loss = self.get_loss(**pred_dict)

        if not (isinstance(loss, torch.Tensor) and len(loss.size()) == 0):
            if not isinstance(loss, torch.Tensor):
                raise TypeError(f"Loss expected to be a torch.Tensor, got {type(loss)}")
            loss = torch.sum(loss) / (loss.view(-1)).size(0)
            # raise RuntimeError(f"The size of loss expected to be torch.Size([]), got {loss.size()}")

        return loss
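# Illustrative sketch (not part of the diff): a model that computes its own loss in forward()
# and exposes it under the default 'loss' key, paired with LossInForward. The layer sizes are
# arbitrary.
import torch.nn as nn

class ModelWithLoss(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(10, 2)

    def forward(self, x, target):
        pred = self.fc(x)
        return {'pred': pred, 'loss': F.cross_entropy(pred, target)}
# trainer = Trainer(..., model=ModelWithLoss(), loss=LossInForward())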
@@ -378,13 +393,13 @@ def mask(predict, truth, **kwargs):
    if kwargs.get("mask") is None:
        return predict, truth
    mask = kwargs["mask"]

    predict, truth = squash(predict, truth)
    mask = mask.view(-1, )

    predict = torch.masked_select(predict.permute(1, 0), mask).view(predict.size()[-1], -1).permute(1, 0)
    truth = torch.masked_select(truth, mask)

    return predict, truth
@@ -399,4 +414,3 @@ def make_mask(lens, tar_len):
    mask = [torch.ge(lens, i + 1) for i in range(tar_len)]
    mask = torch.stack(mask, 1)
    return mask
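# Illustrative check (not part of the diff): row i of make_mask's result is true for the first
# lens[i] positions and false afterwards.
example_mask = make_mask(torch.LongTensor([3, 1]), tar_len=4)
# -> rows [1, 1, 1, 0] and [1, 0, 0, 0] (the exact dtype depends on the torch version)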
@@ -1,31 +1,25 @@
"""
The metrics module implements the common evaluation metrics fastNLP needs; they are generally used as a
parameter of :class:`~fastNLP.Trainer`.
"""
import inspect
from collections import defaultdict

import numpy as np
import torch

from .utils import _CheckError
from .utils import _CheckRes
from .utils import _build_args
from .utils import _check_arg_dict_list
from .utils import _get_func_signature
from .utils import seq_lens_to_masks
from .vocabulary import Vocabulary
class MetricBase(object):
    """Base class of all metrics.

    Every Metric passed to the Trainer or Tester must inherit from this class and must override the
    evaluate() and get_metric() methods.

    evaluate(xxx) receives one batch of data at a time.
@@ -94,17 +88,17 @@ class MetricBase(object):
            return {'acc': acc}  # a dict must be returned; the key is this metric's name, which is shown in the Trainer's progress bar

    ``MetricBase`` will run checks on the input dicts ``pred_dict`` and ``target_dict``.
    ``pred_dict`` is the return value of the model's ``forward()`` or ``predict()`` function.
    ``target_dict`` is the ground truth from the DataSet; a field counts as ground truth when its ``is_target`` is set to True.

    ``MetricBase`` performs the following type checks:

    1. whether self.evaluate has varargs, which is not supported.
    2. whether any parameter needed by self.evaluate is in neither ``pred_dict`` nor ``target_dict``.
    3. whether any parameter needed by self.evaluate is in both ``pred_dict`` and ``target_dict``.

    Besides this, before the parameters are passed into self.evaluate, this function checks for parameters
    in ``pred_dict`` and ``target_dict`` that are not used;
    no such check is performed if kwargs is a parameter of self.evaluate
@@ -267,13 +261,18 @@ class MetricBase(object):

class AccuracyMetric(MetricBase):
    """
    Alias: :class:`fastNLP.AccuracyMetric` :class:`fastNLP.core.metrics.AccuracyMetric`

    Accuracy metric (for the other metrics, see :doc:`fastNLP.core.metrics`).

    :param pred: the mapping for `pred` in the parameter map; None means the mapping is `pred` -> `pred`
    :param target: the mapping for `target` in the parameter map; None means the mapping is `target` -> `target`
    :param seq_len: the mapping for `seq_lens` in the parameter map; None means the mapping is `seq_len` -> `seq_len`
    """

    def __init__(self, pred=None, target=None, seq_len=None):
        super().__init__()

        self._init_param_map(pred=pred, target=target, seq_len=seq_len)
@@ -282,7 +281,8 @@ class AccuracyMetric(MetricBase):
        self.acc_count = 0

    def evaluate(self, pred, target, seq_len=None):
        """
        The evaluate function accumulates the metric statistics over one batch of predictions.

        :param torch.Tensor pred: the prediction tensor, of shape torch.Size([B,]), torch.Size([B, n_classes]),
            torch.Size([B, max_len]), or torch.Size([B, max_len, n_classes])
@@ -327,7 +327,8 @@ class AccuracyMetric(MetricBase):
        self.total += np.prod(list(pred.size()))

    def get_metric(self, reset=True):
        """
        The get_metric function computes the final evaluation result from the statistics accumulated by evaluate.

        :param bool reset: whether to clear the accumulated statistics after get_metric is called.
        :return dict evaluate_result: {"acc": float}
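# Illustrative usage (not part of the diff): accumulate over batches, then read the result.
metric = AccuracyMetric(pred='pred', target='target')
metric.evaluate(pred=torch.LongTensor([1, 0, 1]), target=torch.LongTensor([1, 1, 1]))
print(metric.get_metric())  # {'acc': 0.666667} -- two of three correct; rounding may vary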
@@ -430,8 +431,6 @@ def _bio_tag_to_spans(tags, ignore_labels=None):

class SpanFPreRecMetric(MetricBase):
    """
    Computes F, precision, and recall over spans in sequence-labeling problems.
    For example, in Chinese part-of-speech tagging annotated per character, the sentence '中国在亚洲' may
    carry the POS tags (using BMES) ['B-NN', 'E-NN', 'S-DET', 'B-NN', 'E-NN']. This metric computes F1 for such cases.
@@ -455,26 +454,24 @@ class SpanFPreRecMetric(MetricBase):
        ...
        }

    :param tag_vocab: the label :class:`~fastNLP.Vocabulary`. Supported labels are "B" (without a label part) or
        "B-xxx" (where xxx is a label, such as NN in POS). During decoding, tags with the same xxx are treated
        as one label; e.g. ['B-NN', 'E-NN'] is merged into a single 'NN'.
    :param str pred: the key used to fetch the prediction data from the dict passed to evaluate(); if None, 'pred' is used
    :param str target: the key used to fetch the target data from the dict passed to evaluate(); if None, 'target' is used
    :param str seq_len: the key used to fetch the sequence-length data from the dict passed to evaluate(); if None, 'seq_lens' is used.
    :param str encoding_type: currently supports bio and bmes
    :param list ignore_labels: a list of str; the classes in this list are excluded from the computation. For
        example, passing ['NN'] in POS tagging means the 'NN' label is not evaluated
    :param bool only_gross: whether to compute only the overall f1, precision, and recall; if False, the per-label
        f1, precision, and recall are returned as well
    :param str f_type: 'micro' or 'macro'. 'micro': first compute the overall counts of TP, FN, and FP, then compute
        f, precision, and recall; 'macro': compute f, precision, and recall per class separately, then average them
        (each class's f carries equal weight)
    :param float beta: the f_beta score, f_beta = (1 + beta^2)*(pre*rec)/(beta^2*pre + rec). Common values are
        beta=0.5, 1, 2. With 0.5, precision is weighted more than recall; with 1, they are weighted equally;
        with 2, recall is weighted more than precision.
    """

    def __init__(self, tag_vocab, pred=None, target=None, seq_len=None, encoding_type='bio', ignore_labels=None,
                 only_gross=True, f_type='micro', beta=1):

        encoding_type = encoding_type.lower()
        if not isinstance(tag_vocab, Vocabulary):
@@ -647,20 +644,18 @@ class BMESF1PreRecMetric(MetricBase):
    target has shape (batch_size, max_len)
    seq_lens has shape (batch_size, )

    The tag idx of each of the four BMES tags must be declared; every index that is not b_idx, m_idx,
    e_idx, or s_idx is treated as s_idx.

    :param b_idx: int, the tag idx of the Begin tag.
    :param m_idx: int, the tag idx of the Middle tag.
    :param e_idx: int, the tag idx of the End tag.
    :param s_idx: int, the tag idx of the Single tag
    :param pred: str, the key used to fetch the prediction data from the dict passed to evaluate(); if None, 'pred' is used
    :param target: str, the key used to fetch the target data from the dict passed to evaluate(); if None, 'target' is used
    :param seq_len: str, the key used to fetch the sequence-length data from the dict passed to evaluate(); if None, 'seq_len' is used.
    """

    def __init__(self, b_idx=0, m_idx=1, e_idx=2, s_idx=3, pred=None, target=None, seq_len=None):
        super().__init__()

        self._init_param_map(pred=pred, target=target, seq_len=seq_len)
@@ -831,21 +826,23 @@ def _pred_topk(y_prob, k=1):

class SQuADMetric(MetricBase):
    """SQuAD数据集metric
    """
    SQuAD数据集metric

    :param pred1: 参数映射表中`pred1`的映射关系,None表示映射关系为`pred1`->`pred1`
    :param pred2: 参数映射表中`pred2`的映射关系,None表示映射关系为`pred2`->`pred2`
    :param target1: 参数映射表中`target1`的映射关系,None表示映射关系为`target1`->`target1`
    :param target2: 参数映射表中`target2`的映射关系,None表示映射关系为`target2`->`target2`
    :param float beta: f_beta分数,f_beta = (1 + beta^2)*(pre*rec)/(beta^2*pre + rec). 常用为beta=0.5, 1, 2. 若为0.5
        则精确率的权重高于召回率;若为1,则两者平等;若为2,则召回率权重高于精确率。
    :param bool right_open: right_open为True表示start跟end指针指向一个左闭右开区间,为False表示指向一个左闭右闭区间。
    :param bool print_predict_stat: True则输出预测答案是否为空与正确答案是否为空的统计信息, False则不输出
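
    一个最小示意(其中'pred_start'等key为假设的名称,需与模型输出及DataSet中的field名对应)::

        from fastNLP.core.metrics import SQuADMetric

        metric = SQuADMetric(pred1='pred_start', pred2='pred_end',
                             target1='target_start', target2='target_end')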
""" | """ | ||||
def __init__(self, pred1=None, pred2=None, target1=None, target2=None, | def __init__(self, pred1=None, pred2=None, target1=None, target2=None, | ||||
beta=1, right_open=True, print_predict_stat=False): | beta=1, right_open=True, print_predict_stat=False): | ||||
""" | |||||
:param pred1: 参数映射表中`pred1`的映射关系,None表示映射关系为`pred1`->`pred1` | |||||
:param pred2: 参数映射表中`pred2`的映射关系,None表示映射关系为`pred2`->`pred2` | |||||
:param target1: 参数映射表中`target1`的映射关系,None表示映射关系为`target1`->`target1` | |||||
:param target2: 参数映射表中`target2`的映射关系,None表示映射关系为`target2`->`target2` | |||||
:param float beta: f_beta分数,f_beta = (1 + beta^2)*(pre*rec)/(beta^2*pre + rec). 常用为beta=0.5, 1, 2. 若为0.5 | |||||
则精确率的权重高于召回率;若为1,则两者平等;若为2,则召回率权重高于精确率。 | |||||
:param bool right_open: right_open为true表示start跟end指针指向一个左闭右开区间,为false表示指向一个左闭右闭区间。 | |||||
:param bool print_predict_stat: True则输出预测答案是否为空与正确答案是否为空的统计信息, False则不输出 | |||||
""" | |||||
super(SQuADMetric, self).__init__() | super(SQuADMetric, self).__init__() | ||||
self._init_param_map(pred1=pred1, pred2=pred2, target1=target1, target2=target2) | self._init_param_map(pred1=pred1, pred2=pred2, target1=target1, target2=target2) | ||||
@@ -1,11 +1,16 @@

"""
optimizer 模块定义了 fastNLP 中所需的各种优化器,一般做为 :class:`~fastNLP.Trainer` 的参数使用。
"""
import torch


class Optimizer(object):
    """
    别名::class:`fastNLP.Optimizer` :class:`fastNLP.core.optimizer.Optimizer`

    :param model_params: a generator. E.g. ``model.parameters()`` for PyTorch models.
    :param kwargs: additional parameters.
    :param model_params: a generator. E.g. ``model.parameters()`` for PyTorch models.
    :param kwargs: additional parameters.
    """
    def __init__(self, model_params, **kwargs):
        if model_params is not None and not hasattr(model_params, "__next__"):
@@ -26,10 +31,11 @@ class Optimizer(object):

class SGD(Optimizer):
    """
    别名::class:`fastNLP.SGD` :class:`fastNLP.core.optimizer.SGD`

    :param float lr: learning rate. Default: 0.01
    :param float momentum: momentum. Default: 0
    :param model_params: a generator. E.g. ``model.parameters()`` for PyTorch models.
    :param float lr: learning rate. Default: 0.001
    :param float momentum: momentum. Default: 0
    :param model_params: a generator. E.g. ``model.parameters()`` for PyTorch models.
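
    一个最小示意(构造后作为 :class:`~fastNLP.Trainer` 的optimizer参数传入,模型参数由Trainer在内部填入)::

        from fastNLP import SGD

        optimizer = SGD(lr=0.01, momentum=0.9)
        # trainer = Trainer(train_data, model, optimizer=optimizer)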
""" | """ | ||||
def __init__(self, lr=0.001, momentum=0, model_params=None): | def __init__(self, lr=0.001, momentum=0, model_params=None): | ||||
@@ -47,10 +53,11 @@ class SGD(Optimizer):

class Adam(Optimizer):
    """
    别名::class:`fastNLP.Adam` :class:`fastNLP.core.optimizer.Adam`

    :param float lr: learning rate
    :param float weight_decay:
    :param model_params: a generator. E.g. ``model.parameters()`` for PyTorch models.
    :param float lr: learning rate
    :param float weight_decay: 权重衰减(L2惩罚)系数。Default: 0
    :param model_params: a generator. E.g. ``model.parameters()`` for PyTorch models.
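
    一个最小示意(用法与SGD相同)::

        from fastNLP import Adam

        optimizer = Adam(lr=0.001, weight_decay=1e-5)
        # trainer = Trainer(train_data, model, optimizer=optimizer)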
""" | """ | ||||
def __init__(self, lr=0.001, weight_decay=0, betas=(0.9, 0.999), eps=1e-8, amsgrad=False, model_params=None): | def __init__(self, lr=0.001, weight_decay=0, betas=(0.9, 0.999), eps=1e-8, amsgrad=False, model_params=None): | ||||
@@ -2,10 +2,10 @@ from collections import defaultdict

import torch

from fastNLP.core import Batch
from fastNLP.core import DataSet
from fastNLP.core import SequentialSampler
from fastNLP.core.utils import _build_args
from . import Batch
from . import DataSet
from . import SequentialSampler
from .utils import _build_args


class Predictor(object):
@@ -1,22 +1,24 @@

"""
sampler 子类实现了 fastNLP 所需的各种采样器。

.. _Sampler:

"""
__all__ = ["Sampler", "BucketSampler", "SequentialSampler", "RandomSampler"]

from itertools import chain

import numpy as np
import torch


class Sampler(object):
    """ `Sampler` 类的基类. 规定以何种顺序取出data中的元素
    """
    别名::class:`fastNLP.Sampler` :class:`fastNLP.core.sampler.Sampler`

    `Sampler` 类的基类. 规定以何种顺序取出data中的元素

    子类必须实现 ``__call__`` 方法. 输入 `DataSet` 对象, 返回其中元素的下标序列
    """
    def __call__(self, data_set):
        """
        :param DataSet data_set: `DataSet` 对象, 需要Sample的数据
@@ -26,56 +28,62 @@ class Sampler(object):

class SequentialSampler(Sampler):
    """顺序取出元素的 `Sampler`

    .. _SequentialSampler:
    """
    别名::class:`fastNLP.SequentialSampler` :class:`fastNLP.core.sampler.SequentialSampler`

    顺序取出元素的 `Sampler`
    """
    def __call__(self, data_set):
        return list(range(len(data_set)))


class RandomSampler(Sampler):
    """
    .. _RandomSampler:

    别名::class:`fastNLP.RandomSampler` :class:`fastNLP.core.sampler.RandomSampler`

    随机化取元素的 `Sampler`
    """
    def __call__(self, data_set):
        return list(np.random.permutation(len(data_set)))
class BucketSampler(Sampler):
    """带Bucket的 `Random Sampler`. 可以随机地取出长度相似的元素
    """
    别名::class:`fastNLP.BucketSampler` :class:`fastNLP.core.sampler.BucketSampler`

    带Bucket的 `Random Sampler`. 可以随机地取出长度相似的元素

    :param int num_buckets: bucket的数量
    :param int batch_size: batch的大小
    :param str seq_lens_field_name: 对应序列长度的 `field` 的名字
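
    一个最小示意(假设数据集中存在名为'seq_len'的field,记录每个样本的长度)::

        from fastNLP import BucketSampler

        sampler = BucketSampler(num_buckets=10, batch_size=32, seq_lens_field_name='seq_len')
        # trainer = Trainer(train_data, model, sampler=sampler)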
""" | """ | ||||
def __init__(self, num_buckets=10, batch_size=32, seq_lens_field_name='seq_len'): | def __init__(self, num_buckets=10, batch_size=32, seq_lens_field_name='seq_len'): | ||||
self.num_buckets = num_buckets | self.num_buckets = num_buckets | ||||
self.batch_size = batch_size | self.batch_size = batch_size | ||||
self.seq_lens_field_name = seq_lens_field_name | self.seq_lens_field_name = seq_lens_field_name | ||||
def __call__(self, data_set): | def __call__(self, data_set): | ||||
seq_lens = data_set.get_all_fields()[self.seq_lens_field_name].content | seq_lens = data_set.get_all_fields()[self.seq_lens_field_name].content | ||||
total_sample_num = len(seq_lens) | total_sample_num = len(seq_lens) | ||||
bucket_indexes = [] | bucket_indexes = [] | ||||
assert total_sample_num>=self.num_buckets, "The number of samples is smaller than the number of buckets." | |||||
assert total_sample_num >= self.num_buckets, "The number of samples is smaller than the number of buckets." | |||||
num_sample_per_bucket = total_sample_num // self.num_buckets | num_sample_per_bucket = total_sample_num // self.num_buckets | ||||
for i in range(self.num_buckets): | for i in range(self.num_buckets): | ||||
bucket_indexes.append([num_sample_per_bucket * i, num_sample_per_bucket * (i + 1)]) | bucket_indexes.append([num_sample_per_bucket * i, num_sample_per_bucket * (i + 1)]) | ||||
bucket_indexes[-1][1] = total_sample_num | bucket_indexes[-1][1] = total_sample_num | ||||
sorted_seq_lens = list(sorted([(idx, seq_len) for | sorted_seq_lens = list(sorted([(idx, seq_len) for | ||||
idx, seq_len in zip(range(total_sample_num), seq_lens)], | idx, seq_len in zip(range(total_sample_num), seq_lens)], | ||||
key=lambda x: x[1])) | key=lambda x: x[1])) | ||||
batchs = [] | batchs = [] | ||||
left_init_indexes = [] | left_init_indexes = [] | ||||
for b_idx in range(self.num_buckets): | for b_idx in range(self.num_buckets): | ||||
start_idx = bucket_indexes[b_idx][0] | start_idx = bucket_indexes[b_idx][0] | ||||
@@ -90,7 +98,7 @@ class BucketSampler(Sampler):

        if (left_init_indexes) != 0:
            batchs.append(left_init_indexes)
        np.random.shuffle(batchs)

        return list(chain(*batchs))

@@ -128,10 +136,10 @@ def k_means_1d(x, k, max_iter=100):

    if len(sorted_x) < k:
        raise ValueError("too few buckets")
    gap = len(sorted_x) / k

    centroids = np.array([sorted_x[int(x * gap)] for x in range(k)])
    assign = None

    for i in range(max_iter):
        # Cluster Assignment step
        assign = np.array([np.argmin([np.absolute(x_i - x) for x in centroids]) for x_i in x])

@@ -163,7 +171,7 @@ def k_means_bucketing(lengths, buckets):

    bucket_data = [[] for _ in buckets]
    num_buckets = len(buckets)
    _, assignments = k_means_1d(lengths, num_buckets)

    for idx, bucket_id in enumerate(assignments):
        if buckets[bucket_id] is None or lengths[idx] <= buckets[bucket_id]:
            bucket_data[bucket_id].append(idx)
@@ -1,81 +1,81 @@

import torch
from torch import nn

from fastNLP.core.batch import Batch
from fastNLP.core.dataset import DataSet
from fastNLP.core.metrics import _prepare_metrics
from fastNLP.core.sampler import SequentialSampler
from fastNLP.core.utils import _CheckError
from fastNLP.core.utils import _build_args
from fastNLP.core.utils import _check_loss_evaluate
from fastNLP.core.utils import _move_dict_value_to_device
from fastNLP.core.utils import _get_func_signature
from fastNLP.core.utils import _get_model_device
from fastNLP.core.utils import _move_model_to_device


class Tester(object):
    """
    Tester是在提供数据,模型以及metric的情况下进行性能测试的类

    Example::

        import numpy as np
        import torch
        from torch import nn
        from fastNLP import Tester
        from fastNLP import DataSet
        from fastNLP import AccuracyMetric

        class Model(nn.Module):
            def __init__(self):
                super().__init__()
                self.fc = nn.Linear(1, 1)
            def forward(self, a):
                return {'pred': self.fc(a.unsqueeze(1)).squeeze(1)}

        model = Model()

        dataset = DataSet({'a': np.arange(10, dtype=float), 'b':np.arange(10, dtype=float)*2})
        dataset.set_input('a')
        dataset.set_target('b')

        tester = Tester(dataset, model, metrics=AccuracyMetric())
        eval_results = tester.test()

    这里Metric的映射规律是和 Trainer_ 中一致的,请参考 Trainer_ 使用metrics。
    """
    def __init__(self, data, model, metrics, batch_size=16, device=None, verbose=1):
        """传入模型,数据以及metric进行验证。

        :param DataSet data: 需要测试的数据集
        :param torch.nn.module model: 使用的模型
        :param MetricBase metrics: 一个Metric或者一个列表的metric对象
        :param int batch_size: evaluation时使用的batch_size有多大。
        :param str,int,torch.device,list(int) device: 将模型load到哪个设备。默认为None,即Trainer不对模型
            的计算位置进行管理。支持以下的输入:

            1. str: ['cpu', 'cuda', 'cuda:0', 'cuda:1', ...] 依次为'cpu'中, 可见的第一个GPU中, 可见的第一个GPU中,
            可见的第二个GPU中;

            2. torch.device:将模型装载到torch.device上。

            3. int: 将使用device_id为该值的gpu进行训练

            4. list(int):如果多于1个device,将使用torch.nn.DataParallel包裹model, 并使用传入的device。

            5. None. 为None则不对模型进行任何处理,如果传入的model为torch.nn.DataParallel该值必须为None。

        :param int verbose: 如果为0不输出任何信息; 如果为1,打印出验证结果。
        """
"""
tester模块实现了 fastNLP 所需的Tester类,能在提供数据、模型以及metric的情况下进行性能测试。

Example::

    import numpy as np
    import torch
    from torch import nn
    from fastNLP import Tester
    from fastNLP import DataSet
    from fastNLP import AccuracyMetric

    class Model(nn.Module):
        def __init__(self):
            super().__init__()
            self.fc = nn.Linear(1, 1)
        def forward(self, a):
            return {'pred': self.fc(a.unsqueeze(1)).squeeze(1)}

    model = Model()

    dataset = DataSet({'a': np.arange(10, dtype=float), 'b':np.arange(10, dtype=float)*2})
    dataset.set_input('a')
    dataset.set_target('b')

    tester = Tester(dataset, model, metrics=AccuracyMetric())
    eval_results = tester.test()

这里Metric的映射规律是和 :class:`fastNLP.Trainer` 中一致的,具体使用请参考 :doc:`trainer 模块<fastNLP.core.trainer>` 的1.3部分
"""
import torch
from torch import nn

from .batch import Batch
from .dataset import DataSet
from .metrics import _prepare_metrics
from .sampler import SequentialSampler
from .utils import _CheckError
from .utils import _build_args
from .utils import _check_loss_evaluate
from .utils import _move_dict_value_to_device
from .utils import _get_func_signature
from .utils import _get_model_device
from .utils import _move_model_to_device


class Tester(object):
    """
    别名::class:`fastNLP.Tester` :class:`fastNLP.core.tester.Tester`

    Tester是在提供数据,模型以及metric的情况下进行性能测试的类。需要传入模型,数据以及metric进行验证。

    :param data: 需要测试的数据集, :class:`~fastNLP.DataSet` 类型
    :param torch.nn.module model: 使用的模型
    :param metrics: :class:`~fastNLP.core.metrics.MetricBase` 或者一个列表的 :class:`~fastNLP.core.metrics.MetricBase`
    :param int batch_size: evaluation时使用的batch_size有多大。
    :param str,int,torch.device,list(int) device: 将模型load到哪个设备。默认为None,即Trainer不对模型
        的计算位置进行管理。支持以下的输入:

        1. str: ['cpu', 'cuda', 'cuda:0', 'cuda:1', ...] 依次为'cpu'中, 可见的第一个GPU中, 可见的第一个GPU中,
        可见的第二个GPU中;

        2. torch.device:将模型装载到torch.device上。

        3. int: 将使用device_id为该值的gpu进行训练

        4. list(int):如果多于1个device,将使用torch.nn.DataParallel包裹model, 并使用传入的device。

        5. None. 为None则不对模型进行任何处理,如果传入的model为torch.nn.DataParallel该值必须为None。

    :param int verbose: 如果为0不输出任何信息; 如果为1,打印出验证结果。
    """
    def __init__(self, data, model, metrics, batch_size=16, device=None, verbose=1):
        super(Tester, self).__init__()

        if not isinstance(data, DataSet):

@@ -103,7 +103,7 @@ class Tester(object):

    def test(self):
        """开始进行验证,并返回验证结果。

        :return dict(dict) eval_results: dict为二层嵌套结构,dict的第一层是metric的名称; 第二层是这个metric的指标。
        :return Dict[Dict] : dict的二层嵌套结构,dict的第一层是metric的名称; 第二层是这个metric的指标。
            一个AccuracyMetric的例子为{'AccuracyMetric': {'acc': 1.0}}。
        """
        # turn on the testing mode; clean up the history
@@ -1,13 +1,17 @@

"""
Trainer的说明文档

.. _Trainer:

Trainer在fastNLP中用于组织单任务的训练过程,可以避免用户在不同训练任务中重复撰写 (1) epoch循环; (2) 将数据分成不同的Batch; (3)
对Batch进行pad; (4) 每个epoch结束或一定step后进行验证集验证; (5) 保存获得更好验证性能的模型等。

1. Trainer的基本使用

Trainer在fastNLP中用于组织单任务的训练过程,可以避免用户在不同训练任务中重复撰写以下步骤的代码

    (1) epoch循环;
    (2) 将数据分成不同的Batch;
    (3) 对Batch进行pad;
    (4) 每个epoch结束或一定step后进行验证集验证;
    (5) 保存获得更好验证性能的模型。

1 Trainer的基本使用

下面的例子是使用神经网络来预测一个序列中是否有偶数个1。

Example::
@@ -20,8 +24,8 @@ Trainer在fastNLP中用于组织单任务的训练过程,可以避免用户在

    from fastNLP import DataSet
    from fastNLP import Trainer
    from fastNLP.core.losses import CrossEntropyLoss
    from fastNLP.core.metrics import AccuracyMetric
    from fastNLP import CrossEntropyLoss
    from fastNLP import AccuracyMetric
    from fastNLP.modules.decoder import MLP

    # 模型
@@ -56,208 +60,214 @@ Trainer在fastNLP中用于组织单任务的训练过程,可以避免用户在

由上面的例子可以看出通过使用Trainer,可以使得训练部分的代码大幅减少。
使用Trainer需要满足以下几个条件:

1. 模型
    1. 模型的forward()的参数名需要与DataSet中的名字对应。实际上fastNLP在将DataSet中的数据传递给模型forward()时,是
    通过匹配名称实现的。所以上例中,如果Model的forward函数修改为forward(self, data), 则DataSet中的'x'这个field就应该
    改名为'data'。
    2. 传递给forward()的参数是DataSet中被设置为input的那些field。但如果forward()中没有对应的参数,则不会将数据传递
    给forward()。例如,DataSet中'x1', 'x2'都是input,但是模型的函数为forward(self, x1), 那么'x2'不会传递给forward()。
    3. 模型的forward()返回值需要为一个dict。

1.1 模型
    1 模型的forward()的参数名需要与DataSet中的名字对应(见下方示意)。实际上fastNLP在将DataSet中的数据传递给模型forward()时,是
    通过匹配名称实现的。所以上例中,如果Model的forward函数修改为forward(self, data), 则DataSet中的'x'这个field就应该
    改名为'data'。
    2 传递给forward()的参数是DataSet中被设置为input的那些field。但如果forward()中没有对应的参数,则不会将数据传递
    给forward()。例如,DataSet中'x1', 'x2'都是input,但是模型的函数为forward(self, x1), 那么'x2'不会传递给forward()。
    3 模型的forward()返回值需要为一个dict。
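
下面用一个最小示意来说明第1条的名称匹配(field名'x'与forward()的参数名一致;示例为假设的写法):

Example::

    import numpy as np
    from torch import nn
    from fastNLP import DataSet

    class Model(nn.Module):
        def __init__(self):
            super().__init__()
            self.fc = nn.Linear(1, 1)
        def forward(self, x):  # 参数名x与DataSet中的field 'x'对应
            return {'pred': self.fc(x.float().unsqueeze(1))}  # 返回值必须为dict

    dataset = DataSet({'x': np.arange(10)})
    dataset.set_input('x')  # 只有被设置为input的field才会传入forward()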

2. Loss
    fastNLP中的为了不限制forward函数的返回内容数量(比如一些复杂任务需要返回多个内容,如Dependency Parsing, Loss_ 与 Metric_ 都使
    用了通过名称来匹配相应内容的策略。如上面的例子中

    Example::

        trainer = Trainer(tr_dataset, model, loss=CrossEntropyLoss(target='label'),
                optimizer=SGD(model.parameters(), lr=0.1),n_epochs=1000,
                dev_data = dev_data, metrics=AccuracyMetric(target='label'))

    loss被设置为了 CrossEntropyLoss_ , 但在初始化的时候传入了target='label'这个参数, CrossEntropyLoss_ 的初始化
    参数为(pred=None, target=None, padding_idx=-100)。这里的两个参数分别为计算CrossEntropy时需要使用到的模型的预测值
    与真实值。其中'pred'一般来自于模型forward()的返回结果,'target'一般是来自于DataSet中被设置为target的
    field。由于每个人对真实值或者model的返回值取名并不一样,所以fastNLP的 Loss_ 提供一种类似于映射的机制来匹配
    对应的值,比如这里 CrossEntropyLoss_ 将尝试找到名为'label'的内容来作为真实值得到loss;而pred=None, 则 CrossEntropyLoss_
    使用'pred'作为名称匹配预测值,正好forward的返回值也叫pred,所以这里不需要申明pred。

    尽管fastNLP使用了映射机制来使得loss的计算变得比较灵活,但有些情况下loss必须在模型中进行计算,比如使用了CRF的模型。fastNLP中提供了 LossInForward_ 这
    个loss。这个loss的原理是直接在forward()的返回结果中找到loss_key(默认寻找'loss')指定的那个tensor,
    并使用它作为loss。 如果Trainer初始化没有提供loss则默认使用 LossInForward_ 。详细例子可以参照 TODO 补充一个例子

1.2 Loss
    为了不限制forward函数的返回内容数量(比如一些复杂任务需要返回多个内容,如Dependency Parsing),
    fastNLP中的 :mod:`Loss<fastNLP.core.losses>` 与 :mod:`Metric<fastNLP.core.metrics>` 都使用了通过名称来匹配相应内容的策略。如上面的例子中

    Example::

        trainer = Trainer(tr_dataset, model, loss=CrossEntropyLoss(target='label'),
                optimizer=SGD(model.parameters(), lr=0.1),n_epochs=1000,
                dev_data = dev_data, metrics=AccuracyMetric(target='label'))

    loss被设置为了 :class:`~fastNLP.CrossEntropyLoss` , 但在初始化的时候传入了target='label'这个参数,
    :class:`~fastNLP.CrossEntropyLoss` 的初始化参数为(pred=None, target=None, padding_idx=-100)。
    这里的两个参数分别为计算CrossEntropy时需要使用到的模型的预测值与真实值。
    其中 `pred` 一般来自于模型forward()的返回结果,`target` 一般是来自于DataSet中被设置为target的field。
    由于每个人对真实值或者model的返回值取名并不一样,所以fastNLP的 :mod:`Loss<fastNLP.core.losses>` 提供一种类似于映射的机制来匹配对应的值,
    比如这里 :class:`~fastNLP.CrossEntropyLoss` 将尝试找到名为'label'的内容来作为真实值得到loss;
    而pred=None, 则 :class:`~fastNLP.CrossEntropyLoss` 使用'pred'作为名称匹配预测值,
    正好forward的返回值也叫pred,所以这里不需要声明pred。

    尽管fastNLP使用了映射机制来使得loss的计算变得比较灵活,但有些情况下loss必须在模型中进行计算,比如使用了CRF的模型。
    fastNLP中提供了 :class:`~fastNLP.LossInForward` 这个loss。
    这个loss的原理是直接在forward()的返回结果中找到loss_key(默认寻找'loss')指定的那个tensor,并使用它作为loss。
    如果Trainer初始化没有提供loss则默认使用 :class:`~fastNLP.LossInForward` 。下面给出一个最小示意(模型与数据均为假设,仿照第2节的检查例子写成)。
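
    Example::

        import numpy as np
        import torch
        from torch import nn
        from torch.optim import SGD
        from fastNLP import Trainer
        from fastNLP import DataSet

        class Model(nn.Module):
            def __init__(self):
                super().__init__()
                self.fc = nn.Linear(1, 1)
            def forward(self, x, b):
                loss = torch.mean((self.fc(x.float().unsqueeze(1)).squeeze(1) - b.float())**2)
                return {'loss': loss}  # 在forward()内部计算loss,并以'loss'为key返回

        model = Model()
        dataset = DataSet({'x': np.arange(10), 'b': np.arange(10)*2})
        dataset.set_input('x', 'b')
        # loss=None时默认使用LossInForward,直接取forward()返回dict中的'loss'
        trainer = Trainer(dataset, model, loss=None, optimizer=SGD(model.parameters(), lr=0.001))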

3. Metric
    Metric_ 使用了与上述Loss一样的策略,即使用名称进行匹配。AccuracyMetric(target='label')的情况与CrossEntropyLoss 是同理的。

    在进行验证时,可能用到的计算与forward()中不太一致,没有办法直接从forward()的结果中得到预测值,这时模型可以提供一个predict()方法,
    如果提供的模型具有predict方法,则在模型验证时将调用predict()方法获取预测结果,传入到predict()的参数也是从DataSet中被设置为input
    的field中选择出来的; 与forward()一样,返回值需要为一个dict。具体例子可以参考 TODO 补充一个例子

1.3 Metric
    :mod:`Metric<fastNLP.core.metrics>` 使用了与上述Loss一样的策略,即使用名称进行匹配。
    AccuracyMetric(target='label')的情况与CrossEntropyLoss 是同理的。

    在进行验证时,可能用到的计算与forward()中不太一致,没有办法直接从forward()的结果中得到预测值,这时模型可以提供一个predict()方法,
    如果提供的模型具有predict方法,则在模型验证时将调用predict()方法获取预测结果,
    传入到predict()的参数也是从DataSet中被设置为input的field中选择出来的;
    与forward()一样,返回值需要为一个dict。下面给出一个最小示意(由第2节的Example2.3改写,predict()的返回key改为'pred'并为dev_data设置了target,使映射能够成功)。
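
    Example::

        import numpy as np
        import torch
        from torch import nn
        from torch.optim import SGD
        from fastNLP import Trainer
        from fastNLP import DataSet
        from fastNLP import AccuracyMetric

        class Model(nn.Module):
            def __init__(self):
                super().__init__()
                self.fc = nn.Linear(1, 1)
            def forward(self, a, b):
                loss = torch.mean((self.fc(a.float().unsqueeze(1)) - b.float())**2)
                return {'loss': loss}
            def predict(self, a):  # 验证时将优先调用predict()而非forward()
                return {'pred': self.fc(a.float().unsqueeze(1))}  # key为'pred',可被AccuracyMetric自动匹配

        model = Model()
        dataset = DataSet({'a': np.arange(10), 'b': np.arange(10)*2})
        dev_data = DataSet({'a': np.arange(10, 20), 'b': np.arange(10, 20)*2})
        dataset.set_input('a', 'b')
        dev_data.set_input('a')
        dev_data.set_target('b')
        trainer = Trainer(dataset, model, loss=None, optimizer=SGD(model.parameters(), lr=0.001),
                          dev_data=dev_data, metrics=AccuracyMetric())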

2. Trainer的代码检查
    由于在fastNLP中采取了映射的机制,所以难免可能存在对应出错的情况。Trainer提供一种映射检查机制,可以通过check_code_level来进行控制
    比如下面的例子中,由于各种原因产生的报错

Example1::

    import numpy as np
    from torch import nn
    import torch
    from torch.optim import SGD
    from fastNLP import Trainer
    from fastNLP import DataSet

    class Model(nn.Module):
        def __init__(self):
            super().__init__()
            self.fc = nn.Linear(1, 1)
        def forward(self, x, b):
            loss = torch.mean((self.fc(x)-b)**2)
            return {'loss': loss}

    model = Model()

    dataset = DataSet({'a': np.arange(10), 'b':np.arange(10)*2})
    dataset.set_input('a', 'b')

    trainer = Trainer(dataset, model, loss=None, optimizer=SGD(model.parameters(), lr=0.001))

    trainer = Trainer(dataset, model, SGD(model.parameters()))
    # 会报以下的错误
    # input fields after batch(if batch size is 2):
    #     a: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2])
    #     b: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2])
    # There is no target field.
    # ....
    # NameError:
    # Problems occurred when calling Model.forward(self, x, b)
    #     missing param: ['x']
    #     unused field: ['a']
    #     Suggestion: You need to provide ['x'] in DataSet and set it as input.

这里就是由于在Trainer初始化的时候,fastNLP会尝试使用一个batch_size=2的batch去运行一遍forward()以及backward()。这里有两类
信息可以为你提供参考

1. 'input fields after batch...'这部分显示的是train dataset经过Batch操作后,每个field对应的类型以及进行shape。这里
因为train dataset没有target所以没有显示。根据这里可以看出是否正确将需要的内容设置为了input或target。

2. NameError,NameError发生在映射出错的情况。这里报错的原因是由于尝试进行forward计算时(可以通过Model.forward(self, x, b)判断
出当前是在调取forward),却没有获取到forward()函数中需要的'x';在报错信息中同时指出了缺'x',而'a'没有被使用,那么可能
就是由于field的名称不对。这里将dataset中'a'这个field的名称改为'x',或者model的参数从'x'修改为'a'都可以解决问题。

下面的例子是由于loss计算的时候找不到需要的值

Example2::

    import numpy as np
    from torch import nn
    from torch.optim import SGD
    from fastNLP import Trainer
    from fastNLP import DataSet
    from fastNLP.core.losses import L1Loss
    import torch

    class Model(nn.Module):
        def __init__(self):
            super().__init__()
            self.fc = nn.Linear(1, 1)
        def forward(self, a):
            return {'pred_b': self.fc(a.unsqueeze(1)).squeeze(1), 'No use':1}

    model = Model()

    dataset = DataSet({'a': np.arange(10, dtype=float), 'b':np.arange(10, dtype=float)*2})

    dataset.set_input('a')
    dataset.set_target('b')

    trainer = Trainer(dataset, model, loss=L1Loss(target='label'), optimizer=SGD(model.parameters(), lr=0.001))
    # 报错信息如下
    # input fields after batch(if batch size is 2):
    #     a: (1)type:torch.Tensor (2)dtype:torch.float32, (3)shape:torch.Size([2])
    # target fields after batch(if batch size is 2):
    #     b: (1)type:torch.Tensor (2)dtype:torch.float32, (3)shape:torch.Size([2])
    # ....
    # NameError:
    # Problems occurred when calling L1Loss.get_loss(self, pred, target)
    #     missing param: ['pred(assign to `pred` in `L1Loss`)', 'label(assign to `target` in `L1Loss`)']
    #     unused field: ['b']
    #     unused param: ['pred_b', 'No use']
    #     target field: ['b']
    #     param from Model.forward(self, a): ['pred_b', 'No use']
    #     Suggestion: (1). Check key assignment for `target` when initialize L1Loss. Or provide `label` in DataSet or output of Model.forward(self, a).
    #             (2). Check key assignment for `pred` when initialize L1Loss. Or provide `pred` in DataSet or output of Model.forward(self, a).

报错信息也包含两部分:

1. 第一部分与上面是一样的

2. 这里报错的原因是由于计算loss的时候找不到相应的值(通过L1Loss.get_loss(self, pred, target)判断出来的);报错的原因是因为
`pred`和`label`(我们在初始化L1Loss时将target指定为了label)都没有找到。这里'unused field'是DataSet中出现了,但却没有
被设置为input或者target的field;'unused param'是forward()中返回且没有被使用到的内容;'target field'是被设置为了
target的field; 'param from Model.forward(self, a)'是forward()返回的所有key。"Suggestion"是关于当前错误处理的建议。

但是在一些情况下,比如forward()返回值只有一个,target也只有一个,fastNLP不会进行匹配,而直接将forward()的结果作为pred, 将
DataSet中的target设置为target。上面的例子在返回值中加入了一个'No use'则只是为了使得Loss去匹配结果。

下面是带有dev dataset时如果出现错误会发生的报错,

Example3::

    import numpy as np
    from torch import nn
    from torch.optim import SGD
    from fastNLP import Trainer
    from fastNLP import DataSet
    from fastNLP import AccuracyMetric
    import torch

    class Model(nn.Module):
        def __init__(self):
            super().__init__()
            self.fc = nn.Linear(1, 1)
        def forward(self, a, b):
            loss = torch.mean((self.fc(a.float().unsqueeze(1))-b.float())**2)
            return {'loss': loss}
        def predict(self, a):  # 使用predict()进行验证
            return {'output':self.fc(a.float().unsqueeze(1))} #这里return的值不包含'pred'这个key
    model = Model()

    dataset = DataSet({'a': np.arange(10), 'b':np.arange(10)*2})
    dev_data = DataSet({'a': np.arange(10, 20), 'b':np.arange(10, 20)*2})

    dataset.set_input('a', 'b')
    dev_data.set_input('a')  # 这里没有设置target

    trainer = Trainer(dataset, model, loss=None, optimizer=SGD(model.parameters(), lr=0.001),
                     dev_data=dev_data, metrics=AccuracyMetric())

    # 报错信息
    # ...
    # NameError:
    # Problems occurred when calling AccuracyMetric.evaluate(self, pred, target, seq_len=None)
    #     missing param: ['pred(assign to `pred` in `AccuracyMetric`)', 'target(assign to `target` in `AccuracyMetric`)']
    #     unused param: ['output']
    #     target field: []
    #     param from Model.predict(self, a): ['output']
    #     Suggestion: (1). Check key assignment for `pred` when initialize AccuracyMetric. Or provide `pred` in DataSet or output of Model.predict(self, a).
    #             (2). Check key assignment for `target` when initialize AccuracyMetric. Or provide `target` in DataSet or output of Model.predict(self, a).

报错信息和前面都是类似的,但是可以通过'AccuracyMetric.evaluate(self, pred, target, seq_len=None)'看出这里是evaluation
的时候发生了错误。这样避免了需要在完成一整个epoch的训练才能发现evaluation弄错的情况。这里的修改是通过在初始化metric的时候
指明通过'output'获取`pred`, 即AccuracyMetric(pred='output')。

可以通过check_code_level调节检查的强度。默认为0,即进行检查。

2 Trainer的代码检查
    由于在fastNLP中采取了映射的机制,所以难免可能存在对应出错的情况。Trainer提供一种映射检查机制,可以通过check_code_level来进行控制。
    比如下面的例子中,由于各种原因产生的报错:

Example2.1
::

    import numpy as np
    from torch import nn
    import torch
    from torch.optim import SGD
    from fastNLP import Trainer
    from fastNLP import DataSet

    class Model(nn.Module):
        def __init__(self):
            super().__init__()
            self.fc = nn.Linear(1, 1)
        def forward(self, x, b):
            loss = torch.mean((self.fc(x)-b)**2)
            return {'loss': loss}

    model = Model()

    dataset = DataSet({'a': np.arange(10), 'b':np.arange(10)*2})
    dataset.set_input('a', 'b')

    trainer = Trainer(dataset, model, loss=None, optimizer=SGD(model.parameters(), lr=0.001))

    trainer = Trainer(dataset, model, SGD(model.parameters()))
    # 会报以下的错误
    # input fields after batch(if batch size is 2):
    #     a: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2])
    #     b: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2])
    # There is no target field.
    # ....
    # NameError:
    # Problems occurred when calling Model.forward(self, x, b)
    #     missing param: ['x']
    #     unused field: ['a']
    #     Suggestion: You need to provide ['x'] in DataSet and set it as input.

这里就是由于在Trainer初始化的时候,fastNLP会尝试使用一个batch_size=2的batch去运行一遍forward()以及backward()。这里有两类
信息可以为你提供参考

1 'input fields after batch...'这部分显示的是train dataset经过Batch操作后,每个field对应的类型以及shape。这里
因为train dataset没有target所以没有显示。根据这里可以看出是否正确将需要的内容设置为了input或target。

2 NameError,NameError发生在映射出错的情况。这里报错的原因是由于尝试进行forward计算时(可以通过Model.forward(self, x, b)判断
出当前是在调取forward),却没有获取到forward()函数中需要的'x';在报错信息中同时指出了缺'x',而'a'没有被使用,那么可能
就是由于field的名称不对。这里将dataset中'a'这个field的名称改为'x',或者model的参数从'x'修改为'a'都可以解决问题。

下面的例子是由于loss计算的时候找不到需要的值。

Example2.2
::

    import numpy as np
    from torch import nn
    from torch.optim import SGD
    from fastNLP import Trainer
    from fastNLP import DataSet
    from fastNLP import L1Loss
    import torch

    class Model(nn.Module):
        def __init__(self):
            super().__init__()
            self.fc = nn.Linear(1, 1)
        def forward(self, a):
            return {'pred_b': self.fc(a.unsqueeze(1)).squeeze(1), 'No use':1}

    model = Model()

    dataset = DataSet({'a': np.arange(10, dtype=float), 'b':np.arange(10, dtype=float)*2})

    dataset.set_input('a')
    dataset.set_target('b')

    trainer = Trainer(dataset, model, loss=L1Loss(target='label'), optimizer=SGD(model.parameters(), lr=0.001))
    # 报错信息如下
    # input fields after batch(if batch size is 2):
    #     a: (1)type:torch.Tensor (2)dtype:torch.float32, (3)shape:torch.Size([2])
    # target fields after batch(if batch size is 2):
    #     b: (1)type:torch.Tensor (2)dtype:torch.float32, (3)shape:torch.Size([2])
    # ....
    # NameError:
    # Problems occurred when calling L1Loss.get_loss(self, pred, target)
    #     missing param: ['pred(assign to `pred` in `L1Loss`)', 'label(assign to `target` in `L1Loss`)']
    #     unused field: ['b']
    #     unused param: ['pred_b', 'No use']
    #     target field: ['b']
    #     param from Model.forward(self, a): ['pred_b', 'No use']
    #     Suggestion: (1). Check key assignment for `target` when initialize L1Loss. Or provide `label` in DataSet or output of Model.forward(self, a).
    #             (2). Check key assignment for `pred` when initialize L1Loss. Or provide `pred` in DataSet or output of Model.forward(self, a).

报错信息也包含两部分:

1 第一部分与上面是一样的

2 这里报错的原因是由于计算loss的时候找不到相应的值(通过L1Loss.get_loss(self, pred, target)判断出来的);
报错的原因是因为 `pred` 和 `label` (我们在初始化L1Loss时将target指定为了label)都没有找到。
这里'unused field'是DataSet中出现了,但却没有被设置为input或者target的field;
'unused param'是forward()中返回且没有被使用到的内容;'target field'是被设置为了target的field;
'param from Model.forward(self, a)'是forward()返回的所有key。"Suggestion"是关于当前错误处理的建议。

但是在一些情况下,比如forward()返回值只有一个,target也只有一个,fastNLP不会进行匹配,而直接将forward()的结果作为pred,
将DataSet中的target设置为target。上面的例子在返回值中加入了一个'No use'则只是为了使得Loss去匹配结果。

下面是带有dev dataset时如果出现错误会发生的报错,

Example2.3
::

    import numpy as np
    from torch import nn
    from torch.optim import SGD
    from fastNLP import Trainer
    from fastNLP import DataSet
    from fastNLP import AccuracyMetric
    import torch

    class Model(nn.Module):
        def __init__(self):
            super().__init__()
            self.fc = nn.Linear(1, 1)
        def forward(self, a, b):
            loss = torch.mean((self.fc(a.float().unsqueeze(1))-b.float())**2)
            return {'loss': loss}
        def predict(self, a):  # 使用predict()进行验证
            return {'output':self.fc(a.float().unsqueeze(1))} #这里return的值不包含'pred'这个key
    model = Model()

    dataset = DataSet({'a': np.arange(10), 'b':np.arange(10)*2})
    dev_data = DataSet({'a': np.arange(10, 20), 'b':np.arange(10, 20)*2})

    dataset.set_input('a', 'b')
    dev_data.set_input('a')  # 这里没有设置target

    trainer = Trainer(dataset, model, loss=None, optimizer=SGD(model.parameters(), lr=0.001),
                     dev_data=dev_data, metrics=AccuracyMetric())

    # 报错信息
    # ...
    # NameError:
    # Problems occurred when calling AccuracyMetric.evaluate(self, pred, target, seq_len=None)
    #     missing param: ['pred(assign to `pred` in `AccuracyMetric`)', 'target(assign to `target` in `AccuracyMetric`)']
    #     unused param: ['output']
    #     target field: []
    #     param from Model.predict(self, a): ['output']
    #     Suggestion: (1). Check key assignment for `pred` when initialize AccuracyMetric. Or provide `pred` in DataSet or output of Model.predict(self, a).
    #             (2). Check key assignment for `target` when initialize AccuracyMetric. Or provide `target` in DataSet or output of Model.predict(self, a).

报错信息和前面都是类似的,但是可以通过'AccuracyMetric.evaluate(self, pred, target, seq_len=None)'看出这里是evaluation
的时候发生了错误。这样避免了需要在完成一整个epoch的训练才能发现evaluation弄错的情况。这里的修改是通过在初始化metric的时候
指明通过'output'获取`pred`, 即AccuracyMetric(pred='output')。

3. Trainer与callback

可以通过check_code_level调节检查的强度。默认为0,即进行检查。

3 Trainer与callback
虽然Trainer本身已经集成了一些功能,但仍然不足以囊括训练过程中可能需要到的功能,比如负采样,learning rate decay, Early Stop等。
为了解决这个问题fastNLP引入了callback的机制,Callback_ 是一种在Trainer训练过程中特定阶段会运行的函数集合,所有的 Callback_ 都具有
on_*(比如on_train_start, on_backward_begin)等函数。如果 Callback 实现了该函数,则Trainer运行至对应阶段,会进行调用。

为了解决这个问题fastNLP引入了callback的机制,:class:`~fastNLP.Callback` 是一种在Trainer训练过程中特定阶段会运行的函数集合,
所有的 :class:`~fastNLP.Callback` 都具有on_*(比如on_train_begin, on_backward_begin)等函数。
如果 Callback 实现了该函数,则Trainer运行至对应阶段,会进行调用。
我们将Train.train()这个函数内部分为以下的阶段,在对应阶段会触发相应的调用。

@@ -286,12 +296,11 @@ Trainer在fastNLP中用于组织单任务的训练过程,可以避免用户在

    callback.on_train_end()  # 训练结束
    callback.on_exception()  # 这是一个特殊的步骤,在训练过程中遭遇exception会跳转到这里

fastNLP已经自带了很多callback函数供使用,可以参考 Callback_ 。一些关于callback的例子,请参考 #TODO callback的例子

fastNLP已经自带了很多callback函数供使用,可以参考 :class:`~fastNLP.Callback` 。
下面给出一个自定义callback的最小示意(PrintEpochCallback为假设的类名,仅作演示)。
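
Example::

    from fastNLP import Callback

    class PrintEpochCallback(Callback):
        def on_epoch_end(self):  # 每个epoch结束时由Trainer调用
            print("epoch {}/{} finished".format(self.epoch, self.n_epochs))

    # trainer = Trainer(train_data, model, callbacks=[PrintEpochCallback()])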
""" | """ | ||||
import os | import os | ||||
import time | import time | ||||
from datetime import datetime | from datetime import datetime | ||||
@@ -300,32 +309,91 @@ from datetime import timedelta

import numpy as np
import torch
from torch import nn
import warnings

try:
    from tqdm.autonotebook import tqdm
except:
    from fastNLP.core.utils import _pseudo_tqdm as tqdm

from fastNLP.core.batch import Batch
from fastNLP.core.callback import CallbackManager, CallbackException
from fastNLP.core.dataset import DataSet
from fastNLP.core.losses import _prepare_losser
from fastNLP.core.metrics import _prepare_metrics
from fastNLP.core.sampler import Sampler
from fastNLP.core.sampler import RandomSampler
from fastNLP.core.sampler import SequentialSampler
from fastNLP.core.tester import Tester
from fastNLP.core.utils import _CheckError
from fastNLP.core.utils import _build_args
from fastNLP.core.utils import _check_forward_error
from fastNLP.core.utils import _check_loss_evaluate
from fastNLP.core.utils import _move_dict_value_to_device
from fastNLP.core.utils import _get_func_signature
from fastNLP.core.utils import _get_model_device
from fastNLP.core.optimizer import Optimizer
from fastNLP.core.utils import _move_model_to_device
    from .utils import _pseudo_tqdm as tqdm

from .batch import Batch
from .callback import CallbackManager, CallbackException
from .dataset import DataSet
from .losses import _prepare_losser
from .metrics import _prepare_metrics
from .sampler import Sampler
from .sampler import RandomSampler
from .sampler import SequentialSampler
from .tester import Tester
from .utils import _CheckError
from .utils import _build_args
from .utils import _check_forward_error
from .utils import _check_loss_evaluate
from .utils import _move_dict_value_to_device
from .utils import _get_func_signature
from .utils import _get_model_device
from .optimizer import Optimizer
from .utils import _move_model_to_device
class Trainer(object):
    """
    别名::class:`fastNLP.Trainer` :class:`fastNLP.core.trainer.Trainer`

    Trainer在fastNLP中用于组织单任务的训练过程,可以避免用户在不同训练任务中重复撰写
        (1) epoch循环;
        (2) 将数据分成不同的Batch;
        (3) 对Batch进行pad;
        (4) 每个epoch结束或一定step后进行验证集验证;
        (5) 保存获得更好验证性能的模型等。

    详细的介绍参见 :doc:`fastNLP.core.trainer`

    :param train_data: 训练集, :class:`~fastNLP.DataSet` 类型。
    :param nn.modules model: 待训练的模型
    :param torch.optim.Optimizer optimizer: 优化器。如果为None,则Trainer使用默认的Adam(model.parameters(), lr=4e-3)这个优化器
    :param int batch_size: 训练和验证的时候的batch大小。
    :param loss: 使用的 :class:`~fastNLP.core.losses.LossBase` 对象。当为None时,默认使用 :class:`~fastNLP.LossInForward`
    :param sampler: Batch数据生成的顺序, :class:`~fastNLP.Sampler` 类型。如果为None,默认使用 :class:`~fastNLP.RandomSampler`
    :param update_every: int, 多少步更新一次梯度。用于希望累计梯度的场景,比如需要128的batch_size, 但是直接设为128
        会导致内存不足,通过设置batch_size=32, update_every=4达到目的(见下方示意)。当optimizer为None时,该参数无效。
    :param int n_epochs: 需要优化迭代多少次。
    :param int print_every: 多少次反向传播更新tqdm显示的loss; 如果use_tqdm=False, 则多少次反向传播打印loss。
    :param dev_data: 用于做验证的DataSet, :class:`~fastNLP.DataSet` 类型。
    :param metrics: 验证的评估函数。可以只使用一个 :class:`Metric<fastNLP.core.metrics.MetricBase>` ,
        也可以使用多个 :class:`Metric<fastNLP.core.metrics.MetricBase>` ,通过列表传入。
        如验证时取得了更好的验证结果(如果有多个Metric,以列表中第一个Metric为准),且save_path不为None,
        则保存当前模型。Metric种类详见 :doc:`metrics模块 <fastNLP.core.metrics>` 。仅在传入dev_data时有效。
    :param str,None metric_key: :class:`Metric<fastNLP.core.metrics.MetricBase>` 有时会有多个指标,
        比如 :class:`~fastNLP.core.metrics.SpanFPreRecMetric` 中包含了'f', 'pre', 'rec'。此时需
        要指定以哪个指标为准。另外有些指标是越小效果越好,比如语言模型的困惑度,这种情况下,在key前面增加一个'-'来表
        明验证时,值越小越好(比如: "-ppl")。仅在传入dev_data时有效。
    :param int validate_every: 多少个step在验证集上验证一次; 如果为-1,则每个epoch结束验证一次。仅在传入dev_data时有效。
    :param str,None save_path: 将模型保存路径。如果为None,则不保存模型。如果dev_data为None,则保存最后一次迭代的模型。
        保存的时候不仅保存了参数,还保存了模型结构。即便使用DataParallel,这里也只保存模型。
    :param bool prefetch: 是否使用额外的进程来产生batch数据。理论上会使得Batch迭代更快。
    :param bool use_tqdm: 是否使用tqdm来显示训练进度; 如果为False,则将loss打印在终端中。
    :param str,int,torch.device,list(int) device: 将模型load到哪个设备。默认为None,即Trainer不对模型
        的计算位置进行管理。支持以下的输入:

        1. str: ['cpu', 'cuda', 'cuda:0', 'cuda:1', ...] 依次为'cpu'中, 可见的第一个GPU中, 可见的第一个GPU中,
        可见的第二个GPU中;

        2. torch.device:将模型装载到torch.device上。

        3. int: 将使用device_id为该值的gpu进行训练

        4. list(int):如果多于1个device,将使用torch.nn.DataParallel包裹model, 并使用传入的device。

        5. None. 为None则不对模型进行任何处理,如果传入的model为torch.nn.DataParallel该值必须为None。

    :param list(callbacks) callbacks: 用于在train过程中起调节作用的回调函数。比如early stop,negative sampling等可以
        通过callback机制实现。 可使用的callback参见 :doc:`callback模块 <fastNLP.core.callback>`
    :param int check_code_level: 模型检查等级. -1: 不进行检查; 0: 仅出现错误时停止; 1: 如果有field没有被使用,
        报告警告信息; 2: 有任何field没有被使用都报错. 检查的原理是通过使用很小的batch(默认2个sample)来运行代码,但是
        这个过程理论上不会修改任何参数,只是会检查能否运行。但如果(1)模型中存在将batch_size写为某个固定值的情况;
        (2)模型中存在累加前向计算次数的,可能会多计算1次。以上情况建议将check_code_level设置为-1。
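
    下面给出update_every(梯度累积)用法的最小示意,等效batch大小为batch_size * update_every = 32 * 4 = 128
    (train_data、model、loss等均为假设已定义的对象)::

        trainer = Trainer(train_data, model, loss=loss,
                          batch_size=32, update_every=4, n_epochs=10)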
""" | |||||
def __init__(self, train_data, model, optimizer=None, loss=None, | def __init__(self, train_data, model, optimizer=None, loss=None, | ||||
batch_size=32, sampler=None, update_every=1, | batch_size=32, sampler=None, update_every=1, | ||||
n_epochs=10, print_every=5, | n_epochs=10, print_every=5, | ||||
@@ -334,74 +402,30 @@ class Trainer(object):

                 prefetch=False, use_tqdm=True, device=None,
                 callbacks=None,
                 check_code_level=0):
""" | |||||
:param DataSet train_data: 训练集 | |||||
:param nn.modules model: 待训练的模型 | |||||
:param torch.optim.Optimizer,None optimizer: 优化器。如果为None,则Trainer使用默认的Adam(model.parameters(), lr=4e-3)这个优化器 | |||||
:param int batch_size: 训练和验证的时候的batch大小。 | |||||
:param LossBase loss: 使用的Loss对象。 详见 LossBase_ 。当loss为None时,默认使用 LossInForward_ 。 | |||||
:param Sampler sampler: Batch数据生成的顺序。详见 Sampler_ 。如果为None,默认使用 RandomSampler_ 。 | |||||
:param update_every: int, 多少步更新一次梯度。用于希望累计梯度的场景,比如需要128的batch_size, 但是直接设为128 | |||||
会导致内存不足,通过设置batch_size=32, update_every=4达到目的。当optimizer为None时,该参数无效。 | |||||
:param int n_epochs: 需要优化迭代多少次。 | |||||
:param int print_every: 多少次反向传播更新tqdm显示的loss; 如果use_tqdm=False, 则多少次反向传播打印loss。 | |||||
:param DataSet dev_data: 用于做验证的DataSet。 | |||||
:param MetricBase,list(MetricBase) metrics: 验证的评估函数。可以只使用一个Metric,也可以使用多个Metric,通过 | |||||
列表传入。如验证时取得了更好的验证结果(如果有多个Metric,以列表中第一个Metric为准),且save_path不为None, | |||||
则保存当前模型。Metric种类详见 Metric_ 。仅在传入dev_data时有效。 | |||||
:param str,None metric_key: Metric_ 有时会有多个指标,比如 SpanFPreRecMetric_ 中包含了'f', 'pre', 'rec'。此时需 | |||||
要指定以哪个指标为准。另外有些指标是越小效果越好,比如语言模型的困惑度,这种情况下,在key前面增加一个'-'来表 | |||||
明验证时,值越小越好(比如: "-ppl")。仅在传入dev_data时有效。 | |||||
:param int validate_every: 多少个step在验证集上验证一次; 如果为-1,则每个epoch结束验证一次。仅在传入dev_data时有 | |||||
效。 | |||||
:param str,None save_path: 将模型保存路径。如果为None,则不保存模型。如果dev_data为None,则保存最后一次迭代的模 | |||||
型。保存的时候不仅保存了参数,还保存了模型结构。即便使用DataParallel,这里也只保存模型。 | |||||
:param prefetch: bool, 是否使用额外的进程对产生batch数据。理论上会使得Batch迭代更快。 | |||||
:param bool use_tqdm: 是否使用tqdm来显示训练进度; 如果为False,则将loss打印在终端中。 | |||||
:param str,int,torch.device,list(int) device: 将模型load到哪个设备。默认为None,即Trainer不对模型 | |||||
的计算位置进行管理。支持以下的输入: | |||||
1. str: ['cpu', 'cuda', 'cuda:0', 'cuda:1', ...] 依次为'cpu'中, 可见的第一个GPU中, 可见的第一个GPU中, | |||||
可见的第二个GPU中; | |||||
2. torch.device:将模型装载到torch.device上。 | |||||
3. int: 将使用device_id为该值的gpu进行训练 | |||||
4. list(int):如果多于1个device,将使用torch.nn.DataParallel包裹model, 并使用传入的device。 | |||||
5. None. 为None则不对模型进行任何处理,如果传入的model为torch.nn.DataParallel该值必须为None。 | |||||
:param list(callbacks) callbacks: 用于在train过程中起调节作用的回调函数。比如early stop,negative sampling等可以 | |||||
通过callback机制实现。 可使用的callback参见 Callback_ 。 | |||||
:param int check_code_level: 模型检查等级. -1: 不进行检查; 0: 仅出现错误时停止; 1: 如果有field没有被使用, | |||||
报告警告信息; 2: 有任何field没有被使用都报错. 检查的原理是通过使用很小的batch(默认2个sample)来运行代码,但是 | |||||
这个过程理论上不会修改任何参数,只是会检查能否运行。但如果(1)模型中存在将batch_size写为某个固定值的情况; | |||||
(2)模型中存在累加前向计算次数的,可能会多计算1次。以上情况建议将check_code_level设置为-1。 | |||||
""" | |||||
        super(Trainer, self).__init__()
        if not isinstance(train_data, DataSet):
            raise TypeError(f"The type of train_data must be fastNLP.DataSet, got {type(train_data)}.")
        if not isinstance(model, nn.Module):
            raise TypeError(f"The type of model must be torch.nn.Module, got {type(model)}.")

        # check metrics and dev_data
        if (not metrics) and dev_data is not None:
            raise ValueError("No metric for dev_data evaluation.")
        if metrics and (dev_data is None):
            raise ValueError("No dev_data for evaluations, pass dev_data or set metrics to None. ")

        # check update every
        assert update_every >= 1, "update_every must be no less than 1."
        self.update_every = int(update_every)

        # check save_path
        if not (save_path is None or isinstance(save_path, str)):
            raise ValueError("save_path can only be None or `str`.")

        # prepare evaluate
        metrics = _prepare_metrics(metrics)

        # parse metric_key
        # increase_better is True. It means the exp result gets better if the indicator increases.
        # It is true by default.
@@ -411,19 +435,19 @@ class Trainer(object):

            self.metric_key = metric_key[1:] if metric_key[0] == "+" or metric_key[0] == "-" else metric_key
        elif len(metrics) > 0:
            self.metric_key = metrics[0].__class__.__name__.lower().strip('metric')

        # prepare loss
        losser = _prepare_losser(loss)

        # sampler check
        if sampler is not None and not isinstance(sampler, Sampler):
            raise ValueError("The type of sampler should be fastNLP.BaseSampler, got {}.".format(type(sampler)))

        if check_code_level > -1:
            _check_code(dataset=train_data, model=model, losser=losser, metrics=metrics, dev_data=dev_data,
                        metric_key=metric_key, check_level=check_code_level,
                        batch_size=min(batch_size, DEFAULT_CHECK_BATCH_SIZE))

        self.train_data = train_data
        self.dev_data = dev_data  # If None, No validation.
        self.model = model
@@ -443,9 +467,9 @@ class Trainer(object):

        self.callback_manager = CallbackManager(env={"trainer": self}, callbacks=callbacks)

        self.n_steps = (len(self.train_data) // self.batch_size + int(
            len(self.train_data) % self.batch_size != 0)) * self.n_epochs

        self.model = _move_model_to_device(self.model, device=device)

        if isinstance(optimizer, torch.optim.Optimizer):
            self.optimizer = optimizer
        elif isinstance(optimizer, Optimizer):
@@ -454,11 +478,11 @@ class Trainer(object):

            self.optimizer = torch.optim.Adam(model.parameters(), lr=4e-3)
        else:
            raise TypeError("optimizer can only be torch.optim.Optimizer type, not {}.".format(type(optimizer)))

        self.use_tqdm = use_tqdm
        self.pbar = None
        self.print_every = abs(self.print_every)

        if self.dev_data is not None:
            self.tester = Tester(model=self.model,
                                 data=self.dev_data,
@@ -466,13 +490,13 @@ class Trainer(object):

                                 batch_size=self.batch_size,
                                 device=None,  # 由上面的部分处理device
                                 verbose=0)

        self.step = 0
        self.start_time = None  # start timestamp

        self.callback_manager = CallbackManager(env={"trainer": self},
                                                callbacks=callbacks)

    def train(self, load_best_model=True):
        """
        使用该函数使Trainer开始训练。
@@ -501,14 +525,14 @@ class Trainer(object):

            self.start_time = str(datetime.now().strftime('%Y-%m-%d-%H-%M-%S'))
            start_time = time.time()
            print("training epochs started " + self.start_time, flush=True)

            try:
                self.callback_manager.on_train_begin()
                self._train()
                self.callback_manager.on_train_end()
            except (CallbackException, KeyboardInterrupt) as e:
                self.callback_manager.on_exception(e)

            if self.dev_data is not None and hasattr(self, 'best_dev_perf'):
                print(
                    "\nIn Epoch:{}/Step:{}, got best dev performance:".format(self.best_dev_epoch, self.best_dev_step) +
@@ -526,9 +550,9 @@ class Trainer(object):

        finally:
            pass
        results['seconds'] = round(time.time() - start_time, 2)

        return results

    def _train(self):
        if not self.use_tqdm:
            from fastNLP.core.utils import _pseudo_tqdm as inner_tqdm

@@ -537,7 +561,7 @@ class Trainer(object):

        self.step = 0
        self.epoch = 0
        start = time.time()

        with inner_tqdm(total=self.n_steps, postfix='loss:{0:<6.5f}', leave=False, dynamic_ncols=True) as pbar:
            self.pbar = pbar if isinstance(pbar, tqdm) else None
            avg_loss = 0

@@ -556,21 +580,21 @@ class Trainer(object):
                # negative sampling; replace unknown; re-weight batch_y
                self.callback_manager.on_batch_begin(batch_x, batch_y, indices)
                prediction = self._data_forward(self.model, batch_x)

                # edit prediction
                self.callback_manager.on_loss_begin(batch_y, prediction)
                loss = self._compute_loss(prediction, batch_y).mean()
                avg_loss += loss.item()
                loss = loss / self.update_every

                # Is loss NaN or inf? requires_grad = False
                self.callback_manager.on_backward_begin(loss)
                self._grad_backward(loss)
                self.callback_manager.on_backward_end()

                self._update()
                self.callback_manager.on_step_end()

                if self.step % self.print_every == 0:
                    avg_loss = float(avg_loss) / self.print_every
                    if self.use_tqdm:

@@ -584,7 +608,7 @@ class Trainer(object):

                        pbar.set_postfix_str(print_output)
                    avg_loss = 0
                self.callback_manager.on_batch_end()

                if ((self.validate_every > 0 and self.step % self.validate_every == 0) or
                        (self.validate_every < 0 and self.step % len(data_iterator) == 0)) \
                        and self.dev_data is not None:
@@ -593,20 +617,20 @@ class Trainer(object):

                                                                    self.n_steps) + \
                               self.tester._format_eval_results(eval_res)
                    pbar.write(eval_str + '\n')

            # ================= mini-batch end ==================== #

            # lr decay; early stopping
            self.callback_manager.on_epoch_end()
        # =============== epochs end =================== #
        pbar.close()
        self.pbar = None
    # ============ tqdm end ============== #

    def _do_validation(self, epoch, step):
        self.callback_manager.on_valid_begin()
        res = self.tester.test()

        is_better_eval = False
        if self._better_eval_result(res):
            if self.save_path is not None:
# get validation results; adjust optimizer | # get validation results; adjust optimizer | ||||
self.callback_manager.on_valid_end(res, self.metric_key, self.optimizer, is_better_eval) | self.callback_manager.on_valid_end(res, self.metric_key, self.optimizer, is_better_eval) | ||||
return res | return res | ||||
def _mode(self, model, is_test=False): | def _mode(self, model, is_test=False): | ||||
"""Train mode or Test mode. This is for PyTorch currently. | """Train mode or Test mode. This is for PyTorch currently. | ||||
@@ -633,21 +657,22 @@ class Trainer(object): | |||||
model.eval() | model.eval() | ||||
else: | else: | ||||
model.train() | model.train() | ||||
def _update(self): | def _update(self): | ||||
"""Perform weight update on a model. | """Perform weight update on a model. | ||||
""" | """ | ||||
if self.optimizer is not None and (self.step + 1) % self.update_every == 0: | if self.optimizer is not None and (self.step + 1) % self.update_every == 0: | ||||
self.optimizer.step() | self.optimizer.step() | ||||
def _data_forward(self, network, x): | def _data_forward(self, network, x): | ||||
x = _build_args(network.forward, **x) | x = _build_args(network.forward, **x) | ||||
y = network(**x) | y = network(**x) | ||||
if not isinstance(y, dict): | if not isinstance(y, dict): | ||||
raise TypeError(f"The return value of {_get_func_signature(network.forward)} should be dict, got {type(y)}.") | |||||
raise TypeError( | |||||
f"The return value of {_get_func_signature(network.forward)} should be dict, got {type(y)}.") | |||||
return y | return y | ||||
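_data_forward filters the batch down to the fields that network.forward actually declares (via _build_args) and insists that the forward pass return a dict. A minimal model satisfying that contract might look like the sketch below; the field names words and pred are illustrative, not mandated by the API:

import torch

class ToyModel(torch.nn.Module):
    def __init__(self, vocab_size=100, num_classes=2):
        super().__init__()
        self.embed = torch.nn.Embedding(vocab_size, 32)
        self.fc = torch.nn.Linear(32, num_classes)

    def forward(self, words):
        # _build_args passes in only the batch fields named in this signature
        hidden = self.embed(words).mean(dim=1)
        return {'pred': self.fc(hidden)}  # a dict, as _data_forward requires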
def _grad_backward(self, loss): | def _grad_backward(self, loss): | ||||
"""Compute gradient with link rules. | """Compute gradient with link rules. | ||||
@@ -658,7 +683,7 @@ class Trainer(object): | |||||
if self.step % self.update_every == 0: | if self.step % self.update_every == 0: | ||||
self.model.zero_grad() | self.model.zero_grad() | ||||
loss.backward() | loss.backward() | ||||
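Together, _grad_backward and _update implement gradient accumulation: gradients are zeroed and the optimizer steps only once every update_every batches, so the batches in between accumulate their gradients. The same pattern in isolation (a sketch with made-up shapes; whether the Trainer also rescales the loss is not shown in this diff):

import torch

update_every = 4
model = torch.nn.Linear(10, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

for step in range(1, 21):
    if step % update_every == 1:                     # zero once per accumulation window
        model.zero_grad()
    x, y = torch.randn(8, 10), torch.randn(8, 1)
    loss = torch.nn.functional.mse_loss(model(x), y) / update_every
    loss.backward()                                  # gradients accumulate across batches
    if step % update_every == 0:                     # one parameter update per window
        optimizer.step()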
def _compute_loss(self, predict, truth): | def _compute_loss(self, predict, truth): | ||||
"""Compute loss given prediction and ground truth. | """Compute loss given prediction and ground truth. | ||||
@@ -667,7 +692,7 @@ class Trainer(object): | |||||
:return: a scalar | :return: a scalar | ||||
""" | """ | ||||
return self.losser(predict, truth) | return self.losser(predict, truth) | ||||
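_compute_loss simply delegates to the losser handed to the Trainer, which resolves its inputs by key from the prediction dict and the target batch. A hedged usage sketch with fastNLP's CrossEntropyLoss; the field names 'pred' and 'label' are assumptions for illustration:

from fastNLP import CrossEntropyLoss

# map the model's output key and the dataset's target field explicitly
losser = CrossEntropyLoss(pred='pred', target='label')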
def _save_model(self, model, model_name, only_param=False): | def _save_model(self, model, model_name, only_param=False): | ||||
""" 存储不含有显卡信息的state_dict或model | """ 存储不含有显卡信息的state_dict或model | ||||
:param model: | :param model: | ||||
@@ -690,7 +715,7 @@ class Trainer(object): | |||||
model.cpu() | model.cpu() | ||||
torch.save(model, model_path) | torch.save(model, model_path) | ||||
model.to(self._model_device) | model.to(self._model_device) | ||||
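The save above round-trips the model through CPU so the checkpoint carries no GPU placement, then restores the original device. The same pattern in isolation (a sketch; save_cpu_checkpoint and model_path are illustrative names):

import torch

def save_cpu_checkpoint(model, model_path, only_param=False):
    device = next(model.parameters()).device
    model.cpu()                                   # strip device info from the saved tensors
    if only_param:
        torch.save(model.state_dict(), model_path)
    else:
        torch.save(model, model_path)
    model.to(device)                              # restore the original placement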
def _load_model(self, model, model_name, only_param=False): | def _load_model(self, model, model_name, only_param=False): | ||||
# return a bool indicating whether the model was reloaded successfully | # return a bool indicating whether the model was reloaded successfully | ||||
if self.save_path is not None: | if self.save_path is not None: | ||||
@@ -708,7 +733,7 @@ class Trainer(object): | |||||
else: | else: | ||||
return False | return False | ||||
return True | return True | ||||
def _better_eval_result(self, metrics): | def _better_eval_result(self, metrics): | ||||
"""Check if the current epoch yields better validation results. | """Check if the current epoch yields better validation results. | ||||
@@ -759,7 +784,7 @@ def _check_code(dataset, model, losser, metrics, batch_size=DEFAULT_CHECK_BATCH_ | |||||
check_level=0): | check_level=0): | ||||
# check the get_loss method | # check the get_loss method | ||||
model_device = model.parameters().__next__().device | model_device = model.parameters().__next__().device | ||||
batch = Batch(dataset=dataset, batch_size=batch_size, sampler=SequentialSampler()) | batch = Batch(dataset=dataset, batch_size=batch_size, sampler=SequentialSampler()) | ||||
for batch_count, (batch_x, batch_y) in enumerate(batch): | for batch_count, (batch_x, batch_y) in enumerate(batch): | ||||
_move_dict_value_to_device(batch_x, batch_y, device=model_device) | _move_dict_value_to_device(batch_x, batch_y, device=model_device) | ||||
@@ -783,13 +808,13 @@ def _check_code(dataset, model, losser, metrics, batch_size=DEFAULT_CHECK_BATCH_ | |||||
print(info_str) | print(info_str) | ||||
_check_forward_error(forward_func=model.forward, dataset=dataset, | _check_forward_error(forward_func=model.forward, dataset=dataset, | ||||
batch_x=batch_x, check_level=check_level) | batch_x=batch_x, check_level=check_level) | ||||
refined_batch_x = _build_args(model.forward, **batch_x) | refined_batch_x = _build_args(model.forward, **batch_x) | ||||
pred_dict = model(**refined_batch_x) | pred_dict = model(**refined_batch_x) | ||||
func_signature = _get_func_signature(model.forward) | func_signature = _get_func_signature(model.forward) | ||||
if not isinstance(pred_dict, dict): | if not isinstance(pred_dict, dict): | ||||
raise TypeError(f"The return value of {func_signature} should be `dict`, not `{type(pred_dict)}`.") | raise TypeError(f"The return value of {func_signature} should be `dict`, not `{type(pred_dict)}`.") | ||||
# loss check | # loss check | ||||
try: | try: | ||||
loss = losser(pred_dict, batch_y) | loss = losser(pred_dict, batch_y) | ||||
@@ -813,7 +838,7 @@ def _check_code(dataset, model, losser, metrics, batch_size=DEFAULT_CHECK_BATCH_ | |||||
model.zero_grad() | model.zero_grad() | ||||
if batch_count + 1 >= DEFAULT_CHECK_NUM_BATCH: | if batch_count + 1 >= DEFAULT_CHECK_NUM_BATCH: | ||||
break | break | ||||
if dev_data is not None: | if dev_data is not None: | ||||
tester = Tester(data=dev_data[:batch_size * DEFAULT_CHECK_NUM_BATCH], model=model, metrics=metrics, | tester = Tester(data=dev_data[:batch_size * DEFAULT_CHECK_NUM_BATCH], model=model, metrics=metrics, | ||||
batch_size=batch_size, verbose=-1) | batch_size=batch_size, verbose=-1) | ||||
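_check_code is a pre-flight check: it pushes a few batches (DEFAULT_CHECK_NUM_BATCH) through forward, loss, and backward with a SequentialSampler, and, when dev_data is supplied, evaluates a same-sized slice with a silent Tester (verbose=-1), so signature and shape mismatches surface before real training starts. A stripped-down version of the idea (a sketch that assumes forward's parameters exactly match the batch fields, skipping the _build_args filtering):

from fastNLP import Batch, SequentialSampler

def dry_run(dataset, model, losser, batch_size=2, num_batches=2):
    for i, (batch_x, batch_y) in enumerate(Batch(dataset, batch_size=batch_size,
                                                 sampler=SequentialSampler())):
        pred = model(**batch_x)          # raises early on signature/shape mismatches
        loss = losser(pred, batch_y)
        loss.backward()
        model.zero_grad()
        if i + 1 >= num_batches:
            break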
@@ -827,7 +852,7 @@ def _check_eval_results(metrics, metric_key, metric_list): | |||||
# metric_list: the metrics used for evaluation, from the Trainer's initialization | # metric_list: the metrics used for evaluation, from the Trainer's initialization | ||||
if isinstance(metrics, tuple): | if isinstance(metrics, tuple): | ||||
loss, metrics = metrics | loss, metrics = metrics | ||||
if isinstance(metrics, dict): | if isinstance(metrics, dict): | ||||
if len(metrics) == 1: | if len(metrics) == 1: | ||||
# only single metric, just use it | # only single metric, just use it | ||||
@@ -838,7 +863,7 @@ def _check_eval_results(metrics, metric_key, metric_list): | |||||
if metrics_name not in metrics: | if metrics_name not in metrics: | ||||
raise RuntimeError(f"{metrics_name} is chosen to do validation, but got {metrics}") | raise RuntimeError(f"{metrics_name} is chosen to do validation, but got {metrics}") | ||||
metric_dict = metrics[metrics_name] | metric_dict = metrics[metrics_name] | ||||
if len(metric_dict) == 1: | if len(metric_dict) == 1: | ||||
indicator_val, indicator = list(metric_dict.values())[0], list(metric_dict.keys())[0] | indicator_val, indicator = list(metric_dict.values())[0], list(metric_dict.keys())[0] | ||||
elif len(metric_dict) > 1 and metric_key is None: | elif len(metric_dict) > 1 and metric_key is None: | ||||
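_check_eval_results expects the nested shape returned by Tester.test, {metric_name: {indicator: value}}: a single indicator is selected automatically, while several indicators require metric_key to name one. Illustrative values (the numbers are invented):

eval_res = {'AccuracyMetric': {'acc': 0.91}}      # one indicator: picked automatically
eval_res_multi = {'SpanFPreRecMetric':            # several indicators: the Trainer needs
                  {'f': 0.78, 'pre': 0.80,        # metric_key (e.g. 'f') to choose one
                   'rec': 0.76}}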
@@ -1,3 +1,7 @@ | |||||
""" | |||||
The utils module implements many utilities used inside and outside fastNLP. The one intended for users is the :func:`cache_results` decorator. | |||||
""" | |||||
__all__ = ["cache_results"] | |||||
import _pickle | import _pickle | ||||
import inspect | import inspect | ||||
import os | import os | ||||
@@ -29,6 +33,8 @@ def _prepare_cache_filepath(filepath): | |||||
# TODO: also save the arguments used when caching, and warn on load if they differ. | # TODO: also save the arguments used when caching, and warn on load if they differ. | ||||
def cache_results(_cache_fp, _refresh=False, _verbose=1): | def cache_results(_cache_fp, _refresh=False, _verbose=1): | ||||
""" | """ | ||||
Alias: :func:`fastNLP.cache_results` :func:`fastNLP.core.utils.cache_results` | |||||
cache_results is fastNLP's decorator for caching data. The example below shows how to use it. | cache_results is fastNLP's decorator for caching data. The example below shows how to use it. | ||||
Example:: | Example:: | ||||