新增CheckPointCallback,用于恢复Trainer的训练;CheckPointCallback会将optimizer、model以及Trainer的状态恢复到保存的epoch。(tags/v0.5.5)
@@ -38,11 +38,13 @@ __all__ = [ | |||||
'SaveModelCallback', | 'SaveModelCallback', | ||||
"CallbackException", | "CallbackException", | ||||
"EarlyStopError", | "EarlyStopError", | ||||
"CheckPointCallback", | |||||
"Padder", | "Padder", | ||||
"AutoPadder", | "AutoPadder", | ||||
"EngChar2DPadder", | "EngChar2DPadder", | ||||
"MetricBase", | |||||
"AccuracyMetric", | "AccuracyMetric", | ||||
"SpanFPreRecMetric", | "SpanFPreRecMetric", | ||||
"CMRC2018Metric", | "CMRC2018Metric", | ||||
@@ -50,7 +50,8 @@ __all__ = [ | |||||
'SaveModelCallback', | 'SaveModelCallback', | ||||
"CallbackException", | "CallbackException", | ||||
"EarlyStopError", | "EarlyStopError", | ||||
"CheckPointCallback", | |||||
"LossFunc", | "LossFunc", | ||||
"CrossEntropyLoss", | "CrossEntropyLoss", | ||||
"L1Loss", | "L1Loss", | ||||
@@ -58,7 +59,8 @@ __all__ = [ | |||||
"NLLLoss", | "NLLLoss", | ||||
"LossInForward", | "LossInForward", | ||||
"CMRC2018Loss", | "CMRC2018Loss", | ||||
"MetricBase", | |||||
"AccuracyMetric", | "AccuracyMetric", | ||||
"SpanFPreRecMetric", | "SpanFPreRecMetric", | ||||
"CMRC2018Metric", | "CMRC2018Metric", | ||||
@@ -79,13 +81,13 @@ from ._logger import logger | |||||
from .batch import DataSetIter, BatchIter, TorchLoaderIter | from .batch import DataSetIter, BatchIter, TorchLoaderIter | ||||
from .callback import Callback, GradientClipCallback, EarlyStopCallback, FitlogCallback, EvaluateCallback, \ | from .callback import Callback, GradientClipCallback, EarlyStopCallback, FitlogCallback, EvaluateCallback, \ | ||||
LRScheduler, ControlC, LRFinder, TensorboardCallback, WarmupCallback, SaveModelCallback, CallbackException, \ | LRScheduler, ControlC, LRFinder, TensorboardCallback, WarmupCallback, SaveModelCallback, CallbackException, \ | ||||
EarlyStopError | |||||
EarlyStopError, CheckPointCallback | |||||
from .const import Const | from .const import Const | ||||
from .dataset import DataSet | from .dataset import DataSet | ||||
from .field import FieldArray, Padder, AutoPadder, EngChar2DPadder | from .field import FieldArray, Padder, AutoPadder, EngChar2DPadder | ||||
from .instance import Instance | from .instance import Instance | ||||
from .losses import LossFunc, CrossEntropyLoss, L1Loss, BCELoss, NLLLoss, LossInForward, CMRC2018Loss | from .losses import LossFunc, CrossEntropyLoss, L1Loss, BCELoss, NLLLoss, LossInForward, CMRC2018Loss | ||||
from .metrics import AccuracyMetric, SpanFPreRecMetric, CMRC2018Metric, ClassifyFPreRecMetric | |||||
from .metrics import AccuracyMetric, SpanFPreRecMetric, CMRC2018Metric, ClassifyFPreRecMetric, MetricBase | |||||
from .optimizer import Optimizer, SGD, Adam, AdamW | from .optimizer import Optimizer, SGD, Adam, AdamW | ||||
from .sampler import SequentialSampler, BucketSampler, RandomSampler, Sampler | from .sampler import SequentialSampler, BucketSampler, RandomSampler, Sampler | ||||
from .tester import Tester | from .tester import Tester | ||||
@@ -64,7 +64,8 @@ __all__ = [ | |||||
"SaveModelCallback", | "SaveModelCallback", | ||||
"CallbackException", | "CallbackException", | ||||
"EarlyStopError" | |||||
"EarlyStopError", | |||||
"CheckPointCallback" | |||||
] | ] | ||||
import os | import os | ||||
@@ -869,6 +870,106 @@ class TensorboardCallback(Callback): | |||||
del self._summary_writer | del self._summary_writer | ||||
class CheckPointCallback(Callback):
    """
    Save the Trainer state at the end of every epoch so that an interrupted run can be resumed
    from the most recently saved epoch.
    """

    def __init__(self, save_path, delete_when_train_finish=True, recovery_fitlog=True):
        """
        Save the current Trainer state at the end of every epoch; the saved state can be used to
        resume a previous run, continuing from the most recent epoch.

        Example1::

            >>> callback = CheckPointCallback('chkp.pt')
            >>> trainer = Trainer(xxx, callback=callback)
            >>> trainer.train()  # if training fails before finishing, simply run it again
            >>>                  # (make sure the data and hyper-parameters are exactly the same)

        Example2::

            >>> fitlog.set_log_dir('xxx')
            >>> callback = CheckPointCallback('chkp.pt')  # must come right after set_log_dir
            >>> trainer = Trainer(xxx, callback=callback)
            >>> trainer.train()  # if training fails before finishing, simply run it again
            >>>                  # (make sure the data and hyper-parameters are exactly the same)

        :param str save_path: where to save the state. It must be a concrete file path such as
            'checkpoints/chtp.pt'. If this file already exists, training automatically starts
            from that checkpoint when the Trainer begins to train.
        :param bool delete_when_train_finish: whether to delete the checkpoint file when training
            finishes normally. Deleting it allows the same path to be reused by the next run.
        :param bool recovery_fitlog: whether to point fitlog back at the log recorded in the
            checkpoint. If True, create this callback on the line right after
            ``fitlog.set_log_dir``. If False, a new log folder is created; otherwise the previous
            one keeps being used.
        """
        super().__init__()
        # Already expanded and absolute, so later checks can use it directly.
        self.save_path = os.path.abspath(os.path.expanduser(save_path))
        self.delete_when_train_finish = delete_when_train_finish
        self.recover_fitlog = recovery_fitlog
        if os.path.exists(self.save_path):
            logger.info("The train will start from the checkpoint saved in {}.".format(self.save_path))
            if self.recover_fitlog:
                self._restore_fitlog()

    def _restore_fitlog(self):
        """Point fitlog at the log directory stored in the checkpoint, if one was recorded."""
        # Only the directory names are needed here, so keep every tensor on the CPU.
        states = torch.load(self.save_path, map_location='cpu')
        if 'fitlog_log_dir' not in states:
            return
        try:
            import fitlog
            # Prefer the concrete per-run directory when it was recorded.
            log_dir = states.get('fitlog_save_log_dir', states['fitlog_log_dir'])
            fitlog.set_log_dir(log_dir, new_log=True)
        except Exception:
            # Restoring fitlog is best-effort: report it, but never block the training itself.
            logger.warning("Fail to recovery the fitlog states.")

    def on_train_begin(self):
        """
        When training begins and a checkpoint file exists, restore the previous run:

        (1) reload the model weights
        (2) reload the optimizer state
        (3) restore the epoch and step counters
        (4) restore the best evaluation results, when they were saved

        :return:
        """
        if not os.path.exists(self.save_path):
            return
        # Load onto the CPU so resuming does not require the device the checkpoint was written
        # from; load_state_dict copies the values into the existing parameters/optimizer state.
        states = torch.load(self.save_path, map_location='cpu')
        self.model.load_state_dict(states['model'])
        self.optimizer.load_state_dict(states['optimizer'])
        # The state was stored at the end of an epoch, so continue with the next one.
        self.trainer.epoch = states['epoch'] + 1
        self.trainer.step = states['step']
        if 'best_dev_epoch' in states:
            self.trainer.best_dev_perf = states['best_dev_perf']
            self.trainer.best_dev_epoch = states['best_dev_epoch']
            self.trainer.best_dev_step = states['best_dev_step']
            self.trainer.best_metric_indicator = states['best_metric_indicator']

    def on_epoch_end(self):
        """
        Save the current state so that the run can be resumed later.

        :return:
        """
        states = {
            'model': {name: param.cpu() for name, param in self.model.state_dict().items()},
            'optimizer': self.optimizer.state_dict(),
            'epoch': self.epoch,
            'step': self.step,
        }
        if self.trainer.best_dev_epoch is not None:
            states['best_dev_epoch'] = self.trainer.best_dev_epoch
            states['best_dev_perf'] = self.trainer.best_dev_perf
            states['best_dev_step'] = self.trainer.best_dev_step
            states['best_metric_indicator'] = self.trainer.best_metric_indicator
        if self.recover_fitlog:
            try:
                import fitlog
                if fitlog._logger._log_dir is not None:
                    states['fitlog_log_dir'] = fitlog._logger._log_dir
                if fitlog._logger._save_log_dir is not None:
                    states['fitlog_save_log_dir'] = fitlog._logger._save_log_dir
            except Exception:
                # Recording the fitlog directories is optional (fitlog may be missing or its
                # private attributes may differ); the checkpoint is still written without them.
                logger.debug("Skip saving the fitlog states in the checkpoint.")

        # Make sure the target folder exists, then write to a temporary file and rename it over
        # the old checkpoint, so a crash while writing never destroys the last good checkpoint.
        os.makedirs(os.path.dirname(self.save_path), exist_ok=True)
        tmp_path = self.save_path + '.tmp'
        torch.save(states, tmp_path)
        os.replace(tmp_path, self.save_path)

    def on_train_end(self):
        """Delete the saved checkpoint after training ends, if the callback was asked to."""
        if self.delete_when_train_finish and os.path.exists(self.save_path):
            os.remove(self.save_path)
class WarmupCallback(Callback): | class WarmupCallback(Callback): | ||||
""" | """ | ||||
learning rate按照一定的速率从0上升到设置的learning rate。 | learning rate按照一定的速率从0上升到设置的learning rate。 | ||||
@@ -562,7 +562,6 @@ class Trainer(object): | |||||
verbose=0, | verbose=0, | ||||
use_tqdm=self.test_use_tqdm) | use_tqdm=self.test_use_tqdm) | ||||
self.step = 0 | |||||
self.start_time = None # start timestamp | self.start_time = None # start timestamp | ||||
if isinstance(callbacks, Callback): | if isinstance(callbacks, Callback): | ||||
@@ -603,7 +602,8 @@ class Trainer(object): | |||||
self.start_time = str(datetime.now().strftime('%Y-%m-%d-%H-%M-%S')) | self.start_time = str(datetime.now().strftime('%Y-%m-%d-%H-%M-%S')) | ||||
start_time = time.time() | start_time = time.time() | ||||
self.logger.info("training epochs started " + self.start_time) | self.logger.info("training epochs started " + self.start_time) | ||||
self.step = 0 | |||||
self.epoch = 1 | |||||
try: | try: | ||||
self.callback_manager.on_train_begin() | self.callback_manager.on_train_begin() | ||||
self._train() | self._train() | ||||
@@ -642,14 +642,12 @@ class Trainer(object): | |||||
from .utils import _pseudo_tqdm as inner_tqdm | from .utils import _pseudo_tqdm as inner_tqdm | ||||
else: | else: | ||||
inner_tqdm = tqdm | inner_tqdm = tqdm | ||||
self.step = 0 | |||||
self.epoch = 0 | |||||
start = time.time() | start = time.time() | ||||
with inner_tqdm(total=self.n_steps, postfix='loss:{0:<6.5f}', leave=False, dynamic_ncols=True) as pbar: | with inner_tqdm(total=self.n_steps, postfix='loss:{0:<6.5f}', leave=False, dynamic_ncols=True) as pbar: | ||||
self.pbar = pbar | self.pbar = pbar | ||||
avg_loss = 0 | avg_loss = 0 | ||||
self.batch_per_epoch = self.data_iterator.num_batches | self.batch_per_epoch = self.data_iterator.num_batches | ||||
for epoch in range(1, self.n_epochs + 1): | |||||
for epoch in range(self.epoch, self.n_epochs + 1): | |||||
self.epoch = epoch | self.epoch = epoch | ||||
pbar.set_description_str(desc="Epoch {}/{}".format(epoch, self.n_epochs)) | pbar.set_description_str(desc="Epoch {}/{}".format(epoch, self.n_epochs)) | ||||
# early stopping | # early stopping | ||||
@@ -26,8 +26,8 @@ def check_loader_paths(paths: Union[str, Dict[str, str]]) -> Dict[str, str]: | |||||
如果paths为不合法的,将直接进行raise相应的错误. 如果paths内不包含train也会报错。 | 如果paths为不合法的,将直接进行raise相应的错误. 如果paths内不包含train也会报错。 | ||||
:param str paths: 路径. 可以为一个文件路径(则认为该文件就是train的文件); 可以为一个文件目录,将在该目录下寻找train(文件名 | |||||
中包含train这个字段), test.txt, dev.txt; 可以为一个dict, 则key是用户自定义的某个文件的名称,value是这个文件的路径。 | |||||
:param str paths: 路径. 可以为一个文件路径(则认为该文件就是train的文件); 可以为一个文件目录,将在该目录下寻找包含train(文件名 | |||||
中包含train这个字段), test, dev这三个字段的文件或文件夹; 可以为一个dict, 则key是用户自定义的某个文件的名称,value是这个文件的路径。 | |||||
:return: | :return: | ||||
""" | """ | ||||
if isinstance(paths, (str, Path)): | if isinstance(paths, (str, Path)): | ||||
@@ -69,8 +69,8 @@ def check_loader_paths(paths: Union[str, Dict[str, str]]) -> Dict[str, str]: | |||||
for key, value in paths.items(): | for key, value in paths.items(): | ||||
if isinstance(key, str) and isinstance(value, str): | if isinstance(key, str) and isinstance(value, str): | ||||
value = os.path.abspath(os.path.expanduser(value)) | value = os.path.abspath(os.path.expanduser(value)) | ||||
if not os.path.isfile(value): | |||||
raise TypeError(f"{value} is not a valid file.") | |||||
if not os.path.exists(value): | |||||
raise TypeError(f"{value} is not a valid path.") | |||||
paths[key] = value | paths[key] = value | ||||
else: | else: | ||||
raise TypeError("All keys and values in paths should be str.") | raise TypeError("All keys and values in paths should be str.") | ||||
@@ -41,8 +41,8 @@ class TestCallback(unittest.TestCase): | |||||
self.tempdir = tempfile.mkdtemp() | self.tempdir = tempfile.mkdtemp() | ||||
def tearDown(self): | def tearDown(self): | ||||
pass | |||||
# shutil.rmtree(self.tempdir) | |||||
import shutil | |||||
shutil.rmtree(self.tempdir) | |||||
def test_gradient_clip(self): | def test_gradient_clip(self): | ||||
data_set, model = prepare_env() | data_set, model = prepare_env() | ||||
@@ -145,7 +145,54 @@ class TestCallback(unittest.TestCase): | |||||
metrics=AccuracyMetric(pred="predict", target="y"), use_tqdm=True, | metrics=AccuracyMetric(pred="predict", target="y"), use_tqdm=True, | ||||
callbacks=fitlog_callback, check_code_level=2) | callbacks=fitlog_callback, check_code_level=2) | ||||
trainer.train() | trainer.train() | ||||
def test_CheckPointCallback(self):
    """Interrupt a training run on purpose, then start it again from the saved checkpoint."""
    from fastNLP import Callback, CheckPointCallback, Tester

    class RaiseCallback(Callback):
        """Abort training by raising once the step counter passes ``stop_step``."""

        def __init__(self, stop_step=10):
            super().__init__()
            self.stop_step = stop_step

        def on_backward_begin(self, loss):
            if self.step > self.stop_step:
                raise RuntimeError()

    def build_trainer(train_data, net, cbs):
        # Both runs must share exactly the same configuration.
        return Trainer(train_data, net, optimizer=SGD(lr=0.1), loss=BCELoss(pred="predict", target="y"),
                       batch_size=32, n_epochs=5, print_every=50, dev_data=train_data,
                       metrics=AccuracyMetric(pred="predict", target="y"), use_tqdm=True,
                       callbacks=cbs, check_code_level=2)

    # First run: it is expected to be interrupted by RaiseCallback.
    data_set, model = prepare_env()
    tester = Tester(data=data_set, model=model, metrics=AccuracyMetric(pred="predict", target="y"))
    import fitlog

    fitlog.set_log_dir(self.tempdir)
    tempfile_path = os.path.join(self.tempdir, 'chkt.pt')
    first_run_callbacks = [
        CheckPointCallback(tempfile_path),
        FitlogCallback(data_set, tester),
        RaiseCallback(100),
    ]
    try:
        build_trainer(data_set, model, first_run_callbacks).train()
    except BaseException:
        # Deliberately broad: any failure of the first run is ignored here.
        pass

    # Simulate running the script again with the code below.
    data_set, model = prepare_env()
    checkpoint_callback = CheckPointCallback(tempfile_path)
    tester = Tester(data=data_set, model=model, metrics=AccuracyMetric(pred="predict", target="y"))
    second_run_callbacks = [checkpoint_callback, FitlogCallback(data_set, tester)]
    build_trainer(data_set, model, second_run_callbacks).train()
def test_save_model_callback(self): | def test_save_model_callback(self): | ||||
data_set, model = prepare_env() | data_set, model = prepare_env() | ||||
top = 3 | top = 3 | ||||