Browse Source

add validate_every in trainer

tags/v0.2.0
yunfan 6 years ago
parent
commit
b78d86584c
1 changed files with 8 additions and 3 deletions
  1. +8
    -3
      fastNLP/core/trainer.py

+ 8
- 3
fastNLP/core/trainer.py View File

@@ -25,7 +25,7 @@ class Trainer(object):
"""Main Training Loop

"""
def __init__(self, train_data, model, n_epochs=3, batch_size=32, print_every=-1,
def __init__(self, train_data, model, n_epochs=3, batch_size=32, print_every=-1, validate_every=-1,
dev_data=None, use_cuda=False, save_path="./save",
optimizer=Optimizer("Adam", lr=0.001, weight_decay=0),
**kwargs):
@@ -39,6 +39,7 @@ class Trainer(object):
self.use_cuda = bool(use_cuda)
self.save_path = str(save_path)
self.print_every = int(print_every)
self.validate_every = int(validate_every)

model_name = model.__class__.__name__
assert hasattr(self.model, 'get_loss'), "model {} has to have a 'get_loss' function.".format(model_name)
@@ -94,7 +95,8 @@ class Trainer(object):

self._train_epoch(data_iterator, self.model, epoch, self.dev_data, start)

if self.dev_data:
# validate_every override validation at end of epochs
if self.dev_data and self.validate_every <= 0:
self.do_validation()
self.save_model(self.model, 'training_model_' + self.start_time)
epoch += 1
@@ -128,10 +130,13 @@ class Trainer(object):
if n_print > 0 and self.step % n_print == 0:
end = time.time()
diff = timedelta(seconds=round(end - start))
print_output = "[epoch: {:>3} step: {:>4}] train loss: {:>4.6} time: {}".format(
print_output = "[epoch: {:>3} step: {:>4}] train loss: {:>4.6} time: {}".format(
epoch, self.step, loss.data, diff)
print(print_output)

if self.validate_every > 0 and self.step % self.validate_every == 0:
self.do_validation()

self.step += 1

def do_validation(self):


Loading…
Cancel
Save