diff --git a/fastNLP/core/callback.py b/fastNLP/core/callback.py
new file mode 100644
index 00000000..b172f3a4
--- /dev/null
+++ b/fastNLP/core/callback.py
@@ -0,0 +1,161 @@
+class Callback(object):
+    """An interface for all callbacks.
+
+    Any customized callback should implement at least one of the following methods.
+
+    """
+
+    def __init__(self):
+        super(Callback, self).__init__()
+
+    def before_train(self):
+        # before the main training loop
+        pass
+
+    def before_epoch(self):
+        # at the beginning of each epoch
+        pass
+
+    def before_batch(self):
+        # at the beginning of each step/mini-batch
+        pass
+
+    def before_loss(self):
+        # after data_forward, and before loss computation
+        pass
+
+    def before_backward(self):
+        # after loss computation, and before gradient backward
+        pass
+
+    def after_batch(self):
+        # at the end of each step/mini-batch
+        pass
+
+    def after_epoch(self):
+        # at the end of each epoch
+        pass
+
+    def after_train(self):
+        # after the training loop
+        pass
+
+
+def transfer(func):
+    """A decorator that forwards a call on the CallbackManager to each of its Callback members.
+
+    :param func: the CallbackManager method being decorated; only its name is used for dispatch.
+    :return: a wrapper that returns the list of values returned by the callbacks.
+    """
+
+    def wrapper(manager):
+        returns = []
+        for callback in manager.callbacks:
+            for env_name, env_value in manager.env.items():
+                setattr(callback, env_name, env_value)
+            returns.append(getattr(callback, func.__name__)())
+        return returns
+
+    return wrapper
+
+
+class CallbackManager(Callback):
+    """A manager for all callbacks passed into Trainer.
+    It collects resources from inside the Trainer and dispatches them to the callbacks.
+
+    """
+
+    def __init__(self, env, callbacks=None):
+        """
+
+        :param dict env: keys are names (str) of Trainer attributes; values are the attributes themselves.
+        :param list callbacks: a list of Callback objects.
+        """
+        super(CallbackManager, self).__init__()
+        # set attribute of trainer environment
+        self.env = env
+
+        self.callbacks = []
+        if callbacks is not None:
+            if isinstance(callbacks, list):
+                if all(isinstance(cb, Callback) for cb in callbacks):
+                    self.callbacks.extend(callbacks)
+                else:
+                    obj = [cb for cb in callbacks if not isinstance(cb, Callback)][0]
+                    raise TypeError(f"Expect sub-classes of Callback. Got {type(obj)}")
+            else:
+                raise TypeError(f"Expect callbacks in CallbackManager(callbacks) to be list. Got {type(callbacks)}.")
+
+    @transfer
+    def before_train(self):
+        pass
+
+    @transfer
+    def before_epoch(self):
+        pass
+
+    @transfer
+    def before_batch(self):
+        pass
+
+    @transfer
+    def before_loss(self):
+        pass
+
+    @transfer
+    def before_backward(self):
+        pass
+
+    @transfer
+    def after_batch(self):
+        pass
+
+    @transfer
+    def after_epoch(self):
+        pass
+
+    @transfer
+    def after_train(self):
+        pass
+
+
+class DummyCallback(Callback):
+    def before_train(self):
+        print("before train!!!")
+        print(self.n_epoch)
+
+    def after_epoch(self):
+        print("after epoch!!!")
+        return 12
+
+
+class EchoCallback(Callback):
+    def before_train(self):
+        print("before_train")
+
+    def before_epoch(self):
+        print("before_epoch")
+
+    def before_batch(self):
+        print("before_batch")
+
+    def before_loss(self):
+        print("before_loss")
+
+    def before_backward(self):
+        print("before_backward")
+
+    def after_batch(self):
+        print("after_batch")
+
+    def after_epoch(self):
+        print("after_epoch")
+
+    def after_train(self):
+        print("after_train")
+
+
+if __name__ == "__main__":
+    manager = CallbackManager(env={"n_epoch": 3}, callbacks=[DummyCallback(), DummyCallback()])
+    manager.before_train()
+    print(manager.after_epoch())
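For reference, a minimal sketch of how a user-defined callback would plug into this API. The class name and the batch-counting logic are illustrative, not part of this patch; the sketch assumes only the hooks defined above and the `trainer` attribute that CallbackManager copies onto each callback from its env (the Trainer below passes env={"trainer": self}).

from fastNLP.core.callback import Callback


class BatchCounterCallback(Callback):
    # Hypothetical example: count mini-batches and report trainer state via the hooks.

    def before_epoch(self):
        self.batch_count = 0

    def after_batch(self):
        self.batch_count += 1

    def after_epoch(self):
        # `self.trainer` is injected by CallbackManager before every hook call,
        # so trainer state such as the global step is readable here.
        print("epoch done after {} batches, global step {}".format(self.batch_count, self.trainer.step))

# Registered like any other callback:
#   trainer = Trainer(..., callbacks=[BatchCounterCallback()])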
diff --git a/fastNLP/core/trainer.py b/fastNLP/core/trainer.py
index c1bb4ec9..db0be67f 100644
--- a/fastNLP/core/trainer.py
+++ b/fastNLP/core/trainer.py
@@ -10,6 +10,7 @@ from torch import nn
 from tqdm.autonotebook import tqdm
 
 from fastNLP.core.batch import Batch
+from fastNLP.core.callback import CallbackManager
 from fastNLP.core.dataset import DataSet
 from fastNLP.core.losses import _prepare_losser
 from fastNLP.core.metrics import _prepare_metrics
@@ -29,7 +30,8 @@ from fastNLP.core.utils import get_func_signature
 class Trainer(object):
     def __init__(self, train_data, model, loss=None, metrics=None, n_epochs=3, batch_size=32, print_every=50,
                  validate_every=-1, dev_data=None, save_path=None, optimizer=Adam(lr=0.01, weight_decay=0),
-                 check_code_level=0, metric_key=None, sampler=RandomSampler(), use_tqdm=True, use_cuda=False):
+                 check_code_level=0, metric_key=None, sampler=RandomSampler(), use_tqdm=True, use_cuda=False,
+                 callbacks=None):
         """
         :param DataSet train_data: the training data
         :param torch.nn.modules.module model: a PyTorch model
@@ -109,6 +111,7 @@ class Trainer(object):
         self.validate_every = int(validate_every)
         self.best_metric_indicator = None
         self.sampler = sampler
+        self.callback_manager = CallbackManager(env={"trainer": self}, callbacks=callbacks)
 
         if isinstance(optimizer, torch.optim.Optimizer):
             self.optimizer = optimizer
@@ -194,10 +197,14 @@ class Trainer(object):
         else:
             path = os.path.join(self.save_path, 'tensorboard_logs_{}'.format(self.start_time))
         self._summary_writer = SummaryWriter(path)
+
+        self.callback_manager.before_train()
         if self.use_tqdm:
             self._tqdm_train()
         else:
             self._print_train()
+        self.callback_manager.after_train()
+
         if self.dev_data is not None:
             print("\nIn Epoch:{}/Step:{}, got best dev performance:".format(self.best_dev_epoch, self.best_dev_step) +
                   self.tester._format_eval_results(self.best_dev_perf),)
@@ -227,11 +234,17 @@ class Trainer(object):
             avg_loss = 0
             for epoch in range(1, self.n_epochs+1):
                 pbar.set_description_str(desc="Epoch {}/{}".format(epoch, self.n_epochs))
+                self.callback_manager.before_epoch()
                 for batch_x, batch_y in data_iterator:
+                    self.callback_manager.before_batch()
                     _move_dict_value_to_device(batch_x, batch_y, device=self._model_device)
                     prediction = self._data_forward(self.model, batch_x)
+
+                    self.callback_manager.before_loss()
                     loss = self._compute_loss(prediction, batch_y)
                     avg_loss += loss.item()
+
+                    self.callback_manager.before_backward()
                     self._grad_backward(loss)
                     self._update()
                     self._summary_writer.add_scalar("loss", loss.item(), global_step=self.step)
@@ -245,6 +258,8 @@ class Trainer(object):
                         avg_loss = 0
                         pbar.update(self.print_every)
                     self.step += 1
+                    self.callback_manager.after_batch()
+
                     if self.validate_every > 0 and self.step % self.validate_every == 0 \
                             and self.dev_data is not None:
                         eval_res = self._do_validation(epoch=epoch, step=self.step)
@@ -259,23 +274,31 @@ class Trainer(object):
                 if epoch!=self.n_epochs:
                     data_iterator = Batch(self.train_data, batch_size=self.batch_size, sampler=self.sampler,
                                           as_numpy=False)
+                self.callback_manager.after_epoch()
             pbar.close()
 
     def _print_train(self):
        epoch = 1
        start = time.time()
        while epoch <= self.n_epochs:
+           self.callback_manager.before_epoch()
            data_iterator = Batch(self.train_data, batch_size=self.batch_size, sampler=self.sampler,
                                  as_numpy=False)
            for batch_x, batch_y in data_iterator:
+               self.callback_manager.before_batch()
                # TODO: this may be problematic if the user moves `prediction` to a different device inside the model
                _move_dict_value_to_device(batch_x, batch_y, device=self._model_device)
                prediction = self._data_forward(self.model, batch_x)
+
+               self.callback_manager.before_loss()
                loss = self._compute_loss(prediction, batch_y)
+
+               self.callback_manager.before_backward()
                self._grad_backward(loss)
                self._update()
+
                self._summary_writer.add_scalar("loss", loss.item(), global_step=self.step)
                for name, param in self.model.named_parameters():
                    if param.requires_grad:
@@ -294,11 +317,13 @@ class Trainer(object):
                        self._do_validation(epoch=epoch, step=self.step)
                self.step += 1
+               self.callback_manager.after_batch()
 
            # validate_every override validation at end of epochs
            if self.dev_data and self.validate_every <= 0:
                self._do_validation(epoch=epoch, step=self.step)
            epoch += 1
+           self.callback_manager.after_epoch()
 
     def _do_validation(self, epoch, step):
         res = self.tester.test()
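Put together, the Trainer changes above fire the hooks in the following order, sketched here for the per-epoch flow shared by `_tqdm_train` and `_print_train` (validation, when enabled, is interleaved with the batch hooks):

# before_train
#     before_epoch            once per epoch
#         before_batch        once per mini-batch
#             (forward pass)
#         before_loss
#             (loss computation)
#         before_backward
#             (backward + parameter update)
#         after_batch
#     after_epoch
# after_train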
diff --git a/test/core/test_callbacks.py b/test/core/test_callbacks.py
new file mode 100644
index 00000000..20822cde
--- /dev/null
+++ b/test/core/test_callbacks.py
@@ -0,0 +1,44 @@
+import unittest
+
+import numpy as np
+
+from fastNLP.core.callback import EchoCallback
+from fastNLP.core.dataset import DataSet
+from fastNLP.core.instance import Instance
+from fastNLP.core.losses import BCELoss
+from fastNLP.core.optimizer import SGD
+from fastNLP.core.trainer import Trainer
+from fastNLP.models.base_model import NaiveClassifier
+
+
+class TestCallback(unittest.TestCase):
+    def test_case(self):
+        def prepare_fake_dataset():
+            mean = np.array([-3, -3])
+            cov = np.array([[1, 0], [0, 1]])
+            class_A = np.random.multivariate_normal(mean, cov, size=(1000,))
+
+            mean = np.array([3, 3])
+            cov = np.array([[1, 0], [0, 1]])
+            class_B = np.random.multivariate_normal(mean, cov, size=(1000,))
+
+            data_set = DataSet([Instance(x=[float(item[0]), float(item[1])], y=[0.0]) for item in class_A] +
+                               [Instance(x=[float(item[0]), float(item[1])], y=[1.0]) for item in class_B])
+            return data_set
+
+        data_set = prepare_fake_dataset()
+        data_set.set_input("x")
+        data_set.set_target("y")
+
+        model = NaiveClassifier(2, 1)
+
+        trainer = Trainer(data_set, model,
+                          loss=BCELoss(pred="predict", target="y"),
+                          n_epochs=1,
+                          batch_size=32,
+                          print_every=50,
+                          optimizer=SGD(lr=0.1),
+                          check_code_level=2,
+                          use_tqdm=False,
+                          callbacks=[EchoCallback()])
+        trainer.train()
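Assuming the repository's usual test layout, the new test should run on its own with pytest:

#   pytest test/core/test_callbacks.py

With `use_tqdm=False` and a single epoch, the EchoCallback prints are expected to interleave with the trainer's output in the hook order sketched earlier.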