diff --git a/fastNLP/core/tester.py b/fastNLP/core/tester.py
index 919554c5..9f9661fd 100644
--- a/fastNLP/core/tester.py
+++ b/fastNLP/core/tester.py
@@ -10,28 +10,32 @@ from fastNLP.core.utils import _build_args
 
 class Tester(object):
     """An collection of model inference and evaluation of performance, used over validation/dev set and test set. """
 
-    def __init__(self, batch_size, evaluator, use_cuda, save_path="./save/", **kwargs):
+    def __init__(self, data, model, batch_size, use_cuda, save_path="./save/", **kwargs):
         super(Tester, self).__init__()
-
+        self.use_cuda = use_cuda
+        self.data = data
         self.batch_size = batch_size
         self.pickle_path = save_path
-        self.use_cuda = use_cuda
-        self._evaluator = evaluator
-
-        self._model = None
-        self.eval_history = []  # evaluation results of all batches
-
-    def test(self, network, dev_data):
         if torch.cuda.is_available() and self.use_cuda:
-            self._model = network.cuda()
+            self._model = model.cuda()
         else:
-            self._model = network
+            self._model = model
+        if hasattr(self._model, 'predict'):
+            assert callable(self._model.predict)
+            self._predict_func = self._model.predict
+        else:
+            self._predict_func = self._model
+        assert hasattr(model, 'evaluate')
+        self._evaluator = model.evaluate
+        self.eval_history = []  # evaluation results of all batches
 
+    def test(self):
         # turn on the testing mode; clean up the history
+        network = self._model
         self.mode(network, is_test=True)
         self.eval_history.clear()
         output, truths = defaultdict(list), defaultdict(list)
-        data_iterator = Batch(dev_data, self.batch_size, sampler=RandomSampler(), as_numpy=False)
+        data_iterator = Batch(self.data, self.batch_size, sampler=RandomSampler(), as_numpy=False)
 
         with torch.no_grad():
             for batch_x, batch_y in data_iterator:
@@ -67,7 +71,7 @@ class Tester(object):
     def data_forward(self, network, x):
         """A forward pass of the model.
""" x = _build_args(network.forward, **x) - y = network(**x) + y = self._predict_func(**x) return y def print_eval_results(self, results): diff --git a/fastNLP/core/trainer.py b/fastNLP/core/trainer.py index d83e3936..b4aa3b65 100644 --- a/fastNLP/core/trainer.py +++ b/fastNLP/core/trainer.py @@ -4,9 +4,10 @@ from datetime import datetime import warnings from collections import defaultdict import os -import itertools +import shutil from tensorboardX import SummaryWriter +import torch from fastNLP.core.batch import Batch from fastNLP.core.loss import Loss @@ -51,17 +52,18 @@ class Trainer(object): self.evaluator = self.model.evaluate if self.dev_data is not None: - valid_args = {"batch_size": self.batch_size, "save_path": self.save_path, - "use_cuda": self.use_cuda, "evaluator": self.evaluator} - self.tester = Tester(**valid_args) + self.tester = Tester(model=self.model, + data=self.dev_data, + batch_size=self.batch_size, + save_path=self.save_path, + use_cuda=self.use_cuda) for k, v in kwargs.items(): setattr(self, k, v) self.tensorboard_path = os.path.join(self.save_path, 'tensorboard_logs') if os.path.exists(self.tensorboard_path): - os.rmdir(self.tensorboard_path) - self._summary_writer = SummaryWriter(self.tensorboard_path) + shutil.rmtree(self.tensorboard_path) self._graph_summaried = False self.step = 0 self.start_time = None # start timestamp @@ -73,26 +75,32 @@ class Trainer(object): :return: """ - if torch.cuda.is_available() and self.use_cuda: - self.model = self.model.cuda() + try: + self._summary_writer = SummaryWriter(self.tensorboard_path) - self.mode(self.model, is_test=False) + if torch.cuda.is_available() and self.use_cuda: + self.model = self.model.cuda() - start = time.time() - self.start_time = str(datetime.now().strftime('%Y-%m-%d-%H-%M-%S')) - print("training epochs started " + self.start_time) + self.mode(self.model, is_test=False) - epoch = 1 - while epoch <= self.n_epochs: + start = time.time() + self.start_time = str(datetime.now().strftime('%Y-%m-%d-%H-%M-%S')) + print("training epochs started " + self.start_time) - data_iterator = Batch(self.train_data, batch_size=self.batch_size, sampler=RandomSampler(), as_numpy=False) + epoch = 1 + while epoch <= self.n_epochs: - self._train_epoch(data_iterator, self.model, epoch, self.dev_data, start) + data_iterator = Batch(self.train_data, batch_size=self.batch_size, sampler=RandomSampler(), as_numpy=False) - if self.dev_data: - self.do_validation() - self.save_model(self.model, 'training_model_' + self.start_time) - epoch += 1 + self._train_epoch(data_iterator, self.model, epoch, self.dev_data, start) + + if self.dev_data: + self.do_validation() + self.save_model(self.model, 'training_model_' + self.start_time) + epoch += 1 + finally: + self._summary_writer.close() + del self._summary_writer def _train_epoch(self, data_iterator, model, epoch, dev_data, start, **kwargs): """Training process in one epoch. 
@@ -127,7 +135,7 @@ class Trainer(object):
             self.step += 1
 
     def do_validation(self):
-        res = self.tester.test(self.model, self.dev_data)
+        res = self.tester.test()
         for name, num in res.items():
             self._summary_writer.add_scalar("valid_{}".format(name), num, global_step=self.step)
         self.save_model(self.model, 'best_model_' + self.start_time)
diff --git a/fastNLP/models/cnn_text_classification.py b/fastNLP/models/cnn_text_classification.py
index 04f0c6d9..a4dcfef2 100644
--- a/fastNLP/models/cnn_text_classification.py
+++ b/fastNLP/models/cnn_text_classification.py
@@ -48,16 +48,16 @@ class CNNText(torch.nn.Module):
 
     def predict(self, word_seq):
         output = self(word_seq)
-        _, predict = output.max(dim=1)
+        _, predict = output['output'].max(dim=1)
         return {'predict': predict}
 
     def get_loss(self, output, label_seq):
         return self._loss(output, label_seq)
 
     def evaluate(self, predict, label_seq):
-        predict, label_seq = torch.stack(predict, dim=0), torch.stack(label_seq, dim=0)
+        predict, label_seq = torch.stack(tuple(predict), dim=0), torch.stack(tuple(label_seq), dim=0)
         predict, label_seq = predict.squeeze(), label_seq.squeeze()
         correct = (predict == label_seq).long().sum().item()
         total = label_seq.size(0)
-        return 1.0 * correct / total
+        return {'acc': 1.0 * correct / total}
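
Usage sketch (illustrative, not part of the patch): after this change a Tester is bound to its model and dataset at construction, the model must expose an evaluate() method (and optionally predict(), which the Tester prefers over forward()), and metrics come back as a dict such as {'acc': ...}. The snippet below assumes a prepared fastNLP DataSet named dev_data whose fields match the model's forward/evaluate arguments (word_seq, label_seq); the CNNText constructor arguments are assumed placeholders, not taken from this patch.

    from fastNLP.core.tester import Tester
    from fastNLP.models.cnn_text_classification import CNNText

    # Constructor arguments are assumptions for illustration only.
    model = CNNText(embed_num=3000, embed_dim=50, num_classes=5)

    tester = Tester(data=dev_data,    # dataset is now bound at construction ...
                    model=model,      # ... and so is the model; predict()/evaluate() are picked up here
                    batch_size=16,
                    use_cuda=False)
    results = tester.test()           # no arguments anymore; uses the bound model and data
    print(results)                    # e.g. {'acc': 0.91}, from CNNText.evaluate

Trainer builds this Tester internally from dev_data when one is given, which is why do_validation() reduces to a bare self.tester.test() in the hunk above.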