@@ -0,0 +1,161 @@ | |||
class Callback(object):
    """Base interface for all training callbacks.

    A custom callback subclasses this and overrides any subset of the hook
    methods below; each hook is fired by the trainer at the corresponding
    point of the training loop and does nothing by default.
    """

    def __init__(self):
        super(Callback, self).__init__()

    def before_train(self):
        """Fired once, just before the main training loop starts."""
        pass

    def before_epoch(self):
        """Fired at the beginning of every epoch."""
        pass

    def before_batch(self):
        """Fired at the beginning of every step/mini-batch."""
        pass

    def before_loss(self):
        """Fired after the forward pass, before the loss is computed."""
        pass

    def before_backward(self):
        """Fired after the loss is computed, before the backward pass."""
        pass

    def after_batch(self):
        """Fired at the end of every step/mini-batch."""
        pass

    def after_epoch(self):
        """Fired at the end of every epoch."""
        pass

    def after_train(self):
        """Fired once, after the training loop has finished."""
        pass
def transfer(func):
    """Decorator that forwards a CallbackManager hook call to every callback.

    When the decorated CallbackManager method is invoked, the wrapper first
    copies the manager's environment onto each callback (so hooks can read
    trainer attributes such as ``self.n_epoch``), then calls the same-named
    hook on each callback in registration order.

    :param func: the CallbackManager hook method being decorated; only its
        ``__name__`` is used to look up the hook on each callback.
    :return: a wrapper returning the list of values produced by the callbacks.
    """
    import functools

    # functools.wraps preserves the hook's name/docstring on the wrapper,
    # which the original version lost.
    @functools.wraps(func)
    def wrapper(manager):
        returns = []
        for callback in manager.callbacks:
            # Refresh the trainer environment on the callback before firing.
            for env_name, env_value in manager.env.items():
                setattr(callback, env_name, env_value)
            returns.append(getattr(callback, func.__name__)())
        return returns

    return wrapper
class CallbackManager(Callback):
    """A manager for all callbacks passed into Trainer.

    It holds the trainer environment and forwards every hook invocation to
    each registered callback (dispatch is done by the ``transfer`` decorator).
    """

    def __init__(self, env, callbacks=None):
        """
        :param dict env: maps a Trainer attribute name (str) to the attribute
            itself; copied onto every callback before each hook fires.
        :param list callbacks: a list of Callback instances, or None.
        :raises TypeError: if ``callbacks`` is not a list, or contains an
            object that is not a Callback subclass instance.
        """
        super(CallbackManager, self).__init__()
        # Trainer environment exposed to every callback on each dispatch.
        self.env = env
        self.callbacks = []
        if callbacks is not None:
            if not isinstance(callbacks, list):
                raise TypeError(f"Expect callbacks in CallbackManager(callbacks) to be list. Got {type(callbacks)}.")
            for cb in callbacks:
                # Report the offending object's actual type. The original code
                # indexed a list of booleans, so the error always said 'bool'.
                if not isinstance(cb, Callback):
                    raise TypeError(f"Expect sub-classes of Callback. Got {type(cb)}")
            self.callbacks.extend(callbacks)

    @transfer
    def before_train(self):
        pass

    @transfer
    def before_epoch(self):
        pass

    @transfer
    def before_batch(self):
        pass

    @transfer
    def before_loss(self):
        pass

    @transfer
    def before_backward(self):
        pass

    @transfer
    def after_batch(self):
        pass

    @transfer
    def after_epoch(self):
        pass

    @transfer
    def after_train(self):
        pass
class DummyCallback(Callback):
    """Toy callback used to exercise the manager plumbing by hand."""

    def before_train(self):
        """Print a marker line followed by the injected ``n_epoch`` value."""
        print("before train!!!", self.n_epoch, sep="\n")

    def after_epoch(self):
        """Print a marker and return a sentinel so return collection is visible."""
        print("after epoch!!!")
        return 12
class EchoCallback(Callback):
    """Callback that prints the name of every hook as it fires."""

    def _echo(self, name):
        # Single print point shared by all hooks.
        print(name)

    def before_train(self):
        self._echo("before_train")

    def before_epoch(self):
        self._echo("before_epoch")

    def before_batch(self):
        self._echo("before_batch")

    def before_loss(self):
        self._echo("before_loss")

    def before_backward(self):
        self._echo("before_backward")

    def after_batch(self):
        self._echo("after_batch")

    def after_epoch(self):
        self._echo("after_epoch")

    def after_train(self):
        self._echo("after_train")
if __name__ == "__main__":
    # Smoke test: two dummy callbacks share the same injected environment.
    cb_manager = CallbackManager(env={"n_epoch": 3},
                                 callbacks=[DummyCallback(), DummyCallback()])
    cb_manager.before_train()
    print(cb_manager.after_epoch())
@@ -10,6 +10,7 @@ from torch import nn | |||
from tqdm.autonotebook import tqdm | |||
from fastNLP.core.batch import Batch | |||
from fastNLP.core.callback import CallbackManager | |||
from fastNLP.core.dataset import DataSet | |||
from fastNLP.core.losses import _prepare_losser | |||
from fastNLP.core.metrics import _prepare_metrics | |||
@@ -29,7 +30,8 @@ from fastNLP.core.utils import get_func_signature | |||
class Trainer(object): | |||
def __init__(self, train_data, model, loss=None, metrics=None, n_epochs=3, batch_size=32, print_every=50, | |||
validate_every=-1, dev_data=None, save_path=None, optimizer=Adam(lr=0.01, weight_decay=0), | |||
check_code_level=0, metric_key=None, sampler=RandomSampler(), use_tqdm=True, use_cuda=False): | |||
check_code_level=0, metric_key=None, sampler=RandomSampler(), use_tqdm=True, use_cuda=False, | |||
callbacks=None): | |||
""" | |||
:param DataSet train_data: the training data | |||
:param torch.nn.modules.module model: a PyTorch model | |||
@@ -109,6 +111,7 @@ class Trainer(object): | |||
self.validate_every = int(validate_every) | |||
self.best_metric_indicator = None | |||
self.sampler = sampler | |||
self.callback_manager = CallbackManager(env={"trainer": self}, callbacks=callbacks) | |||
if isinstance(optimizer, torch.optim.Optimizer): | |||
self.optimizer = optimizer | |||
@@ -194,10 +197,14 @@ class Trainer(object): | |||
else: | |||
path = os.path.join(self.save_path, 'tensorboard_logs_{}'.format(self.start_time)) | |||
self._summary_writer = SummaryWriter(path) | |||
self.callback_manager.before_train() | |||
if self.use_tqdm: | |||
self._tqdm_train() | |||
else: | |||
self._print_train() | |||
self.callback_manager.after_train() | |||
if self.dev_data is not None: | |||
print("\nIn Epoch:{}/Step:{}, got best dev performance:".format(self.best_dev_epoch, self.best_dev_step) + | |||
self.tester._format_eval_results(self.best_dev_perf),) | |||
@@ -227,11 +234,17 @@ class Trainer(object): | |||
avg_loss = 0 | |||
for epoch in range(1, self.n_epochs+1): | |||
pbar.set_description_str(desc="Epoch {}/{}".format(epoch, self.n_epochs)) | |||
self.callback_manager.before_epoch() | |||
for batch_x, batch_y in data_iterator: | |||
self.callback_manager.before_batch() | |||
_move_dict_value_to_device(batch_x, batch_y, device=self._model_device) | |||
prediction = self._data_forward(self.model, batch_x) | |||
self.callback_manager.before_loss() | |||
loss = self._compute_loss(prediction, batch_y) | |||
avg_loss += loss.item() | |||
self.callback_manager.before_backward() | |||
self._grad_backward(loss) | |||
self._update() | |||
self._summary_writer.add_scalar("loss", loss.item(), global_step=self.step) | |||
@@ -245,6 +258,8 @@ class Trainer(object): | |||
avg_loss = 0 | |||
pbar.update(self.print_every) | |||
self.step += 1 | |||
self.callback_manager.after_batch() | |||
if self.validate_every > 0 and self.step % self.validate_every == 0 \ | |||
and self.dev_data is not None: | |||
eval_res = self._do_validation(epoch=epoch, step=self.step) | |||
@@ -259,23 +274,31 @@ class Trainer(object): | |||
if epoch!=self.n_epochs: | |||
data_iterator = Batch(self.train_data, batch_size=self.batch_size, sampler=self.sampler, | |||
as_numpy=False) | |||
self.callback_manager.after_epoch() | |||
pbar.close() | |||
def _print_train(self): | |||
epoch = 1 | |||
start = time.time() | |||
while epoch <= self.n_epochs: | |||
self.callback_manager.before_epoch() | |||
data_iterator = Batch(self.train_data, batch_size=self.batch_size, sampler=self.sampler, | |||
as_numpy=False) | |||
for batch_x, batch_y in data_iterator: | |||
self.callback_manager.before_batch() | |||
# TODO 这里可能会遇到问题,万一用户在model内部修改了prediction的device就会有问题 | |||
_move_dict_value_to_device(batch_x, batch_y, device=self._model_device) | |||
prediction = self._data_forward(self.model, batch_x) | |||
self.callback_manager.before_loss() | |||
loss = self._compute_loss(prediction, batch_y) | |||
self.callback_manager.before_backward() | |||
self._grad_backward(loss) | |||
self._update() | |||
self._summary_writer.add_scalar("loss", loss.item(), global_step=self.step) | |||
for name, param in self.model.named_parameters(): | |||
if param.requires_grad: | |||
@@ -294,11 +317,13 @@ class Trainer(object): | |||
self._do_validation(epoch=epoch, step=self.step) | |||
self.step += 1 | |||
self.callback_manager.after_batch() | |||
# validate_every override validation at end of epochs | |||
if self.dev_data and self.validate_every <= 0: | |||
self._do_validation(epoch=epoch, step=self.step) | |||
epoch += 1 | |||
self.callback_manager.after_epoch() | |||
def _do_validation(self, epoch, step): | |||
res = self.tester.test() | |||
@@ -0,0 +1,44 @@ | |||
import unittest | |||
import numpy as np | |||
from fastNLP.core.callback import EchoCallback | |||
from fastNLP.core.dataset import DataSet | |||
from fastNLP.core.instance import Instance | |||
from fastNLP.core.losses import BCELoss | |||
from fastNLP.core.optimizer import SGD | |||
from fastNLP.core.trainer import Trainer | |||
from fastNLP.models.base_model import NaiveClassifier | |||
class TestCallback(unittest.TestCase):
    def test_case(self):
        """Train a tiny classifier for one epoch with an EchoCallback attached."""

        def prepare_fake_dataset():
            # Two 2-D Gaussian blobs, one per class, 1000 points each.
            identity_cov = np.array([[1, 0], [0, 1]])
            class_A = np.random.multivariate_normal(np.array([-3, -3]), identity_cov, size=(1000,))
            class_B = np.random.multivariate_normal(np.array([3, 3]), identity_cov, size=(1000,))
            instances = [Instance(x=[float(p[0]), float(p[1])], y=[0.0]) for p in class_A]
            instances += [Instance(x=[float(p[0]), float(p[1])], y=[1.0]) for p in class_B]
            return DataSet(instances)

        data_set = prepare_fake_dataset()
        data_set.set_input("x")
        data_set.set_target("y")

        classifier = NaiveClassifier(2, 1)
        Trainer(data_set, classifier,
                loss=BCELoss(pred="predict", target="y"),
                n_epochs=1,
                batch_size=32,
                print_every=50,
                optimizer=SGD(lr=0.1),
                check_code_level=2,
                use_tqdm=False,
                callbacks=[EchoCallback()]).train()