From b78d86584ccd9edb7a62298de42992e243ba3f7d Mon Sep 17 00:00:00 2001 From: yunfan Date: Mon, 26 Nov 2018 18:35:48 +0800 Subject: [PATCH] add validate_every in trainer --- fastNLP/core/trainer.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/fastNLP/core/trainer.py b/fastNLP/core/trainer.py index b4aa3b65..6e439c47 100644 --- a/fastNLP/core/trainer.py +++ b/fastNLP/core/trainer.py @@ -25,7 +25,7 @@ class Trainer(object): """Main Training Loop """ - def __init__(self, train_data, model, n_epochs=3, batch_size=32, print_every=-1, + def __init__(self, train_data, model, n_epochs=3, batch_size=32, print_every=-1, validate_every=-1, dev_data=None, use_cuda=False, save_path="./save", optimizer=Optimizer("Adam", lr=0.001, weight_decay=0), **kwargs): @@ -39,6 +39,7 @@ class Trainer(object): self.use_cuda = bool(use_cuda) self.save_path = str(save_path) self.print_every = int(print_every) + self.validate_every = int(validate_every) model_name = model.__class__.__name__ assert hasattr(self.model, 'get_loss'), "model {} has to have a 'get_loss' function.".format(model_name) @@ -94,7 +95,8 @@ class Trainer(object): self._train_epoch(data_iterator, self.model, epoch, self.dev_data, start) - if self.dev_data: + # validate_every override validation at end of epochs + if self.dev_data and self.validate_every <= 0: self.do_validation() self.save_model(self.model, 'training_model_' + self.start_time) epoch += 1 @@ -128,10 +130,13 @@ class Trainer(object): if n_print > 0 and self.step % n_print == 0: end = time.time() diff = timedelta(seconds=round(end - start)) - print_output = "[epoch: {:>3} step: {:>4}] train loss: {:>4.6} time: {}".format( + print_output = "[epoch: {:>3} step: {:>4}] train loss: {:>4.6} time: {}".format( epoch, self.step, loss.data, diff) print(print_output) + if self.validate_every > 0 and self.step % self.validate_every == 0: + self.do_validation() + self.step += 1 def do_validation(self):