|
|
@@ -12,34 +12,50 @@ class Callback(object): |
|
|
|
# before the main training loop |
|
|
|
pass |
|
|
|
|
|
|
|
def before_epoch(self, *args): |
|
|
|
def before_epoch(self, cur_epoch, total_epoch): |
|
|
|
# at the beginning of each epoch |
|
|
|
pass |
|
|
|
|
|
|
|
def before_batch(self, *args): |
|
|
|
def before_batch(self, batch_x, batch_y, indices): |
|
|
|
# at the beginning of each step/mini-batch |
|
|
|
pass |
|
|
|
|
|
|
|
def before_loss(self, *args): |
|
|
|
def before_loss(self, batch_y, predict_y): |
|
|
|
# after data_forward, and before loss computation |
|
|
|
pass |
|
|
|
|
|
|
|
def before_backward(self, *args): |
|
|
|
def before_backward(self, loss, model): |
|
|
|
# after loss computation, and before gradient backward |
|
|
|
pass |
|
|
|
|
|
|
|
def after_batch(self, *args): |
|
|
|
def after_backward(self, model): |
|
|
|
# after gradient backward, before optimizer step |
|
|
|
pass |
|
|
|
|
|
|
|
def after_step(self): |
|
|
|
# after optimizer step |
|
|
|
pass |
|
|
|
|
|
|
|
def after_batch(self): |
|
|
|
# at the end of each step/mini-batch |
|
|
|
pass |
|
|
|
|
|
|
|
def after_epoch(self, *args): |
|
|
|
def after_valid(self, eval_result, metric_key, optimizer): |
|
|
|
# after validation |
|
|
|
pass |
|
|
|
|
|
|
|
def after_epoch(self, cur_epoch, n_epoch, optimizer): |
|
|
|
# at the end of each epoch |
|
|
|
pass |
|
|
|
|
|
|
|
def after_train(self, *args): |
|
|
|
def after_train(self, model): |
|
|
|
# after training loop |
|
|
|
pass |
|
|
|
|
|
|
|
def on_exception(self, exception, model, indices): |
|
|
|
# when exception raised in training |
|
|
|
pass |
|
|
|
|
|
|
|
|
|
|
|
def transfer(func): |
|
|
|
"""装饰器,将对CallbackManager的调用转发到各个Callback子类. |
|
|
@@ -139,38 +155,49 @@ class DummyCallback(Callback): |
|
|
|
def before_train(self, *arg): |
|
|
|
print(arg) |
|
|
|
|
|
|
|
def after_epoch(self): |
|
|
|
print("after epoch!!!") |
|
|
|
return 12 |
|
|
|
def after_epoch(self, cur_epoch, n_epoch, optimizer): |
|
|
|
print(cur_epoch, n_epoch, optimizer) |
|
|
|
|
|
|
|
|
|
|
|
class EchoCallback(Callback): |
|
|
|
def before_train(self): |
|
|
|
print("before_train") |
|
|
|
|
|
|
|
def before_epoch(self): |
|
|
|
def before_epoch(self, cur_epoch, total_epoch): |
|
|
|
print("before_epoch") |
|
|
|
print("cur_epoch: ", cur_epoch) |
|
|
|
print("total_epoch: ", total_epoch) |
|
|
|
|
|
|
|
def before_batch(self): |
|
|
|
def before_batch(self, batch_x, batch_y, indices): |
|
|
|
print("before_batch") |
|
|
|
print("batch_x:", batch_x) |
|
|
|
print("batch_y:", batch_y) |
|
|
|
print("indices: ", indices) |
|
|
|
|
|
|
|
def before_loss(self): |
|
|
|
def before_loss(self, batch_y, predict_y): |
|
|
|
print("before_loss") |
|
|
|
print("batch_y: ", batch_y) |
|
|
|
print("predict_y: ", predict_y) |
|
|
|
|
|
|
|
def before_backward(self): |
|
|
|
def before_backward(self, loss, model): |
|
|
|
print("before_backward") |
|
|
|
print("loss=", loss) |
|
|
|
print("model: ", model) |
|
|
|
|
|
|
|
def after_batch(self): |
|
|
|
print("after_batch") |
|
|
|
|
|
|
|
def after_epoch(self): |
|
|
|
def after_epoch(self, cur_epoch, n_epoch, optimizer): |
|
|
|
print("after_epoch") |
|
|
|
print("cur_epoch: ", cur_epoch) |
|
|
|
print("n_epoch: ", n_epoch) |
|
|
|
print("optimizer", optimizer) |
|
|
|
|
|
|
|
def after_train(self): |
|
|
|
def after_train(self, model): |
|
|
|
print("after_train") |
|
|
|
print("model: ", model) |
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__": |
|
|
|
manager = CallbackManager(env={"n_epoch": 3}, callbacks=[DummyCallback(), DummyCallback()]) |
|
|
|
manager.before_train(10, 11, 12) |
|
|
|
# print(manager.after_epoch()) |
|
|
|
manager.after_epoch(3, 10, "optimizer") |