Browse Source

add validate_every in trainer

tags/v0.2.0
yunfan 6 years ago
parent
commit
b78d86584c
1 changed files with 8 additions and 3 deletions
  1. +8
    -3
      fastNLP/core/trainer.py

+ 8
- 3
fastNLP/core/trainer.py View File

@@ -25,7 +25,7 @@ class Trainer(object):
"""Main Training Loop """Main Training Loop


""" """
def __init__(self, train_data, model, n_epochs=3, batch_size=32, print_every=-1,
def __init__(self, train_data, model, n_epochs=3, batch_size=32, print_every=-1, validate_every=-1,
dev_data=None, use_cuda=False, save_path="./save", dev_data=None, use_cuda=False, save_path="./save",
optimizer=Optimizer("Adam", lr=0.001, weight_decay=0), optimizer=Optimizer("Adam", lr=0.001, weight_decay=0),
**kwargs): **kwargs):
@@ -39,6 +39,7 @@ class Trainer(object):
self.use_cuda = bool(use_cuda) self.use_cuda = bool(use_cuda)
self.save_path = str(save_path) self.save_path = str(save_path)
self.print_every = int(print_every) self.print_every = int(print_every)
self.validate_every = int(validate_every)


model_name = model.__class__.__name__ model_name = model.__class__.__name__
assert hasattr(self.model, 'get_loss'), "model {} has to have a 'get_loss' function.".format(model_name) assert hasattr(self.model, 'get_loss'), "model {} has to have a 'get_loss' function.".format(model_name)
@@ -94,7 +95,8 @@ class Trainer(object):


self._train_epoch(data_iterator, self.model, epoch, self.dev_data, start) self._train_epoch(data_iterator, self.model, epoch, self.dev_data, start)


if self.dev_data:
# validate_every override validation at end of epochs
if self.dev_data and self.validate_every <= 0:
self.do_validation() self.do_validation()
self.save_model(self.model, 'training_model_' + self.start_time) self.save_model(self.model, 'training_model_' + self.start_time)
epoch += 1 epoch += 1
@@ -128,10 +130,13 @@ class Trainer(object):
if n_print > 0 and self.step % n_print == 0: if n_print > 0 and self.step % n_print == 0:
end = time.time() end = time.time()
diff = timedelta(seconds=round(end - start)) diff = timedelta(seconds=round(end - start))
print_output = "[epoch: {:>3} step: {:>4}] train loss: {:>4.6} time: {}".format(
print_output = "[epoch: {:>3} step: {:>4}] train loss: {:>4.6} time: {}".format(
epoch, self.step, loss.data, diff) epoch, self.step, loss.data, diff)
print(print_output) print(print_output)


if self.validate_every > 0 and self.step % self.validate_every == 0:
self.do_validation()

self.step += 1 self.step += 1


def do_validation(self): def do_validation(self):


Loading…
Cancel
Save