
update callbacks:

* rename callback methods to fastai's notation (before_* / after_* become on_*_begin / on_*_end).
* add a new callback method: on_valid_begin.
tags/v0.3.1^2
FengZiYjun, 6 years ago
parent commit 887fc9281f
5 changed files with 131 additions and 94 deletions:

1. fastNLP/core/callback.py (+52 / -45)
2. fastNLP/core/predictor.py (+45 / -23)
3. fastNLP/core/trainer.py (+12 / -14)
4. reproduction/Biaffine_parser/run.py (+1 / -6)
5. test/core/test_predictor.py (+21 / -6)

fastNLP/core/callback.py (+52 / -45)

@@ -17,37 +17,40 @@ class Callback(object):
         super(Callback, self).__init__()
         self.trainer = None  # reassigned inside Trainer

-    def before_train(self):
+    def on_train_begin(self):
         # before the main training loop
         pass

-    def before_epoch(self, cur_epoch, total_epoch):
+    def on_epoch_begin(self, cur_epoch, total_epoch):
         # at the beginning of each epoch
         pass

-    def before_batch(self, batch_x, batch_y, indices):
+    def on_batch_begin(self, batch_x, batch_y, indices):
         # at the beginning of each step/mini-batch
         pass

-    def before_loss(self, batch_y, predict_y):
+    def on_loss_begin(self, batch_y, predict_y):
         # after data_forward, and before loss computation
         pass

-    def before_backward(self, loss, model):
+    def on_backward_begin(self, loss, model):
         # after loss computation, and before gradient backward
         pass

-    def after_backward(self, model):
+    def on_backward_end(self, model):
         pass

-    def after_step(self, optimizer):
+    def on_step_end(self, optimizer):
         pass

-    def after_batch(self, *args):
+    def on_batch_end(self, *args):
         # at the end of each step/mini-batch
         pass

-    def after_valid(self, eval_result, metric_key, optimizer):
+    def on_valid_begin(self):
+        pass
+
+    def on_valid_end(self, eval_result, metric_key, optimizer):
         """
         Called after each evaluation on the validation set; eval_result is passed in.

@@ -58,7 +61,7 @@ class Callback(object):
         """
         pass

-    def after_epoch(self, cur_epoch, n_epoch, optimizer):
+    def on_epoch_end(self, cur_epoch, n_epoch, optimizer):
         """
         Called at the end of each epoch.

@@ -69,7 +72,7 @@ class Callback(object):
         """
         pass

-    def after_train(self, model):
+    def on_train_end(self, model):
         """
         Called when training ends.

@@ -134,47 +137,51 @@ class CallbackManager(Callback):
            raise TypeError(f"Expect callbacks in CallbackManager(callbacks) to be list. Got {type(callbacks)}.")

     @transfer
-    def before_train(self):
+    def on_train_begin(self):
+        pass
+
+    @transfer
+    def on_epoch_begin(self, cur_epoch, total_epoch):
         pass

     @transfer
-    def before_epoch(self, cur_epoch, total_epoch):
+    def on_batch_begin(self, batch_x, batch_y, indices):
         pass

     @transfer
-    def before_batch(self, batch_x, batch_y, indices):
+    def on_loss_begin(self, batch_y, predict_y):
         pass

     @transfer
-    def before_loss(self, batch_y, predict_y):
+    def on_backward_begin(self, loss, model):
         pass

     @transfer
-    def before_backward(self, loss, model):
+    def on_backward_end(self, model):
         pass

     @transfer
-    def after_backward(self, model):
+    def on_step_end(self, optimizer):
         pass

     @transfer
-    def after_step(self, optimizer):
+    def on_batch_end(self):
         pass

     @transfer
-    def after_batch(self):
+    def on_valid_begin(self):
         pass

     @transfer
-    def after_valid(self, eval_result, metric_key, optimizer):
+    def on_valid_end(self, eval_result, metric_key, optimizer):
         pass

     @transfer
-    def after_epoch(self, cur_epoch, n_epoch, optimizer):
+    def on_epoch_end(self, cur_epoch, n_epoch, optimizer):
         pass

     @transfer
-    def after_train(self, model):
+    def on_train_end(self, model):
         pass

     @transfer

@@ -183,36 +190,36 @@ class CallbackManager(Callback):


 class DummyCallback(Callback):
-    def before_train(self, *arg):
+    def on_train_begin(self, *arg):
         print(arg)

-    def after_epoch(self, cur_epoch, n_epoch, optimizer):
+    def on_epoch_end(self, cur_epoch, n_epoch, optimizer):
         print(cur_epoch, n_epoch, optimizer)


 class EchoCallback(Callback):
-    def before_train(self):
+    def on_train_begin(self):
         print("before_train")

-    def before_epoch(self, cur_epoch, total_epoch):
+    def on_epoch_begin(self, cur_epoch, total_epoch):
         print("before_epoch")

-    def before_batch(self, batch_x, batch_y, indices):
+    def on_batch_begin(self, batch_x, batch_y, indices):
         print("before_batch")

-    def before_loss(self, batch_y, predict_y):
+    def on_loss_begin(self, batch_y, predict_y):
         print("before_loss")

-    def before_backward(self, loss, model):
+    def on_backward_begin(self, loss, model):
         print("before_backward")

-    def after_batch(self):
+    def on_batch_end(self):
         print("after_batch")

-    def after_epoch(self, cur_epoch, n_epoch, optimizer):
+    def on_epoch_end(self, cur_epoch, n_epoch, optimizer):
         print("after_epoch")

-    def after_train(self, model):
+    def on_train_end(self, model):
         print("after_train")

@@ -240,7 +247,7 @@ class GradientClipCallback(Callback):
         self.parameters = parameters
         self.clip_value = clip_value

-    def after_backward(self, model):
+    def on_backward_end(self, model):
         self.clip_fun(model.parameters(), self.clip_value)

@@ -266,7 +273,7 @@ class EarlyStopCallback(Callback):
         self.wait = 0
         self.epoch = 0

-    def after_valid(self, eval_result, metric_key, optimizer):
+    def on_valid_end(self, eval_result, metric_key, optimizer):
         self.epoch += 1
         if not self.trainer._better_eval_result(eval_result):
             # current result is getting worse

@@ -297,7 +304,7 @@ class LRScheduler(Callback):
         else:
             raise ValueError(f"Expect torch.optim.lr_scheduler for LRScheduler. Got {type(lr_scheduler)}.")

-    def before_epoch(self, cur_epoch, total_epoch):
+    def on_epoch_begin(self, cur_epoch, total_epoch):
         self.scheduler.step()
         print("scheduler step ", "lr=", self.trainer.optimizer.param_groups[0]["lr"])

@@ -359,7 +366,7 @@ class LRFinder(Callback):
         self.find = None
         self.loader = ModelLoader()

-    def before_epoch(self, cur_epoch, total_epoch):
+    def on_epoch_begin(self, cur_epoch, total_epoch):
         if cur_epoch == 1:
             self.opt = self.trainer.optimizer  # pytorch optimizer
             self.opt.param_groups[0]["lr"] = self.start_lr

@@ -367,7 +374,7 @@ class LRFinder(Callback):
             ModelSaver("tmp").save_pytorch(self.trainer.model, param_only=True)
             self.find = True

-    def before_backward(self, loss, model):
+    def on_backward_begin(self, loss, model):
         if self.find:
             if torch.isnan(loss) or self.stop is True:
                 self.stop = True

@@ -379,7 +386,7 @@ class LRFinder(Callback):
                 self.best_loss = self.smooth_value.smooth
                 self.best_lr = self.opt.param_groups[0]["lr"]

-    def after_batch(self, *args):
+    def on_batch_end(self, *args):
         if self.find:
             lr = next(self.lr_gen, None)
             if lr is None or self.stop is True or self.loss_history[-1] > 4 * self.best_loss:

@@ -388,7 +395,7 @@ class LRFinder(Callback):
             self.opt.param_groups[0]["lr"] = lr
             # self.loader.load_pytorch(self.trainer.model, "tmp")

-    def after_epoch(self, cur_epoch, n_epoch, optimizer):
+    def on_epoch_end(self, cur_epoch, n_epoch, optimizer):
         if cur_epoch == 1:
             self.opt.param_groups[0]["lr"] = self.best_lr
             self.find = False
@@ -415,7 +422,7 @@ class TensorboardCallback(Callback):
         self._summary_writer = None
         self.graph_added = False

-    def before_train(self):
+    def on_train_begin(self):
         save_dir = self.trainer.save_path
         if save_dir is None:
             path = os.path.join("./", 'tensorboard_logs_{}'.format(self.trainer.start_time))

@@ -423,7 +430,7 @@ class TensorboardCallback(Callback):
             path = os.path.join(save_dir, 'tensorboard_logs_{}'.format(self.trainer.start_time))
         self._summary_writer = SummaryWriter(path)

-    def before_batch(self, batch_x, batch_y, indices):
+    def on_batch_begin(self, batch_x, batch_y, indices):
         if "model" in self.options and self.graph_added is False:
             # tensorboardX has a serious bug here; drawing the model graph is disabled for now
             # from fastNLP.core.utils import _build_args

@@ -433,7 +440,7 @@ class TensorboardCallback(Callback):
             # self._summary_writer.add_graph(self.trainer.model, torch.zeros(32, 2))
             self.graph_added = True

-    def before_backward(self, loss, model):
+    def on_backward_begin(self, loss, model):
         if "loss" in self.options:
             self._summary_writer.add_scalar("loss", loss.item(), global_step=self.trainer.step)

@@ -445,14 +452,14 @@ class TensorboardCallback(Callback):
                     self._summary_writer.add_scalar(name + "_grad_mean", param.grad.mean(),
                                                     global_step=self.trainer.step)

-    def after_valid(self, eval_result, metric_key, optimizer):
+    def on_valid_end(self, eval_result, metric_key, optimizer):
         if "metric" in self.options:
             for name, metric in eval_result.items():
                 for metric_key, metric_val in metric.items():
                     self._summary_writer.add_scalar("valid_{}_{}".format(name, metric_key), metric_val,
                                                     global_step=self.trainer.step)

-    def after_train(self, model):
+    def on_train_end(self, model):
         self._summary_writer.close()
         del self._summary_writer

@@ -464,5 +471,5 @@ class TensorboardCallback(Callback):

 if __name__ == "__main__":
     manager = CallbackManager(env={"n_epoch": 3}, callbacks=[DummyCallback(), DummyCallback()])
-    manager.before_train(10, 11, 12)
+    manager.on_train_begin(10, 11, 12)
     # print(manager.after_epoch())
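
To make the renamed surface concrete, here is a minimal sketch of a user-defined callback written against the new hooks. The hook names and signatures come straight from the diff above; the class name, the tracked loss list, and the print messages are illustrative only.

from fastNLP.core.callback import Callback

class LossTrackerCallback(Callback):
    """Hypothetical example: record training losses and report validation results."""

    def on_train_begin(self):
        # self.trainer has already been assigned inside the Trainer
        self.losses = []

    def on_backward_begin(self, loss, model):
        # fires after loss computation and before the backward pass
        self.losses.append(loss.item())

    def on_valid_end(self, eval_result, metric_key, optimizer):
        # fires after each evaluation on the validation set
        print("step {}: {}".format(self.trainer.step, eval_result))

    def on_train_end(self, model):
        print("mean training loss:", sum(self.losses) / max(len(self.losses), 1))

Registered with a Trainer (the CallbackManager construction above suggests something like Trainer(..., callbacks=[LossTrackerCallback()])), each hook is then dispatched to every registered callback through the @transfer wrappers.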

fastNLP/core/predictor.py (+45 / -23)

@@ -1,7 +1,11 @@
+from collections import defaultdict
+
 import torch

-from fastNLP.core.batch import Batch
-from fastNLP.core.sampler import SequentialSampler
+from fastNLP.core import Batch
+from fastNLP.core import DataSet
+from fastNLP.core import SequentialSampler
+from fastNLP.core.utils import _build_args


 class Predictor(object):
@@ -13,37 +17,55 @@ class Predictor(object):
     Currently, Predictor does not support GPU.
     """

-    def __init__(self):
+    def __init__(self, network):
+        if not isinstance(network, torch.nn.Module):
+            raise ValueError(
+                "Only fastNLP.models.BaseModel or torch.nn.Module is allowed, not {}".format(type(network)))
+        self.network = network
         self.batch_size = 1
         self.batch_output = []

-    def predict(self, network, data):
+    def predict(self, data, seq_len_field_name=None):
         """Perform inference using the trained model.

-        :param network: a PyTorch model (cpu)
         :param data: a DataSet object.
+        :param str seq_len_field_name: field name indicating sequence lengths
         :return: list of batch outputs
         """
-        # turn on the testing mode; clean up the history
-        self.mode(network, test=True)
-        batch_output = []
+        if not isinstance(data, DataSet):
+            raise ValueError("Only DataSet class is allowed, not {}.".format(type(data)))
+        if seq_len_field_name is not None and seq_len_field_name not in data.field_arrays:
+            raise ValueError("Field name {} not found in DataSet {}.".format(seq_len_field_name, data))

-        data_iterator = Batch(data, batch_size=self.batch_size, sampler=SequentialSampler(), as_numpy=False)
+        self.network.eval()
+        batch_output = defaultdict(list)
+        data_iterator = Batch(data, batch_size=self.batch_size, sampler=SequentialSampler(), as_numpy=False,
+                              prefetch=False)

-        for batch_x, _ in data_iterator:
-            with torch.no_grad():
-                prediction = self.data_forward(network, batch_x)
-            batch_output.append(prediction)
+        if hasattr(self.network, "predict"):
+            predict_func = self.network.predict
+        else:
+            predict_func = self.network.forward

-        return batch_output
+        with torch.no_grad():
+            for batch_x, _ in data_iterator:
+                refined_batch_x = _build_args(predict_func, **batch_x)
+                prediction = predict_func(**refined_batch_x)

-    def mode(self, network, test=True):
-        if test:
-            network.eval()
-        else:
-            network.train()
+                if seq_len_field_name is not None:
+                    seq_lens = batch_x[seq_len_field_name].tolist()
+
+                for key, value in prediction.items():
+                    value = value.cpu().numpy()
+                    if len(value.shape) == 1 or (len(value.shape) == 2 and value.shape[1] == 1):
+                        batch_output[key].extend(value.tolist())
+                    else:
+                        if seq_len_field_name is not None:
+                            tmp_batch = []
+                            for idx, seq_len in enumerate(seq_lens):
+                                tmp_batch.append(value[idx, :seq_len])
+                            batch_output[key].extend(tmp_batch)
+                        else:
+                            batch_output[key].append(value)

-    def data_forward(self, network, x):
-        """Forward through network."""
-        y = network(**x)
-        return y
+        return batch_output
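
For orientation, a short usage sketch of the reworked interface follows; network and dataset stand in for a real model and DataSet, and the "seq_len" field name in the last call is a hypothetical example.

from fastNLP.core.predictor import Predictor

# network: any torch.nn.Module whose forward (or predict) method returns a dict of tensors
predictor = Predictor(network)

# returns a defaultdict(list) keyed by the model's output names
outputs = predictor.predict(dataset)

# for variable-length outputs, name the field that stores each instance's true length
# ("seq_len" is a hypothetical field name); predictions are truncated to that length
outputs = predictor.predict(dataset, seq_len_field_name="seq_len")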

fastNLP/core/trainer.py (+12 / -14)

@@ -196,9 +196,9 @@ class Trainer(object):
         print("training epochs started " + self.start_time, flush=True)

         try:
-            self.callback_manager.before_train()
+            self.callback_manager.on_train_begin()
             self._train()
-            self.callback_manager.after_train(self.model)
+            self.callback_manager.on_train_end(self.model)
         except (CallbackException, KeyboardInterrupt) as e:
             self.callback_manager.on_exception(e, self.model)

@@ -237,28 +237,26 @@ class Trainer(object):
         for epoch in range(1, self.n_epochs+1):
             pbar.set_description_str(desc="Epoch {}/{}".format(epoch, self.n_epochs))
             # early stopping
-            self.callback_manager.before_epoch(epoch, self.n_epochs)
+            self.callback_manager.on_epoch_begin(epoch, self.n_epochs)
             for batch_x, batch_y in data_iterator:
                 _move_dict_value_to_device(batch_x, batch_y, device=self._model_device)
                 indices = data_iterator.get_batch_indices()
                 # negative sampling; replace unknown; re-weight batch_y
-                self.callback_manager.before_batch(batch_x, batch_y, indices)
+                self.callback_manager.on_batch_begin(batch_x, batch_y, indices)
                 prediction = self._data_forward(self.model, batch_x)

                 # edit prediction
-                self.callback_manager.before_loss(batch_y, prediction)
+                self.callback_manager.on_loss_begin(batch_y, prediction)
                 loss = self._compute_loss(prediction, batch_y)
                 avg_loss += loss.item()

                 # Is loss NaN or inf? requires_grad = False
-                self.callback_manager.before_backward(loss, self.model)
+                self.callback_manager.on_backward_begin(loss, self.model)
                 self._grad_backward(loss)
-                # gradient clipping
-                self.callback_manager.after_backward(self.model)
+                self.callback_manager.on_backward_end(self.model)

                 self._update()
-                # lr scheduler; lr_finder; one_cycle
-                self.callback_manager.after_step(self.optimizer)
+                self.callback_manager.on_step_end(self.optimizer)

                 if (self.step+1) % self.print_every == 0:
                     if self.use_tqdm:

@@ -272,8 +270,7 @@
                     pbar.set_postfix_str(print_output)
                     avg_loss = 0
                 self.step += 1
-                # do nothing
-                self.callback_manager.after_batch()
+                self.callback_manager.on_batch_end()

                 if ((self.validate_every > 0 and self.step % self.validate_every == 0) or
                         (self.validate_every < 0 and self.step % len(data_iterator) == 0)) \

@@ -287,12 +284,13 @@
             # ================= mini-batch end ==================== #

             # lr decay; early stopping
-            self.callback_manager.after_epoch(epoch, self.n_epochs, self.optimizer)
+            self.callback_manager.on_epoch_end(epoch, self.n_epochs, self.optimizer)
         # =============== epochs end =================== #
         pbar.close()
         # ============ tqdm end ============== #

     def _do_validation(self, epoch, step):
+        self.callback_manager.on_valid_begin()
         res = self.tester.test()

         if self._better_eval_result(res):

@@ -305,7 +303,7 @@
             self.best_dev_epoch = epoch
             self.best_dev_step = step
         # get validation results; adjust optimizer
-        self.callback_manager.after_valid(res, self.metric_key, self.optimizer)
+        self.callback_manager.on_valid_end(res, self.metric_key, self.optimizer)
         return res

     def _mode(self, model, is_test=False):
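
Read together, these changes fix the order in which the Trainer fires the hooks. The outline below is reconstructed from the calls above; validation runs whenever the validate_every condition is met:

on_train_begin
    on_epoch_begin                                   # once per epoch
        on_batch_begin                               # once per mini-batch
            forward -> on_loss_begin -> compute loss
            on_backward_begin -> backward -> on_backward_end
            optimizer update -> on_step_end
        on_batch_end
        on_valid_begin -> evaluate -> on_valid_end   # when validation is due
    on_epoch_end
on_train_end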


reproduction/Biaffine_parser/run.py (+1 / -6)

@@ -4,19 +4,14 @@ import sys
 sys.path.append(os.path.join(os.path.dirname(__file__), '../..'))

 import fastNLP
-import torch

 from fastNLP.core.trainer import Trainer
 from fastNLP.core.instance import Instance
 from fastNLP.api.pipeline import Pipeline
 from fastNLP.models.biaffine_parser import BiaffineParser, ParserMetric, ParserLoss
-from fastNLP.core.vocabulary import Vocabulary
-from fastNLP.core.dataset import DataSet
 from fastNLP.core.tester import Tester
 from fastNLP.io.config_io import ConfigLoader, ConfigSection
 from fastNLP.io.model_io import ModelLoader
-from fastNLP.io.embed_loader import EmbedLoader
-from fastNLP.io.model_io import ModelSaver
 from fastNLP.io.dataset_loader import ConllxDataLoader
 from fastNLP.api.processor import *
 from fastNLP.io.embed_loader import EmbedLoader

@@ -172,7 +167,7 @@ def train(path):
     model.pos_embedding.weight.data[pos_v.padding_idx].fill_(0)

 class MyCallback(Callback):
-    def after_step(self, optimizer):
+    def on_step_end(self, optimizer):
         step = self.trainer.step
         # learning rate decay
         if step > 0 and step % 1000 == 0:
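
The hunk's context window ends at the if statement, but the pattern it illustrates is an inline learning-rate decay driven by on_step_end. A minimal sketch of that pattern, with an invented decay factor (the real values live in the elided body of MyCallback):

class StepDecayCallback(Callback):
    def on_step_end(self, optimizer):
        step = self.trainer.step
        # every 1000 steps, scale the learning rate down (0.8 is illustrative)
        if step > 0 and step % 1000 == 0:
            for pg in optimizer.param_groups:
                pg["lr"] *= 0.8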


test/core/test_predictor.py (+21 / -6)

@@ -1,4 +1,5 @@
 import unittest
+from collections import defaultdict

 import numpy as np
 import torch

@@ -23,12 +24,26 @@ def prepare_fake_dataset():
     return data_set


+class LinearModel(torch.nn.Module):
+    def __init__(self):
+        super(LinearModel, self).__init__()
+        self.linear = Linear(2, 1)
+
+    def forward(self, x):
+        return {"predict": self.linear(x)}
+
+
 class TestPredictor(unittest.TestCase):
-    def test(self):
-        predictor = Predictor()
-        model = Linear(2, 1)
+    def test_simple(self):
+        model = LinearModel()
+        predictor = Predictor(model)
         data = prepare_fake_dataset()
         data.set_input("x")
-        ans = predictor.predict(model, data)
-        self.assertEqual(len(ans), 2000)
-        self.assertTrue(isinstance(ans[0], torch.Tensor))
+        ans = predictor.predict(data)
+        self.assertTrue(isinstance(ans, defaultdict))
+        self.assertTrue("predict" in ans)
+        self.assertTrue(isinstance(ans["predict"], list))
+
+    def test_sequence(self):
+        # test sequence input/output
+        pass
