* rename callback methods. Use fastai's notation. * add a new callback method - on_valid_begin (tags/v0.3.1^2)
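As a quick illustration of the renamed hooks, a minimal callback written against the new names might look like the sketch below; it is wired up through CallbackManager the same way the __main__ block in callback.py does further down. The hook signatures are taken from the diff; the class name, the printed messages, and the import path are assumptions made for the example.

from fastNLP.core.callback import Callback, CallbackManager  # import path assumed

class PrintingCallback(Callback):
    # hypothetical callback using the renamed, fastai-style hooks
    def on_epoch_begin(self, cur_epoch, total_epoch):
        print("epoch {}/{} begins".format(cur_epoch, total_epoch))

    def on_backward_begin(self, loss, model):
        # runs after the loss is computed, before loss.backward()
        print("loss = {:.4f}".format(loss.item()))

    def on_valid_end(self, eval_result, metric_key, optimizer):
        print("validation:", eval_result)

manager = CallbackManager(env={"n_epoch": 3}, callbacks=[PrintingCallback()])
manager.on_epoch_begin(1, 3)  # normally the Trainer drives these calls via @transfer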
@@ -17,37 +17,40 @@ class Callback(object):
         super(Callback, self).__init__()
         self.trainer = None  # re-assigned inside Trainer
-    def before_train(self):
+    def on_train_begin(self):
         # before the main training loop
         pass
-    def before_epoch(self, cur_epoch, total_epoch):
+    def on_epoch_begin(self, cur_epoch, total_epoch):
         # at the beginning of each epoch
         pass
-    def before_batch(self, batch_x, batch_y, indices):
+    def on_batch_begin(self, batch_x, batch_y, indices):
         # at the beginning of each step/mini-batch
         pass
-    def before_loss(self, batch_y, predict_y):
+    def on_loss_begin(self, batch_y, predict_y):
         # after data_forward, and before loss computation
         pass
-    def before_backward(self, loss, model):
+    def on_backward_begin(self, loss, model):
         # after loss computation, and before gradient backward
         pass
-    def after_backward(self, model):
+    def on_backward_end(self, model):
         pass
-    def after_step(self, optimizer):
+    def on_step_end(self, optimizer):
         pass
-    def after_batch(self, *args):
+    def on_batch_end(self, *args):
         # at the end of each step/mini-batch
         pass
-    def after_valid(self, eval_result, metric_key, optimizer):
+    def on_valid_begin(self):
+        pass
+    def on_valid_end(self, eval_result, metric_key, optimizer):
         """
         Called after each evaluation on the validation set. eval_result is passed in.
@@ -58,7 +61,7 @@ class Callback(object):
         """
         pass
-    def after_epoch(self, cur_epoch, n_epoch, optimizer):
+    def on_epoch_end(self, cur_epoch, n_epoch, optimizer):
         """
         Called at the end of each epoch.
@@ -69,7 +72,7 @@ class Callback(object):
         """
         pass
-    def after_train(self, model):
+    def on_train_end(self, model):
         """
         Called when training ends.
@@ -134,47 +137,51 @@ class CallbackManager(Callback):
             raise TypeError(f"Expect callbacks in CallbackManager(callbacks) to be list. Got {type(callbacks)}.")
     @transfer
-    def before_train(self):
+    def on_train_begin(self):
         pass
     @transfer
-    def before_epoch(self, cur_epoch, total_epoch):
+    def on_epoch_begin(self, cur_epoch, total_epoch):
         pass
     @transfer
-    def before_batch(self, batch_x, batch_y, indices):
+    def on_batch_begin(self, batch_x, batch_y, indices):
         pass
     @transfer
-    def before_loss(self, batch_y, predict_y):
+    def on_loss_begin(self, batch_y, predict_y):
         pass
     @transfer
-    def before_backward(self, loss, model):
+    def on_backward_begin(self, loss, model):
         pass
     @transfer
-    def after_backward(self, model):
+    def on_backward_end(self, model):
         pass
     @transfer
-    def after_step(self, optimizer):
+    def on_step_end(self, optimizer):
         pass
     @transfer
-    def after_batch(self):
+    def on_batch_end(self):
         pass
     @transfer
+    def on_valid_begin(self):
+        pass
+    @transfer
-    def after_valid(self, eval_result, metric_key, optimizer):
+    def on_valid_end(self, eval_result, metric_key, optimizer):
         pass
     @transfer
-    def after_epoch(self, cur_epoch, n_epoch, optimizer):
+    def on_epoch_end(self, cur_epoch, n_epoch, optimizer):
         pass
     @transfer
-    def after_train(self, model):
+    def on_train_end(self, model):
         pass
     @transfer
@@ -183,36 +190,36 @@ class CallbackManager(Callback):
 class DummyCallback(Callback):
-    def before_train(self, *arg):
+    def on_train_begin(self, *arg):
         print(arg)
-    def after_epoch(self, cur_epoch, n_epoch, optimizer):
+    def on_epoch_end(self, cur_epoch, n_epoch, optimizer):
         print(cur_epoch, n_epoch, optimizer)
 class EchoCallback(Callback):
-    def before_train(self):
+    def on_train_begin(self):
         print("before_train")
-    def before_epoch(self, cur_epoch, total_epoch):
+    def on_epoch_begin(self, cur_epoch, total_epoch):
         print("before_epoch")
-    def before_batch(self, batch_x, batch_y, indices):
+    def on_batch_begin(self, batch_x, batch_y, indices):
         print("before_batch")
-    def before_loss(self, batch_y, predict_y):
+    def on_loss_begin(self, batch_y, predict_y):
         print("before_loss")
-    def before_backward(self, loss, model):
+    def on_backward_begin(self, loss, model):
         print("before_backward")
-    def after_batch(self):
+    def on_batch_end(self):
         print("after_batch")
-    def after_epoch(self, cur_epoch, n_epoch, optimizer):
+    def on_epoch_end(self, cur_epoch, n_epoch, optimizer):
         print("after_epoch")
-    def after_train(self, model):
+    def on_train_end(self, model):
         print("after_train")
@@ -240,7 +247,7 @@ class GradientClipCallback(Callback):
         self.parameters = parameters
         self.clip_value = clip_value
-    def after_backward(self, model):
+    def on_backward_end(self, model):
         self.clip_fun(model.parameters(), self.clip_value)
@@ -266,7 +273,7 @@ class EarlyStopCallback(Callback):
         self.wait = 0
         self.epoch = 0
-    def after_valid(self, eval_result, metric_key, optimizer):
+    def on_valid_end(self, eval_result, metric_key, optimizer):
         self.epoch += 1
         if not self.trainer._better_eval_result(eval_result):
             # current result is getting worse
@@ -297,7 +304,7 @@ class LRScheduler(Callback):
         else:
             raise ValueError(f"Expect torch.optim.lr_scheduler for LRScheduler. Got {type(lr_scheduler)}.")
-    def before_epoch(self, cur_epoch, total_epoch):
+    def on_epoch_begin(self, cur_epoch, total_epoch):
         self.scheduler.step()
         print("scheduler step ", "lr=", self.trainer.optimizer.param_groups[0]["lr"])
@@ -359,7 +366,7 @@ class LRFinder(Callback):
         self.find = None
         self.loader = ModelLoader()
-    def before_epoch(self, cur_epoch, total_epoch):
+    def on_epoch_begin(self, cur_epoch, total_epoch):
         if cur_epoch == 1:
             self.opt = self.trainer.optimizer  # pytorch optimizer
             self.opt.param_groups[0]["lr"] = self.start_lr
@@ -367,7 +374,7 @@ class LRFinder(Callback):
             ModelSaver("tmp").save_pytorch(self.trainer.model, param_only=True)
             self.find = True
-    def before_backward(self, loss, model):
+    def on_backward_begin(self, loss, model):
         if self.find:
             if torch.isnan(loss) or self.stop is True:
                 self.stop = True
@@ -379,7 +386,7 @@ class LRFinder(Callback):
                 self.best_loss = self.smooth_value.smooth
                 self.best_lr = self.opt.param_groups[0]["lr"]
-    def after_batch(self, *args):
+    def on_batch_end(self, *args):
         if self.find:
             lr = next(self.lr_gen, None)
             if lr is None or self.stop is True or self.loss_history[-1] > 4 * self.best_loss:
@@ -388,7 +395,7 @@ class LRFinder(Callback):
             self.opt.param_groups[0]["lr"] = lr
             # self.loader.load_pytorch(self.trainer.model, "tmp")
-    def after_epoch(self, cur_epoch, n_epoch, optimizer):
+    def on_epoch_end(self, cur_epoch, n_epoch, optimizer):
         if cur_epoch == 1:
             self.opt.param_groups[0]["lr"] = self.best_lr
             self.find = False
@@ -415,7 +422,7 @@ class TensorboardCallback(Callback):
         self._summary_writer = None
         self.graph_added = False
-    def before_train(self):
+    def on_train_begin(self):
         save_dir = self.trainer.save_path
         if save_dir is None:
             path = os.path.join("./", 'tensorboard_logs_{}'.format(self.trainer.start_time))
@@ -423,7 +430,7 @@ class TensorboardCallback(Callback):
             path = os.path.join(save_dir, 'tensorboard_logs_{}'.format(self.trainer.start_time))
         self._summary_writer = SummaryWriter(path)
-    def before_batch(self, batch_x, batch_y, indices):
+    def on_batch_begin(self, batch_x, batch_y, indices):
         if "model" in self.options and self.graph_added is False:
             # tensorboardX has a serious bug here; drawing the model graph is not possible for now
             # from fastNLP.core.utils import _build_args
@@ -433,7 +440,7 @@ class TensorboardCallback(Callback):
             # self._summary_writer.add_graph(self.trainer.model, torch.zeros(32, 2))
             self.graph_added = True
-    def before_backward(self, loss, model):
+    def on_backward_begin(self, loss, model):
         if "loss" in self.options:
             self._summary_writer.add_scalar("loss", loss.item(), global_step=self.trainer.step)
@@ -445,14 +452,14 @@ class TensorboardCallback(Callback):
                     self._summary_writer.add_scalar(name + "_grad_mean", param.grad.mean(),
                                                     global_step=self.trainer.step)
-    def after_valid(self, eval_result, metric_key, optimizer):
+    def on_valid_end(self, eval_result, metric_key, optimizer):
         if "metric" in self.options:
             for name, metric in eval_result.items():
                 for metric_key, metric_val in metric.items():
                     self._summary_writer.add_scalar("valid_{}_{}".format(name, metric_key), metric_val,
                                                     global_step=self.trainer.step)
-    def after_train(self, model):
+    def on_train_end(self, model):
         self._summary_writer.close()
         del self._summary_writer
@@ -464,5 +471,5 @@ class TensorboardCallback(Callback):
 if __name__ == "__main__":
     manager = CallbackManager(env={"n_epoch": 3}, callbacks=[DummyCallback(), DummyCallback()])
-    manager.before_train(10, 11, 12)
+    manager.on_train_begin(10, 11, 12)
     # print(manager.after_epoch())
@@ -1,7 +1,11 @@
+from collections import defaultdict
 import torch
-from fastNLP.core.batch import Batch
-from fastNLP.core.sampler import SequentialSampler
+from fastNLP.core import Batch
+from fastNLP.core import DataSet
+from fastNLP.core import SequentialSampler
+from fastNLP.core.utils import _build_args
 class Predictor(object):
@@ -13,37 +17,55 @@ class Predictor(object):
     Currently, Predictor does not support GPU.
     """
-    def __init__(self):
+    def __init__(self, network):
+        if not isinstance(network, torch.nn.Module):
+            raise ValueError(
+                "Only fastNLP.models.BaseModel or torch.nn,Module is allowed, not {}".format(type(network)))
+        self.network = network
         self.batch_size = 1
         self.batch_output = []
-    def predict(self, network, data):
+    def predict(self, data, seq_len_field_name=None):
         """Perform inference using the trained model.
-        :param network: a PyTorch model (cpu)
         :param data: a DataSet object.
+        :param str seq_len_field_name: field name indicating sequence lengths
         :return: list of batch outputs
         """
-        # turn on the testing mode; clean up the history
-        self.mode(network, test=True)
-        batch_output = []
+        if not isinstance(data, DataSet):
+            raise ValueError("Only Dataset class is allowed, not {}.".format(type(data)))
+        if seq_len_field_name is not None and seq_len_field_name not in data.field_arrays:
+            raise ValueError("Field name {} not found in DataSet {}.".format(seq_len_field_name, data))
-        data_iterator = Batch(data, batch_size=self.batch_size, sampler=SequentialSampler(), as_numpy=False)
+        self.network.eval()
+        batch_output = defaultdict(list)
+        data_iterator = Batch(data, batch_size=self.batch_size, sampler=SequentialSampler(), as_numpy=False,
+                              prefetch=False)
-        for batch_x, _ in data_iterator:
-            with torch.no_grad():
-                prediction = self.data_forward(network, batch_x)
-            batch_output.append(prediction)
+        if hasattr(self.network, "predict"):
+            predict_func = self.network.predict
+        else:
+            predict_func = self.network.forward
-        return batch_output
+        with torch.no_grad():
+            for batch_x, _ in data_iterator:
+                refined_batch_x = _build_args(predict_func, **batch_x)
+                prediction = predict_func(**refined_batch_x)
-    def mode(self, network, test=True):
-        if test:
-            network.eval()
-        else:
-            network.train()
+                if seq_len_field_name is not None:
+                    seq_lens = batch_x[seq_len_field_name].tolist()
+                for key, value in prediction.items():
+                    value = value.cpu().numpy()
+                    if len(value.shape) == 1 or (len(value.shape) == 2 and value.shape[1] == 1):
+                        batch_output[key].extend(value.tolist())
+                    else:
+                        if seq_len_field_name is not None:
+                            tmp_batch = []
+                            for idx, seq_len in enumerate(seq_lens):
+                                tmp_batch.append(value[idx, :seq_len])
+                            batch_output[key].extend(tmp_batch)
+                        else:
+                            batch_output[key].append(value)
-    def data_forward(self, network, x):
-        """Forward through network."""
-        y = network(**x)
-        return y
+        return batch_output
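As an illustration of the new Predictor API shown above, a minimal usage sketch loosely following test_predictor.py further down. The toy model and the DataSet contents are made up for the example, and the dict-of-lists DataSet constructor and import paths are assumptions.

import torch
from fastNLP.core import DataSet
from fastNLP.core.predictor import Predictor  # import path assumed

class ToyModel(torch.nn.Module):
    def __init__(self):
        super(ToyModel, self).__init__()
        self.linear = torch.nn.Linear(2, 1)

    def forward(self, x):
        # Predictor collects every key of the returned dict
        return {"predict": self.linear(x)}

data = DataSet({"x": [[1.0, 1.0], [0.0, 0.0]]})  # assumed dict-of-lists constructor
data.set_input("x")

predictor = Predictor(ToyModel())  # the model is now passed to __init__, not to predict()
ans = predictor.predict(data)      # defaultdict(list), e.g. {"predict": [...]}
print(ans["predict"])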
@@ -196,9 +196,9 @@ class Trainer(object):
         print("training epochs started " + self.start_time, flush=True)
         try:
-            self.callback_manager.before_train()
+            self.callback_manager.on_train_begin()
             self._train()
-            self.callback_manager.after_train(self.model)
+            self.callback_manager.on_train_end(self.model)
         except (CallbackException, KeyboardInterrupt) as e:
             self.callback_manager.on_exception(e, self.model)
@@ -237,28 +237,26 @@ class Trainer(object):
             for epoch in range(1, self.n_epochs+1):
                 pbar.set_description_str(desc="Epoch {}/{}".format(epoch, self.n_epochs))
                 # early stopping
-                self.callback_manager.before_epoch(epoch, self.n_epochs)
+                self.callback_manager.on_epoch_begin(epoch, self.n_epochs)
                 for batch_x, batch_y in data_iterator:
                     _move_dict_value_to_device(batch_x, batch_y, device=self._model_device)
                     indices = data_iterator.get_batch_indices()
                     # negative sampling; replace unknown; re-weight batch_y
-                    self.callback_manager.before_batch(batch_x, batch_y, indices)
+                    self.callback_manager.on_batch_begin(batch_x, batch_y, indices)
                     prediction = self._data_forward(self.model, batch_x)
                     # edit prediction
-                    self.callback_manager.before_loss(batch_y, prediction)
+                    self.callback_manager.on_loss_begin(batch_y, prediction)
                     loss = self._compute_loss(prediction, batch_y)
                     avg_loss += loss.item()
                     # Is loss NaN or inf? requires_grad = False
-                    self.callback_manager.before_backward(loss, self.model)
+                    self.callback_manager.on_backward_begin(loss, self.model)
                     self._grad_backward(loss)
                     # gradient clipping
-                    self.callback_manager.after_backward(self.model)
+                    self.callback_manager.on_backward_end(self.model)
                     self._update()
                     # lr scheduler; lr_finder; one_cycle
-                    self.callback_manager.after_step(self.optimizer)
+                    self.callback_manager.on_step_end(self.optimizer)
                     if (self.step+1) % self.print_every == 0:
                         if self.use_tqdm:
@@ -272,8 +270,7 @@ class Trainer(object):
                             pbar.set_postfix_str(print_output)
                         avg_loss = 0
                     self.step += 1
-                    # do nothing
-                    self.callback_manager.after_batch()
+                    self.callback_manager.on_batch_end()
                     if ((self.validate_every > 0 and self.step % self.validate_every == 0) or
                             (self.validate_every < 0 and self.step % len(data_iterator) == 0)) \
@@ -287,12 +284,13 @@ class Trainer(object):
                 # ================= mini-batch end ==================== #
                 # lr decay; early stopping
-                self.callback_manager.after_epoch(epoch, self.n_epochs, self.optimizer)
+                self.callback_manager.on_epoch_end(epoch, self.n_epochs, self.optimizer)
             # =============== epochs end =================== #
             pbar.close()
         # ============ tqdm end ============== #
     def _do_validation(self, epoch, step):
+        self.callback_manager.on_valid_begin()
         res = self.tester.test()
         if self._better_eval_result(res):
@@ -305,7 +303,7 @@ class Trainer(object):
             self.best_dev_epoch = epoch
             self.best_dev_step = step
         # get validation results; adjust optimizer
-        self.callback_manager.after_valid(res, self.metric_key, self.optimizer)
+        self.callback_manager.on_valid_end(res, self.metric_key, self.optimizer)
         return res
     def _mode(self, model, is_test=False):
@@ -4,19 +4,14 @@ import sys
sys.path.append(os.path.join(os.path.dirname(__file__), '../..'))
import fastNLP
import torch
from fastNLP.core.trainer import Trainer
from fastNLP.core.instance import Instance
from fastNLP.api.pipeline import Pipeline
from fastNLP.models.biaffine_parser import BiaffineParser, ParserMetric, ParserLoss
from fastNLP.core.vocabulary import Vocabulary
from fastNLP.core.dataset import DataSet
from fastNLP.core.tester import Tester
from fastNLP.io.config_io import ConfigLoader, ConfigSection
from fastNLP.io.model_io import ModelLoader
from fastNLP.io.embed_loader import EmbedLoader
from fastNLP.io.model_io import ModelSaver
from fastNLP.io.dataset_loader import ConllxDataLoader
from fastNLP.api.processor import *
from fastNLP.io.embed_loader import EmbedLoader
@@ -172,7 +167,7 @@ def train(path):
     model.pos_embedding.weight.data[pos_v.padding_idx].fill_(0)
 class MyCallback(Callback):
-    def after_step(self, optimizer):
+    def on_step_end(self, optimizer):
         step = self.trainer.step
         # learning rate decay
         if step > 0 and step % 1000 == 0:
@@ -1,4 +1,5 @@
 import unittest
+from collections import defaultdict
 import numpy as np
 import torch
@@ -23,12 +24,26 @@ def prepare_fake_dataset():
     return data_set
+class LinearModel(torch.nn.Module):
+    def __init__(self):
+        super(LinearModel, self).__init__()
+        self.linear = Linear(2, 1)
+    def forward(self, x):
+        return {"predict": self.linear(x)}
 class TestPredictor(unittest.TestCase):
-    def test(self):
-        predictor = Predictor()
-        model = Linear(2, 1)
+    def test_simple(self):
+        model = LinearModel()
+        predictor = Predictor(model)
         data = prepare_fake_dataset()
         data.set_input("x")
-        ans = predictor.predict(model, data)
-        self.assertEqual(len(ans), 2000)
-        self.assertTrue(isinstance(ans[0], torch.Tensor))
+        ans = predictor.predict(data)
+        self.assertTrue(isinstance(ans, defaultdict))
+        self.assertTrue("predict" in ans)
+        self.assertTrue(isinstance(ans["predict"], list))
+    def test_sequence(self):
+        # test sequence input/output
+        pass