From a6dbbe9812f301f1e3dfcc02d984ee53dad0df5d Mon Sep 17 00:00:00 2001 From: ChenXin Date: Tue, 15 Jan 2019 11:45:02 +0800 Subject: [PATCH] remove the gpu_id info when saving --- fastNLP/core/trainer.py | 17 ++++++++++++++--- 1 file changed, 14 insertions(+), 3 deletions(-) diff --git a/fastNLP/core/trainer.py b/fastNLP/core/trainer.py index 6dcc9c78..9cc5431c 100644 --- a/fastNLP/core/trainer.py +++ b/fastNLP/core/trainer.py @@ -365,12 +365,23 @@ class Trainer(object): return self.losser(predict, truth) def _save_model(self, model, model_name, only_param=False): + """ 存储不含有显卡信息的state_dict或model + :param model: + :param model_name: + :param only_param: + :return: + """ if self.save_path is not None: - model_name = os.path.join(self.save_path, model_name) + model_path = os.path.join(self.save_path, model_name) if only_param: - torch.save(model.state_dict(), model_name) + state_dict = model.state_dict() + for key in state_dict: + state_dict[key] = state_dict[key].cpu() + torch.save(state_dict, model_path) else: - torch.save(model, model_name) + model.cpu() + torch.save(model, model_path) + model.cuda() def _load_model(self, model, model_name, only_param=False): # 返回bool值指示是否成功reload模型