
Add LR finder: use the first epoch to find the best lr, then train with it from the second epoch onward

FengZiYjun committed 5 years ago (tags/v0.3.1^2) · commit 62ea4f7fed
2 changed files with 98 additions and 5 deletions:

  1. fastNLP/core/callback.py (+80 / -0)
  2. test/core/test_callbacks.py (+18 / -5)

fastNLP/core/callback.py (+80 / -0)

@@ -1,3 +1,8 @@
+import torch
+
+from fastNLP.io.model_io import ModelSaver, ModelLoader
+
+
 class Callback(object):
     """An Interface for all callbacks.

@@ -315,6 +320,81 @@ class ControlC(Callback):
             raise exception  # re-raise unrecognized exceptions

+
+class SmoothValue(object):
+    def __init__(self, beta: float):
+        self.beta, self.n, self.mov_avg = beta, 0, 0
+        self.smooth = None
+
+    def add_value(self, val: float) -> None:
+        """Add `val` and update the bias-corrected smoothed value."""
+        self.n += 1
+        self.mov_avg = self.beta * self.mov_avg + (1 - self.beta) * val
+        # divide by (1 - beta^n) to undo the bias toward 0 from initializing mov_avg at 0
+        self.smooth = self.mov_avg / (1 - self.beta ** self.n)
+
+
+class LRFinder(Callback):
+    """fastai-style lr_finder (learning-rate range test)."""
+
+    def __init__(self, n_batch, start_lr=1e-6, end_lr=10):
+        """Use the first epoch to find the best learning rate, then apply it from the second epoch onward.
+
+        :param n_batch: number of iterations in one epoch
+        :param start_lr: lower bound of the learning-rate sweep
+        :param end_lr: upper bound of the learning-rate sweep
+        """
+        super(LRFinder, self).__init__()
+        self.start_lr, self.end_lr = start_lr, end_lr
+        self.num_it = n_batch
+        self.stop = False
+        self.best_loss = 0.
+        self.best_lr = None
+        self.loss_history = []
+        self.smooth_value = SmoothValue(0.8)
+        self.opt = None
+        scale = (self.end_lr - self.start_lr) / self.num_it
+
+        # linear sweep: batch k+1 runs at start_lr + scale * (k + 1); batch 0 runs at start_lr
+        self.lr_gen = (self.start_lr + scale * (step + 1) for step in range(self.num_it))
+        self.find = None
+        self.loader = ModelLoader()
+
+    def before_epoch(self, cur_epoch, total_epoch):
+        if cur_epoch == 1:
+            self.opt = self.trainer.optimizer  # pytorch optimizer
+            self.opt.param_groups[0]["lr"] = self.start_lr
+            # save the initial weights so they can be restored after the search epoch
+            ModelSaver("tmp").save_pytorch(self.trainer.model, param_only=True)
+            self.find = True
+
+    def before_backward(self, loss, model):
+        if self.find:
+            if torch.isnan(loss) or self.stop is True:
+                self.stop = True
+                return
+            loss_val = loss.detach().cpu().data
+            self.loss_history.append(loss_val)
+            self.smooth_value.add_value(loss_val)
+            if self.best_loss == 0. or self.smooth_value.smooth < self.best_loss:
+                self.best_loss = self.smooth_value.smooth
+                self.best_lr = self.opt.param_groups[0]["lr"]
+
+    def after_batch(self, *args):
+        if self.find:
+            lr = next(self.lr_gen, None)
+            # stop once the sweep is exhausted or the loss diverges past 4x the best smoothed loss
+            if lr is None or self.stop is True or self.loss_history[-1] > 4 * self.best_loss:
+                self.stop = True
+                return
+            self.opt.param_groups[0]["lr"] = lr
+            # self.loader.load_pytorch(self.trainer.model, "tmp")
+
+    def after_epoch(self, cur_epoch, n_epoch, optimizer):
+        if cur_epoch == 1:
+            self.opt.param_groups[0]["lr"] = self.best_lr
+            self.find = False
+            # restore the weights saved before the search epoch
+            ModelLoader().load_pytorch(self.trainer.model, "tmp")
+            print("Model reset. \nFind best lr={}".format(self.best_lr))
+

 if __name__ == "__main__":
     manager = CallbackManager(env={"n_epoch": 3}, callbacks=[DummyCallback(), DummyCallback()])
     manager.before_train(10, 11, 12)
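
Note on the smoothing used above: SmoothValue keeps a bias-corrected exponential moving average of the batch loss. Restated in LaTeX directly from add_value (LRFinder uses beta = 0.8):

    \mathrm{mov\_avg}_n = \beta \, \mathrm{mov\_avg}_{n-1} + (1 - \beta) \, x_n,
    \qquad
    \mathrm{smooth}_n = \frac{\mathrm{mov\_avg}_n}{1 - \beta^{n}}

The division by 1 - beta^n undoes the bias toward zero that comes from initializing mov_avg at 0 (the same correction Adam applies to its moment estimates); without it the first few smoothed losses would look artificially small, and best_lr would tend to stick near start_lr.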


test/core/test_callbacks.py (+18 / -5)

@@ -3,7 +3,7 @@ import unittest
 import numpy as np
 import torch

-from fastNLP.core.callback import EchoCallback, EarlyStopCallback, GradientClipCallback, LRScheduler, ControlC
+from fastNLP.core.callback import EchoCallback, EarlyStopCallback, GradientClipCallback, LRScheduler, ControlC, LRFinder
 from fastNLP.core.dataset import DataSet
 from fastNLP.core.instance import Instance
 from fastNLP.core.losses import BCELoss
@@ -52,7 +52,7 @@ class TestCallback(unittest.TestCase):
         data_set, model = prepare_env()
         trainer = Trainer(data_set, model,
                           loss=BCELoss(pred="predict", target="y"),
-                          n_epochs=30,
+                          n_epochs=20,
                           batch_size=32,
                           print_every=50,
                           optimizer=SGD(lr=0.1),
@@ -67,7 +67,7 @@ class TestCallback(unittest.TestCase):
         data_set, model = prepare_env()
         trainer = Trainer(data_set, model,
                           loss=BCELoss(pred="predict", target="y"),
-                          n_epochs=50,
+                          n_epochs=20,
                           batch_size=32,
                           print_every=50,
                           optimizer=SGD(lr=0.01),
@@ -83,7 +83,7 @@ class TestCallback(unittest.TestCase):
         optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
         trainer = Trainer(data_set, model,
                           loss=BCELoss(pred="predict", target="y"),
-                          n_epochs=50,
+                          n_epochs=5,
                           batch_size=32,
                           print_every=50,
                           optimizer=optimizer,
@@ -98,7 +98,7 @@ class TestCallback(unittest.TestCase):
         data_set, model = prepare_env()
         trainer = Trainer(data_set, model,
                           loss=BCELoss(pred="predict", target="y"),
-                          n_epochs=50,
+                          n_epochs=5,
                           batch_size=32,
                           print_every=50,
                           optimizer=SGD(lr=0.1),
@@ -106,3 +106,16 @@ class TestCallback(unittest.TestCase):
                           use_tqdm=False,
                           callbacks=[ControlC(False)])
         trainer.train()
+
+    def test_LRFinder(self):
+        data_set, model = prepare_env()
+        trainer = Trainer(data_set, model,
+                          loss=BCELoss(pred="predict", target="y"),
+                          n_epochs=5,
+                          batch_size=32,
+                          print_every=50,
+                          optimizer=SGD(lr=0.1),
+                          check_code_level=2,
+                          use_tqdm=False,
+                          callbacks=[LRFinder(len(data_set) // 32)])
+        trainer.train()
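
For reference, a minimal sketch of driving the new callback outside the unit test and inspecting the sweep afterwards. It assumes the prepare_env() helper and the Trainer/BCELoss/SGD/LRFinder imports from test_callbacks.py above; matplotlib is an assumption for plotting, not a dependency of this commit.

import matplotlib.pyplot as plt  # assumption: plotting only

data_set, model = prepare_env()  # helper defined in test_callbacks.py
finder = LRFinder(n_batch=len(data_set) // 32)
trainer = Trainer(data_set, model,
                  loss=BCELoss(pred="predict", target="y"),
                  n_epochs=5,
                  batch_size=32,
                  optimizer=SGD(lr=0.1),
                  use_tqdm=False,
                  callbacks=[finder])
trainer.train()  # epoch 1 sweeps the lr; epochs 2-5 train at finder.best_lr

# Rebuild the lrs the sweep used: batch 0 runs at start_lr, and after_batch
# then raises the lr linearly, so loss_history[i] pairs with start_lr + scale * i.
scale = (finder.end_lr - finder.start_lr) / finder.num_it
lrs = [finder.start_lr + scale * i for i in range(len(finder.loss_history))]
plt.plot(lrs, finder.loss_history)
plt.xlabel("learning rate")
plt.ylabel("training loss")
plt.show()
print("best lr:", finder.best_lr)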
