|
|
@@ -130,7 +130,8 @@ class Callback(object): |
|
|
|
|
|
|
|
@property |
|
|
|
def pbar(self): |
|
|
|
"""如果在Callback中需要打印内容,请使用self.pbar.write(str)。否则可能出现命令行显示效果不太好的问题。""" |
|
|
|
"""如果在Callback中需要打印内容,请使用self.pbar.write(str)。否则可能出现命令行显示效果不太好的问题。在 |
|
|
|
on_train_begin(), on_train_end(), on_exception()中请不要使用该属性,通过print输出即可。""" |
|
|
|
return self._trainer.pbar |
|
|
|
|
|
|
|
@property |
|
|
@@ -440,7 +441,7 @@ class LRScheduler(Callback): |
|
|
|
raise ValueError(f"Expect torch.optim.lr_scheduler for LRScheduler. Got {type(lr_scheduler)}.") |
|
|
|
|
|
|
|
def on_epoch_begin(self): |
|
|
|
self.scheduler.step() |
|
|
|
self.scheduler.step(self.epoch) |
|
|
|
|
|
|
|
|
|
|
|
class ControlC(Callback): |
|
|
@@ -526,7 +527,7 @@ class LRFinder(Callback): |
|
|
|
if torch.isnan(loss) or self.stop is True: |
|
|
|
self.stop = True |
|
|
|
return |
|
|
|
loss_val = loss.detach().cpu().data |
|
|
|
loss_val = loss.detach().mean().item() |
|
|
|
self.loss_history.append(loss_val) |
|
|
|
self.smooth_value.add_value(loss_val) |
|
|
|
if self.best_loss == 0. or self.smooth_value.smooth < self.best_loss: |
|
|
@@ -548,7 +549,7 @@ class LRFinder(Callback): |
|
|
|
self.find = False |
|
|
|
# reset model |
|
|
|
ModelLoader().load_pytorch(self.trainer.model, "tmp") |
|
|
|
print("Model reset. \nFind best lr={}".format(self.best_lr)) |
|
|
|
self.pbar.write("Model reset. \nFind best lr={}".format(self.best_lr)) |
|
|
|
|
|
|
|
|
|
|
|
class TensorboardCallback(Callback): |
|
|
|