|
|
@@ -617,6 +617,14 @@ class Trainer(object): |
|
|
|
elif on_exception == 'raise': |
|
|
|
raise e |
|
|
|
|
|
|
|
if self.dev_data is not None and self.best_dev_perf is not None and load_best_model: |
|
|
|
model_name = "best_" + "_".join([self.model.__class__.__name__, self.metric_key, self.start_time]) |
|
|
|
load_succeed = self._load_model(self.model, model_name) |
|
|
|
if load_succeed: |
|
|
|
self.logger.info("Reloaded the best model.") |
|
|
|
else: |
|
|
|
self.logger.info("Fail to reload best model.") |
|
|
|
finally: |
|
|
|
if self.dev_data is not None and self.best_dev_perf is not None: |
|
|
|
self.logger.info( |
|
|
|
"\nIn Epoch:{}/Step:{}, got best dev performance:".format(self.best_dev_epoch, self.best_dev_step)) |
|
|
@@ -624,15 +632,7 @@ class Trainer(object): |
|
|
|
results['best_eval'] = self.best_dev_perf |
|
|
|
results['best_epoch'] = self.best_dev_epoch |
|
|
|
results['best_step'] = self.best_dev_step |
|
|
|
if load_best_model: |
|
|
|
model_name = "best_" + "_".join([self.model.__class__.__name__, self.metric_key, self.start_time]) |
|
|
|
load_succeed = self._load_model(self.model, model_name) |
|
|
|
if load_succeed: |
|
|
|
self.logger.info("Reloaded the best model.") |
|
|
|
else: |
|
|
|
self.logger.info("Fail to reload best model.") |
|
|
|
finally: |
|
|
|
pass |
|
|
|
|
|
|
|
results['seconds'] = round(time.time() - start_time, 2) |
|
|
|
|
|
|
|
return results |
|
|
|