Browse Source

improved callbacks & trainer code structure

tags/v0.3.0^2
FengZiYjun 6 years ago
parent
commit
e8ea6ea322
2 changed files with 53 additions and 23 deletions
  1. +45
    -18
      fastNLP/core/callback.py
  2. +8
    -5
      fastNLP/core/trainer.py

+ 45
- 18
fastNLP/core/callback.py View File

@@ -12,34 +12,50 @@ class Callback(object):
# before the main training loop
pass

def before_epoch(self, *args):
def before_epoch(self, cur_epoch, total_epoch):
# at the beginning of each epoch
pass

def before_batch(self, *args):
def before_batch(self, batch_x, batch_y, indices):
# at the beginning of each step/mini-batch
pass

def before_loss(self, *args):
def before_loss(self, batch_y, predict_y):
# after data_forward, and before loss computation
pass

def before_backward(self, *args):
def before_backward(self, loss, model):
# after loss computation, and before gradient backward
pass

def after_batch(self, *args):
def after_backward(self, model):
# after gradient backward, before optimizer step
pass

def after_step(self):
# after optimizer step
pass

def after_batch(self):
# at the end of each step/mini-batch
pass

def after_epoch(self, *args):
def after_valid(self, eval_result, metric_key, optimizer):
# after validation
pass

def after_epoch(self, cur_epoch, n_epoch, optimizer):
# at the end of each epoch
pass

def after_train(self, *args):
def after_train(self, model):
# after training loop
pass

def on_exception(self, exception, model, indices):
# when exception raised in training
pass


def transfer(func):
"""装饰器,将对CallbackManager的调用转发到各个Callback子类.
@@ -139,38 +155,49 @@ class DummyCallback(Callback):
def before_train(self, *arg):
print(arg)

def after_epoch(self):
print("after epoch!!!")
return 12
def after_epoch(self, cur_epoch, n_epoch, optimizer):
print(cur_epoch, n_epoch, optimizer)


class EchoCallback(Callback):
def before_train(self):
print("before_train")

def before_epoch(self):
def before_epoch(self, cur_epoch, total_epoch):
print("before_epoch")
print("cur_epoch: ", cur_epoch)
print("total_epoch: ", total_epoch)

def before_batch(self):
def before_batch(self, batch_x, batch_y, indices):
print("before_batch")
print("batch_x:", batch_x)
print("batch_y:", batch_y)
print("indices: ", indices)

def before_loss(self):
def before_loss(self, batch_y, predict_y):
print("before_loss")
print("batch_y: ", batch_y)
print("predict_y: ", predict_y)

def before_backward(self):
def before_backward(self, loss, model):
print("before_backward")
print("loss=", loss)
print("model: ", model)

def after_batch(self):
print("after_batch")

def after_epoch(self):
def after_epoch(self, cur_epoch, n_epoch, optimizer):
print("after_epoch")
print("cur_epoch: ", cur_epoch)
print("n_epoch: ", n_epoch)
print("optimizer", optimizer)

def after_train(self):
def after_train(self, model):
print("after_train")
print("model: ", model)


if __name__ == "__main__":
manager = CallbackManager(env={"n_epoch": 3}, callbacks=[DummyCallback(), DummyCallback()])
manager.before_train(10, 11, 12)
# print(manager.after_epoch())
manager.after_epoch(3, 10, "optimizer")

+ 8
- 5
fastNLP/core/trainer.py View File

@@ -296,23 +296,26 @@ class Trainer(object):
epoch = 1
start = time.time()
while epoch <= self.n_epochs:
self.callback_manager.before_epoch()
self.callback_manager.before_epoch(epoch, self.n_epochs)

data_iterator = Batch(self.train_data, batch_size=self.batch_size, sampler=self.sampler,
as_numpy=False)

for batch_x, batch_y in data_iterator:
self.callback_manager.before_batch()
self.callback_manager.before_batch(batch_x, batch_y, data_iterator.get_batch_indices())
# TODO 这里可能会遇到问题,万一用户在model内部修改了prediction的device就会有问题
_move_dict_value_to_device(batch_x, batch_y, device=self._model_device)
prediction = self._data_forward(self.model, batch_x)

self.callback_manager.before_loss()
self.callback_manager.before_loss(batch_y, prediction)
loss = self._compute_loss(prediction, batch_y)

self.callback_manager.before_backward()
self.callback_manager.before_backward(loss, self.model)
self._grad_backward(loss)

self.callback_manager.after_backward(self.model)
self._update()
self.callback_manager.after_step()

self._summary_writer.add_scalar("loss", loss.item(), global_step=self.step)
for name, param in self.model.named_parameters():
@@ -338,7 +341,7 @@ class Trainer(object):
if self.dev_data and self.validate_every <= 0:
self._do_validation(epoch=epoch, step=self.step)
epoch += 1
self.callback_manager.after_epoch()
self.callback_manager.after_epoch(epoch, self.n_epochs, self.optimizer)

def _do_validation(self, epoch, step):
res = self.tester.test()


Loading…
Cancel
Save