|
@@ -718,7 +718,7 @@ class Trainer(object): |
|
|
self._save_model(self.model, |
|
|
self._save_model(self.model, |
|
|
"best_" + "_".join([self.model.__class__.__name__, self.metric_key, self.start_time])) |
|
|
"best_" + "_".join([self.model.__class__.__name__, self.metric_key, self.start_time])) |
|
|
elif self._load_best_model: |
|
|
elif self._load_best_model: |
|
|
self._best_model_states = {name: param.cpu().clone() for name, param in self.model.state_dict()} |
|
|
|
|
|
|
|
|
self._best_model_states = {name: param.cpu().clone() for name, param in self.model.state_dict().items()} |
|
|
self.best_dev_perf = res |
|
|
self.best_dev_perf = res |
|
|
self.best_dev_epoch = epoch |
|
|
self.best_dev_epoch = epoch |
|
|
self.best_dev_step = step |
|
|
self.best_dev_step = step |
|
|