diff --git a/fastNLP/core/callback.py b/fastNLP/core/callback.py
index 48d7333c..b1a480cc 100644
--- a/fastNLP/core/callback.py
+++ b/fastNLP/core/callback.py
@@ -17,37 +17,40 @@ class Callback(object):
         super(Callback, self).__init__()
         self.trainer = None  # reassigned inside Trainer

-    def before_train(self):
+    def on_train_begin(self):
         # before the main training loop
         pass

-    def before_epoch(self, cur_epoch, total_epoch):
+    def on_epoch_begin(self, cur_epoch, total_epoch):
         # at the beginning of each epoch
         pass

-    def before_batch(self, batch_x, batch_y, indices):
+    def on_batch_begin(self, batch_x, batch_y, indices):
         # at the beginning of each step/mini-batch
         pass

-    def before_loss(self, batch_y, predict_y):
+    def on_loss_begin(self, batch_y, predict_y):
         # after data_forward, and before loss computation
         pass

-    def before_backward(self, loss, model):
+    def on_backward_begin(self, loss, model):
         # after loss computation, and before gradient backward
         pass

-    def after_backward(self, model):
+    def on_backward_end(self, model):
         pass

-    def after_step(self, optimizer):
+    def on_step_end(self, optimizer):
         pass

-    def after_batch(self, *args):
+    def on_batch_end(self, *args):
         # at the end of each step/mini-batch
         pass

-    def after_valid(self, eval_result, metric_key, optimizer):
+    def on_valid_begin(self):
+        pass
+
+    def on_valid_end(self, eval_result, metric_key, optimizer):
         """
         Called after each evaluation on the validation set; eval_result is passed in.

@@ -58,7 +61,7 @@ class Callback(object):
         """
         pass

-    def after_epoch(self, cur_epoch, n_epoch, optimizer):
+    def on_epoch_end(self, cur_epoch, n_epoch, optimizer):
         """
         Called at the end of each epoch.

@@ -69,7 +72,7 @@ class Callback(object):
         """
         pass

-    def after_train(self, model):
+    def on_train_end(self, model):
         """
         Called when training is finished.

@@ -134,47 +137,51 @@ class CallbackManager(Callback):
             raise TypeError(f"Expect callbacks in CallbackManager(callbacks) to be list. Got {type(callbacks)}.")

     @transfer
-    def before_train(self):
+    def on_train_begin(self):
+        pass
+
+    @transfer
+    def on_epoch_begin(self, cur_epoch, total_epoch):
         pass

     @transfer
-    def before_epoch(self, cur_epoch, total_epoch):
+    def on_batch_begin(self, batch_x, batch_y, indices):
         pass

     @transfer
-    def before_batch(self, batch_x, batch_y, indices):
+    def on_loss_begin(self, batch_y, predict_y):
         pass

     @transfer
-    def before_loss(self, batch_y, predict_y):
+    def on_backward_begin(self, loss, model):
         pass

     @transfer
-    def before_backward(self, loss, model):
+    def on_backward_end(self, model):
         pass

     @transfer
-    def after_backward(self, model):
+    def on_step_end(self, optimizer):
         pass

     @transfer
-    def after_step(self, optimizer):
+    def on_batch_end(self):
         pass

     @transfer
-    def after_batch(self):
+    def on_valid_begin(self):
         pass

     @transfer
-    def after_valid(self, eval_result, metric_key, optimizer):
+    def on_valid_end(self, eval_result, metric_key, optimizer):
         pass

     @transfer
-    def after_epoch(self, cur_epoch, n_epoch, optimizer):
+    def on_epoch_end(self, cur_epoch, n_epoch, optimizer):
         pass

     @transfer
-    def after_train(self, model):
+    def on_train_end(self, model):
         pass

     @transfer
@@ -183,36 +190,36 @@ class CallbackManager(Callback):


 class DummyCallback(Callback):
-    def before_train(self, *arg):
+    def on_train_begin(self, *arg):
         print(arg)

-    def after_epoch(self, cur_epoch, n_epoch, optimizer):
+    def on_epoch_end(self, cur_epoch, n_epoch, optimizer):
         print(cur_epoch, n_epoch, optimizer)


 class EchoCallback(Callback):
-    def before_train(self):
+    def on_train_begin(self):
         print("before_train")

-    def before_epoch(self, cur_epoch, total_epoch):
+    def on_epoch_begin(self, cur_epoch, total_epoch):
         print("before_epoch")

-    def before_batch(self, batch_x, batch_y, indices):
+    def on_batch_begin(self, batch_x, batch_y, indices):
         print("before_batch")

-    def before_loss(self, batch_y, predict_y):
+    def on_loss_begin(self, batch_y, predict_y):
         print("before_loss")

-    def before_backward(self, loss, model):
+    def on_backward_begin(self, loss, model):
         print("before_backward")

-    def after_batch(self):
+    def on_batch_end(self):
         print("after_batch")

-    def after_epoch(self, cur_epoch, n_epoch, optimizer):
+    def on_epoch_end(self, cur_epoch, n_epoch, optimizer):
         print("after_epoch")

-    def after_train(self, model):
+    def on_train_end(self, model):
         print("after_train")

@@ -240,7 +247,7 @@ class GradientClipCallback(Callback):
         self.parameters = parameters
         self.clip_value = clip_value

-    def after_backward(self, model):
+    def on_backward_end(self, model):
         self.clip_fun(model.parameters(), self.clip_value)


@@ -266,7 +273,7 @@ class EarlyStopCallback(Callback):
         self.wait = 0
         self.epoch = 0

-    def after_valid(self, eval_result, metric_key, optimizer):
+    def on_valid_end(self, eval_result, metric_key, optimizer):
         self.epoch += 1
         if not self.trainer._better_eval_result(eval_result):
             # current result is getting worse
@@ -297,7 +304,7 @@ class LRScheduler(Callback):
         else:
             raise ValueError(f"Expect torch.optim.lr_scheduler for LRScheduler. Got {type(lr_scheduler)}.")

-    def before_epoch(self, cur_epoch, total_epoch):
+    def on_epoch_begin(self, cur_epoch, total_epoch):
         self.scheduler.step()
         print("scheduler step ", "lr=", self.trainer.optimizer.param_groups[0]["lr"])

@@ -359,7 +366,7 @@ class LRFinder(Callback):
         self.find = None
         self.loader = ModelLoader()

-    def before_epoch(self, cur_epoch, total_epoch):
+    def on_epoch_begin(self, cur_epoch, total_epoch):
         if cur_epoch == 1:
             self.opt = self.trainer.optimizer  # pytorch optimizer
             self.opt.param_groups[0]["lr"] = self.start_lr
@@ -367,7 +374,7 @@
             ModelSaver("tmp").save_pytorch(self.trainer.model, param_only=True)
             self.find = True

-    def before_backward(self, loss, model):
+    def on_backward_begin(self, loss, model):
         if self.find:
             if torch.isnan(loss) or self.stop is True:
                 self.stop = True
@@ -379,7 +386,7 @@
                 self.best_loss = self.smooth_value.smooth
                 self.best_lr = self.opt.param_groups[0]["lr"]

-    def after_batch(self, *args):
+    def on_batch_end(self, *args):
         if self.find:
             lr = next(self.lr_gen, None)
             if lr is None or self.stop is True or self.loss_history[-1] > 4 * self.best_loss:
@@ -388,7 +395,7 @@
             self.opt.param_groups[0]["lr"] = lr
             # self.loader.load_pytorch(self.trainer.model, "tmp")

-    def after_epoch(self, cur_epoch, n_epoch, optimizer):
+    def on_epoch_end(self, cur_epoch, n_epoch, optimizer):
         if cur_epoch == 1:
             self.opt.param_groups[0]["lr"] = self.best_lr
             self.find = False
@@ -415,7 +422,7 @@ class TensorboardCallback(Callback):
         self._summary_writer = None
         self.graph_added = False

-    def before_train(self):
+    def on_train_begin(self):
         save_dir = self.trainer.save_path
         if save_dir is None:
             path = os.path.join("./", 'tensorboard_logs_{}'.format(self.trainer.start_time))
@@ -423,7 +430,7 @@
             path = os.path.join(save_dir, 'tensorboard_logs_{}'.format(self.trainer.start_time))
         self._summary_writer = SummaryWriter(path)

-    def before_batch(self, batch_x, batch_y, indices):
+    def on_batch_begin(self, batch_x, batch_y, indices):
         if "model" in self.options and self.graph_added is False:
             # tensorboardX has a serious bug here; drawing the model graph is not possible for now
             # from fastNLP.core.utils import _build_args
@@ -433,7 +440,7 @@
             # self._summary_writer.add_graph(self.trainer.model, torch.zeros(32, 2))
             self.graph_added = True

-    def before_backward(self, loss, model):
+    def on_backward_begin(self, loss, model):
         if "loss" in self.options:
             self._summary_writer.add_scalar("loss", loss.item(), global_step=self.trainer.step)

@@ -445,14 +452,14 @@
                     self._summary_writer.add_scalar(name + "_grad_mean", param.grad.mean(),
                                                     global_step=self.trainer.step)

-    def after_valid(self, eval_result, metric_key, optimizer):
+    def on_valid_end(self, eval_result, metric_key, optimizer):
         if "metric" in self.options:
             for name, metric in eval_result.items():
                 for metric_key, metric_val in metric.items():
                     self._summary_writer.add_scalar("valid_{}_{}".format(name, metric_key), metric_val,
                                                     global_step=self.trainer.step)

-    def after_train(self, model):
+    def on_train_end(self, model):
         self._summary_writer.close()
         del self._summary_writer
@@ -464,5 +471,5 @@ class TensorboardCallback(Callback):

 if __name__ == "__main__":
     manager = CallbackManager(env={"n_epoch": 3}, callbacks=[DummyCallback(), DummyCallback()])
-    manager.before_train(10, 11, 12)
+    manager.on_train_begin(10, 11, 12)
     # print(manager.after_epoch())
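
For reference, a minimal sketch of a user-defined callback under the renamed hooks. The callback class and the commented-out Trainer wiring are illustrative assumptions (the patch only confirms that callbacks are dispatched through CallbackManager and that callbacks can read self.trainer.step), not part of this diff:

from fastNLP.core.callback import Callback

class LossLoggerCallback(Callback):
    # hypothetical example callback, not part of this patch
    def on_backward_begin(self, loss, model):
        # fires after the loss is computed and before the backward pass
        print("step", self.trainer.step, "loss", loss.item())

    def on_epoch_end(self, cur_epoch, n_epoch, optimizer):
        print("finished epoch {}/{}".format(cur_epoch, n_epoch))

# trainer = Trainer(..., callbacks=[LossLoggerCallback()])  # assumed Trainer signature
# trainer.train()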
diff --git a/fastNLP/core/predictor.py b/fastNLP/core/predictor.py
index de9ddc8c..ae648e47 100644
--- a/fastNLP/core/predictor.py
+++ b/fastNLP/core/predictor.py
@@ -1,7 +1,11 @@
+from collections import defaultdict
+
 import torch

-from fastNLP.core.batch import Batch
-from fastNLP.core.sampler import SequentialSampler
+from fastNLP.core import Batch
+from fastNLP.core import DataSet
+from fastNLP.core import SequentialSampler
+from fastNLP.core.utils import _build_args


 class Predictor(object):
@@ -13,37 +17,55 @@ class Predictor(object):
     Currently, Predictor does not support GPU.
     """

-    def __init__(self):
+    def __init__(self, network):
+        if not isinstance(network, torch.nn.Module):
+            raise ValueError(
+                "Only fastNLP.models.BaseModel or torch.nn.Module is allowed, not {}".format(type(network)))
+        self.network = network
         self.batch_size = 1
         self.batch_output = []

-    def predict(self, network, data):
+    def predict(self, data, seq_len_field_name=None):
         """Perform inference using the trained model.

-        :param network: a PyTorch model (cpu)
         :param data: a DataSet object.
+        :param str seq_len_field_name: field name indicating sequence lengths
         :return: list of batch outputs
         """
-        # turn on the testing mode; clean up the history
-        self.mode(network, test=True)
-        batch_output = []
+        if not isinstance(data, DataSet):
+            raise ValueError("Only Dataset class is allowed, not {}.".format(type(data)))
+        if seq_len_field_name is not None and seq_len_field_name not in data.field_arrays:
+            raise ValueError("Field name {} not found in DataSet {}.".format(seq_len_field_name, data))

-        data_iterator = Batch(data, batch_size=self.batch_size, sampler=SequentialSampler(), as_numpy=False)
+        self.network.eval()
+        batch_output = defaultdict(list)
+        data_iterator = Batch(data, batch_size=self.batch_size, sampler=SequentialSampler(), as_numpy=False,
+                              prefetch=False)

-        for batch_x, _ in data_iterator:
-            with torch.no_grad():
-                prediction = self.data_forward(network, batch_x)
-            batch_output.append(prediction)
+        if hasattr(self.network, "predict"):
+            predict_func = self.network.predict
+        else:
+            predict_func = self.network.forward

-        return batch_output
+        with torch.no_grad():
+            for batch_x, _ in data_iterator:
+                refined_batch_x = _build_args(predict_func, **batch_x)
+                prediction = predict_func(**refined_batch_x)

-    def mode(self, network, test=True):
-        if test:
-            network.eval()
-        else:
-            network.train()
+                if seq_len_field_name is not None:
+                    seq_lens = batch_x[seq_len_field_name].tolist()
+
+                for key, value in prediction.items():
+                    value = value.cpu().numpy()
+                    if len(value.shape) == 1 or (len(value.shape) == 2 and value.shape[1] == 1):
+                        batch_output[key].extend(value.tolist())
+                    else:
+                        if seq_len_field_name is not None:
+                            tmp_batch = []
+                            for idx, seq_len in enumerate(seq_lens):
+                                tmp_batch.append(value[idx, :seq_len])
+                            batch_output[key].extend(tmp_batch)
+                        else:
+                            batch_output[key].append(value)

-    def data_forward(self, network, x):
-        """Forward through network."""
-        y = network(**x)
-        return y
+        return batch_output
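
A minimal usage sketch for the reworked Predictor. The variables model and dataset, and the field name "seq_len", are illustrative assumptions; only the Predictor(network) constructor and predict(data, seq_len_field_name=...) signature come from the hunk above:

from fastNLP.core.predictor import Predictor

predictor = Predictor(model)           # model: a torch.nn.Module whose forward/predict returns a dict
output = predictor.predict(dataset)    # dict-like {output key: list of per-instance predictions}
# For variable-length outputs, name the field holding each instance's length so
# each prediction is truncated to its true sequence length:
output = predictor.predict(dataset, seq_len_field_name="seq_len")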
diff --git a/fastNLP/core/trainer.py b/fastNLP/core/trainer.py
index ed2f366b..ddd35b28 100644
--- a/fastNLP/core/trainer.py
+++ b/fastNLP/core/trainer.py
@@ -196,9 +196,9 @@ class Trainer(object):
         print("training epochs started " + self.start_time, flush=True)

         try:
-            self.callback_manager.before_train()
+            self.callback_manager.on_train_begin()
             self._train()
-            self.callback_manager.after_train(self.model)
+            self.callback_manager.on_train_end(self.model)
         except (CallbackException, KeyboardInterrupt) as e:
             self.callback_manager.on_exception(e, self.model)

@@ -237,28 +237,26 @@ class Trainer(object):
             for epoch in range(1, self.n_epochs+1):
                 pbar.set_description_str(desc="Epoch {}/{}".format(epoch, self.n_epochs))
                 # early stopping
-                self.callback_manager.before_epoch(epoch, self.n_epochs)
+                self.callback_manager.on_epoch_begin(epoch, self.n_epochs)
                 for batch_x, batch_y in data_iterator:
                     _move_dict_value_to_device(batch_x, batch_y, device=self._model_device)
                     indices = data_iterator.get_batch_indices()
                     # negative sampling; replace unknown; re-weight batch_y
-                    self.callback_manager.before_batch(batch_x, batch_y, indices)
+                    self.callback_manager.on_batch_begin(batch_x, batch_y, indices)
                     prediction = self._data_forward(self.model, batch_x)

                     # edit prediction
-                    self.callback_manager.before_loss(batch_y, prediction)
+                    self.callback_manager.on_loss_begin(batch_y, prediction)
                     loss = self._compute_loss(prediction, batch_y)
                     avg_loss += loss.item()

                     # Is loss NaN or inf? requires_grad = False
-                    self.callback_manager.before_backward(loss, self.model)
+                    self.callback_manager.on_backward_begin(loss, self.model)
                     self._grad_backward(loss)
-                    # gradient clipping
-                    self.callback_manager.after_backward(self.model)
+                    self.callback_manager.on_backward_end(self.model)

                     self._update()
-                    # lr scheduler; lr_finder; one_cycle
-                    self.callback_manager.after_step(self.optimizer)
+                    self.callback_manager.on_step_end(self.optimizer)

                     if (self.step+1) % self.print_every == 0:
                         if self.use_tqdm:
@@ -272,8 +270,7 @@ class Trainer(object):
                             pbar.set_postfix_str(print_output)
                         avg_loss = 0
                     self.step += 1
-                    # do nothing
-                    self.callback_manager.after_batch()
+                    self.callback_manager.on_batch_end()

                     if ((self.validate_every > 0 and self.step % self.validate_every == 0) or
                         (self.validate_every < 0 and self.step % len(data_iterator) == 0)) \
@@ -287,12 +284,13 @@ class Trainer(object):
                 # ================= mini-batch end ==================== #

                 # lr decay; early stopping
-                self.callback_manager.after_epoch(epoch, self.n_epochs, self.optimizer)
+                self.callback_manager.on_epoch_end(epoch, self.n_epochs, self.optimizer)
             # =============== epochs end =================== #
             pbar.close()
         # ============ tqdm end ============== #

     def _do_validation(self, epoch, step):
+        self.callback_manager.on_valid_begin()
         res = self.tester.test()

         if self._better_eval_result(res):
@@ -305,7 +303,7 @@ class Trainer(object):
             self.best_dev_epoch = epoch
             self.best_dev_step = step
         # get validation results; adjust optimizer
-        self.callback_manager.after_valid(res, self.metric_key, self.optimizer)
+        self.callback_manager.on_valid_end(res, self.metric_key, self.optimizer)
         return res

     def _mode(self, model, is_test=False):
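
Condensed pseudocode of the call order the patched _train/_do_validation now produce; helper names such as model_forward and compute_loss stand in for Trainer internals and are not real fastNLP functions:

callback_manager.on_train_begin()
for epoch in range(1, n_epochs + 1):
    callback_manager.on_epoch_begin(epoch, n_epochs)
    for batch_x, batch_y in data_iterator:
        callback_manager.on_batch_begin(batch_x, batch_y, indices)
        prediction = model_forward(batch_x)
        callback_manager.on_loss_begin(batch_y, prediction)
        loss = compute_loss(prediction, batch_y)
        callback_manager.on_backward_begin(loss, model)
        loss.backward()
        callback_manager.on_backward_end(model)
        optimizer.step()
        callback_manager.on_step_end(optimizer)
        callback_manager.on_batch_end()
        # when validation is due:
        #     on_valid_begin() ... tester.test() ... on_valid_end(res, metric_key, optimizer)
    callback_manager.on_epoch_end(epoch, n_epochs, optimizer)
callback_manager.on_train_end(model)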
diff --git a/reproduction/Biaffine_parser/run.py b/reproduction/Biaffine_parser/run.py
index 98ef02fa..c226ce69 100644
--- a/reproduction/Biaffine_parser/run.py
+++ b/reproduction/Biaffine_parser/run.py
@@ -4,19 +4,14 @@ import sys
 sys.path.append(os.path.join(os.path.dirname(__file__), '../..'))

 import fastNLP
-import torch

 from fastNLP.core.trainer import Trainer
 from fastNLP.core.instance import Instance
 from fastNLP.api.pipeline import Pipeline
 from fastNLP.models.biaffine_parser import BiaffineParser, ParserMetric, ParserLoss
-from fastNLP.core.vocabulary import Vocabulary
-from fastNLP.core.dataset import DataSet
 from fastNLP.core.tester import Tester
 from fastNLP.io.config_io import ConfigLoader, ConfigSection
 from fastNLP.io.model_io import ModelLoader
-from fastNLP.io.embed_loader import EmbedLoader
-from fastNLP.io.model_io import ModelSaver
 from fastNLP.io.dataset_loader import ConllxDataLoader
 from fastNLP.api.processor import *
 from fastNLP.io.embed_loader import EmbedLoader
@@ -172,7 +167,7 @@ def train(path):
     model.pos_embedding.weight.data[pos_v.padding_idx].fill_(0)

     class MyCallback(Callback):
-        def after_step(self, optimizer):
+        def on_step_end(self, optimizer):
             step = self.trainer.step
             # learning rate decay
             if step > 0 and step % 1000 == 0:
diff --git a/test/core/test_predictor.py b/test/core/test_predictor.py
index 8be5f289..c779e3ac 100644
--- a/test/core/test_predictor.py
+++ b/test/core/test_predictor.py
@@ -1,4 +1,5 @@
 import unittest
+from collections import defaultdict

 import numpy as np
 import torch
@@ -23,12 +24,26 @@ def prepare_fake_dataset():
     return data_set


+class LinearModel(torch.nn.Module):
+    def __init__(self):
+        super(LinearModel, self).__init__()
+        self.linear = Linear(2, 1)
+
+    def forward(self, x):
+        return {"predict": self.linear(x)}
+
+
 class TestPredictor(unittest.TestCase):
-    def test(self):
-        predictor = Predictor()
-        model = Linear(2, 1)
+    def test_simple(self):
+        model = LinearModel()
+        predictor = Predictor(model)
         data = prepare_fake_dataset()
         data.set_input("x")
-        ans = predictor.predict(model, data)
-        self.assertEqual(len(ans), 2000)
-        self.assertTrue(isinstance(ans[0], torch.Tensor))
+        ans = predictor.predict(data)
+        self.assertTrue(isinstance(ans, defaultdict))
+        self.assertTrue("predict" in ans)
+        self.assertTrue(isinstance(ans["predict"], list))
+
+    def test_sequence(self):
+        # test sequence input/output
+        pass
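
For users migrating existing callbacks, the hook renaming introduced by this patch maps old names to new names as follows; on_valid_begin is new and has no old counterpart, and callbacks that still override the old names will no longer be called:

before_train    -> on_train_begin
before_epoch    -> on_epoch_begin
before_batch    -> on_batch_begin
before_loss     -> on_loss_begin
before_backward -> on_backward_begin
after_backward  -> on_backward_end
after_step      -> on_step_end
after_batch     -> on_batch_end
after_valid     -> on_valid_end
after_epoch     -> on_epoch_end
after_train     -> on_train_end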