diff --git a/fastNLP/core/trainer.py b/fastNLP/core/trainer.py
index 6dcc9c78..9cc5431c 100644
--- a/fastNLP/core/trainer.py
+++ b/fastNLP/core/trainer.py
@@ -365,12 +365,39 @@ class Trainer(object):
         return self.losser(predict, truth)
 
     def _save_model(self, model, model_name, only_param=False):
+        """Save ``model`` (or only its parameters) with all tensors on the CPU.
+
+        The checkpoint is written to ``os.path.join(self.save_path, model_name)``;
+        nothing is saved when ``self.save_path`` is None.
+
+        :param model: the model to save.
+        :param model_name: file name of the checkpoint inside ``self.save_path``.
+        :param only_param: if True, save only the ``state_dict``; otherwise
+            save the whole model object.
+        :return: None
+        """
         if self.save_path is not None:
-            model_name = os.path.join(self.save_path, model_name)
+            model_path = os.path.join(self.save_path, model_name)
             if only_param:
-                torch.save(model.state_dict(), model_name)
+                # state_dict() returns a fresh dict, so swapping in CPU copies
+                # here does not move the live model off its device.
+                state_dict = model.state_dict()
+                for key in state_dict:
+                    state_dict[key] = state_dict[key].cpu()
+                torch.save(state_dict, model_path)
             else:
-                torch.save(model, model_name)
+                # Remember where the model lives so it can be put back afterwards;
+                # an unconditional model.cuda() fails on CPU-only machines and
+                # would move a model on cuda:N to the default GPU.
+                first_param = next(model.parameters(), None)
+                device = None if first_param is None else first_param.device
+                model.cpu()
+                try:
+                    torch.save(model, model_path)
+                finally:
+                    # Restore the device even if saving fails, so training can go on.
+                    if device is not None:
+                        model.to(device)
 
     def _load_model(self, model, model_name, only_param=False):
         # 返回bool值指示是否成功reload模型