Browse Source

remove the gpu_id info when saving

tags/v0.3.1^2
ChenXin 6 years ago
parent
commit
a6dbbe9812
1 changed files with 14 additions and 3 deletions
  1. +14
    -3
      fastNLP/core/trainer.py

+ 14
- 3
fastNLP/core/trainer.py View File

@@ -365,12 +365,23 @@ class Trainer(object):
return self.losser(predict, truth) return self.losser(predict, truth)


def _save_model(self, model, model_name, only_param=False): def _save_model(self, model, model_name, only_param=False):
""" 存储不含有显卡信息的state_dict或model
:param model:
:param model_name:
:param only_param:
:return:
"""
if self.save_path is not None: if self.save_path is not None:
model_name = os.path.join(self.save_path, model_name)
model_path = os.path.join(self.save_path, model_name)
if only_param: if only_param:
torch.save(model.state_dict(), model_name)
state_dict = model.state_dict()
for key in state_dict:
state_dict[key] = state_dict[key].cpu()
torch.save(state_dict, model_path)
else: else:
torch.save(model, model_name)
model.cpu()
torch.save(model, model_path)
model.cuda()


def _load_model(self, model, model_name, only_param=False): def _load_model(self, model, model_name, only_param=False):
# 返回bool值指示是否成功reload模型 # 返回bool值指示是否成功reload模型


Loading…
Cancel
Save