Browse Source

improved callbacks & trainer code structure

tags/v0.3.0^2
FengZiYjun 6 years ago
parent
commit
e8ea6ea322
2 changed files with 53 additions and 23 deletions
  1. fastNLP/core/callback.py (+45, -18)
  2. fastNLP/core/trainer.py (+8, -5)

fastNLP/core/callback.py (+45, -18) View File

@@ -12,34 +12,50 @@ class Callback(object):
# before the main training loop # before the main training loop
pass pass


def before_epoch(self, *args):
def before_epoch(self, cur_epoch, total_epoch):
# at the beginning of each epoch # at the beginning of each epoch
pass pass


def before_batch(self, *args):
def before_batch(self, batch_x, batch_y, indices):
# at the beginning of each step/mini-batch # at the beginning of each step/mini-batch
pass pass


def before_loss(self, *args):
def before_loss(self, batch_y, predict_y):
# after data_forward, and before loss computation # after data_forward, and before loss computation
pass pass


def before_backward(self, *args):
def before_backward(self, loss, model):
# after loss computation, and before gradient backward # after loss computation, and before gradient backward
pass pass


def after_batch(self, *args):
def after_backward(self, model):
# after gradient backward, before optimizer step
pass

def after_step(self):
# after optimizer step
pass

def after_batch(self):
# at the end of each step/mini-batch # at the end of each step/mini-batch
pass pass


def after_epoch(self, *args):
def after_valid(self, eval_result, metric_key, optimizer):
# after validation
pass

def after_epoch(self, cur_epoch, n_epoch, optimizer):
# at the end of each epoch # at the end of each epoch
pass pass


def after_train(self, *args):
def after_train(self, model):
# after training loop # after training loop
pass pass


def on_exception(self, exception, model, indices):
# when exception raised in training
pass



def transfer(func): def transfer(func):
"""装饰器,将对CallbackManager的调用转发到各个Callback子类. """装饰器,将对CallbackManager的调用转发到各个Callback子类.
@@ -139,38 +155,49 @@ class DummyCallback(Callback):
def before_train(self, *arg): def before_train(self, *arg):
print(arg) print(arg)


def after_epoch(self):
print("after epoch!!!")
return 12
def after_epoch(self, cur_epoch, n_epoch, optimizer):
print(cur_epoch, n_epoch, optimizer)




class EchoCallback(Callback): class EchoCallback(Callback):
def before_train(self): def before_train(self):
print("before_train") print("before_train")


def before_epoch(self):
def before_epoch(self, cur_epoch, total_epoch):
print("before_epoch") print("before_epoch")
print("cur_epoch: ", cur_epoch)
print("total_epoch: ", total_epoch)


def before_batch(self):
def before_batch(self, batch_x, batch_y, indices):
print("before_batch") print("before_batch")
print("batch_x:", batch_x)
print("batch_y:", batch_y)
print("indices: ", indices)


def before_loss(self):
def before_loss(self, batch_y, predict_y):
print("before_loss") print("before_loss")
print("batch_y: ", batch_y)
print("predict_y: ", predict_y)


def before_backward(self):
def before_backward(self, loss, model):
print("before_backward") print("before_backward")
print("loss=", loss)
print("model: ", model)


def after_batch(self): def after_batch(self):
print("after_batch") print("after_batch")


def after_epoch(self):
def after_epoch(self, cur_epoch, n_epoch, optimizer):
print("after_epoch") print("after_epoch")
print("cur_epoch: ", cur_epoch)
print("n_epoch: ", n_epoch)
print("optimizer", optimizer)


def after_train(self):
def after_train(self, model):
print("after_train") print("after_train")
print("model: ", model)




if __name__ == "__main__": if __name__ == "__main__":
manager = CallbackManager(env={"n_epoch": 3}, callbacks=[DummyCallback(), DummyCallback()]) manager = CallbackManager(env={"n_epoch": 3}, callbacks=[DummyCallback(), DummyCallback()])
manager.before_train(10, 11, 12)
# print(manager.after_epoch())
manager.after_epoch(3, 10, "optimizer")

fastNLP/core/trainer.py (+8, -5) View File

@@ -296,23 +296,26 @@ class Trainer(object):
epoch = 1 epoch = 1
start = time.time() start = time.time()
while epoch <= self.n_epochs: while epoch <= self.n_epochs:
self.callback_manager.before_epoch()
self.callback_manager.before_epoch(epoch, self.n_epochs)


data_iterator = Batch(self.train_data, batch_size=self.batch_size, sampler=self.sampler, data_iterator = Batch(self.train_data, batch_size=self.batch_size, sampler=self.sampler,
as_numpy=False) as_numpy=False)


for batch_x, batch_y in data_iterator: for batch_x, batch_y in data_iterator:
self.callback_manager.before_batch()
self.callback_manager.before_batch(batch_x, batch_y, data_iterator.get_batch_indices())
# TODO 这里可能会遇到问题,万一用户在model内部修改了prediction的device就会有问题 # TODO 这里可能会遇到问题,万一用户在model内部修改了prediction的device就会有问题
_move_dict_value_to_device(batch_x, batch_y, device=self._model_device) _move_dict_value_to_device(batch_x, batch_y, device=self._model_device)
prediction = self._data_forward(self.model, batch_x) prediction = self._data_forward(self.model, batch_x)


self.callback_manager.before_loss()
self.callback_manager.before_loss(batch_y, prediction)
loss = self._compute_loss(prediction, batch_y) loss = self._compute_loss(prediction, batch_y)


self.callback_manager.before_backward()
self.callback_manager.before_backward(loss, self.model)
self._grad_backward(loss) self._grad_backward(loss)

self.callback_manager.after_backward(self.model)
self._update() self._update()
self.callback_manager.after_step()


self._summary_writer.add_scalar("loss", loss.item(), global_step=self.step) self._summary_writer.add_scalar("loss", loss.item(), global_step=self.step)
for name, param in self.model.named_parameters(): for name, param in self.model.named_parameters():
@@ -338,7 +341,7 @@ class Trainer(object):
if self.dev_data and self.validate_every <= 0: if self.dev_data and self.validate_every <= 0:
self._do_validation(epoch=epoch, step=self.step) self._do_validation(epoch=epoch, step=self.step)
epoch += 1 epoch += 1
self.callback_manager.after_epoch()
self.callback_manager.after_epoch(epoch, self.n_epochs, self.optimizer)


def _do_validation(self, epoch, step): def _do_validation(self, epoch, step):
res = self.tester.test() res = self.tester.test()


Loading…
Cancel
Save