Browse Source

修复一个小bug

新增CheckPointCallback用于恢复Trainer的训练; CheckPointCallback会将optimizer,model以及Trainer的状态恢复到保存的epoch
tags/v0.5.5
yh_cc 4 years ago
parent
commit
9293a6c1ab
6 changed files with 169 additions and 19 deletions
  1. +4
    -2
      fastNLP/__init__.py
  2. +6
    -4
      fastNLP/core/__init__.py
  3. +102
    -1
      fastNLP/core/callback.py
  4. +3
    -5
      fastNLP/core/trainer.py
  5. +4
    -4
      fastNLP/io/utils.py
  6. +50
    -3
      test/core/test_callbacks.py

+ 4
- 2
fastNLP/__init__.py View File

@@ -38,11 +38,13 @@ __all__ = [
'SaveModelCallback',
"CallbackException",
"EarlyStopError",
"CheckPointCallback",

"Padder",
"AutoPadder",
"EngChar2DPadder",

"MetricBase",
"AccuracyMetric",
"SpanFPreRecMetric",
"CMRC2018Metric",


+ 6
- 4
fastNLP/core/__init__.py View File

@@ -50,7 +50,8 @@ __all__ = [
'SaveModelCallback',
"CallbackException",
"EarlyStopError",
"CheckPointCallback",

"LossFunc",
"CrossEntropyLoss",
"L1Loss",
@@ -58,7 +59,8 @@ __all__ = [
"NLLLoss",
"LossInForward",
"CMRC2018Loss",

"MetricBase",
"AccuracyMetric",
"SpanFPreRecMetric",
"CMRC2018Metric",
@@ -79,13 +81,13 @@ from ._logger import logger
from .batch import DataSetIter, BatchIter, TorchLoaderIter
from .callback import Callback, GradientClipCallback, EarlyStopCallback, FitlogCallback, EvaluateCallback, \
LRScheduler, ControlC, LRFinder, TensorboardCallback, WarmupCallback, SaveModelCallback, CallbackException, \
EarlyStopError
EarlyStopError, CheckPointCallback
from .const import Const
from .dataset import DataSet
from .field import FieldArray, Padder, AutoPadder, EngChar2DPadder
from .instance import Instance
from .losses import LossFunc, CrossEntropyLoss, L1Loss, BCELoss, NLLLoss, LossInForward, CMRC2018Loss
from .metrics import AccuracyMetric, SpanFPreRecMetric, CMRC2018Metric, ClassifyFPreRecMetric
from .metrics import AccuracyMetric, SpanFPreRecMetric, CMRC2018Metric, ClassifyFPreRecMetric, MetricBase
from .optimizer import Optimizer, SGD, Adam, AdamW
from .sampler import SequentialSampler, BucketSampler, RandomSampler, Sampler
from .tester import Tester


+ 102
- 1
fastNLP/core/callback.py View File

@@ -64,7 +64,8 @@ __all__ = [
"SaveModelCallback",
"CallbackException",
"EarlyStopError"
"EarlyStopError",
"CheckPointCallback"
]

import os
@@ -869,6 +870,106 @@ class TensorboardCallback(Callback):
del self._summary_writer


class CheckPointCallback(Callback):
    def __init__(self, save_path, delete_when_train_finish=True, recovery_fitlog=True):
        """
        Save the Trainer state at the end of every epoch so that an interrupted run
        can be resumed from the most recently completed epoch. Restores the model
        weights, the optimizer state and the Trainer counters.

        Example1::

            >>> callback = CheckPointCallback('chkp.pt')
            >>> trainer = Trainer(xxx, callback=callback)
            >>> trainer.train()  # if training dies midway, simply rerun the script
            ...                  # (make sure to use exactly the same data and hyper-parameters)

        Example2::

            >>> fitlog.set_log_dir('xxx')
            >>> callback = CheckPointCallback('chkp.pt')  # must come on the line right after set_log_dir
            >>> trainer = Trainer(xxx, callback=callback)
            >>> trainer.train()  # if training dies midway, simply rerun the script
            ...                  # (make sure to use exactly the same data and hyper-parameters)

        :param str save_path: where to store the state; must be a concrete file path such as
            'checkpoints/chkp.pt'. If the file already exists, training automatically resumes
            from that checkpoint.
        :param bool delete_when_train_finish: whether to delete the checkpoint file once training
            finishes normally, so that the path can be reused by later runs.
        :param bool recovery_fitlog: whether to restore fitlog to the previous log directory.
            If True, initialize this callback on the line immediately after fitlog.set_log_dir;
            if False, a new log folder is created instead of continuing the old one.
        """
        super().__init__()
        self.save_path = os.path.abspath(os.path.expanduser(save_path))
        self.delete_when_train_finish = delete_when_train_finish
        self.recover_fitlog = recovery_fitlog
        # self.save_path is already expanded and absolute; no second expanduser needed.
        if os.path.exists(self.save_path):
            logger.info("The train will start from the checkpoint saved in {}.".format(self.save_path))
            if self.recover_fitlog:
                # NOTE(review): no map_location — a checkpoint written on a GPU machine may
                # fail to load on a CPU-only one; confirm before relying on cross-device resume.
                states = torch.load(self.save_path)
                if 'fitlog_log_dir' in states:
                    try:
                        import fitlog
                        log_dir = states['fitlog_log_dir']
                        if 'fitlog_save_log_dir' in states:
                            log_dir = states['fitlog_save_log_dir']
                        fitlog.set_log_dir(log_dir, new_log=True)
                    except Exception:
                        # best-effort: fitlog may be missing or incompatible; never abort training
                        logger.error("Fail to recover the fitlog states.")

    def on_train_begin(self):
        """
        When training starts and a checkpoint file exists, restore the saved run:
        (1) reload the model weights
        (2) reload the optimizer state
        (3) restore the current epoch number
        (4) restore the best evaluation performance so far
        (5) (optional) point fitlog back at the previous log directory

        :return:
        """
        if os.path.exists(self.save_path):
            states = torch.load(self.save_path)
            self.model.load_state_dict(states['model'])
            self.optimizer.load_state_dict(states['optimizer'])
            # the state was saved at epoch end, so resume from the following epoch
            self.trainer.epoch = states['epoch'] + 1
            self.trainer.step = states['step']
            if 'best_dev_epoch' in states:
                self.trainer.best_dev_perf = states['best_dev_perf']
                self.trainer.best_dev_epoch = states['best_dev_epoch']
                self.trainer.best_dev_step = states['best_dev_step']
                self.trainer.best_metric_indicator = states['best_metric_indicator']

    def on_epoch_end(self):
        """
        Save the current state so that the run can be resumed later.

        :return:
        """
        states = {}
        # store weights on CPU so the checkpoint loads regardless of the saving device
        states['model'] = {name: param.cpu() for name, param in self.model.state_dict().items()}
        states['optimizer'] = self.optimizer.state_dict()
        states['epoch'] = self.epoch
        states['step'] = self.step
        if self.trainer.best_dev_epoch is not None:
            states['best_dev_epoch'] = self.trainer.best_dev_epoch
            states['best_dev_perf'] = self.trainer.best_dev_perf
            states['best_dev_step'] = self.trainer.best_dev_step
            states['best_metric_indicator'] = self.trainer.best_metric_indicator
        if self.recover_fitlog:
            try:
                import fitlog
                if fitlog._logger._log_dir is not None:
                    states['fitlog_log_dir'] = fitlog._logger._log_dir
                if fitlog._logger._save_log_dir is not None:
                    states['fitlog_save_log_dir'] = fitlog._logger._save_log_dir
            except Exception:
                # fitlog state is optional; silently skip if it cannot be captured
                pass
        torch.save(states, self.save_path)

    def on_train_end(self):
        # training finished normally: optionally remove the checkpoint so the path can be reused
        if self.delete_when_train_finish:
            if os.path.exists(self.save_path):
                os.remove(self.save_path)


class WarmupCallback(Callback):
"""
learning rate按照一定的速率从0上升到设置的learning rate。


+ 3
- 5
fastNLP/core/trainer.py View File

@@ -562,7 +562,6 @@ class Trainer(object):
verbose=0,
use_tqdm=self.test_use_tqdm)

self.step = 0
self.start_time = None # start timestamp

if isinstance(callbacks, Callback):
@@ -603,7 +602,8 @@ class Trainer(object):
self.start_time = str(datetime.now().strftime('%Y-%m-%d-%H-%M-%S'))
start_time = time.time()
self.logger.info("training epochs started " + self.start_time)

self.step = 0
self.epoch = 1
try:
self.callback_manager.on_train_begin()
self._train()
@@ -642,14 +642,12 @@ class Trainer(object):
from .utils import _pseudo_tqdm as inner_tqdm
else:
inner_tqdm = tqdm
self.step = 0
self.epoch = 0
start = time.time()
with inner_tqdm(total=self.n_steps, postfix='loss:{0:<6.5f}', leave=False, dynamic_ncols=True) as pbar:
self.pbar = pbar
avg_loss = 0
self.batch_per_epoch = self.data_iterator.num_batches
for epoch in range(1, self.n_epochs + 1):
for epoch in range(self.epoch, self.n_epochs + 1):
self.epoch = epoch
pbar.set_description_str(desc="Epoch {}/{}".format(epoch, self.n_epochs))
# early stopping


+ 4
- 4
fastNLP/io/utils.py View File

@@ -26,8 +26,8 @@ def check_loader_paths(paths: Union[str, Dict[str, str]]) -> Dict[str, str]:

如果paths为不合法的,将直接进行raise相应的错误. 如果paths内不包含train也会报错。

:param str paths: 路径. 可以为一个文件路径(则认为该文件就是train的文件); 可以为一个文件目录,将在该目录下寻找train(文件名
中包含train这个字段), test.txt, dev.txt; 可以为一个dict, 则key是用户自定义的某个文件的名称,value是这个文件的路径。
:param str paths: 路径. 可以为一个文件路径(则认为该文件就是train的文件); 可以为一个文件目录,将在该目录下寻找包含train(文件名
中包含train这个字段), test, dev这三个字段的文件或文件夹; 可以为一个dict, 则key是用户自定义的某个文件的名称,value是这个文件的路径。
:return:
"""
if isinstance(paths, (str, Path)):
@@ -69,8 +69,8 @@ def check_loader_paths(paths: Union[str, Dict[str, str]]) -> Dict[str, str]:
for key, value in paths.items():
if isinstance(key, str) and isinstance(value, str):
value = os.path.abspath(os.path.expanduser(value))
if not os.path.isfile(value):
raise TypeError(f"{value} is not a valid file.")
if not os.path.exists(value):
raise TypeError(f"{value} is not a valid path.")
paths[key] = value
else:
raise TypeError("All keys and values in paths should be str.")


+ 50
- 3
test/core/test_callbacks.py View File

@@ -41,8 +41,8 @@ class TestCallback(unittest.TestCase):
self.tempdir = tempfile.mkdtemp()
def tearDown(self):
pass
# shutil.rmtree(self.tempdir)
import shutil
shutil.rmtree(self.tempdir)
def test_gradient_clip(self):
data_set, model = prepare_env()
@@ -145,7 +145,54 @@ class TestCallback(unittest.TestCase):
metrics=AccuracyMetric(pred="predict", target="y"), use_tqdm=True,
callbacks=fitlog_callback, check_code_level=2)
trainer.train()

def test_CheckPointCallback(self):
    """Train with a deliberately-failing callback, then verify that a second run
    resumes from the CheckPointCallback checkpoint and finishes normally."""
    from fastNLP import CheckPointCallback, Callback
    from fastNLP import Tester

    class RaiseCallback(Callback):
        def __init__(self, stop_step=10):
            super().__init__()
            self.stop_step = stop_step

        def on_backward_begin(self, loss):
            # simulate a crash partway through training
            if self.step > self.stop_step:
                raise RuntimeError()

    data_set, model = prepare_env()
    tester = Tester(data=data_set, model=model, metrics=AccuracyMetric(pred="predict", target="y"))
    import fitlog

    fitlog.set_log_dir(self.tempdir)
    tempfile_path = os.path.join(self.tempdir, 'chkt.pt')
    callbacks = [CheckPointCallback(tempfile_path)]

    fitlog_callback = FitlogCallback(data_set, tester)
    callbacks.append(fitlog_callback)

    callbacks.append(RaiseCallback(100))
    try:
        trainer = Trainer(data_set, model, optimizer=SGD(lr=0.1), loss=BCELoss(pred="predict", target="y"),
                          batch_size=32, n_epochs=5, print_every=50, dev_data=data_set,
                          metrics=AccuracyMetric(pred="predict", target="y"), use_tqdm=True,
                          callbacks=callbacks, check_code_level=2)
        trainer.train()
    except Exception:
        # the RuntimeError raised by RaiseCallback (possibly wrapped by the
        # callback manager) is expected here; never swallow KeyboardInterrupt
        pass
    # simulate rerunning the script: training should resume from the checkpoint
    data_set, model = prepare_env()
    callbacks = [CheckPointCallback(tempfile_path)]
    tester = Tester(data=data_set, model=model, metrics=AccuracyMetric(pred="predict", target="y"))
    fitlog_callback = FitlogCallback(data_set, tester)
    callbacks.append(fitlog_callback)

    trainer = Trainer(data_set, model, optimizer=SGD(lr=0.1), loss=BCELoss(pred="predict", target="y"),
                      batch_size=32, n_epochs=5, print_every=50, dev_data=data_set,
                      metrics=AccuracyMetric(pred="predict", target="y"), use_tqdm=True,
                      callbacks=callbacks, check_code_level=2)
    trainer.train()

def test_save_model_callback(self):
data_set, model = prepare_env()
top = 3


Loading…
Cancel
Save