|
|
@@ -25,7 +25,7 @@ class Trainer(object): |
|
|
|
"""Main Training Loop |
|
|
|
|
|
|
|
""" |
|
|
|
def __init__(self, train_data, model, n_epochs=3, batch_size=32, print_every=-1, |
|
|
|
def __init__(self, train_data, model, n_epochs=3, batch_size=32, print_every=-1, validate_every=-1, |
|
|
|
dev_data=None, use_cuda=False, save_path="./save", |
|
|
|
optimizer=Optimizer("Adam", lr=0.001, weight_decay=0), |
|
|
|
**kwargs): |
|
|
@@ -39,6 +39,7 @@ class Trainer(object): |
|
|
|
self.use_cuda = bool(use_cuda) |
|
|
|
self.save_path = str(save_path) |
|
|
|
self.print_every = int(print_every) |
|
|
|
self.validate_every = int(validate_every) |
|
|
|
|
|
|
|
model_name = model.__class__.__name__ |
|
|
|
assert hasattr(self.model, 'get_loss'), "model {} has to have a 'get_loss' function.".format(model_name) |
|
|
@@ -94,7 +95,8 @@ class Trainer(object): |
|
|
|
|
|
|
|
self._train_epoch(data_iterator, self.model, epoch, self.dev_data, start) |
|
|
|
|
|
|
|
if self.dev_data: |
|
|
|
# validate_every override validation at end of epochs |
|
|
|
if self.dev_data and self.validate_every <= 0: |
|
|
|
self.do_validation() |
|
|
|
self.save_model(self.model, 'training_model_' + self.start_time) |
|
|
|
epoch += 1 |
|
|
@@ -128,10 +130,13 @@ class Trainer(object): |
|
|
|
if n_print > 0 and self.step % n_print == 0: |
|
|
|
end = time.time() |
|
|
|
diff = timedelta(seconds=round(end - start)) |
|
|
|
print_output = "[epoch: {:>3} step: {:>4}] train loss: {:>4.6} time: {}".format( |
|
|
|
print_output = "[epoch: {:>3} step: {:>4}] train loss: {:>4.6} time: {}".format( |
|
|
|
epoch, self.step, loss.data, diff) |
|
|
|
print(print_output) |
|
|
|
|
|
|
|
if self.validate_every > 0 and self.step % self.validate_every == 0: |
|
|
|
self.do_validation() |
|
|
|
|
|
|
|
self.step += 1 |
|
|
|
|
|
|
|
def do_validation(self): |
|
|
|