@@ -0,0 +1,161 @@
class Callback(object):
    """An interface for all callbacks.

    Any customized callback should implement at least one of the following methods.
    """

    def __init__(self):
        super(Callback, self).__init__()

    def before_train(self):
        # before the main training loop
        pass

    def before_epoch(self):
        # at the beginning of each epoch
        pass

    def before_batch(self):
        # at the beginning of each step/mini-batch
        pass

    def before_loss(self):
        # after data_forward, and before loss computation
        pass

    def before_backward(self):
        # after loss computation, and before gradient backward
        pass

    def after_batch(self):
        # at the end of each step/mini-batch
        pass

    def after_epoch(self):
        # at the end of each epoch
        pass

    def after_train(self):
        # after the training loop
        pass
def transfer(func):
    """Decorator that forwards a call on the CallbackManager to all of its callbacks.

    :param func: the (empty) CallbackManager method being decorated; only its name is used
    :return: a wrapper that injects the manager's env into each callback, calls the
        same-named method on it, and collects the return values
    """
    def wrapper(manager):
        returns = []
        for callback in manager.callbacks:
            for env_name, env_value in manager.env.items():
                setattr(callback, env_name, env_value)
            returns.append(getattr(callback, func.__name__)())
        return returns
    return wrapper
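# Sketch of the decorator's effect: with manager.env == {"n_epoch": 3}, calling
# manager.before_epoch() sets callback.n_epoch = 3 on every registered callback,
# then invokes callback.before_epoch(), and returns the list of return values
# (see the __main__ demo at the bottom of this file).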
class CallbackManager(Callback):
    """A manager for all callbacks passed into Trainer.

    It collects resources inside the Trainer and dispatches calls to the callbacks.
    """

    def __init__(self, env, callbacks=None):
        """
        :param dict env: key is the name of a Trainer attribute (str); value is the attribute itself
        :param list callbacks: a list of Callback objects
        """
        super(CallbackManager, self).__init__()

        # attributes of the trainer environment, injected into callbacks on every call
        self.env = env

        self.callbacks = []
        if callbacks is not None:
            if isinstance(callbacks, list):
                if all(isinstance(cb, Callback) for cb in callbacks):
                    self.callbacks.extend(callbacks)
                else:
                    # report the first element that is not a Callback
                    obj = [cb for cb in callbacks if not isinstance(cb, Callback)][0]
                    raise TypeError(f"Expect sub-classes of Callback. Got {type(obj)}")
            else:
                raise TypeError(f"Expect callbacks in CallbackManager(callbacks) to be list. Got {type(callbacks)}.")
    @transfer
    def before_train(self):
        pass

    @transfer
    def before_epoch(self):
        pass

    @transfer
    def before_batch(self):
        pass

    @transfer
    def before_loss(self):
        pass

    @transfer
    def before_backward(self):
        pass

    @transfer
    def after_batch(self):
        pass

    @transfer
    def after_epoch(self):
        pass

    @transfer
    def after_train(self):
        pass
class DummyCallback(Callback):
    def before_train(self):
        print("before train!!!")
        print(self.n_epoch)

    def after_epoch(self):
        print("after epoch!!!")
        return 12


class EchoCallback(Callback):
    def before_train(self):
        print("before_train")

    def before_epoch(self):
        print("before_epoch")

    def before_batch(self):
        print("before_batch")

    def before_loss(self):
        print("before_loss")

    def before_backward(self):
        print("before_backward")

    def after_batch(self):
        print("after_batch")

    def after_epoch(self):
        print("after_epoch")

    def after_train(self):
        print("after_train")
if __name__ == "__main__":
    manager = CallbackManager(env={"n_epoch": 3}, callbacks=[DummyCallback(), DummyCallback()])
    manager.before_train()
    print(manager.after_epoch())
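For reference, running this module directly should print something along these lines (inferred from the demo above: env={"n_epoch": 3} is copied onto both DummyCallback instances, and after_epoch returns 12 from each):

    before train!!!
    3
    before train!!!
    3
    after epoch!!!
    after epoch!!!
    [12, 12]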
@@ -10,6 +10,7 @@ from torch import nn
from tqdm.autonotebook import tqdm

from fastNLP.core.batch import Batch
from fastNLP.core.callback import CallbackManager
from fastNLP.core.dataset import DataSet
from fastNLP.core.losses import _prepare_losser
from fastNLP.core.metrics import _prepare_metrics
@@ -29,7 +30,8 @@ from fastNLP.core.utils import get_func_signature
class Trainer(object):
    def __init__(self, train_data, model, loss=None, metrics=None, n_epochs=3, batch_size=32, print_every=50,
                 validate_every=-1, dev_data=None, save_path=None, optimizer=Adam(lr=0.01, weight_decay=0),
                 check_code_level=0, metric_key=None, sampler=RandomSampler(), use_tqdm=True, use_cuda=False,
                 callbacks=None):
        """
        :param DataSet train_data: the training data
        :param torch.nn.modules.module model: a PyTorch model
        :param list callbacks: a list of Callback objects, whose hook methods will be called at the
            corresponding points of the training loop (optional)
@@ -109,6 +111,7 @@ class Trainer(object):
        self.validate_every = int(validate_every)
        self.best_metric_indicator = None
        self.sampler = sampler
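        # make the Trainer itself visible to every callback: transfer() copies each
        # env entry onto the callbacks, so hook methods can read `self.trainer`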
        self.callback_manager = CallbackManager(env={"trainer": self}, callbacks=callbacks)

        if isinstance(optimizer, torch.optim.Optimizer):
            self.optimizer = optimizer
@@ -194,10 +197,14 @@ class Trainer(object):
        else:
            path = os.path.join(self.save_path, 'tensorboard_logs_{}'.format(self.start_time))
        self._summary_writer = SummaryWriter(path)

        self.callback_manager.before_train()

        if self.use_tqdm:
            self._tqdm_train()
        else:
            self._print_train()

        self.callback_manager.after_train()

        if self.dev_data is not None:
            print("\nIn Epoch:{}/Step:{}, got best dev performance:".format(self.best_dev_epoch, self.best_dev_step) +
                  self.tester._format_eval_results(self.best_dev_perf),)
@@ -227,11 +234,17 @@ class Trainer(object):
        avg_loss = 0
        for epoch in range(1, self.n_epochs+1):
            pbar.set_description_str(desc="Epoch {}/{}".format(epoch, self.n_epochs))

            self.callback_manager.before_epoch()

            for batch_x, batch_y in data_iterator:
                self.callback_manager.before_batch()

                _move_dict_value_to_device(batch_x, batch_y, device=self._model_device)
                prediction = self._data_forward(self.model, batch_x)

                self.callback_manager.before_loss()

                loss = self._compute_loss(prediction, batch_y)
                avg_loss += loss.item()

                self.callback_manager.before_backward()

                self._grad_backward(loss)
                self._update()
                self._summary_writer.add_scalar("loss", loss.item(), global_step=self.step)
@@ -245,6 +258,8 @@ class Trainer(object):
                    avg_loss = 0
                    pbar.update(self.print_every)
                self.step += 1

                self.callback_manager.after_batch()

                if self.validate_every > 0 and self.step % self.validate_every == 0 \
                        and self.dev_data is not None:
                    eval_res = self._do_validation(epoch=epoch, step=self.step)
@@ -259,23 +274,31 @@ class Trainer(object):
            if epoch != self.n_epochs:
                data_iterator = Batch(self.train_data, batch_size=self.batch_size, sampler=self.sampler,
                                      as_numpy=False)

            self.callback_manager.after_epoch()
        pbar.close()
    def _print_train(self):
        epoch = 1
        start = time.time()
        while epoch <= self.n_epochs:
            self.callback_manager.before_epoch()

            data_iterator = Batch(self.train_data, batch_size=self.batch_size, sampler=self.sampler,
                                  as_numpy=False)
            for batch_x, batch_y in data_iterator:
                self.callback_manager.before_batch()

                # TODO: this may break if the user moves `prediction` to a different device inside the model
                _move_dict_value_to_device(batch_x, batch_y, device=self._model_device)
                prediction = self._data_forward(self.model, batch_x)

                self.callback_manager.before_loss()

                loss = self._compute_loss(prediction, batch_y)

                self.callback_manager.before_backward()

                self._grad_backward(loss)
                self._update()
                self._summary_writer.add_scalar("loss", loss.item(), global_step=self.step)
                for name, param in self.model.named_parameters():
                    if param.requires_grad:
@@ -294,11 +317,13 @@ class Trainer(object):
                        self._do_validation(epoch=epoch, step=self.step)
                self.step += 1

                self.callback_manager.after_batch()

            # validate_every overrides validation at the end of each epoch
            if self.dev_data and self.validate_every <= 0:
                self._do_validation(epoch=epoch, step=self.step)
            epoch += 1

            self.callback_manager.after_epoch()
    def _do_validation(self, epoch, step):
        res = self.tester.test()
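Putting the hooks in both training loops together, one call to Trainer.train() fires the callbacks in this order (a schematic of the control flow above, not code from the patch):

    before_train
    for each epoch:
        before_epoch
        for each mini-batch:
            before_batch
            # forward pass (_data_forward)
            before_loss
            # loss computation (_compute_loss)
            before_backward
            # backward pass and parameter update (_grad_backward, _update)
            after_batch
        after_epoch
    after_train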
@@ -0,0 +1,44 @@
import unittest

import numpy as np

from fastNLP.core.callback import EchoCallback
from fastNLP.core.dataset import DataSet
from fastNLP.core.instance import Instance
from fastNLP.core.losses import BCELoss
from fastNLP.core.optimizer import SGD
from fastNLP.core.trainer import Trainer
from fastNLP.models.base_model import NaiveClassifier


class TestCallback(unittest.TestCase):
    def test_case(self):
        def prepare_fake_dataset():
            mean = np.array([-3, -3])
            cov = np.array([[1, 0], [0, 1]])
            class_A = np.random.multivariate_normal(mean, cov, size=(1000,))

            mean = np.array([3, 3])
            cov = np.array([[1, 0], [0, 1]])
            class_B = np.random.multivariate_normal(mean, cov, size=(1000,))

            data_set = DataSet([Instance(x=[float(item[0]), float(item[1])], y=[0.0]) for item in class_A] +
                               [Instance(x=[float(item[0]), float(item[1])], y=[1.0]) for item in class_B])
            return data_set

        data_set = prepare_fake_dataset()
        data_set.set_input("x")
        data_set.set_target("y")

        model = NaiveClassifier(2, 1)

        trainer = Trainer(data_set, model,
                          loss=BCELoss(pred="predict", target="y"),
                          n_epochs=1,
                          batch_size=32,
                          print_every=50,
                          optimizer=SGD(lr=0.1),
                          check_code_level=2,
                          use_tqdm=False,
                          callbacks=[EchoCallback()])
        trainer.train()
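Because the Trainer registers itself under env={"trainer": self}, any callback can reach Trainer attributes through self.trainer once training starts. A minimal sketch (EpochReporter is a hypothetical callback, not part of this patch):

from fastNLP.core.callback import Callback

class EpochReporter(Callback):
    """Illustrative only: report progress at the start of each epoch."""
    def before_epoch(self):
        # self.trainer is injected by CallbackManager via its env dict
        print("running for {} epochs in total".format(self.trainer.n_epochs))

trainer = Trainer(data_set, model,
                  loss=BCELoss(pred="predict", target="y"),
                  n_epochs=1,
                  callbacks=[EpochReporter()])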