|
|
@@ -1,3 +1,8 @@ |
|
|
|
import torch |
|
|
|
|
|
|
|
from fastNLP.io.model_io import ModelSaver, ModelLoader |
|
|
|
|
|
|
|
|
|
|
|
class Callback(object): |
|
|
|
"""An Interface for all callbacks. |
|
|
|
|
|
|
@@ -315,6 +320,81 @@ class ControlC(Callback): |
|
|
|
raise exception # 抛出陌生Error |
|
|
|
|
|
|
|
|
|
|
|
class SmoothValue(object): |
|
|
|
def __init__(self, beta: float): |
|
|
|
self.beta, self.n, self.mov_avg = beta, 0, 0 |
|
|
|
self.smooth = None |
|
|
|
|
|
|
|
def add_value(self, val: float) -> None: |
|
|
|
"Add `val` to calculate updated smoothed value." |
|
|
|
self.n += 1 |
|
|
|
self.mov_avg = self.beta * self.mov_avg + (1 - self.beta) * val |
|
|
|
self.smooth = self.mov_avg / (1 - self.beta ** self.n) |
|
|
|
|
|
|
|
|
|
|
|
class LRFinder(Callback): |
|
|
|
"""fastai lr_finder""" |
|
|
|
|
|
|
|
def __init__(self, n_batch, start_lr=1e-6, end_lr=10): |
|
|
|
"""用第一个 epoch 找最佳的学习率,从第二个epoch开始应用它 |
|
|
|
|
|
|
|
:param n_batch: 一个epoch内的iteration数 |
|
|
|
:param start_lr: 学习率下界 |
|
|
|
:param end_lr: 学习率上界 |
|
|
|
""" |
|
|
|
super(LRFinder, self).__init__() |
|
|
|
self.start_lr, self.end_lr = start_lr, end_lr |
|
|
|
self.num_it = n_batch |
|
|
|
self.stop = False |
|
|
|
self.best_loss = 0. |
|
|
|
self.best_lr = None |
|
|
|
self.loss_history = [] |
|
|
|
self.smooth_value = SmoothValue(0.8) |
|
|
|
self.opt = None |
|
|
|
scale = (self.end_lr - self.start_lr) / self.num_it |
|
|
|
|
|
|
|
self.lr_gen = (self.start_lr + scale * (step + 1) for step in range(self.num_it)) |
|
|
|
self.find = None |
|
|
|
self.loader = ModelLoader() |
|
|
|
|
|
|
|
def before_epoch(self, cur_epoch, total_epoch): |
|
|
|
if cur_epoch == 1: |
|
|
|
self.opt = self.trainer.optimizer # pytorch optimizer |
|
|
|
self.opt.param_groups[0]["lr"] = self.start_lr |
|
|
|
# save model |
|
|
|
ModelSaver("tmp").save_pytorch(self.trainer.model, param_only=True) |
|
|
|
self.find = True |
|
|
|
|
|
|
|
def before_backward(self, loss, model): |
|
|
|
if self.find: |
|
|
|
if torch.isnan(loss) or self.stop is True: |
|
|
|
self.stop = True |
|
|
|
return |
|
|
|
loss_val = loss.detach().cpu().data |
|
|
|
self.loss_history.append(loss_val) |
|
|
|
self.smooth_value.add_value(loss_val) |
|
|
|
if self.best_loss == 0. or self.smooth_value.smooth < self.best_loss: |
|
|
|
self.best_loss = self.smooth_value.smooth |
|
|
|
self.best_lr = self.opt.param_groups[0]["lr"] |
|
|
|
|
|
|
|
def after_batch(self, *args): |
|
|
|
if self.find: |
|
|
|
lr = next(self.lr_gen, None) |
|
|
|
if lr is None or self.stop is True or self.loss_history[-1] > 4 * self.best_loss: |
|
|
|
self.stop = True |
|
|
|
return |
|
|
|
self.opt.param_groups[0]["lr"] = lr |
|
|
|
# self.loader.load_pytorch(self.trainer.model, "tmp") |
|
|
|
|
|
|
|
def after_epoch(self, cur_epoch, n_epoch, optimizer): |
|
|
|
if cur_epoch == 1: |
|
|
|
self.opt.param_groups[0]["lr"] = self.best_lr |
|
|
|
self.find = False |
|
|
|
# reset model |
|
|
|
ModelLoader().load_pytorch(self.trainer.model, "tmp") |
|
|
|
print("Model reset. \nFind best lr={}".format(self.best_lr)) |
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__": |
|
|
|
manager = CallbackManager(env={"n_epoch": 3}, callbacks=[DummyCallback(), DummyCallback()]) |
|
|
|
manager.before_train(10, 11, 12) |
|
|
|