From 62ea4f7fed30671d816364ff40bc937daf7d97a5 Mon Sep 17 00:00:00 2001 From: FengZiYjun Date: Sat, 19 Jan 2019 18:40:43 +0800 Subject: [PATCH] =?UTF-8?q?=E6=B7=BB=E5=8A=A0LR=20finder=EF=BC=8C=E7=94=A8?= =?UTF-8?q?=E7=AC=AC=E4=B8=80=E4=B8=AAepoch=E6=89=BE=E6=9C=80=E4=BD=B3lr,?= =?UTF-8?q?=E4=BB=8E=E7=AC=AC=E4=BA=8C=E4=B8=AAepoch=E5=BC=80=E5=A7=8B?= =?UTF-8?q?=E8=AE=AD=E7=BB=83?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- fastNLP/core/callback.py | 80 +++++++++++++++++++++++++++++++++++++ test/core/test_callbacks.py | 23 ++++++++--- 2 files changed, 98 insertions(+), 5 deletions(-) diff --git a/fastNLP/core/callback.py b/fastNLP/core/callback.py index f354ffc6..e0053124 100644 --- a/fastNLP/core/callback.py +++ b/fastNLP/core/callback.py @@ -1,3 +1,8 @@ +import torch + +from fastNLP.io.model_io import ModelSaver, ModelLoader + + class Callback(object): """An Interface for all callbacks. @@ -315,6 +320,81 @@ class ControlC(Callback): raise exception # 抛出陌生Error +class SmoothValue(object): + def __init__(self, beta: float): + self.beta, self.n, self.mov_avg = beta, 0, 0 + self.smooth = None + + def add_value(self, val: float) -> None: + "Add `val` to calculate updated smoothed value." + self.n += 1 + self.mov_avg = self.beta * self.mov_avg + (1 - self.beta) * val + self.smooth = self.mov_avg / (1 - self.beta ** self.n) + + +class LRFinder(Callback): + """fastai lr_finder""" + + def __init__(self, n_batch, start_lr=1e-6, end_lr=10): + """用第一个 epoch 找最佳的学习率,从第二个epoch开始应用它 + + :param n_batch: 一个epoch内的iteration数 + :param start_lr: 学习率下界 + :param end_lr: 学习率上界 + """ + super(LRFinder, self).__init__() + self.start_lr, self.end_lr = start_lr, end_lr + self.num_it = n_batch + self.stop = False + self.best_loss = 0. + self.best_lr = None + self.loss_history = [] + self.smooth_value = SmoothValue(0.8) + self.opt = None + scale = (self.end_lr - self.start_lr) / self.num_it + + self.lr_gen = (self.start_lr + scale * (step + 1) for step in range(self.num_it)) + self.find = None + self.loader = ModelLoader() + + def before_epoch(self, cur_epoch, total_epoch): + if cur_epoch == 1: + self.opt = self.trainer.optimizer # pytorch optimizer + self.opt.param_groups[0]["lr"] = self.start_lr + # save model + ModelSaver("tmp").save_pytorch(self.trainer.model, param_only=True) + self.find = True + + def before_backward(self, loss, model): + if self.find: + if torch.isnan(loss) or self.stop is True: + self.stop = True + return + loss_val = loss.detach().cpu().data + self.loss_history.append(loss_val) + self.smooth_value.add_value(loss_val) + if self.best_loss == 0. or self.smooth_value.smooth < self.best_loss: + self.best_loss = self.smooth_value.smooth + self.best_lr = self.opt.param_groups[0]["lr"] + + def after_batch(self, *args): + if self.find: + lr = next(self.lr_gen, None) + if lr is None or self.stop is True or self.loss_history[-1] > 4 * self.best_loss: + self.stop = True + return + self.opt.param_groups[0]["lr"] = lr + # self.loader.load_pytorch(self.trainer.model, "tmp") + + def after_epoch(self, cur_epoch, n_epoch, optimizer): + if cur_epoch == 1: + self.opt.param_groups[0]["lr"] = self.best_lr + self.find = False + # reset model + ModelLoader().load_pytorch(self.trainer.model, "tmp") + print("Model reset. \nFind best lr={}".format(self.best_lr)) + + if __name__ == "__main__": manager = CallbackManager(env={"n_epoch": 3}, callbacks=[DummyCallback(), DummyCallback()]) manager.before_train(10, 11, 12) diff --git a/test/core/test_callbacks.py b/test/core/test_callbacks.py index 59f2be1b..d0c1fb13 100644 --- a/test/core/test_callbacks.py +++ b/test/core/test_callbacks.py @@ -3,7 +3,7 @@ import unittest import numpy as np import torch -from fastNLP.core.callback import EchoCallback, EarlyStopCallback, GradientClipCallback, LRScheduler, ControlC +from fastNLP.core.callback import EchoCallback, EarlyStopCallback, GradientClipCallback, LRScheduler, ControlC, LRFinder from fastNLP.core.dataset import DataSet from fastNLP.core.instance import Instance from fastNLP.core.losses import BCELoss @@ -52,7 +52,7 @@ class TestCallback(unittest.TestCase): data_set, model = prepare_env() trainer = Trainer(data_set, model, loss=BCELoss(pred="predict", target="y"), - n_epochs=30, + n_epochs=20, batch_size=32, print_every=50, optimizer=SGD(lr=0.1), @@ -67,7 +67,7 @@ class TestCallback(unittest.TestCase): data_set, model = prepare_env() trainer = Trainer(data_set, model, loss=BCELoss(pred="predict", target="y"), - n_epochs=50, + n_epochs=20, batch_size=32, print_every=50, optimizer=SGD(lr=0.01), @@ -83,7 +83,7 @@ class TestCallback(unittest.TestCase): optimizer = torch.optim.SGD(model.parameters(), lr=0.01) trainer = Trainer(data_set, model, loss=BCELoss(pred="predict", target="y"), - n_epochs=50, + n_epochs=5, batch_size=32, print_every=50, optimizer=optimizer, @@ -98,7 +98,7 @@ class TestCallback(unittest.TestCase): data_set, model = prepare_env() trainer = Trainer(data_set, model, loss=BCELoss(pred="predict", target="y"), - n_epochs=50, + n_epochs=5, batch_size=32, print_every=50, optimizer=SGD(lr=0.1), @@ -106,3 +106,16 @@ class TestCallback(unittest.TestCase): use_tqdm=False, callbacks=[ControlC(False)]) trainer.train() + + def test_LRFinder(self): + data_set, model = prepare_env() + trainer = Trainer(data_set, model, + loss=BCELoss(pred="predict", target="y"), + n_epochs=5, + batch_size=32, + print_every=50, + optimizer=SGD(lr=0.1), + check_code_level=2, + use_tqdm=False, + callbacks=[LRFinder(len(data_set) // 32)]) + trainer.train()