diff --git a/fastNLP/core/tester.py b/fastNLP/core/tester.py
index 5495dbec..919554c5 100644
--- a/fastNLP/core/tester.py
+++ b/fastNLP/core/tester.py
@@ -1,10 +1,11 @@
+import itertools
 from collections import defaultdict
 
 import torch
 
 from fastNLP.core.batch import Batch
 from fastNLP.core.sampler import RandomSampler
-
+from fastNLP.core.utils import _build_args
 
 class Tester(object):
     """A collection of model inference and evaluation of performance, used over validation/dev set and test set. """
@@ -40,7 +41,12 @@ class Tester(object):
                 output[k].append(v)
             for k, v in batch_y.items():
                 truths[k].append(v)
-        eval_results = self.evaluate(**output, **truths)
+        for k, v in output.items():
+            output[k] = itertools.chain(*v)
+        for k, v in truths.items():
+            truths[k] = itertools.chain(*v)
+        args = _build_args(self._evaluator, **output, **truths)
+        eval_results = self._evaluator(**args)
         print("[tester] {}".format(self.print_eval_results(eval_results)))
         self.mode(network, is_test=False)
         self.metrics = eval_results
@@ -60,14 +66,10 @@ class Tester(object):
 
     def data_forward(self, network, x):
         """A forward pass of the model. """
+        x = _build_args(network.forward, **x)
         y = network(**x)
         return y
 
-    def evaluate(self, **kwargs):
-        """Compute evaluation metrics.
-        """
-        return self._evaluator(**kwargs)
-
     def print_eval_results(self, results):
         """Override this method to support more print formats.
 
diff --git a/fastNLP/core/trainer.py b/fastNLP/core/trainer.py
index 2a6458c6..a21f2ded 100644
--- a/fastNLP/core/trainer.py
+++ b/fastNLP/core/trainer.py
@@ -21,9 +21,8 @@ class Trainer(object):
     """
 
     def __init__(self, train_data, model, n_epochs=3, batch_size=32, print_every=-1,
-                 dev_data=None, use_cuda=False, loss=Loss(None), save_path="./save",
+                 dev_data=None, use_cuda=False, save_path="./save",
                  optimizer=Optimizer("Adam", lr=0.001, weight_decay=0),
-                 evaluator=Evaluator(),
                  **kwargs):
         super(Trainer, self).__init__()
 
@@ -36,9 +35,16 @@ class Trainer(object):
         self.save_path = str(save_path)
         self.print_every = int(print_every)
 
-        self.loss_func = self.model.loss if hasattr(self.model, "loss") else loss.get()
-        self.optimizer = optimizer.construct_from_pytorch(self.model.parameters())
-        self.evaluator = evaluator
+        model_name = model.__class__.__name__
+        assert hasattr(self.model, 'get_loss'), "model {} has to have a 'get_loss' function.".format(model_name)
+        self.loss_func = self.model.get_loss
+        if isinstance(optimizer, torch.optim.Optimizer):
+            self.optimizer = optimizer
+        else:
+            self.optimizer = optimizer.construct_from_pytorch(self.model.parameters())
+
+        assert hasattr(self.model, 'evaluate'), "model {} has to have an 'evaluate' function.".format(model_name)
+        self.evaluator = self.model.evaluate
 
         if self.dev_data is not None:
             valid_args = {"batch_size": self.batch_size, "save_path": self.save_path,
@@ -48,7 +54,10 @@ class Trainer(object):
         for k, v in kwargs.items():
             setattr(self, k, v)
 
-        self._summary_writer = SummaryWriter(os.path.join(self.save_path, 'tensorboard_logs'))
+        self.tensorboard_path = os.path.join(self.save_path, 'tensorboard_logs')
+        if os.path.exists(self.tensorboard_path):
+            os.rmdir(self.tensorboard_path)
+        self._summary_writer = SummaryWriter(self.tensorboard_path)
         self._graph_summaried = False
         self.step = 0
         self.start_time = None  # start timestamp
@@ -138,6 +147,7 @@ class Trainer(object):
         self.optimizer.step()
 
     def data_forward(self, network, x):
+        x = _build_args(network.forward, **x)
         y = network(**x)
         if not self._graph_summaried:
             # self._summary_writer.add_graph(network, x, verbose=False)
@@ -161,12 +171,9 @@ class Trainer(object):
         :param truth: ground truth label vector
         :return: a scalar
         """
-        if isinstance(predict, dict) and isinstance(truth, dict):
-            return self.loss_func(**predict, **truth)
-        if len(truth) > 1:
-            raise NotImplementedError("Not ready to handle multi-labels.")
-        truth = list(truth.values())[0] if len(truth) > 0 else None
-        return self.loss_func(predict, truth)
+        assert isinstance(predict, dict) and isinstance(truth, dict)
+        args = _build_args(self.loss_func, **predict, **truth)
+        return self.loss_func(**args)
 
     def save_model(self, model, model_name, only_param=False):
         model_name = os.path.join(self.save_path, model_name)
diff --git a/fastNLP/models/cnn_text_classification.py b/fastNLP/models/cnn_text_classification.py
index e814717b..04f0c6d9 100644
--- a/fastNLP/models/cnn_text_classification.py
+++ b/fastNLP/models/cnn_text_classification.py
@@ -46,5 +46,18 @@ class CNNText(torch.nn.Module):
         x = self.fc(x)  # [N,C] -> [N, N_class]
         return {'output':x}
 
-    def loss(self, output, label_seq):
+    def predict(self, word_seq):
+        output = self(word_seq)['output']
+        _, predict = output.max(dim=1)
+        return {'predict': predict}
+
+    def get_loss(self, output, label_seq):
         return self._loss(output, label_seq)
+
+    def evaluate(self, predict, label_seq):
+        predict, label_seq = torch.stack(list(predict), dim=0), torch.stack(list(label_seq), dim=0)
+        predict, label_seq = predict.squeeze(), label_seq.squeeze()
+        correct = (predict == label_seq).long().sum().item()
+        total = label_seq.size(0)
+        return 1.0 * correct / total
+
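
The changes lean on fastNLP.core.utils._build_args, whose implementation is not part of this diff. Judging from its call sites (it is handed a callable plus a superset of keyword arguments, e.g. **predict and **truth together), it filters keyword arguments down to the callable's signature. A minimal sketch of that behavior, not the actual fastNLP implementation:

import inspect

def _build_args(func, **kwargs):
    # Sketch only: keep the keyword arguments that func's signature
    # declares, so callers may pass a superset of candidate kwargs.
    # The real helper presumably also handles functions that accept
    # **kwargs and reports missing required arguments.
    needed = set(inspect.getfullargspec(func).args)
    return {k: v for k, v in kwargs.items() if k in needed}

This is what lets Trainer.get_loss splat **predict and **truth into one call and have only the parameters that self.loss_func actually declares reach it.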
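
In Tester.test, the per-key lists of batch tensors are flattened with itertools.chain(*v) before evaluation. Iterating a tensor yields its dim-0 slices, so chaining a list of [B, C] batch tensors produces one [C] row per sample. A small illustration with made-up data:

import itertools

import torch

batches = [torch.tensor([[0.1, 0.9], [0.8, 0.2]]),
           torch.tensor([[0.3, 0.7]])]
# Each [B, C] batch tensor iterates into B row tensors of shape [C],
# so the chain yields one row per sample across all batches.
rows = list(itertools.chain(*batches))
stacked = torch.stack(rows, dim=0)  # shape [3, 2]

Note that a chain object is a one-shot iterator, so an evaluator has to materialize it (as CNNText.evaluate does before torch.stack) exactly once.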
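
Because loss= and evaluator= are gone from Trainer.__init__, any model handed to the Trainer must now define get_loss, and evaluate as well (both are asserted at construction time). Below is a hypothetical minimal model satisfying the new contract; ToyModel and its dimensions are made up, while the word_seq/label_seq/output names mirror the CNNText example:

import torch

class ToyModel(torch.nn.Module):
    def __init__(self, vocab_size=100, embed_dim=16, n_class=2):
        super(ToyModel, self).__init__()
        self.embed = torch.nn.Embedding(vocab_size, embed_dim)
        self.fc = torch.nn.Linear(embed_dim, n_class)
        self._loss = torch.nn.CrossEntropyLoss()

    def forward(self, word_seq):
        x = self.embed(word_seq).mean(dim=1)  # [N, L] -> [N, embed_dim]
        return {'output': self.fc(x)}  # keys are matched by name via _build_args

    def get_loss(self, output, label_seq):  # required by the new Trainer.__init__
        return self._loss(output, label_seq)

    def evaluate(self, output, label_seq):  # required as well; called through Tester
        # Tester passes flattened per-sample rows (see the chain note above),
        # so materialize the iterators before stacking.
        predict = torch.stack(list(output), dim=0).argmax(dim=-1)
        truth = torch.stack(list(label_seq), dim=0).squeeze()
        return (predict == truth).float().mean().item()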