diff --git a/fastNLP/core/trainer.py b/fastNLP/core/trainer.py index 290a89c1..61969c2e 100644 --- a/fastNLP/core/trainer.py +++ b/fastNLP/core/trainer.py @@ -718,7 +718,7 @@ class Trainer(object): self._save_model(self.model, "best_" + "_".join([self.model.__class__.__name__, self.metric_key, self.start_time])) elif self._load_best_model: - self._best_model_states = {name: param.cpu().clone() for name, param in self.model.named_parameters()} + self._best_model_states = {name: param.cpu().clone() for name, param in self.model.state_dict()} self.best_dev_perf = res self.best_dev_epoch = epoch self.best_dev_step = step