|
@@ -365,12 +365,23 @@ class Trainer(object): |
|
|
return self.losser(predict, truth) |
|
|
return self.losser(predict, truth) |
|
|
|
|
|
|
|
|
def _save_model(self, model, model_name, only_param=False): |
|
|
def _save_model(self, model, model_name, only_param=False): |
|
|
|
|
|
""" 存储不含有显卡信息的state_dict或model |
|
|
|
|
|
:param model: |
|
|
|
|
|
:param model_name: |
|
|
|
|
|
:param only_param: |
|
|
|
|
|
:return: |
|
|
|
|
|
""" |
|
|
if self.save_path is not None: |
|
|
if self.save_path is not None: |
|
|
model_name = os.path.join(self.save_path, model_name) |
|
|
|
|
|
|
|
|
model_path = os.path.join(self.save_path, model_name) |
|
|
if only_param: |
|
|
if only_param: |
|
|
torch.save(model.state_dict(), model_name) |
|
|
|
|
|
|
|
|
state_dict = model.state_dict() |
|
|
|
|
|
for key in state_dict: |
|
|
|
|
|
state_dict[key] = state_dict[key].cpu() |
|
|
|
|
|
torch.save(state_dict, model_path) |
|
|
else: |
|
|
else: |
|
|
torch.save(model, model_name) |
|
|
|
|
|
|
|
|
model.cpu() |
|
|
|
|
|
torch.save(model, model_path) |
|
|
|
|
|
model.cuda() |
|
|
|
|
|
|
|
|
def _load_model(self, model, model_name, only_param=False): |
|
|
def _load_model(self, model, model_name, only_param=False): |
|
|
# 返回bool值指示是否成功reload模型 |
|
|
# 返回bool值指示是否成功reload模型 |
|
|