From 7024df46af9c7222c6e356978295091cda9b074d Mon Sep 17 00:00:00 2001 From: yh_cc Date: Fri, 27 Dec 2019 21:32:02 +0800 Subject: [PATCH] =?UTF-8?q?update=20SaveModelCallback=E5=92=8CCheckPointCa?= =?UTF-8?q?llback,=20=E4=BD=BF=E4=BB=96=E4=BB=AC=E5=8F=AF=E4=BB=A5?= =?UTF-8?q?=E5=9C=A8DataParallel=E7=9A=84=E6=83=85=E5=86=B5=E4=B8=8B?= =?UTF-8?q?=E6=AD=A3=E5=B8=B8=E5=B7=A5=E4=BD=9C?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- fastNLP/core/callback.py | 15 ++++++++++++--- fastNLP/core/utils.py | 4 +++- 2 files changed, 15 insertions(+), 4 deletions(-) diff --git a/fastNLP/core/callback.py b/fastNLP/core/callback.py index b4eb95df..7162ee69 100644 --- a/fastNLP/core/callback.py +++ b/fastNLP/core/callback.py @@ -87,6 +87,7 @@ from .dataset import DataSet from .tester import Tester from ._logger import logger from .utils import _check_fp16 +from ._parallel_utils import _model_contains_inner_module try: import fitlog @@ -914,7 +915,7 @@ class CheckPointCallback(Callback): log_dir = states['fitlog_save_log_dir'] fitlog.set_log_dir(log_dir, new_log=True) except: - print("Fail to recovery the fitlog states.") + logger.error("Fail to recovery the fitlog states.") def on_train_begin(self): """ @@ -929,7 +930,10 @@ class CheckPointCallback(Callback): """ if os.path.exists(os.path.expanduser(self.save_path)): states = torch.load(self.save_path) - self.model.load_state_dict(states['model']) + model = self.model + if _model_contains_inner_module(model): + model = model.module + model.load_state_dict(states['model']) self.optimizer.load_state_dict(states['optimizer']) self.trainer.epoch = states['epoch'] + 1 # 因为是结束储存的,所以需要从下一个epoch开始 self.trainer.step = states['step'] @@ -947,7 +951,10 @@ class CheckPointCallback(Callback): :return: """ states = {} - states['model'] = {name:param.cpu() for name, param in self.model.state_dict().items()} + model = self.model + if _model_contains_inner_module(model): + model = model.module + states['model'] = {name:param.cpu() for name, param in model.state_dict().items()} states['optimizer'] = self.optimizer.state_dict() states['epoch'] = self.epoch states['step'] = self.step @@ -966,12 +973,14 @@ class CheckPointCallback(Callback): except: pass torch.save(states, self.save_path) + logger.debug("Checkpoint:{} has been saved in epoch:{}.".format(self.save_path, self.epoch)) def on_train_end(self): # 训练结束,根据情况删除保存的内容 if self.delete_when_train_finish: if os.path.exists(self.save_path): os.remove(self.save_path) + logger.debug("Checkpoint:{} has been removed.".format(self.save_path)) class WarmupCallback(Callback): diff --git a/fastNLP/core/utils.py b/fastNLP/core/utils.py index 27f266fa..ba9ec850 100644 --- a/fastNLP/core/utils.py +++ b/fastNLP/core/utils.py @@ -19,6 +19,8 @@ import torch.nn as nn from typing import List from ._logger import logger from prettytable import PrettyTable +from ._parallel_utils import _model_contains_inner_module + try: from apex import amp except: @@ -179,7 +181,7 @@ def _save_model(model, model_name, save_dir, only_param=False): model_path = os.path.join(save_dir, model_name) if not os.path.isdir(save_dir): os.makedirs(save_dir, exist_ok=True) - if isinstance(model, nn.DataParallel): + if _model_contains_inner_module(model): model = model.module if only_param: state_dict = model.state_dict()