
Update SaveModelCallback and CheckPointCallback so that they work correctly when the model is wrapped in DataParallel.

tags/v0.5.5
yh_cc committed 5 years ago
commit 7024df46af

2 changed files with 15 additions and 4 deletions:
  1. fastNLP/core/callback.py  (+12, -3)
  2. fastNLP/core/utils.py  (+3, -1)
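Both files now rely on a new helper, `_model_contains_inner_module`, to decide whether to unwrap the model before touching its `state_dict`. The helper itself is not shown in this diff; a minimal sketch, assuming it simply detects the standard parallel wrappers whose real network lives in `model.module`:

```python
import torch.nn as nn

def _model_contains_inner_module(model):
    # Sketch only: True for wrappers that keep the real network in `model.module`.
    return isinstance(model, (nn.DataParallel, nn.parallel.DistributedDataParallel))
```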

fastNLP/core/callback.py  (+12, -3)

@@ -87,6 +87,7 @@ from .dataset import DataSet
 from .tester import Tester
 from ._logger import logger
 from .utils import _check_fp16
+from ._parallel_utils import _model_contains_inner_module
 
 try:
     import fitlog
@@ -914,7 +915,7 @@ class CheckPointCallback(Callback):
                 log_dir = states['fitlog_save_log_dir']
                 fitlog.set_log_dir(log_dir, new_log=True)
             except:
-                print("Fail to recovery the fitlog states.")
+                logger.error("Fail to recovery the fitlog states.")
 
     def on_train_begin(self):
         """
@@ -929,7 +930,10 @@ class CheckPointCallback(Callback):
         """
         if os.path.exists(os.path.expanduser(self.save_path)):
             states = torch.load(self.save_path)
-            self.model.load_state_dict(states['model'])
+            model = self.model
+            if _model_contains_inner_module(model):
+                model = model.module
+            model.load_state_dict(states['model'])
             self.optimizer.load_state_dict(states['optimizer'])
             self.trainer.epoch = states['epoch'] + 1  # the checkpoint was saved at the end of an epoch, so resume from the next one
             self.trainer.step = states['step']
@@ -947,7 +951,10 @@ class CheckPointCallback(Callback):
         :return:
         """
         states = {}
-        states['model'] = {name:param.cpu() for name, param in self.model.state_dict().items()}
+        model = self.model
+        if _model_contains_inner_module(model):
+            model = model.module
+        states['model'] = {name:param.cpu() for name, param in model.state_dict().items()}
         states['optimizer'] = self.optimizer.state_dict()
         states['epoch'] = self.epoch
         states['step'] = self.step
@@ -966,12 +973,14 @@ class CheckPointCallback(Callback):
         except:
             pass
         torch.save(states, self.save_path)
+        logger.debug("Checkpoint:{} has been saved in epoch:{}.".format(self.save_path, self.epoch))
 
     def on_train_end(self):
         # training has finished; delete the saved checkpoint if configured to do so
         if self.delete_when_train_finish:
             if os.path.exists(self.save_path):
                 os.remove(self.save_path)
+                logger.debug("Checkpoint:{} has been removed.".format(self.save_path))
 
 
 class WarmupCallback(Callback):
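Why unwrapping matters here: `nn.DataParallel` registers the real network as `.module`, so its `state_dict()` keys all carry a `module.` prefix and would not load back into the plain model. A quick illustration in plain PyTorch:

```python
import torch.nn as nn

net = nn.Linear(4, 2)
wrapped = nn.DataParallel(net)

print(list(net.state_dict()))      # ['weight', 'bias']
print(list(wrapped.state_dict()))  # ['module.weight', 'module.bias']

# loading the inner module's state_dict keeps the keys compatible with the plain model
net.load_state_dict(wrapped.module.state_dict())
```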


fastNLP/core/utils.py  (+3, -1)

@@ -19,6 +19,8 @@ import torch.nn as nn
 from typing import List
 from ._logger import logger
 from prettytable import PrettyTable
+from ._parallel_utils import _model_contains_inner_module
+
 try:
     from apex import amp
 except:
@@ -179,7 +181,7 @@ def _save_model(model, model_name, save_dir, only_param=False):
     model_path = os.path.join(save_dir, model_name)
     if not os.path.isdir(save_dir):
         os.makedirs(save_dir, exist_ok=True)
-    if isinstance(model, nn.DataParallel):
+    if _model_contains_inner_module(model):
         model = model.module
     if only_param:
         state_dict = model.state_dict()
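Taken together, the checkpoint written by CheckPointCallback and the weights written through `_save_model` (which SaveModelCallback presumably goes through) now hold the unwrapped parameters, so they load into either a plain model or a re-wrapped one. A self-contained sketch of that round trip, mirroring the unwrap-and-save pattern from this diff (the `unwrap` helper below is illustrative, not the library's):

```python
import os, tempfile
import torch
import torch.nn as nn

def unwrap(model):
    # same idea as the check added in this commit: parallel wrappers keep the real net in .module
    if isinstance(model, (nn.DataParallel, nn.parallel.DistributedDataParallel)):
        return model.module
    return model

wrapped = nn.DataParallel(nn.Linear(4, 2))

# save the unwrapped weights on CPU, as the updated checkpoint code does
path = os.path.join(tempfile.mkdtemp(), "checkpoint.pt")
states = {'model': {name: p.cpu() for name, p in unwrap(wrapped).state_dict().items()}}
torch.save(states, path)

# the checkpoint loads into a plain module as well as into a re-wrapped one
plain = nn.Linear(4, 2)
plain.load_state_dict(torch.load(path)['model'])
unwrap(nn.DataParallel(plain)).load_state_dict(torch.load(path)['model'])
```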

