|
@@ -17,37 +17,40 @@ class Callback(object): |
|
|
super(Callback, self).__init__() |
|
|
super(Callback, self).__init__() |
|
|
self.trainer = None # 在Trainer内部被重新赋值 |
|
|
self.trainer = None # 在Trainer内部被重新赋值 |
|
|
|
|
|
|
|
|
def before_train(self): |
|
|
|
|
|
|
|
|
def on_train_begin(self): |
|
|
# before the main training loop |
|
|
# before the main training loop |
|
|
pass |
|
|
pass |
|
|
|
|
|
|
|
|
def before_epoch(self, cur_epoch, total_epoch): |
|
|
|
|
|
|
|
|
def on_epoch_begin(self, cur_epoch, total_epoch): |
|
|
# at the beginning of each epoch |
|
|
# at the beginning of each epoch |
|
|
pass |
|
|
pass |
|
|
|
|
|
|
|
|
def before_batch(self, batch_x, batch_y, indices): |
|
|
|
|
|
|
|
|
def on_batch_begin(self, batch_x, batch_y, indices): |
|
|
# at the beginning of each step/mini-batch |
|
|
# at the beginning of each step/mini-batch |
|
|
pass |
|
|
pass |
|
|
|
|
|
|
|
|
def before_loss(self, batch_y, predict_y): |
|
|
|
|
|
|
|
|
def on_loss_begin(self, batch_y, predict_y): |
|
|
# after data_forward, and before loss computation |
|
|
# after data_forward, and before loss computation |
|
|
pass |
|
|
pass |
|
|
|
|
|
|
|
|
def before_backward(self, loss, model): |
|
|
|
|
|
|
|
|
def on_backward_begin(self, loss, model): |
|
|
# after loss computation, and before gradient backward |
|
|
# after loss computation, and before gradient backward |
|
|
pass |
|
|
pass |
|
|
|
|
|
|
|
|
def after_backward(self, model): |
|
|
|
|
|
|
|
|
def on_backward_end(self, model): |
|
|
pass |
|
|
pass |
|
|
|
|
|
|
|
|
def after_step(self, optimizer): |
|
|
|
|
|
|
|
|
def on_step_end(self, optimizer): |
|
|
pass |
|
|
pass |
|
|
|
|
|
|
|
|
def after_batch(self, *args): |
|
|
|
|
|
|
|
|
def on_batch_end(self, *args): |
|
|
# at the end of each step/mini-batch |
|
|
# at the end of each step/mini-batch |
|
|
pass |
|
|
pass |
|
|
|
|
|
|
|
|
def after_valid(self, eval_result, metric_key, optimizer): |
|
|
|
|
|
|
|
|
def on_valid_begin(self): |
|
|
|
|
|
pass |
|
|
|
|
|
|
|
|
|
|
|
def on_valid_end(self, eval_result, metric_key, optimizer): |
|
|
""" |
|
|
""" |
|
|
每次执行验证机的evaluation后会调用。传入eval_result |
|
|
每次执行验证机的evaluation后会调用。传入eval_result |
|
|
|
|
|
|
|
@@ -58,7 +61,7 @@ class Callback(object): |
|
|
""" |
|
|
""" |
|
|
pass |
|
|
pass |
|
|
|
|
|
|
|
|
def after_epoch(self, cur_epoch, n_epoch, optimizer): |
|
|
|
|
|
|
|
|
def on_epoch_end(self, cur_epoch, n_epoch, optimizer): |
|
|
""" |
|
|
""" |
|
|
每个epoch结束将会调用该方法 |
|
|
每个epoch结束将会调用该方法 |
|
|
|
|
|
|
|
@@ -69,7 +72,7 @@ class Callback(object): |
|
|
""" |
|
|
""" |
|
|
pass |
|
|
pass |
|
|
|
|
|
|
|
|
def after_train(self, model): |
|
|
|
|
|
|
|
|
def on_train_end(self, model): |
|
|
""" |
|
|
""" |
|
|
训练结束,调用该方法 |
|
|
训练结束,调用该方法 |
|
|
|
|
|
|
|
@@ -134,47 +137,51 @@ class CallbackManager(Callback): |
|
|
raise TypeError(f"Expect callbacks in CallbackManager(callbacks) to be list. Got {type(callbacks)}.") |
|
|
raise TypeError(f"Expect callbacks in CallbackManager(callbacks) to be list. Got {type(callbacks)}.") |
|
|
|
|
|
|
|
|
@transfer |
|
|
@transfer |
|
|
def before_train(self): |
|
|
|
|
|
|
|
|
def on_train_begin(self): |
|
|
|
|
|
pass |
|
|
|
|
|
|
|
|
|
|
|
@transfer |
|
|
|
|
|
def on_epoch_begin(self, cur_epoch, total_epoch): |
|
|
pass |
|
|
pass |
|
|
|
|
|
|
|
|
@transfer |
|
|
@transfer |
|
|
def before_epoch(self, cur_epoch, total_epoch): |
|
|
|
|
|
|
|
|
def on_batch_begin(self, batch_x, batch_y, indices): |
|
|
pass |
|
|
pass |
|
|
|
|
|
|
|
|
@transfer |
|
|
@transfer |
|
|
def before_batch(self, batch_x, batch_y, indices): |
|
|
|
|
|
|
|
|
def on_loss_begin(self, batch_y, predict_y): |
|
|
pass |
|
|
pass |
|
|
|
|
|
|
|
|
@transfer |
|
|
@transfer |
|
|
def before_loss(self, batch_y, predict_y): |
|
|
|
|
|
|
|
|
def on_backward_begin(self, loss, model): |
|
|
pass |
|
|
pass |
|
|
|
|
|
|
|
|
@transfer |
|
|
@transfer |
|
|
def before_backward(self, loss, model): |
|
|
|
|
|
|
|
|
def on_backward_end(self, model): |
|
|
pass |
|
|
pass |
|
|
|
|
|
|
|
|
@transfer |
|
|
@transfer |
|
|
def after_backward(self, model): |
|
|
|
|
|
|
|
|
def on_step_end(self, optimizer): |
|
|
pass |
|
|
pass |
|
|
|
|
|
|
|
|
@transfer |
|
|
@transfer |
|
|
def after_step(self, optimizer): |
|
|
|
|
|
|
|
|
def on_batch_end(self): |
|
|
pass |
|
|
pass |
|
|
|
|
|
|
|
|
@transfer |
|
|
@transfer |
|
|
def after_batch(self): |
|
|
|
|
|
|
|
|
def on_valid_begin(self): |
|
|
pass |
|
|
pass |
|
|
|
|
|
|
|
|
@transfer |
|
|
@transfer |
|
|
def after_valid(self, eval_result, metric_key, optimizer): |
|
|
|
|
|
|
|
|
def on_valid_end(self, eval_result, metric_key, optimizer): |
|
|
pass |
|
|
pass |
|
|
|
|
|
|
|
|
@transfer |
|
|
@transfer |
|
|
def after_epoch(self, cur_epoch, n_epoch, optimizer): |
|
|
|
|
|
|
|
|
def on_epoch_end(self, cur_epoch, n_epoch, optimizer): |
|
|
pass |
|
|
pass |
|
|
|
|
|
|
|
|
@transfer |
|
|
@transfer |
|
|
def after_train(self, model): |
|
|
|
|
|
|
|
|
def on_train_end(self, model): |
|
|
pass |
|
|
pass |
|
|
|
|
|
|
|
|
@transfer |
|
|
@transfer |
|
@@ -183,36 +190,36 @@ class CallbackManager(Callback): |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class DummyCallback(Callback): |
|
|
class DummyCallback(Callback): |
|
|
def before_train(self, *arg): |
|
|
|
|
|
|
|
|
def on_train_begin(self, *arg): |
|
|
print(arg) |
|
|
print(arg) |
|
|
|
|
|
|
|
|
def after_epoch(self, cur_epoch, n_epoch, optimizer): |
|
|
|
|
|
|
|
|
def on_epoch_end(self, cur_epoch, n_epoch, optimizer): |
|
|
print(cur_epoch, n_epoch, optimizer) |
|
|
print(cur_epoch, n_epoch, optimizer) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class EchoCallback(Callback): |
|
|
class EchoCallback(Callback): |
|
|
def before_train(self): |
|
|
|
|
|
|
|
|
def on_train_begin(self): |
|
|
print("before_train") |
|
|
print("before_train") |
|
|
|
|
|
|
|
|
def before_epoch(self, cur_epoch, total_epoch): |
|
|
|
|
|
|
|
|
def on_epoch_begin(self, cur_epoch, total_epoch): |
|
|
print("before_epoch") |
|
|
print("before_epoch") |
|
|
|
|
|
|
|
|
def before_batch(self, batch_x, batch_y, indices): |
|
|
|
|
|
|
|
|
def on_batch_begin(self, batch_x, batch_y, indices): |
|
|
print("before_batch") |
|
|
print("before_batch") |
|
|
|
|
|
|
|
|
def before_loss(self, batch_y, predict_y): |
|
|
|
|
|
|
|
|
def on_loss_begin(self, batch_y, predict_y): |
|
|
print("before_loss") |
|
|
print("before_loss") |
|
|
|
|
|
|
|
|
def before_backward(self, loss, model): |
|
|
|
|
|
|
|
|
def on_backward_begin(self, loss, model): |
|
|
print("before_backward") |
|
|
print("before_backward") |
|
|
|
|
|
|
|
|
def after_batch(self): |
|
|
|
|
|
|
|
|
def on_batch_end(self): |
|
|
print("after_batch") |
|
|
print("after_batch") |
|
|
|
|
|
|
|
|
def after_epoch(self, cur_epoch, n_epoch, optimizer): |
|
|
|
|
|
|
|
|
def on_epoch_end(self, cur_epoch, n_epoch, optimizer): |
|
|
print("after_epoch") |
|
|
print("after_epoch") |
|
|
|
|
|
|
|
|
def after_train(self, model): |
|
|
|
|
|
|
|
|
def on_train_end(self, model): |
|
|
print("after_train") |
|
|
print("after_train") |
|
|
|
|
|
|
|
|
|
|
|
|
|
@@ -240,7 +247,7 @@ class GradientClipCallback(Callback): |
|
|
self.parameters = parameters |
|
|
self.parameters = parameters |
|
|
self.clip_value = clip_value |
|
|
self.clip_value = clip_value |
|
|
|
|
|
|
|
|
def after_backward(self, model): |
|
|
|
|
|
|
|
|
def on_backward_end(self, model): |
|
|
self.clip_fun(model.parameters(), self.clip_value) |
|
|
self.clip_fun(model.parameters(), self.clip_value) |
|
|
|
|
|
|
|
|
|
|
|
|
|
@@ -266,7 +273,7 @@ class EarlyStopCallback(Callback): |
|
|
self.wait = 0 |
|
|
self.wait = 0 |
|
|
self.epoch = 0 |
|
|
self.epoch = 0 |
|
|
|
|
|
|
|
|
def after_valid(self, eval_result, metric_key, optimizer): |
|
|
|
|
|
|
|
|
def on_valid_end(self, eval_result, metric_key, optimizer): |
|
|
self.epoch += 1 |
|
|
self.epoch += 1 |
|
|
if not self.trainer._better_eval_result(eval_result): |
|
|
if not self.trainer._better_eval_result(eval_result): |
|
|
# current result is getting worse |
|
|
# current result is getting worse |
|
@@ -297,7 +304,7 @@ class LRScheduler(Callback): |
|
|
else: |
|
|
else: |
|
|
raise ValueError(f"Expect torch.optim.lr_scheduler for LRScheduler. Got {type(lr_scheduler)}.") |
|
|
raise ValueError(f"Expect torch.optim.lr_scheduler for LRScheduler. Got {type(lr_scheduler)}.") |
|
|
|
|
|
|
|
|
def before_epoch(self, cur_epoch, total_epoch): |
|
|
|
|
|
|
|
|
def on_epoch_begin(self, cur_epoch, total_epoch): |
|
|
self.scheduler.step() |
|
|
self.scheduler.step() |
|
|
print("scheduler step ", "lr=", self.trainer.optimizer.param_groups[0]["lr"]) |
|
|
print("scheduler step ", "lr=", self.trainer.optimizer.param_groups[0]["lr"]) |
|
|
|
|
|
|
|
@@ -359,7 +366,7 @@ class LRFinder(Callback): |
|
|
self.find = None |
|
|
self.find = None |
|
|
self.loader = ModelLoader() |
|
|
self.loader = ModelLoader() |
|
|
|
|
|
|
|
|
def before_epoch(self, cur_epoch, total_epoch): |
|
|
|
|
|
|
|
|
def on_epoch_begin(self, cur_epoch, total_epoch): |
|
|
if cur_epoch == 1: |
|
|
if cur_epoch == 1: |
|
|
self.opt = self.trainer.optimizer # pytorch optimizer |
|
|
self.opt = self.trainer.optimizer # pytorch optimizer |
|
|
self.opt.param_groups[0]["lr"] = self.start_lr |
|
|
self.opt.param_groups[0]["lr"] = self.start_lr |
|
@@ -367,7 +374,7 @@ class LRFinder(Callback): |
|
|
ModelSaver("tmp").save_pytorch(self.trainer.model, param_only=True) |
|
|
ModelSaver("tmp").save_pytorch(self.trainer.model, param_only=True) |
|
|
self.find = True |
|
|
self.find = True |
|
|
|
|
|
|
|
|
def before_backward(self, loss, model): |
|
|
|
|
|
|
|
|
def on_backward_begin(self, loss, model): |
|
|
if self.find: |
|
|
if self.find: |
|
|
if torch.isnan(loss) or self.stop is True: |
|
|
if torch.isnan(loss) or self.stop is True: |
|
|
self.stop = True |
|
|
self.stop = True |
|
@@ -379,7 +386,7 @@ class LRFinder(Callback): |
|
|
self.best_loss = self.smooth_value.smooth |
|
|
self.best_loss = self.smooth_value.smooth |
|
|
self.best_lr = self.opt.param_groups[0]["lr"] |
|
|
self.best_lr = self.opt.param_groups[0]["lr"] |
|
|
|
|
|
|
|
|
def after_batch(self, *args): |
|
|
|
|
|
|
|
|
def on_batch_end(self, *args): |
|
|
if self.find: |
|
|
if self.find: |
|
|
lr = next(self.lr_gen, None) |
|
|
lr = next(self.lr_gen, None) |
|
|
if lr is None or self.stop is True or self.loss_history[-1] > 4 * self.best_loss: |
|
|
if lr is None or self.stop is True or self.loss_history[-1] > 4 * self.best_loss: |
|
@@ -388,7 +395,7 @@ class LRFinder(Callback): |
|
|
self.opt.param_groups[0]["lr"] = lr |
|
|
self.opt.param_groups[0]["lr"] = lr |
|
|
# self.loader.load_pytorch(self.trainer.model, "tmp") |
|
|
# self.loader.load_pytorch(self.trainer.model, "tmp") |
|
|
|
|
|
|
|
|
def after_epoch(self, cur_epoch, n_epoch, optimizer): |
|
|
|
|
|
|
|
|
def on_epoch_end(self, cur_epoch, n_epoch, optimizer): |
|
|
if cur_epoch == 1: |
|
|
if cur_epoch == 1: |
|
|
self.opt.param_groups[0]["lr"] = self.best_lr |
|
|
self.opt.param_groups[0]["lr"] = self.best_lr |
|
|
self.find = False |
|
|
self.find = False |
|
@@ -415,7 +422,7 @@ class TensorboardCallback(Callback): |
|
|
self._summary_writer = None |
|
|
self._summary_writer = None |
|
|
self.graph_added = False |
|
|
self.graph_added = False |
|
|
|
|
|
|
|
|
def before_train(self): |
|
|
|
|
|
|
|
|
def on_train_begin(self): |
|
|
save_dir = self.trainer.save_path |
|
|
save_dir = self.trainer.save_path |
|
|
if save_dir is None: |
|
|
if save_dir is None: |
|
|
path = os.path.join("./", 'tensorboard_logs_{}'.format(self.trainer.start_time)) |
|
|
path = os.path.join("./", 'tensorboard_logs_{}'.format(self.trainer.start_time)) |
|
@@ -423,7 +430,7 @@ class TensorboardCallback(Callback): |
|
|
path = os.path.join(save_dir, 'tensorboard_logs_{}'.format(self.trainer.start_time)) |
|
|
path = os.path.join(save_dir, 'tensorboard_logs_{}'.format(self.trainer.start_time)) |
|
|
self._summary_writer = SummaryWriter(path) |
|
|
self._summary_writer = SummaryWriter(path) |
|
|
|
|
|
|
|
|
def before_batch(self, batch_x, batch_y, indices): |
|
|
|
|
|
|
|
|
def on_batch_begin(self, batch_x, batch_y, indices): |
|
|
if "model" in self.options and self.graph_added is False: |
|
|
if "model" in self.options and self.graph_added is False: |
|
|
# tesorboardX 这里有大bug,暂时没法画模型图 |
|
|
# tesorboardX 这里有大bug,暂时没法画模型图 |
|
|
# from fastNLP.core.utils import _build_args |
|
|
# from fastNLP.core.utils import _build_args |
|
@@ -433,7 +440,7 @@ class TensorboardCallback(Callback): |
|
|
# self._summary_writer.add_graph(self.trainer.model, torch.zeros(32, 2)) |
|
|
# self._summary_writer.add_graph(self.trainer.model, torch.zeros(32, 2)) |
|
|
self.graph_added = True |
|
|
self.graph_added = True |
|
|
|
|
|
|
|
|
def before_backward(self, loss, model): |
|
|
|
|
|
|
|
|
def on_backward_begin(self, loss, model): |
|
|
if "loss" in self.options: |
|
|
if "loss" in self.options: |
|
|
self._summary_writer.add_scalar("loss", loss.item(), global_step=self.trainer.step) |
|
|
self._summary_writer.add_scalar("loss", loss.item(), global_step=self.trainer.step) |
|
|
|
|
|
|
|
@@ -445,14 +452,14 @@ class TensorboardCallback(Callback): |
|
|
self._summary_writer.add_scalar(name + "_grad_mean", param.grad.mean(), |
|
|
self._summary_writer.add_scalar(name + "_grad_mean", param.grad.mean(), |
|
|
global_step=self.trainer.step) |
|
|
global_step=self.trainer.step) |
|
|
|
|
|
|
|
|
def after_valid(self, eval_result, metric_key, optimizer): |
|
|
|
|
|
|
|
|
def on_valid_end(self, eval_result, metric_key, optimizer): |
|
|
if "metric" in self.options: |
|
|
if "metric" in self.options: |
|
|
for name, metric in eval_result.items(): |
|
|
for name, metric in eval_result.items(): |
|
|
for metric_key, metric_val in metric.items(): |
|
|
for metric_key, metric_val in metric.items(): |
|
|
self._summary_writer.add_scalar("valid_{}_{}".format(name, metric_key), metric_val, |
|
|
self._summary_writer.add_scalar("valid_{}_{}".format(name, metric_key), metric_val, |
|
|
global_step=self.trainer.step) |
|
|
global_step=self.trainer.step) |
|
|
|
|
|
|
|
|
def after_train(self, model): |
|
|
|
|
|
|
|
|
def on_train_end(self, model): |
|
|
self._summary_writer.close() |
|
|
self._summary_writer.close() |
|
|
del self._summary_writer |
|
|
del self._summary_writer |
|
|
|
|
|
|
|
@@ -464,5 +471,5 @@ class TensorboardCallback(Callback): |
|
|
|
|
|
|
|
|
if __name__ == "__main__": |
|
|
if __name__ == "__main__": |
|
|
manager = CallbackManager(env={"n_epoch": 3}, callbacks=[DummyCallback(), DummyCallback()]) |
|
|
manager = CallbackManager(env={"n_epoch": 3}, callbacks=[DummyCallback(), DummyCallback()]) |
|
|
manager.before_train(10, 11, 12) |
|
|
|
|
|
|
|
|
manager.on_train_begin(10, 11, 12) |
|
|
# print(manager.after_epoch()) |
|
|
# print(manager.after_epoch()) |