|
@@ -1,7 +1,62 @@ |
|
|
""" |
|
|
|
|
|
callback模块实现了 fastNLP 中的Callback类,用于增强 :class:`~fastNLP.Trainer` 类, |
|
|
|
|
|
|
|
|
r""" |
|
|
|
|
|
callback模块实现了 fastNLP 中的许多 callback 类,用于增强 :class:`~fastNLP.Trainer` 类, |
|
|
|
|
|
|
|
|
|
|
|
我们将 :meth:`~fastNLP.Trainer.train` 这个函数内部分为以下的阶段,在对应阶段会触发相应的调用:: |
|
|
|
|
|
|
|
|
|
|
|
callback.on_train_begin() # 开始进行训练 |
|
|
|
|
|
for i in range(1, n_epochs+1): |
|
|
|
|
|
callback.on_epoch_begin() # 开始新的epoch |
|
|
|
|
|
for batch_x, batch_y in Batch: |
|
|
|
|
|
callback.on_batch_begin(batch_x, batch_y, indices) # batch_x是设置为input的field,batch_y是设置为target的field |
|
|
|
|
|
获取模型输出 |
|
|
|
|
|
callback.on_loss_begin() |
|
|
|
|
|
计算loss |
|
|
|
|
|
callback.on_backward_begin() # 可以进行一些检查,比如loss是否为None |
|
|
|
|
|
反向梯度回传 |
|
|
|
|
|
callback.on_backward_end() # 进行梯度截断等 |
|
|
|
|
|
进行参数更新 |
|
|
|
|
|
callback.on_step_end() |
|
|
|
|
|
callback.on_batch_end() |
|
|
|
|
|
# 根据设置进行evaluation,比如这是本epoch最后一个batch或者达到一定step |
|
|
|
|
|
if do evaluation: |
|
|
|
|
|
callback.on_valid_begin() |
|
|
|
|
|
进行dev data上的验证 |
|
|
|
|
|
callback.on_valid_end() # 可以在其它数据集上进行验证 |
|
|
|
|
|
callback.on_epoch_end() # epoch结束调用 |
|
|
|
|
|
callback.on_train_end() # 训练结束 |
|
|
|
|
|
callback.on_exception() # 这是一个特殊的步骤,在训练过程中遭遇exception会跳转到这里 |
|
|
|
|
|
|
|
|
关于Trainer的详细文档,请参见 :doc:`trainer 模块<fastNLP.core.trainer>` |
|
|
关于Trainer的详细文档,请参见 :doc:`trainer 模块<fastNLP.core.trainer>` |
|
|
|
|
|
|
|
|
|
|
|
如下面的例子所示,我们可以使用内置的 callback 类,或者继承 :class:`~fastNLP.core.callback.Callback` |
|
|
|
|
|
定义自己的 callback 类:: |
|
|
|
|
|
|
|
|
|
|
|
from fastNLP import Callback, EarlyStopCallback, Trainer, CrossEntropyLoss, AccuracyMetric |
|
|
|
|
|
from fastNLP.models import CNNText |
|
|
|
|
|
|
|
|
|
|
|
start_time = time.time() |
|
|
|
|
|
|
|
|
|
|
|
class MyCallback(Callback): |
|
|
|
|
|
def on_epoch_end(self): |
|
|
|
|
|
print('{:d}ms\n\n'.format(round((time.time()-start_time)*1000))) |
|
|
|
|
|
|
|
|
|
|
|
model = CNNText((len(vocab),50), num_classes=5, padding=2, dropout=0.1) |
|
|
|
|
|
trainer = Trainer(model=model, train_data=train_data, dev_data=dev_data, loss=CrossEntropyLoss(), |
|
|
|
|
|
metrics=AccuracyMetric(), callbacks=[MyCallback(),EarlyStopCallback(10)]) |
|
|
|
|
|
trainer.train() |
|
|
|
|
|
|
|
|
""" |
|
|
""" |
|
|
|
|
|
__all__ = [ |
|
|
|
|
|
"Callback", |
|
|
|
|
|
"GradientClipCallback", |
|
|
|
|
|
"EarlyStopCallback", |
|
|
|
|
|
"TensorboardCallback", |
|
|
|
|
|
"LRScheduler", |
|
|
|
|
|
"ControlC", |
|
|
|
|
|
|
|
|
|
|
|
"CallbackException", |
|
|
|
|
|
"EarlyStopError" |
|
|
|
|
|
] |
|
|
import os |
|
|
import os |
|
|
import torch |
|
|
import torch |
|
|
from ..io.model_io import ModelSaver, ModelLoader |
|
|
from ..io.model_io import ModelSaver, ModelLoader |
|
@@ -19,7 +74,7 @@ class Callback(object): |
|
|
Callback是fastNLP中被设计用于增强 :class:`~fastNLP.Trainer` 的类。 |
|
|
Callback是fastNLP中被设计用于增强 :class:`~fastNLP.Trainer` 的类。 |
|
|
如果Callback被传递给了 Trainer , 则 Trainer 会在对应的阶段调用Callback的函数, |
|
|
如果Callback被传递给了 Trainer , 则 Trainer 会在对应的阶段调用Callback的函数, |
|
|
具体调用时机可以通过 :doc:`trainer 模块<fastNLP.core.trainer>` 查看。 |
|
|
具体调用时机可以通过 :doc:`trainer 模块<fastNLP.core.trainer>` 查看。 |
|
|
这是Callback的基类,所有的callback必须继承自这个类(参见 :doc:`callback 模块 <fastNLP.core.callback>` ) |
|
|
|
|
|
|
|
|
这是Callback的基类,所有的callback必须继承自这个类 |
|
|
|
|
|
|
|
|
""" |
|
|
""" |
|
|
|
|
|
|
|
@@ -236,7 +291,7 @@ class CallbackManager(Callback): |
|
|
|
|
|
|
|
|
for env_name, env_val in env.items(): |
|
|
for env_name, env_val in env.items(): |
|
|
for callback in self.callbacks: |
|
|
for callback in self.callbacks: |
|
|
print(callback, env_name, env_val ) |
|
|
|
|
|
|
|
|
print(callback, env_name, env_val) |
|
|
setattr(callback, '_' + env_name, env_val) # Callback.trainer |
|
|
setattr(callback, '_' + env_name, env_val) # Callback.trainer |
|
|
|
|
|
|
|
|
@_transfer |
|
|
@_transfer |
|
@@ -294,12 +349,15 @@ class CallbackManager(Callback): |
|
|
|
|
|
|
|
|
class GradientClipCallback(Callback): |
|
|
class GradientClipCallback(Callback): |
|
|
""" |
|
|
""" |
|
|
|
|
|
别名::class:`fastNLP.GradientClipCallback` :class:`fastNLP.core.callback.GradientClipCallback` |
|
|
|
|
|
|
|
|
每次backward后、参数更新前,将parameter的gradient clip到某个范围。 |
|
|
每次backward后、参数更新前,将parameter的gradient clip到某个范围。 |
|
|
|
|
|
|
|
|
:param None,torch.Tensor,List[torch.Tensor] parameters: 一般通过model.parameters()获得。如果为None则默认对Trainer |
|
|
:param None,torch.Tensor,List[torch.Tensor] parameters: 一般通过model.parameters()获得。如果为None则默认对Trainer |
|
|
的model中所有参数进行clip |
|
|
的model中所有参数进行clip |
|
|
:param float clip_value: 将gradient 限制到[-clip_value, clip_value]。clip_value应该为正数 |
|
|
:param float clip_value: 将gradient 限制到[-clip_value, clip_value]。clip_value应该为正数 |
|
|
:param str clip_type: 支持'norm', 'value'两种:: |
|
|
|
|
|
|
|
|
:param str clip_type: 支持'norm', 'value' |
|
|
|
|
|
两种:: |
|
|
|
|
|
|
|
|
1 'norm', 将gradient的norm rescale到[-clip_value, clip_value] |
|
|
1 'norm', 将gradient的norm rescale到[-clip_value, clip_value] |
|
|
|
|
|
|
|
@@ -331,8 +389,11 @@ class GradientClipCallback(Callback): |
|
|
|
|
|
|
|
|
class EarlyStopCallback(Callback): |
|
|
class EarlyStopCallback(Callback): |
|
|
""" |
|
|
""" |
|
|
|
|
|
别名::class:`fastNLP.EarlyStopCallback` :class:`fastNLP.core.callback.EarlyStopCallback` |
|
|
|
|
|
|
|
|
|
|
|
多少个epoch没有变好就停止训练,相关类 :class:`EarlyStopError` |
|
|
|
|
|
|
|
|
:param int patience: 多少个epoch没有变好就停止训练 |
|
|
|
|
|
|
|
|
:param int patience: epoch的数量 |
|
|
""" |
|
|
""" |
|
|
|
|
|
|
|
|
def __init__(self, patience): |
|
|
def __init__(self, patience): |
|
@@ -358,11 +419,10 @@ class EarlyStopCallback(Callback): |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class LRScheduler(Callback): |
|
|
class LRScheduler(Callback): |
|
|
"""对PyTorch LR Scheduler的包装以使得其可以被Trainer所使用 |
|
|
|
|
|
|
|
|
|
|
|
Example:: |
|
|
|
|
|
|
|
|
""" |
|
|
|
|
|
别名::class:`fastNLP.LRScheduler` :class:`fastNLP.core.callback.LRScheduler` |
|
|
|
|
|
|
|
|
from fastNLP import LRScheduler |
|
|
|
|
|
|
|
|
对PyTorch LR Scheduler的包装以使得其可以被Trainer所使用 |
|
|
|
|
|
|
|
|
:param torch.optim.lr_scheduler._LRScheduler lr_scheduler: PyTorch的lr_scheduler |
|
|
:param torch.optim.lr_scheduler._LRScheduler lr_scheduler: PyTorch的lr_scheduler |
|
|
""" |
|
|
""" |
|
@@ -382,6 +442,7 @@ class LRScheduler(Callback): |
|
|
|
|
|
|
|
|
class ControlC(Callback): |
|
|
class ControlC(Callback): |
|
|
""" |
|
|
""" |
|
|
|
|
|
别名::class:`fastNLP.ControlC` :class:`fastNLP.core.callback.ControlC` |
|
|
|
|
|
|
|
|
:param bool quit_all: 若为True,则检测到control+C 直接退出程序;否则只退出Trainer |
|
|
:param bool quit_all: 若为True,则检测到control+C 直接退出程序;否则只退出Trainer |
|
|
""" |
|
|
""" |
|
@@ -418,6 +479,8 @@ class SmoothValue(object): |
|
|
|
|
|
|
|
|
class LRFinder(Callback): |
|
|
class LRFinder(Callback): |
|
|
""" |
|
|
""" |
|
|
|
|
|
别名::class:`fastNLP.LRFinder` :class:`fastNLP.core.callback.LRFinder` |
|
|
|
|
|
|
|
|
用第一个 epoch 找最佳的学习率,从第二个epoch开始应用它 |
|
|
用第一个 epoch 找最佳的学习率,从第二个epoch开始应用它 |
|
|
|
|
|
|
|
|
:param float start_lr: 学习率下界 |
|
|
:param float start_lr: 学习率下界 |
|
@@ -442,7 +505,7 @@ class LRFinder(Callback): |
|
|
def lr_gen(self): |
|
|
def lr_gen(self): |
|
|
scale = (self.end_lr - self.start_lr) / self.batch_per_epoch |
|
|
scale = (self.end_lr - self.start_lr) / self.batch_per_epoch |
|
|
return (self.start_lr + scale * (step + 1) for step in range(self.batch_per_epoch)) |
|
|
return (self.start_lr + scale * (step + 1) for step in range(self.batch_per_epoch)) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@property |
|
|
@property |
|
|
def num_it(self): |
|
|
def num_it(self): |
|
|
return self.batch_per_epoch |
|
|
return self.batch_per_epoch |
|
@@ -487,10 +550,17 @@ class LRFinder(Callback): |
|
|
|
|
|
|
|
|
class TensorboardCallback(Callback): |
|
|
class TensorboardCallback(Callback): |
|
|
""" |
|
|
""" |
|
|
|
|
|
别名::class:`fastNLP.TensorboardCallback` :class:`fastNLP.core.callback.TensorboardCallback` |
|
|
|
|
|
|
|
|
接受以下一个或多个字符串作为参数: |
|
|
接受以下一个或多个字符串作为参数: |
|
|
- "model" |
|
|
- "model" |
|
|
- "loss" |
|
|
- "loss" |
|
|
- "metric" |
|
|
- "metric" |
|
|
|
|
|
|
|
|
|
|
|
.. warning:: |
|
|
|
|
|
fastNLP 已停止对此功能的维护,请等待 fastNLP 兼容 PyTorch1.1 的下一个版本。 |
|
|
|
|
|
或者使用和 fastNLP 高度配合的 fitlog(参见 :doc:`/user/with_fitlog` )。 |
|
|
|
|
|
|
|
|
""" |
|
|
""" |
|
|
|
|
|
|
|
|
def __init__(self, *options): |
|
|
def __init__(self, *options): |
|
|