@@ -1,3 +1,8 @@
+import torch
+
+from fastNLP.io.model_io import ModelSaver, ModelLoader
+
+
 class Callback(object):
     """An Interface for all callbacks.
@@ -315,6 +320,81 @@ class ControlC(Callback):
             raise exception  # re-raise any unfamiliar exception
+
+
+class SmoothValue(object):
+    """Exponential moving average with bias correction."""
+
+    def __init__(self, beta: float):
+        self.beta, self.n, self.mov_avg = beta, 0, 0
+        self.smooth = None
+
+    def add_value(self, val: float) -> None:
+        """Add `val` and update the bias-corrected smoothed value."""
+        self.n += 1
+        self.mov_avg = self.beta * self.mov_avg + (1 - self.beta) * val
+        self.smooth = self.mov_avg / (1 - self.beta ** self.n)
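The division by `1 - beta ** n` above is the standard bias correction for an exponential moving average initialized at zero; without it the first few smoothed values are dragged toward zero. A minimal sketch of the effect (editor's illustration, not part of the patch):

    sv = SmoothValue(beta=0.8)
    for raw in [1.0, 1.0, 1.0]:
        sv.add_value(raw)
        print(sv.mov_avg, sv.smooth)
    # mov_avg climbs slowly (0.2, 0.36, 0.488) while the corrected
    # `smooth` stays at 1.0, matching the constant input from step one.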
+
+
+class LRFinder(Callback):
+    """An implementation of fastai's learning-rate finder."""
+
+    def __init__(self, n_batch, start_lr=1e-6, end_lr=10):
+        """Use the first epoch to search for the best learning rate, then apply it from the second epoch on.
+
+        :param n_batch: number of iterations in one epoch
+        :param start_lr: lower bound of the learning-rate sweep
+        :param end_lr: upper bound of the learning-rate sweep
+        """
+        super(LRFinder, self).__init__()
+        self.start_lr, self.end_lr = start_lr, end_lr
+        self.num_it = n_batch
+        self.stop = False
+        self.best_loss = 0.
+        self.best_lr = None
+        self.loss_history = []
+        self.smooth_value = SmoothValue(0.8)
+        self.opt = None
+        scale = (self.end_lr - self.start_lr) / self.num_it
+        self.lr_gen = (self.start_lr + scale * (step + 1) for step in range(self.num_it))
+        self.find = None
+        self.loader = ModelLoader()
+
+    def before_epoch(self, cur_epoch, total_epoch):
+        if cur_epoch == 1:
+            self.opt = self.trainer.optimizer  # pytorch optimizer
+            self.opt.param_groups[0]["lr"] = self.start_lr
+            # save the initial weights so they can be restored after the sweep
+            ModelSaver("tmp").save_pytorch(self.trainer.model, param_only=True)
+            self.find = True
+
+    def before_backward(self, loss, model):
+        if self.find:
+            if torch.isnan(loss) or self.stop is True:
+                self.stop = True
+                return
+            loss_val = loss.detach().cpu().item()
+            self.loss_history.append(loss_val)
+            self.smooth_value.add_value(loss_val)
+            if self.best_loss == 0. or self.smooth_value.smooth < self.best_loss:
+                self.best_loss = self.smooth_value.smooth
+                self.best_lr = self.opt.param_groups[0]["lr"]
+
+    def after_batch(self, *args):
+        if self.find:
+            lr = next(self.lr_gen, None)
+            # stop when the sweep is exhausted or the loss has diverged
+            if lr is None or self.stop is True or self.loss_history[-1] > 4 * self.best_loss:
+                self.stop = True
+                return
+            self.opt.param_groups[0]["lr"] = lr
+            # self.loader.load_pytorch(self.trainer.model, "tmp")
+
+    def after_epoch(self, cur_epoch, n_epoch, optimizer):
+        if cur_epoch == 1:
+            self.opt.param_groups[0]["lr"] = self.best_lr
+            self.find = False
+            # restore the saved weights and continue training with the best lr
+            ModelLoader().load_pytorch(self.trainer.model, "tmp")
+            print("Model reset.\nFound best lr={}".format(self.best_lr))
 if __name__ == "__main__":
     manager = CallbackManager(env={"n_epoch": 3}, callbacks=[DummyCallback(), DummyCallback()])
     manager.before_train(10, 11, 12)
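A note on the sweep itself: `lr_gen` ramps the learning rate linearly from `start_lr` to `end_lr` over the epoch, whereas fastai's finder typically sweeps geometrically, so with this linear ramp almost no steps are spent near the small end of the range. The schedule is easy to inspect on its own (editor's sketch with illustrative values):

    start_lr, end_lr, num_it = 1e-6, 10, 100
    scale = (end_lr - start_lr) / num_it
    lrs = [start_lr + scale * (step + 1) for step in range(num_it)]
    print(lrs[0], lrs[49], lrs[-1])  # ~0.1, ~5.0, 10.0: a straight ramp up to end_lr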
@@ -3,7 +3,7 @@ import unittest
 import numpy as np
 import torch
-from fastNLP.core.callback import EchoCallback, EarlyStopCallback, GradientClipCallback, LRScheduler, ControlC
+from fastNLP.core.callback import EchoCallback, EarlyStopCallback, GradientClipCallback, LRScheduler, ControlC, LRFinder
 from fastNLP.core.dataset import DataSet
 from fastNLP.core.instance import Instance
 from fastNLP.core.losses import BCELoss
@@ -52,7 +52,7 @@ class TestCallback(unittest.TestCase):
         data_set, model = prepare_env()
         trainer = Trainer(data_set, model,
                           loss=BCELoss(pred="predict", target="y"),
-                          n_epochs=30,
+                          n_epochs=20,
                           batch_size=32,
                           print_every=50,
                           optimizer=SGD(lr=0.1),
@@ -67,7 +67,7 @@ class TestCallback(unittest.TestCase):
         data_set, model = prepare_env()
         trainer = Trainer(data_set, model,
                           loss=BCELoss(pred="predict", target="y"),
-                          n_epochs=50,
+                          n_epochs=20,
                           batch_size=32,
                           print_every=50,
                           optimizer=SGD(lr=0.01),
@@ -83,7 +83,7 @@ class TestCallback(unittest.TestCase):
         optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
         trainer = Trainer(data_set, model,
                           loss=BCELoss(pred="predict", target="y"),
-                          n_epochs=50,
+                          n_epochs=5,
                           batch_size=32,
                           print_every=50,
                           optimizer=optimizer,
@@ -98,7 +98,7 @@ class TestCallback(unittest.TestCase):
         data_set, model = prepare_env()
         trainer = Trainer(data_set, model,
                           loss=BCELoss(pred="predict", target="y"),
-                          n_epochs=50,
+                          n_epochs=5,
                           batch_size=32,
                           print_every=50,
                           optimizer=SGD(lr=0.1),
@@ -106,3 +106,16 @@ class TestCallback(unittest.TestCase):
                           use_tqdm=False,
                           callbacks=[ControlC(False)])
         trainer.train()
+
+    def test_LRFinder(self):
+        data_set, model = prepare_env()
+        trainer = Trainer(data_set, model,
+                          loss=BCELoss(pred="predict", target="y"),
+                          n_epochs=5,
+                          batch_size=32,
+                          print_every=50,
+                          optimizer=SGD(lr=0.1),
+                          check_code_level=2,
+                          use_tqdm=False,
+                          callbacks=[LRFinder(len(data_set) // 32)])
+        trainer.train()
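One coupling worth flagging in this test: `n_batch` is derived as `len(data_set) // 32` and has to agree with `batch_size=32`, or the sweep ends before (or after) the first epoch does. A hedged sketch of a safer computation, assuming the trainer also runs a final partial batch:

    import math

    batch_size = 32
    n_batch = math.ceil(len(data_set) / batch_size)  # counts the trailing partial batch too
    callbacks = [LRFinder(n_batch)]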