新增CheckPointCallback用于恢复Trainer的训练; CheckPointCallback会将optimizer, model以及Trainer的状态恢复到保存的epoch (tag: v0.5.5)
@@ -38,11 +38,13 @@ __all__ = [ | |||
'SaveModelCallback', | |||
"CallbackException", | |||
"EarlyStopError", | |||
"CheckPointCallback", | |||
"Padder", | |||
"AutoPadder", | |||
"EngChar2DPadder", | |||
"MetricBase", | |||
"AccuracyMetric", | |||
"SpanFPreRecMetric", | |||
"CMRC2018Metric", | |||
@@ -50,7 +50,8 @@ __all__ = [ | |||
'SaveModelCallback', | |||
"CallbackException", | |||
"EarlyStopError", | |||
"CheckPointCallback", | |||
"LossFunc", | |||
"CrossEntropyLoss", | |||
"L1Loss", | |||
@@ -58,7 +59,8 @@ __all__ = [ | |||
"NLLLoss", | |||
"LossInForward", | |||
"CMRC2018Loss", | |||
"MetricBase", | |||
"AccuracyMetric", | |||
"SpanFPreRecMetric", | |||
"CMRC2018Metric", | |||
@@ -79,13 +81,13 @@ from ._logger import logger | |||
from .batch import DataSetIter, BatchIter, TorchLoaderIter | |||
from .callback import Callback, GradientClipCallback, EarlyStopCallback, FitlogCallback, EvaluateCallback, \ | |||
LRScheduler, ControlC, LRFinder, TensorboardCallback, WarmupCallback, SaveModelCallback, CallbackException, \ | |||
EarlyStopError | |||
EarlyStopError, CheckPointCallback | |||
from .const import Const | |||
from .dataset import DataSet | |||
from .field import FieldArray, Padder, AutoPadder, EngChar2DPadder | |||
from .instance import Instance | |||
from .losses import LossFunc, CrossEntropyLoss, L1Loss, BCELoss, NLLLoss, LossInForward, CMRC2018Loss | |||
from .metrics import AccuracyMetric, SpanFPreRecMetric, CMRC2018Metric, ClassifyFPreRecMetric | |||
from .metrics import AccuracyMetric, SpanFPreRecMetric, CMRC2018Metric, ClassifyFPreRecMetric, MetricBase | |||
from .optimizer import Optimizer, SGD, Adam, AdamW | |||
from .sampler import SequentialSampler, BucketSampler, RandomSampler, Sampler | |||
from .tester import Tester | |||
@@ -64,7 +64,8 @@ __all__ = [ | |||
"SaveModelCallback", | |||
"CallbackException", | |||
"EarlyStopError" | |||
"EarlyStopError", | |||
"CheckPointCallback" | |||
] | |||
import os | |||
@@ -869,6 +870,106 @@ class TensorboardCallback(Callback): | |||
del self._summary_writer | |||
class CheckPointCallback(Callback):
    """
    Saves the full Trainer state (model weights, optimizer state, epoch/step
    counters and the best evaluation result so far) at the end of every epoch,
    so that an interrupted run can be resumed transparently by re-running the
    same script with the same data and hyper-parameters.
    """

    def __init__(self, save_path, delete_when_train_finish=True, recovery_fitlog=True):
        """
        Save the current Trainer state at the end of each epoch so a failed run
        can resume from the most recently completed epoch.

        Example1::

            >>> callback = CheckPointCallback('chkp.pt')
            >>> trainer = Trainer(xxx, callback=callback)
            >>> trainer.train()  # if the run fails before finishing, just run it again
            >>>                  # (make sure the data and hyper-parameters are identical)

        Example2::

            >>> fitlog.set_log_dir('xxx')
            >>> callback = CheckPointCallback('chkp.pt')  # must be created on the line right after set_log_dir
            >>> trainer = Trainer(xxx, callback=callback)
            >>> trainer.train()  # if the run fails before finishing, just run it again

        :param str save_path: where to store the checkpoint, e.g. 'checkpoints/chtp.pt'. If this file
            already exists when the Trainer starts, training automatically resumes from that checkpoint.
        :param bool delete_when_train_finish: whether to delete the checkpoint file when training finishes
            normally, so the same path can be reused by the next run.
        :param bool recovery_fitlog: whether to resume logging into the fitlog directory recorded in the
            checkpoint. If True, instantiate this callback on the line right after ``fitlog.set_log_dir``;
            if False, a new log folder is created instead of continuing the previous one.
        """
        super().__init__()
        self.save_path = os.path.abspath(os.path.expanduser(save_path))
        self.delete_when_train_finish = delete_when_train_finish
        self.recover_fitlog = recovery_fitlog
        # save_path is already expanded and absolute, so test it directly
        # (the original re-applied expanduser here redundantly).
        if os.path.exists(self.save_path):
            logger.info("The train will start from the checkpoint saved in {}.".format(self.save_path))
            if self.recover_fitlog:
                states = torch.load(self.save_path)
                if 'fitlog_log_dir' in states:
                    try:
                        import fitlog
                        log_dir = states['fitlog_log_dir']
                        if 'fitlog_save_log_dir' in states:
                            log_dir = states['fitlog_save_log_dir']
                        fitlog.set_log_dir(log_dir, new_log=True)
                    except Exception:
                        # fitlog may be missing or its internals may differ across
                        # versions; recovering the log dir is best-effort only.
                        # (Was a bare `except:` with print(); narrowed and routed
                        # through the module logger.)
                        logger.error("Fail to recovery the fitlog states.")

    def on_train_begin(self):
        """
        When training starts and a checkpoint file exists, restore:

        (1) the model weights;
        (2) the optimizer state;
        (3) the epoch and step counters;
        (4) the best evaluation result so far (if one was recorded);
        (5) (optional, handled in ``__init__``) the fitlog log directory.

        :return:
        """
        if os.path.exists(self.save_path):
            states = torch.load(self.save_path)
            self.model.load_state_dict(states['model'])
            self.optimizer.load_state_dict(states['optimizer'])
            # The checkpoint is written at epoch end, so resume from the next epoch.
            self.trainer.epoch = states['epoch'] + 1
            self.trainer.step = states['step']
            if 'best_dev_epoch' in states:
                self.trainer.best_dev_perf = states['best_dev_perf']
                self.trainer.best_dev_epoch = states['best_dev_epoch']
                self.trainer.best_dev_step = states['best_dev_step']
                self.trainer.best_metric_indicator = states['best_metric_indicator']

    def on_epoch_end(self):
        """
        Persist the current training state so the run can be recovered later.

        :return:
        """
        states = {
            # Move parameters to CPU so the checkpoint loads on any device.
            'model': {name: param.cpu() for name, param in self.model.state_dict().items()},
            'optimizer': self.optimizer.state_dict(),
            'epoch': self.epoch,
            'step': self.step,
        }
        if self.trainer.best_dev_epoch is not None:
            states['best_dev_epoch'] = self.trainer.best_dev_epoch
            states['best_dev_perf'] = self.trainer.best_dev_perf
            states['best_dev_step'] = self.trainer.best_dev_step
            states['best_metric_indicator'] = self.trainer.best_metric_indicator
        if self.recover_fitlog:
            try:
                import fitlog
                if fitlog._logger._log_dir is not None:
                    states['fitlog_log_dir'] = fitlog._logger._log_dir
                if fitlog._logger._save_log_dir is not None:
                    states['fitlog_save_log_dir'] = fitlog._logger._save_log_dir
            except Exception:
                # fitlog is optional; skip recording its dirs if unavailable.
                # (Was a bare `except:`, which would also swallow KeyboardInterrupt.)
                pass
        torch.save(states, self.save_path)

    def on_train_end(self):
        # Training finished normally; optionally remove the checkpoint so the
        # save path can be reused by a fresh run.
        if self.delete_when_train_finish:
            if os.path.exists(self.save_path):
                os.remove(self.save_path)
class WarmupCallback(Callback): | |||
""" | |||
learning rate按照一定的速率从0上升到设置的learning rate。 | |||
@@ -562,7 +562,6 @@ class Trainer(object): | |||
verbose=0, | |||
use_tqdm=self.test_use_tqdm) | |||
self.step = 0 | |||
self.start_time = None # start timestamp | |||
if isinstance(callbacks, Callback): | |||
@@ -603,7 +602,8 @@ class Trainer(object): | |||
self.start_time = str(datetime.now().strftime('%Y-%m-%d-%H-%M-%S')) | |||
start_time = time.time() | |||
self.logger.info("training epochs started " + self.start_time) | |||
self.step = 0 | |||
self.epoch = 1 | |||
try: | |||
self.callback_manager.on_train_begin() | |||
self._train() | |||
@@ -642,14 +642,12 @@ class Trainer(object): | |||
from .utils import _pseudo_tqdm as inner_tqdm | |||
else: | |||
inner_tqdm = tqdm | |||
self.step = 0 | |||
self.epoch = 0 | |||
start = time.time() | |||
with inner_tqdm(total=self.n_steps, postfix='loss:{0:<6.5f}', leave=False, dynamic_ncols=True) as pbar: | |||
self.pbar = pbar | |||
avg_loss = 0 | |||
self.batch_per_epoch = self.data_iterator.num_batches | |||
for epoch in range(1, self.n_epochs + 1): | |||
for epoch in range(self.epoch, self.n_epochs + 1): | |||
self.epoch = epoch | |||
pbar.set_description_str(desc="Epoch {}/{}".format(epoch, self.n_epochs)) | |||
# early stopping | |||
@@ -26,8 +26,8 @@ def check_loader_paths(paths: Union[str, Dict[str, str]]) -> Dict[str, str]: | |||
如果paths为不合法的,将直接进行raise相应的错误. 如果paths内不包含train也会报错。 | |||
:param str paths: 路径. 可以为一个文件路径(则认为该文件就是train的文件); 可以为一个文件目录,将在该目录下寻找train(文件名 | |||
中包含train这个字段), test.txt, dev.txt; 可以为一个dict, 则key是用户自定义的某个文件的名称,value是这个文件的路径。 | |||
:param str paths: 路径. 可以为一个文件路径(则认为该文件就是train的文件); 可以为一个文件目录,将在该目录下寻找包含train(文件名 | |||
中包含train这个字段), test, dev这三个字段的文件或文件夹; 可以为一个dict, 则key是用户自定义的某个文件的名称,value是这个文件的路径。 | |||
:return: | |||
""" | |||
if isinstance(paths, (str, Path)): | |||
@@ -69,8 +69,8 @@ def check_loader_paths(paths: Union[str, Dict[str, str]]) -> Dict[str, str]: | |||
for key, value in paths.items(): | |||
if isinstance(key, str) and isinstance(value, str): | |||
value = os.path.abspath(os.path.expanduser(value)) | |||
if not os.path.isfile(value): | |||
raise TypeError(f"{value} is not a valid file.") | |||
if not os.path.exists(value): | |||
raise TypeError(f"{value} is not a valid path.") | |||
paths[key] = value | |||
else: | |||
raise TypeError("All keys and values in paths should be str.") | |||
@@ -41,8 +41,8 @@ class TestCallback(unittest.TestCase): | |||
self.tempdir = tempfile.mkdtemp() | |||
def tearDown(self): | |||
pass | |||
# shutil.rmtree(self.tempdir) | |||
import shutil | |||
shutil.rmtree(self.tempdir) | |||
def test_gradient_clip(self): | |||
data_set, model = prepare_env() | |||
@@ -145,7 +145,54 @@ class TestCallback(unittest.TestCase): | |||
metrics=AccuracyMetric(pred="predict", target="y"), use_tqdm=True, | |||
callbacks=fitlog_callback, check_code_level=2) | |||
trainer.train() | |||
def test_CheckPointCallback(self):
    """Simulate a crash mid-training, then verify a second run resumes from the checkpoint."""
    from fastNLP import CheckPointCallback, Callback
    from fastNLP import Tester

    class RaiseCallback(Callback):
        # Raises after a fixed number of steps to emulate an unexpected failure.
        def __init__(self, stop_step=10):
            super().__init__()
            self.stop_step = stop_step

        def on_backward_begin(self, loss):
            if self.step > self.stop_step:
                raise RuntimeError()

    data_set, model = prepare_env()
    tester = Tester(data=data_set, model=model, metrics=AccuracyMetric(pred="predict", target="y"))
    import fitlog
    fitlog.set_log_dir(self.tempdir)
    tempfile_path = os.path.join(self.tempdir, 'chkt.pt')
    callbacks = [CheckPointCallback(tempfile_path)]

    fitlog_callback = FitlogCallback(data_set, tester)
    callbacks.append(fitlog_callback)

    callbacks.append(RaiseCallback(100))
    try:
        trainer = Trainer(data_set, model, optimizer=SGD(lr=0.1), loss=BCELoss(pred="predict", target="y"),
                          batch_size=32, n_epochs=5, print_every=50, dev_data=data_set,
                          metrics=AccuracyMetric(pred="predict", target="y"), use_tqdm=True,
                          callbacks=callbacks, check_code_level=2)
        trainer.train()
    except RuntimeError:
        # Expected: RaiseCallback deliberately aborts the first run mid-training.
        # (Was a bare `except: pass`, which would also hide genuine test failures.)
        pass

    # Simulate re-running the script: a fresh model/Trainer with the same
    # checkpoint path should resume from the saved state and finish cleanly.
    data_set, model = prepare_env()
    callbacks = [CheckPointCallback(tempfile_path)]
    tester = Tester(data=data_set, model=model, metrics=AccuracyMetric(pred="predict", target="y"))
    fitlog_callback = FitlogCallback(data_set, tester)
    callbacks.append(fitlog_callback)

    trainer = Trainer(data_set, model, optimizer=SGD(lr=0.1), loss=BCELoss(pred="predict", target="y"),
                      batch_size=32, n_epochs=5, print_every=50, dev_data=data_set,
                      metrics=AccuracyMetric(pred="predict", target="y"), use_tqdm=True,
                      callbacks=callbacks, check_code_level=2)
    trainer.train()
def test_save_model_callback(self): | |||
data_set, model = prepare_env() | |||
top = 3 | |||