
add callback in Trainer

FengZiYjun, 5 years ago
commit 179d12327a (tags/v0.3.0^2)
3 changed files with 231 additions and 1 deletion:

1. fastNLP/core/callback.py (+161 -0)
2. fastNLP/core/trainer.py (+26 -1)
3. test/core/test_callbacks.py (+44 -0)

fastNLP/core/callback.py (+161 -0)

@@ -0,0 +1,161 @@
class Callback(object):
"""An Interface for all callbacks.

Any customized callback should implement at least one of the following methods.

"""

def __init__(self):
super(Callback, self).__init__()

def before_train(self):
# before the main training loop
pass

def before_epoch(self):
# at the beginning of each epoch
pass

def before_batch(self):
# at the beginning of each step/mini-batch
pass

def before_loss(self):
# after data_forward, and before loss computation
pass

def before_backward(self):
# after loss computation, and before gradient backward
pass

def after_batch(self):
# at the end of each step/mini-batch
pass

def after_epoch(self):
# at the end of each epoch
pass

def after_train(self):
# after training loop
pass


def transfer(func):
"""装饰器,将对CallbackManager的调用转发到各个Callback子类.

:param func:
:return:
"""

    def wrapper(manager):
        returns = []
        for callback in manager.callbacks:
            # expose the Trainer environment (e.g. the trainer itself) as attributes on the callback
            for env_name, env_value in manager.env.items():
                setattr(callback, env_name, env_value)
            # call the same-named hook on the callback and collect its return value
            returns.append(getattr(callback, func.__name__)())
        return returns

return wrapper


class CallbackManager(Callback):
"""A manager for all callbacks passed into Trainer.
It collects resources inside Trainer and raise callbacks.

"""

def __init__(self, env, callbacks=None):
"""

        :param dict env: the key is the name of a Trainer attribute (str); the value is the attribute itself.
        :param list callbacks: a list of Callback objects, or None.
"""
super(CallbackManager, self).__init__()
# set attribute of trainer environment
self.env = env

        self.callbacks = []
        if callbacks is not None:
            if isinstance(callbacks, list):
                if all(isinstance(cb, Callback) for cb in callbacks):
                    self.callbacks.extend(callbacks)
                else:
                    # report the first element that is not a Callback
                    obj = [cb for cb in callbacks if not isinstance(cb, Callback)][0]
                    raise TypeError(f"Expect sub-classes of Callback. Got {type(obj)}")
            else:
                raise TypeError(f"Expect callbacks in CallbackManager(callbacks) to be list. Got {type(callbacks)}.")

@transfer
def before_train(self):
pass

@transfer
def before_epoch(self):
pass

@transfer
def before_batch(self):
pass

@transfer
def before_loss(self):
pass

@transfer
def before_backward(self):
pass

@transfer
def after_batch(self):
pass

@transfer
def after_epoch(self):
pass

@transfer
def after_train(self):
pass


class DummyCallback(Callback):
def before_train(self):
print("before train!!!")
print(self.n_epoch)

def after_epoch(self):
print("after epoch!!!")
return 12


class EchoCallback(Callback):
def before_train(self):
print("before_train")

def before_epoch(self):
print("before_epoch")

def before_batch(self):
print("before_batch")

def before_loss(self):
print("before_loss")

def before_backward(self):
print("before_backward")

def after_batch(self):
print("after_batch")

def after_epoch(self):
print("after_epoch")

def after_train(self):
print("after_train")


if __name__ == "__main__":
manager = CallbackManager(env={"n_epoch": 3}, callbacks=[DummyCallback(), DummyCallback()])
manager.before_train()
print(manager.after_epoch())
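Taken together, callback.py gives the Trainer a small hook system: a user subclasses Callback, overrides only the hooks it cares about, and CallbackManager injects its env entries (for the Trainer, the trainer itself) as attributes before each dispatch. Below is a minimal sketch of a user-defined callback; the BatchCounterCallback name is an illustrative assumption, not part of this commit:

from fastNLP.core.callback import Callback

class BatchCounterCallback(Callback):
    """Hypothetical callback that counts mini-batches per epoch."""

    def before_epoch(self):
        self.n_batches = 0  # plain attributes work; env values are injected onto the callback the same way

    def after_batch(self):
        self.n_batches += 1

    def after_epoch(self):
        print("epoch finished after {} batches".format(self.n_batches))

# passed to Trainer through the new `callbacks` argument, e.g.:
# trainer = Trainer(train_data, model, ..., callbacks=[BatchCounterCallback()])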

fastNLP/core/trainer.py (+26 -1)

@@ -10,6 +10,7 @@ from torch import nn
 from tqdm.autonotebook import tqdm

 from fastNLP.core.batch import Batch
+from fastNLP.core.callback import CallbackManager
 from fastNLP.core.dataset import DataSet
 from fastNLP.core.losses import _prepare_losser
 from fastNLP.core.metrics import _prepare_metrics
@@ -29,7 +30,8 @@ from fastNLP.core.utils import get_func_signature
 class Trainer(object):
     def __init__(self, train_data, model, loss=None, metrics=None, n_epochs=3, batch_size=32, print_every=50,
                  validate_every=-1, dev_data=None, save_path=None, optimizer=Adam(lr=0.01, weight_decay=0),
-                 check_code_level=0, metric_key=None, sampler=RandomSampler(), use_tqdm=True, use_cuda=False):
+                 check_code_level=0, metric_key=None, sampler=RandomSampler(), use_tqdm=True, use_cuda=False,
+                 callbacks=None):
         """
         :param DataSet train_data: the training data
         :param torch.nn.modules.module model: a PyTorch model
@@ -109,6 +111,7 @@ class Trainer(object):
         self.validate_every = int(validate_every)
         self.best_metric_indicator = None
         self.sampler = sampler
+        self.callback_manager = CallbackManager(env={"trainer": self}, callbacks=callbacks)

         if isinstance(optimizer, torch.optim.Optimizer):
             self.optimizer = optimizer
@@ -194,10 +197,14 @@ class Trainer(object):
         else:
             path = os.path.join(self.save_path, 'tensorboard_logs_{}'.format(self.start_time))
         self._summary_writer = SummaryWriter(path)
+
+        self.callback_manager.before_train()
         if self.use_tqdm:
             self._tqdm_train()
         else:
             self._print_train()
+        self.callback_manager.after_train()
+
         if self.dev_data is not None:
             print("\nIn Epoch:{}/Step:{}, got best dev performance:".format(self.best_dev_epoch, self.best_dev_step) +
                   self.tester._format_eval_results(self.best_dev_perf),)
@@ -227,11 +234,17 @@ class Trainer(object):
         avg_loss = 0
         for epoch in range(1, self.n_epochs+1):
             pbar.set_description_str(desc="Epoch {}/{}".format(epoch, self.n_epochs))
+            self.callback_manager.before_epoch()
             for batch_x, batch_y in data_iterator:
+                self.callback_manager.before_batch()
                 _move_dict_value_to_device(batch_x, batch_y, device=self._model_device)
                 prediction = self._data_forward(self.model, batch_x)
+
+                self.callback_manager.before_loss()
                 loss = self._compute_loss(prediction, batch_y)
                 avg_loss += loss.item()
+
+                self.callback_manager.before_backward()
                 self._grad_backward(loss)
                 self._update()
                 self._summary_writer.add_scalar("loss", loss.item(), global_step=self.step)
@@ -245,6 +258,8 @@ class Trainer(object):
                     avg_loss = 0
                     pbar.update(self.print_every)
                 self.step += 1
+                self.callback_manager.after_batch()
+
                 if self.validate_every > 0 and self.step % self.validate_every == 0 \
                         and self.dev_data is not None:
                     eval_res = self._do_validation(epoch=epoch, step=self.step)
@@ -259,23 +274,31 @@ class Trainer(object):
             if epoch!=self.n_epochs:
                 data_iterator = Batch(self.train_data, batch_size=self.batch_size, sampler=self.sampler,
                                       as_numpy=False)
+            self.callback_manager.after_epoch()
         pbar.close()

     def _print_train(self):
         epoch = 1
         start = time.time()
         while epoch <= self.n_epochs:
+            self.callback_manager.before_epoch()

             data_iterator = Batch(self.train_data, batch_size=self.batch_size, sampler=self.sampler,
                                   as_numpy=False)

             for batch_x, batch_y in data_iterator:
+                self.callback_manager.before_batch()
                 # TODO: potential issue here: if the user changes the device of `prediction` inside the model, this will break
                 _move_dict_value_to_device(batch_x, batch_y, device=self._model_device)
                 prediction = self._data_forward(self.model, batch_x)
+
+                self.callback_manager.before_loss()
                 loss = self._compute_loss(prediction, batch_y)
+
+                self.callback_manager.before_backward()
                 self._grad_backward(loss)
                 self._update()
+
                 self._summary_writer.add_scalar("loss", loss.item(), global_step=self.step)
                 for name, param in self.model.named_parameters():
                     if param.requires_grad:
@@ -294,11 +317,13 @@ class Trainer(object):
                     self._do_validation(epoch=epoch, step=self.step)

                 self.step += 1
+                self.callback_manager.after_batch()

             # validate_every overrides validation at the end of epochs
             if self.dev_data and self.validate_every <= 0:
                 self._do_validation(epoch=epoch, step=self.step)
             epoch += 1
+            self.callback_manager.after_epoch()

     def _do_validation(self, epoch, step):
         res = self.tester.test()
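Both training loops now fire the hooks in the same order. The sketch below condenses that order into a runnable toy loop; n_epochs, batches_per_epoch, and the stand-in comments are assumptions for illustration, while the hook sequence itself comes from the diff above:

from fastNLP.core.callback import CallbackManager, EchoCallback

n_epochs, batches_per_epoch = 2, 3  # assumed toy values
manager = CallbackManager(env={"trainer": None}, callbacks=[EchoCallback()])

manager.before_train()
for epoch in range(1, n_epochs + 1):
    manager.before_epoch()
    for _ in range(batches_per_epoch):
        manager.before_batch()
        # ... _data_forward(model, batch_x) runs here ...
        manager.before_loss()
        # ... _compute_loss(prediction, batch_y) runs here ...
        manager.before_backward()
        # ... _grad_backward(loss) and _update() run here ...
        manager.after_batch()
    manager.after_epoch()
manager.after_train()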


test/core/test_callbacks.py (+44 -0)

@@ -0,0 +1,44 @@
import unittest

import numpy as np

from fastNLP.core.callback import EchoCallback
from fastNLP.core.dataset import DataSet
from fastNLP.core.instance import Instance
from fastNLP.core.losses import BCELoss
from fastNLP.core.optimizer import SGD
from fastNLP.core.trainer import Trainer
from fastNLP.models.base_model import NaiveClassifier


class TestCallback(unittest.TestCase):
def test_case(self):
def prepare_fake_dataset():
mean = np.array([-3, -3])
cov = np.array([[1, 0], [0, 1]])
class_A = np.random.multivariate_normal(mean, cov, size=(1000,))

mean = np.array([3, 3])
cov = np.array([[1, 0], [0, 1]])
class_B = np.random.multivariate_normal(mean, cov, size=(1000,))

data_set = DataSet([Instance(x=[float(item[0]), float(item[1])], y=[0.0]) for item in class_A] +
[Instance(x=[float(item[0]), float(item[1])], y=[1.0]) for item in class_B])
return data_set

data_set = prepare_fake_dataset()
data_set.set_input("x")
data_set.set_target("y")

model = NaiveClassifier(2, 1)

trainer = Trainer(data_set, model,
loss=BCELoss(pred="predict", target="y"),
n_epochs=1,
batch_size=32,
print_every=50,
optimizer=SGD(lr=0.1),
check_code_level=2,
use_tqdm=False,
callbacks=[EchoCallback()])
trainer.train()
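Beyond this end-to-end smoke test, the type checks in CallbackManager.__init__ could be covered directly. A small sketch of such a test follows; the TestCallbackTypeCheck class is an illustrative addition, not part of this commit:

import unittest

from fastNLP.core.callback import Callback, CallbackManager

class TestCallbackTypeCheck(unittest.TestCase):
    def test_rejects_non_callback_element(self):
        # a list containing a non-Callback object should raise TypeError
        with self.assertRaises(TypeError):
            CallbackManager(env={}, callbacks=["not a callback"])

    def test_rejects_non_list(self):
        # a bare Callback (not wrapped in a list) should also raise TypeError
        with self.assertRaises(TypeError):
            CallbackManager(env={}, callbacks=Callback())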
