
update callbacks:

* rename callback methods to use fastai's notation
* add a new callback method: on_valid_begin
FengZiYjun · 6 years ago · commit 887fc9281f · tags/v0.3.1^2
5 changed files with 131 additions and 94 deletions
  1. fastNLP/core/callback.py (+52, -45)
  2. fastNLP/core/predictor.py (+45, -23)
  3. fastNLP/core/trainer.py (+12, -14)
  4. reproduction/Biaffine_parser/run.py (+1, -6)
  5. test/core/test_predictor.py (+21, -6)

fastNLP/core/callback.py (+52, -45)

@@ -17,37 +17,40 @@ class Callback(object):
super(Callback, self).__init__()
self.trainer = None  # reassigned inside Trainer

- def before_train(self):
+ def on_train_begin(self):
# before the main training loop
pass

- def before_epoch(self, cur_epoch, total_epoch):
+ def on_epoch_begin(self, cur_epoch, total_epoch):
# at the beginning of each epoch
pass

- def before_batch(self, batch_x, batch_y, indices):
+ def on_batch_begin(self, batch_x, batch_y, indices):
# at the beginning of each step/mini-batch
pass

- def before_loss(self, batch_y, predict_y):
+ def on_loss_begin(self, batch_y, predict_y):
# after data_forward, and before loss computation
pass

- def before_backward(self, loss, model):
+ def on_backward_begin(self, loss, model):
# after loss computation, and before gradient backward
pass

- def after_backward(self, model):
+ def on_backward_end(self, model):
pass

- def after_step(self, optimizer):
+ def on_step_end(self, optimizer):
pass

- def after_batch(self, *args):
+ def on_batch_end(self, *args):
# at the end of each step/mini-batch
pass

- def after_valid(self, eval_result, metric_key, optimizer):
+ def on_valid_begin(self):
+     pass
+
+ def on_valid_end(self, eval_result, metric_key, optimizer):
"""
Called after each evaluation on the validation set; eval_result is passed in.

@@ -58,7 +61,7 @@ class Callback(object):
"""
pass

- def after_epoch(self, cur_epoch, n_epoch, optimizer):
+ def on_epoch_end(self, cur_epoch, n_epoch, optimizer):
"""
Called at the end of each epoch.

@@ -69,7 +72,7 @@ class Callback(object):
"""
pass

- def after_train(self, model):
+ def on_train_end(self, model):
"""
Called when training ends.

@@ -134,47 +137,51 @@ class CallbackManager(Callback):
raise TypeError(f"Expect callbacks in CallbackManager(callbacks) to be list. Got {type(callbacks)}.")

@transfer
- def before_train(self):
+ def on_train_begin(self):
pass

+ @transfer
+ def on_epoch_begin(self, cur_epoch, total_epoch):
+     pass

@transfer
- def before_epoch(self, cur_epoch, total_epoch):
+ def on_batch_begin(self, batch_x, batch_y, indices):
pass

@transfer
- def before_batch(self, batch_x, batch_y, indices):
+ def on_loss_begin(self, batch_y, predict_y):
pass

@transfer
- def before_loss(self, batch_y, predict_y):
+ def on_backward_begin(self, loss, model):
pass

@transfer
- def before_backward(self, loss, model):
+ def on_backward_end(self, model):
pass

@transfer
- def after_backward(self, model):
+ def on_step_end(self, optimizer):
pass

@transfer
- def after_step(self, optimizer):
+ def on_batch_end(self):
pass

@transfer
- def after_batch(self):
+ def on_valid_begin(self):
pass

@transfer
- def after_valid(self, eval_result, metric_key, optimizer):
+ def on_valid_end(self, eval_result, metric_key, optimizer):
pass

@transfer
- def after_epoch(self, cur_epoch, n_epoch, optimizer):
+ def on_epoch_end(self, cur_epoch, n_epoch, optimizer):
pass

@transfer
- def after_train(self, model):
+ def on_train_end(self, model):
pass

@transfer
@@ -183,36 +190,36 @@ class CallbackManager(Callback):


class DummyCallback(Callback):
- def before_train(self, *arg):
+ def on_train_begin(self, *arg):
print(arg)

- def after_epoch(self, cur_epoch, n_epoch, optimizer):
+ def on_epoch_end(self, cur_epoch, n_epoch, optimizer):
print(cur_epoch, n_epoch, optimizer)


class EchoCallback(Callback):
- def before_train(self):
+ def on_train_begin(self):
print("before_train")

- def before_epoch(self, cur_epoch, total_epoch):
+ def on_epoch_begin(self, cur_epoch, total_epoch):
print("before_epoch")

- def before_batch(self, batch_x, batch_y, indices):
+ def on_batch_begin(self, batch_x, batch_y, indices):
print("before_batch")

- def before_loss(self, batch_y, predict_y):
+ def on_loss_begin(self, batch_y, predict_y):
print("before_loss")

- def before_backward(self, loss, model):
+ def on_backward_begin(self, loss, model):
print("before_backward")

- def after_batch(self):
+ def on_batch_end(self):
print("after_batch")

- def after_epoch(self, cur_epoch, n_epoch, optimizer):
+ def on_epoch_end(self, cur_epoch, n_epoch, optimizer):
print("after_epoch")

- def after_train(self, model):
+ def on_train_end(self, model):
print("after_train")


@@ -240,7 +247,7 @@ class GradientClipCallback(Callback):
self.parameters = parameters
self.clip_value = clip_value

- def after_backward(self, model):
+ def on_backward_end(self, model):
self.clip_fun(model.parameters(), self.clip_value)


@@ -266,7 +273,7 @@ class EarlyStopCallback(Callback):
self.wait = 0
self.epoch = 0

- def after_valid(self, eval_result, metric_key, optimizer):
+ def on_valid_end(self, eval_result, metric_key, optimizer):
self.epoch += 1
if not self.trainer._better_eval_result(eval_result):
# current result is getting worse
@@ -297,7 +304,7 @@ class LRScheduler(Callback):
else:
raise ValueError(f"Expect torch.optim.lr_scheduler for LRScheduler. Got {type(lr_scheduler)}.")

- def before_epoch(self, cur_epoch, total_epoch):
+ def on_epoch_begin(self, cur_epoch, total_epoch):
self.scheduler.step()
print("scheduler step ", "lr=", self.trainer.optimizer.param_groups[0]["lr"])

@@ -359,7 +366,7 @@ class LRFinder(Callback):
self.find = None
self.loader = ModelLoader()

- def before_epoch(self, cur_epoch, total_epoch):
+ def on_epoch_begin(self, cur_epoch, total_epoch):
if cur_epoch == 1:
self.opt = self.trainer.optimizer # pytorch optimizer
self.opt.param_groups[0]["lr"] = self.start_lr
@@ -367,7 +374,7 @@ class LRFinder(Callback):
ModelSaver("tmp").save_pytorch(self.trainer.model, param_only=True)
self.find = True

- def before_backward(self, loss, model):
+ def on_backward_begin(self, loss, model):
if self.find:
if torch.isnan(loss) or self.stop is True:
self.stop = True
@@ -379,7 +386,7 @@ class LRFinder(Callback):
self.best_loss = self.smooth_value.smooth
self.best_lr = self.opt.param_groups[0]["lr"]

- def after_batch(self, *args):
+ def on_batch_end(self, *args):
if self.find:
lr = next(self.lr_gen, None)
if lr is None or self.stop is True or self.loss_history[-1] > 4 * self.best_loss:
@@ -388,7 +395,7 @@ class LRFinder(Callback):
self.opt.param_groups[0]["lr"] = lr
# self.loader.load_pytorch(self.trainer.model, "tmp")

- def after_epoch(self, cur_epoch, n_epoch, optimizer):
+ def on_epoch_end(self, cur_epoch, n_epoch, optimizer):
if cur_epoch == 1:
self.opt.param_groups[0]["lr"] = self.best_lr
self.find = False
@@ -415,7 +422,7 @@ class TensorboardCallback(Callback):
self._summary_writer = None
self.graph_added = False

- def before_train(self):
+ def on_train_begin(self):
save_dir = self.trainer.save_path
if save_dir is None:
path = os.path.join("./", 'tensorboard_logs_{}'.format(self.trainer.start_time))
@@ -423,7 +430,7 @@ class TensorboardCallback(Callback):
path = os.path.join(save_dir, 'tensorboard_logs_{}'.format(self.trainer.start_time))
self._summary_writer = SummaryWriter(path)

- def before_batch(self, batch_x, batch_y, indices):
+ def on_batch_begin(self, batch_x, batch_y, indices):
if "model" in self.options and self.graph_added is False:
# tensorboardX has a serious bug here; drawing the model graph is not possible for now
# from fastNLP.core.utils import _build_args
@@ -433,7 +440,7 @@ class TensorboardCallback(Callback):
# self._summary_writer.add_graph(self.trainer.model, torch.zeros(32, 2))
self.graph_added = True

- def before_backward(self, loss, model):
+ def on_backward_begin(self, loss, model):
if "loss" in self.options:
self._summary_writer.add_scalar("loss", loss.item(), global_step=self.trainer.step)

@@ -445,14 +452,14 @@ class TensorboardCallback(Callback):
self._summary_writer.add_scalar(name + "_grad_mean", param.grad.mean(),
global_step=self.trainer.step)

- def after_valid(self, eval_result, metric_key, optimizer):
+ def on_valid_end(self, eval_result, metric_key, optimizer):
if "metric" in self.options:
for name, metric in eval_result.items():
for metric_key, metric_val in metric.items():
self._summary_writer.add_scalar("valid_{}_{}".format(name, metric_key), metric_val,
global_step=self.trainer.step)

- def after_train(self, model):
+ def on_train_end(self, model):
self._summary_writer.close()
del self._summary_writer

@@ -464,5 +471,5 @@ class TensorboardCallback(Callback):

if __name__ == "__main__":
manager = CallbackManager(env={"n_epoch": 3}, callbacks=[DummyCallback(), DummyCallback()])
- manager.before_train(10, 11, 12)
+ manager.on_train_begin(10, 11, 12)
# print(manager.after_epoch())
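
With the rename in place, user callbacks override the fastai-style `on_<event>` hooks instead of the old `before_*`/`after_*` names. A minimal sketch of a callback written against the new names (the `Trainer(..., callbacks=[...])` wiring is fastNLP's existing API; the loss-logging behaviour is a hypothetical example, not part of this commit):

```python
from fastNLP.core.callback import Callback

class LossHistoryCallback(Callback):
    """Hypothetical example: collect per-batch losses and report an epoch average."""

    def __init__(self):
        super(LossHistoryCallback, self).__init__()
        self.losses = []

    def on_epoch_begin(self, cur_epoch, total_epoch):
        self.losses = []

    def on_backward_begin(self, loss, model):
        # called after loss computation and before the gradient backward pass
        self.losses.append(loss.item())

    def on_epoch_end(self, cur_epoch, n_epoch, optimizer):
        mean_loss = sum(self.losses) / max(len(self.losses), 1)
        print("epoch {}/{}: mean loss {:.4f}".format(cur_epoch, n_epoch, mean_loss))
```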

fastNLP/core/predictor.py (+45, -23)

@@ -1,7 +1,11 @@
+ from collections import defaultdict

import torch

- from fastNLP.core.batch import Batch
- from fastNLP.core.sampler import SequentialSampler
+ from fastNLP.core import Batch
+ from fastNLP.core import DataSet
+ from fastNLP.core import SequentialSampler
+ from fastNLP.core.utils import _build_args


class Predictor(object):
@@ -13,37 +17,55 @@ class Predictor(object):
Currently, Predictor does not support GPU.
"""

- def __init__(self):
+ def __init__(self, network):
+     if not isinstance(network, torch.nn.Module):
+         raise ValueError(
+             "Only fastNLP.models.BaseModel or torch.nn.Module is allowed, not {}".format(type(network)))
+     self.network = network
self.batch_size = 1
self.batch_output = []

- def predict(self, network, data):
+ def predict(self, data, seq_len_field_name=None):
"""Perform inference using the trained model.

- :param network: a PyTorch model (cpu)
:param data: a DataSet object.
+ :param str seq_len_field_name: field name indicating sequence lengths
:return: list of batch outputs
"""
- # turn on the testing mode; clean up the history
- self.mode(network, test=True)
- batch_output = []
+ if not isinstance(data, DataSet):
+     raise ValueError("Only Dataset class is allowed, not {}.".format(type(data)))
+ if seq_len_field_name is not None and seq_len_field_name not in data.field_arrays:
+     raise ValueError("Field name {} not found in DataSet {}.".format(seq_len_field_name, data))

- data_iterator = Batch(data, batch_size=self.batch_size, sampler=SequentialSampler(), as_numpy=False)
+ self.network.eval()
+ batch_output = defaultdict(list)
+ data_iterator = Batch(data, batch_size=self.batch_size, sampler=SequentialSampler(), as_numpy=False,
+                       prefetch=False)

- for batch_x, _ in data_iterator:
-     with torch.no_grad():
-         prediction = self.data_forward(network, batch_x)
-         batch_output.append(prediction)
+ if hasattr(self.network, "predict"):
+     predict_func = self.network.predict
+ else:
+     predict_func = self.network.forward

- return batch_output
+ with torch.no_grad():
+     for batch_x, _ in data_iterator:
+         refined_batch_x = _build_args(predict_func, **batch_x)
+         prediction = predict_func(**refined_batch_x)

- def mode(self, network, test=True):
-     if test:
-         network.eval()
-     else:
-         network.train()
+         if seq_len_field_name is not None:
+             seq_lens = batch_x[seq_len_field_name].tolist()

+         for key, value in prediction.items():
+             value = value.cpu().numpy()
+             if len(value.shape) == 1 or (len(value.shape) == 2 and value.shape[1] == 1):
+                 batch_output[key].extend(value.tolist())
+             else:
+                 if seq_len_field_name is not None:
+                     tmp_batch = []
+                     for idx, seq_len in enumerate(seq_lens):
+                         tmp_batch.append(value[idx, :seq_len])
+                     batch_output[key].extend(tmp_batch)
+                 else:
+                     batch_output[key].append(value)

- def data_forward(self, network, x):
-     """Forward through network."""
-     y = network(**x)
-     return y
+ return batch_output
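
With this change the model is bound at construction time and `predict()` takes only the data, plus the optional `seq_len_field_name`. A usage sketch of the new call pattern; the toy model and field values are assumptions for illustration, and note that the model's `forward` (or `predict`) must return a dict, since `predict()` iterates over `prediction.items()`:

```python
import torch
from fastNLP.core.dataset import DataSet
from fastNLP.core.predictor import Predictor

class ToyModel(torch.nn.Module):
    """Hypothetical model whose forward returns a dict of outputs."""
    def __init__(self):
        super(ToyModel, self).__init__()
        self.linear = torch.nn.Linear(2, 1)

    def forward(self, x):
        # cast defensively in case the batch delivers double-precision tensors
        return {"predict": self.linear(x.float())}

data = DataSet({"x": [[1.0, 2.0], [3.0, 4.0]]})
data.set_input("x")

predictor = Predictor(ToyModel())
# a defaultdict mapping each output key to a list of per-instance values
output = predictor.predict(data)
```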

fastNLP/core/trainer.py (+12, -14)

@@ -196,9 +196,9 @@ class Trainer(object):
print("training epochs started " + self.start_time, flush=True)

try:
- self.callback_manager.before_train()
+ self.callback_manager.on_train_begin()
self._train()
- self.callback_manager.after_train(self.model)
+ self.callback_manager.on_train_end(self.model)
except (CallbackException, KeyboardInterrupt) as e:
self.callback_manager.on_exception(e, self.model)

@@ -237,28 +237,26 @@ class Trainer(object):
for epoch in range(1, self.n_epochs+1):
pbar.set_description_str(desc="Epoch {}/{}".format(epoch, self.n_epochs))
# early stopping
- self.callback_manager.before_epoch(epoch, self.n_epochs)
+ self.callback_manager.on_epoch_begin(epoch, self.n_epochs)
for batch_x, batch_y in data_iterator:
_move_dict_value_to_device(batch_x, batch_y, device=self._model_device)
indices = data_iterator.get_batch_indices()
# negative sampling; replace unknown; re-weight batch_y
- self.callback_manager.before_batch(batch_x, batch_y, indices)
+ self.callback_manager.on_batch_begin(batch_x, batch_y, indices)
prediction = self._data_forward(self.model, batch_x)

# edit prediction
- self.callback_manager.before_loss(batch_y, prediction)
+ self.callback_manager.on_loss_begin(batch_y, prediction)
loss = self._compute_loss(prediction, batch_y)
avg_loss += loss.item()

# Is loss NaN or inf? requires_grad = False
- self.callback_manager.before_backward(loss, self.model)
+ self.callback_manager.on_backward_begin(loss, self.model)
self._grad_backward(loss)
# gradient clipping
- self.callback_manager.after_backward(self.model)
+ self.callback_manager.on_backward_end(self.model)

self._update()
# lr scheduler; lr_finder; one_cycle
- self.callback_manager.after_step(self.optimizer)
+ self.callback_manager.on_step_end(self.optimizer)

if (self.step+1) % self.print_every == 0:
if self.use_tqdm:
@@ -272,8 +270,7 @@ class Trainer(object):
pbar.set_postfix_str(print_output)
avg_loss = 0
self.step += 1
- # do nothing
- self.callback_manager.after_batch()
+ self.callback_manager.on_batch_end()

if ((self.validate_every > 0 and self.step % self.validate_every == 0) or
(self.validate_every < 0 and self.step % len(data_iterator) == 0)) \
@@ -287,12 +284,13 @@ class Trainer(object):
# ================= mini-batch end ==================== #

# lr decay; early stopping
- self.callback_manager.after_epoch(epoch, self.n_epochs, self.optimizer)
+ self.callback_manager.on_epoch_end(epoch, self.n_epochs, self.optimizer)
# =============== epochs end =================== #
pbar.close()
# ============ tqdm end ============== #

def _do_validation(self, epoch, step):
+ self.callback_manager.on_valid_begin()
res = self.tester.test()

if self._better_eval_result(res):
@@ -305,7 +303,7 @@ class Trainer(object):
self.best_dev_epoch = epoch
self.best_dev_step = step
# get validation results; adjust optimizer
- self.callback_manager.after_valid(res, self.metric_key, self.optimizer)
+ self.callback_manager.on_valid_end(res, self.metric_key, self.optimizer)
return res

def _mode(self, model, is_test=False):
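
Note the new `on_valid_begin` call at the top of `_do_validation`, before the tester runs, which gives callbacks a matched begin/end pair around each validation pass. A sketch of what that enables, e.g. timing validation (hypothetical example, not part of the commit):

```python
import time

from fastNLP.core.callback import Callback

class ValidationTimerCallback(Callback):
    """Hypothetical example: measure the wall-clock time of each validation run."""

    def on_valid_begin(self):
        self._valid_start = time.time()

    def on_valid_end(self, eval_result, metric_key, optimizer):
        elapsed = time.time() - self._valid_start
        print("validation took {:.2f}s; result: {}".format(elapsed, eval_result))
```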


reproduction/Biaffine_parser/run.py (+1, -6)

@@ -4,19 +4,14 @@ import sys
sys.path.append(os.path.join(os.path.dirname(__file__), '../..'))

import fastNLP
import torch

from fastNLP.core.trainer import Trainer
from fastNLP.core.instance import Instance
from fastNLP.api.pipeline import Pipeline
from fastNLP.models.biaffine_parser import BiaffineParser, ParserMetric, ParserLoss
from fastNLP.core.vocabulary import Vocabulary
from fastNLP.core.dataset import DataSet
from fastNLP.core.tester import Tester
from fastNLP.io.config_io import ConfigLoader, ConfigSection
from fastNLP.io.model_io import ModelLoader
from fastNLP.io.embed_loader import EmbedLoader
from fastNLP.io.model_io import ModelSaver
from fastNLP.io.dataset_loader import ConllxDataLoader
from fastNLP.api.processor import *
from fastNLP.io.embed_loader import EmbedLoader
@@ -172,7 +167,7 @@ def train(path):
model.pos_embedding.weight.data[pos_v.padding_idx].fill_(0)

class MyCallback(Callback):
- def after_step(self, optimizer):
+ def on_step_end(self, optimizer):
step = self.trainer.step
# learning rate decay
if step > 0 and step % 1000 == 0:
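
The diff truncates the body of `MyCallback.on_step_end`; a plausible completion of the learning-rate decay it sets up (the decay factor and loop below are assumptions, not taken from the file):

```python
from fastNLP.core.callback import Callback

class MyCallback(Callback):
    def on_step_end(self, optimizer):
        step = self.trainer.step
        # learning rate decay: hypothetically, scale the lr down every 1000 steps
        if step > 0 and step % 1000 == 0:
            for param_group in optimizer.param_groups:
                param_group["lr"] *= 0.8
```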


test/core/test_predictor.py (+21, -6)

@@ -1,4 +1,5 @@
import unittest
+ from collections import defaultdict

import numpy as np
import torch
@@ -23,12 +24,26 @@ def prepare_fake_dataset():
return data_set


+ class LinearModel(torch.nn.Module):
+     def __init__(self):
+         super(LinearModel, self).__init__()
+         self.linear = Linear(2, 1)
+
+     def forward(self, x):
+         return {"predict": self.linear(x)}

class TestPredictor(unittest.TestCase):
- def test(self):
-     predictor = Predictor()
-     model = Linear(2, 1)
+ def test_simple(self):
+     model = LinearModel()
+     predictor = Predictor(model)
data = prepare_fake_dataset()
data.set_input("x")
- ans = predictor.predict(model, data)
- self.assertEqual(len(ans), 2000)
- self.assertTrue(isinstance(ans[0], torch.Tensor))
+ ans = predictor.predict(data)
+ self.assertTrue(isinstance(ans, defaultdict))
+ self.assertTrue("predict" in ans)
+ self.assertTrue(isinstance(ans["predict"], list))

+ def test_sequence(self):
+     # test sequence input/output
+     pass
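
The new `test_sequence` is left as a stub. A sketch of what it might exercise, given the `seq_len_field_name` handling added to `Predictor.predict` (the model, field names, and shapes are assumptions):

```python
import torch
from fastNLP.core.dataset import DataSet
from fastNLP.core.predictor import Predictor

class SeqModel(torch.nn.Module):
    """Hypothetical model returning one value per time step, shape (batch, max_len)."""
    def forward(self, x):
        return {"predict": x.float() * 2}

def test_sequence_sketch():
    data = DataSet({"x": [[1, 2, 3, 0], [4, 5, 0, 0]], "seq_len": [3, 2]})
    data.set_input("x", "seq_len")
    ans = Predictor(SeqModel()).predict(data, seq_len_field_name="seq_len")
    # each 2-D output row should be truncated to its instance's true length
    assert [len(row) for row in ans["predict"]] == [3, 2]
```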
