From e8ea6ea3228a28979ef063e3689599aa4c5179e3 Mon Sep 17 00:00:00 2001
From: FengZiYjun
Date: Fri, 11 Jan 2019 18:14:58 +0800
Subject: [PATCH] improved callbacks & trainer code structure

---
 fastNLP/core/callback.py | 63 ++++++++++++++++++++++++++++------------
 fastNLP/core/trainer.py  | 13 +++++----
 2 files changed, 53 insertions(+), 23 deletions(-)

diff --git a/fastNLP/core/callback.py b/fastNLP/core/callback.py
index 8b6bfdc2..18230e63 100644
--- a/fastNLP/core/callback.py
+++ b/fastNLP/core/callback.py
@@ -12,34 +12,50 @@ class Callback(object):
         # before the main training loop
         pass
 
-    def before_epoch(self, *args):
+    def before_epoch(self, cur_epoch, total_epoch):
         # at the beginning of each epoch
         pass
 
-    def before_batch(self, *args):
+    def before_batch(self, batch_x, batch_y, indices):
         # at the beginning of each step/mini-batch
         pass
 
-    def before_loss(self, *args):
+    def before_loss(self, batch_y, predict_y):
         # after data_forward, and before loss computation
         pass
 
-    def before_backward(self, *args):
+    def before_backward(self, loss, model):
         # after loss computation, and before gradient backward
         pass
 
-    def after_batch(self, *args):
+    def after_backward(self, model):
+        # after gradient backward, before optimizer step
+        pass
+
+    def after_step(self):
+        # after optimizer step
+        pass
+
+    def after_batch(self):
         # at the end of each step/mini-batch
         pass
 
-    def after_epoch(self, *args):
+    def after_valid(self, eval_result, metric_key, optimizer):
+        # after validation
+        pass
+
+    def after_epoch(self, cur_epoch, n_epoch, optimizer):
         # at the end of each epoch
         pass
 
-    def after_train(self, *args):
+    def after_train(self, model):
         # after training loop
         pass
 
+    def on_exception(self, exception, model, indices):
+        # when an exception is raised during training
+        pass
+
 
 def transfer(func):
     """Decorator that forwards a call on CallbackManager to every registered Callback subclass.
@@ -139,38 +155,49 @@ class DummyCallback(Callback):
     def before_train(self, *arg):
         print(arg)
 
-    def after_epoch(self):
-        print("after epoch!!!")
-        return 12
+    def after_epoch(self, cur_epoch, n_epoch, optimizer):
+        print(cur_epoch, n_epoch, optimizer)
 
 
 class EchoCallback(Callback):
     def before_train(self):
         print("before_train")
 
-    def before_epoch(self):
+    def before_epoch(self, cur_epoch, total_epoch):
         print("before_epoch")
+        print("cur_epoch: ", cur_epoch)
+        print("total_epoch: ", total_epoch)
 
-    def before_batch(self):
+    def before_batch(self, batch_x, batch_y, indices):
         print("before_batch")
+        print("batch_x:", batch_x)
+        print("batch_y:", batch_y)
+        print("indices: ", indices)
 
-    def before_loss(self):
+    def before_loss(self, batch_y, predict_y):
         print("before_loss")
+        print("batch_y: ", batch_y)
+        print("predict_y: ", predict_y)
 
-    def before_backward(self):
+    def before_backward(self, loss, model):
         print("before_backward")
+        print("loss=", loss)
+        print("model: ", model)
 
     def after_batch(self):
         print("after_batch")
 
-    def after_epoch(self):
+    def after_epoch(self, cur_epoch, n_epoch, optimizer):
         print("after_epoch")
+        print("cur_epoch: ", cur_epoch)
+        print("n_epoch: ", n_epoch)
+        print("optimizer", optimizer)
 
-    def after_train(self):
+    def after_train(self, model):
         print("after_train")
+        print("model: ", model)
 
 
 if __name__ == "__main__":
     manager = CallbackManager(env={"n_epoch": 3}, callbacks=[DummyCallback(), DummyCallback()])
-    manager.before_train(10, 11, 12)
-    # print(manager.after_epoch())
+    manager.after_epoch(3, 10, "optimizer")
diff --git a/fastNLP/core/trainer.py b/fastNLP/core/trainer.py
index 44d219b4..479ab79a 100644
--- a/fastNLP/core/trainer.py
+++ b/fastNLP/core/trainer.py
@@ -296,23 +296,26 @@ class Trainer(object):
         epoch = 1
         start = time.time()
         while epoch <= self.n_epochs:
-            self.callback_manager.before_epoch()
+            self.callback_manager.before_epoch(epoch, self.n_epochs)
             data_iterator = Batch(self.train_data, batch_size=self.batch_size, sampler=self.sampler, as_numpy=False)
             for batch_x, batch_y in data_iterator:
-                self.callback_manager.before_batch()
+                self.callback_manager.before_batch(batch_x, batch_y, data_iterator.get_batch_indices())
                 # TODO: this could break if the user moves prediction to a different device inside the model
                 _move_dict_value_to_device(batch_x, batch_y, device=self._model_device)
                 prediction = self._data_forward(self.model, batch_x)
-                self.callback_manager.before_loss()
+                self.callback_manager.before_loss(batch_y, prediction)
                 loss = self._compute_loss(prediction, batch_y)
-                self.callback_manager.before_backward()
+                self.callback_manager.before_backward(loss, self.model)
                 self._grad_backward(loss)
+
+                self.callback_manager.after_backward(self.model)
                 self._update()
+                self.callback_manager.after_step()
 
                 self._summary_writer.add_scalar("loss", loss.item(), global_step=self.step)
                 for name, param in self.model.named_parameters():
@@ -338,7 +341,7 @@ class Trainer(object):
             if self.dev_data and self.validate_every <= 0:
                 self._do_validation(epoch=epoch, step=self.step)
             epoch += 1
-            self.callback_manager.after_epoch()
+            self.callback_manager.after_epoch(epoch, self.n_epochs, self.optimizer)
 
     def _do_validation(self, epoch, step):
         res = self.tester.test()
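
With the hooks now taking concrete arguments instead of *args, a subclass can act on training state directly. Below is a minimal sketch of such a callback; the LossLoggerCallback name and its running-average logic are illustrative and not part of the patch, and it assumes loss arrives as a scalar tensor exposing .item(), as in the Trainer hunk above.

    from fastNLP.core.callback import Callback


    class LossLoggerCallback(Callback):
        """Illustrative subclass: tracks the running training loss and
        reports the per-epoch average."""

        def __init__(self):
            super(LossLoggerCallback, self).__init__()
            self._loss_sum = 0.0   # loss accumulated over the current epoch
            self._n_batches = 0    # mini-batches seen in the current epoch

        def before_epoch(self, cur_epoch, total_epoch):
            # reset the accumulators at the start of every epoch
            self._loss_sum, self._n_batches = 0.0, 0

        def before_backward(self, loss, model):
            # loss is the scalar tensor the Trainer is about to back-propagate
            self._loss_sum += loss.item()
            self._n_batches += 1

        def after_epoch(self, cur_epoch, n_epoch, optimizer):
            avg = self._loss_sum / max(self._n_batches, 1)
            print("epoch {}/{}: avg loss {:.4f}".format(cur_epoch, n_epoch, avg))

Registration would mirror the __main__ block above, e.g. CallbackManager(env={"n_epoch": 3}, callbacks=[LossLoggerCallback()]).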
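
For orientation, the order in which the patched Trainer fires the hooks can be exercised directly through CallbackManager, assuming the manager forwards every hook added in this patch, as the trainer.py hunk requires. The sketch below mirrors one batch of the training loop; the fake batch dicts and the None stand-ins for model and optimizer are placeholders, and the before_train/after_train/after_batch placements follow the base-class comments rather than the hunks shown.

    from fastNLP.core.callback import CallbackManager, EchoCallback

    manager = CallbackManager(env={"n_epoch": 1}, callbacks=[EchoCallback()])
    batch_x, batch_y, indices = {"words": [1, 2, 3]}, {"target": [0]}, [0]

    manager.before_train()                          # once, before the main loop
    manager.before_epoch(1, 1)                      # cur_epoch, total_epoch
    manager.before_batch(batch_x, batch_y, indices)
    manager.before_loss(batch_y, {"pred": [0]})     # predict_y comes from data_forward
    manager.before_backward(0.5, None)              # loss, model
    manager.after_backward(None)                    # between backward and optimizer step
    manager.after_step()                            # right after the optimizer step
    manager.after_batch()
    manager.after_epoch(2, 1, None)                 # note: the Trainer increments epoch first
    manager.after_train(None)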