From 9293a6c1ab36bba1253347be561c0983c12a3457 Mon Sep 17 00:00:00 2001 From: yh_cc Date: Thu, 26 Dec 2019 17:45:14 +0800 Subject: [PATCH] =?UTF-8?q?=E4=BF=AE=E5=A4=8D=E4=B8=80=E4=B8=AA=E5=B0=8Fbu?= =?UTF-8?q?g?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 新增CheckPointCallback用于恢复Trainer的训练; CheckPointCallback会将optimizer,model以及Trainer的状态恢复到保存的epoch --- fastNLP/__init__.py | 6 ++- fastNLP/core/__init__.py | 10 ++-- fastNLP/core/callback.py | 103 +++++++++++++++++++++++++++++++++++- fastNLP/core/trainer.py | 8 ++- fastNLP/io/utils.py | 8 +-- test/core/test_callbacks.py | 53 +++++++++++++++++-- 6 files changed, 169 insertions(+), 19 deletions(-) diff --git a/fastNLP/__init__.py b/fastNLP/__init__.py index 95d6f12c..76265a01 100644 --- a/fastNLP/__init__.py +++ b/fastNLP/__init__.py @@ -38,11 +38,13 @@ __all__ = [ 'SaveModelCallback', "CallbackException", "EarlyStopError", - + "CheckPointCallback", + "Padder", "AutoPadder", "EngChar2DPadder", - + + "MetricBase", "AccuracyMetric", "SpanFPreRecMetric", "CMRC2018Metric", diff --git a/fastNLP/core/__init__.py b/fastNLP/core/__init__.py index 62bf4a77..18cdcac4 100644 --- a/fastNLP/core/__init__.py +++ b/fastNLP/core/__init__.py @@ -50,7 +50,8 @@ __all__ = [ 'SaveModelCallback', "CallbackException", "EarlyStopError", - + "CheckPointCallback", + "LossFunc", "CrossEntropyLoss", "L1Loss", @@ -58,7 +59,8 @@ __all__ = [ "NLLLoss", "LossInForward", "CMRC2018Loss", - + + "MetricBase", "AccuracyMetric", "SpanFPreRecMetric", "CMRC2018Metric", @@ -79,13 +81,13 @@ from ._logger import logger from .batch import DataSetIter, BatchIter, TorchLoaderIter from .callback import Callback, GradientClipCallback, EarlyStopCallback, FitlogCallback, EvaluateCallback, \ LRScheduler, ControlC, LRFinder, TensorboardCallback, WarmupCallback, SaveModelCallback, CallbackException, \ - EarlyStopError + EarlyStopError, CheckPointCallback from .const import Const from .dataset import DataSet from .field import FieldArray, Padder, AutoPadder, EngChar2DPadder from .instance import Instance from .losses import LossFunc, CrossEntropyLoss, L1Loss, BCELoss, NLLLoss, LossInForward, CMRC2018Loss -from .metrics import AccuracyMetric, SpanFPreRecMetric, CMRC2018Metric, ClassifyFPreRecMetric +from .metrics import AccuracyMetric, SpanFPreRecMetric, CMRC2018Metric, ClassifyFPreRecMetric, MetricBase from .optimizer import Optimizer, SGD, Adam, AdamW from .sampler import SequentialSampler, BucketSampler, RandomSampler, Sampler from .tester import Tester diff --git a/fastNLP/core/callback.py b/fastNLP/core/callback.py index 095ebc3d..add039eb 100644 --- a/fastNLP/core/callback.py +++ b/fastNLP/core/callback.py @@ -64,7 +64,8 @@ __all__ = [ "SaveModelCallback", "CallbackException", - "EarlyStopError" + "EarlyStopError", + "CheckPointCallback" ] import os @@ -869,6 +870,106 @@ class TensorboardCallback(Callback): del self._summary_writer +class CheckPointCallback(Callback): + def __init__(self, save_path, delete_when_train_finish=True, recovery_fitlog=True): + """ + 用于在每个epoch结束的时候保存一下当前的Trainer状态,可以用于恢复之前的运行。使用最近的一个epoch继续训练 + 一段示例代码 + Example1:: + + >>> callback = CheckPointCallback('chkp.pt') + >>> trainer = Trainer(xxx, callback=callback) + >>> trainer.train() # 如果训练过程没结束就fail,请直接再次运行即可(请务必保证与上次使用了完全相同的数据与超参数) + + Example2:: + + >>> fitlog.set_log_dir('xxx') + >>> callback = CheckPointCallback('chkp.pt') # 一定要在set_log_dir下一行就接着CheckPointCallback + >>> trainer = Trainer(xxx, callback=callback) + >>> trainer.train() # 如果训练过程没结束就fail,请直接再次运行即可(请务必保证与上次使用了完全相同的数据与超参数) + + :param str save_path: 将状态保存到哪个位置。需要指定一个具体的路径,比如'checkpoints/chtp.pt'。如果检查到该文件存在,将在 + Trainer开始训练的时候自动从这个Checkpoint处开始运行。 + :param bool delete_when_train_finish: 如果Train正常运行完毕,是否自动删除。删除该文件可以使得路径自动复用。 + :param bool recovery_fitlog: 是否恢复fitlog为对应的log,如果为True请将本Callback放在fitlog.set_log_dir后面一行初始化。 + 如果为False,将新建一个log folder否则继续使用之前的。 + """ + super().__init__() + self.save_path = os.path.abspath(os.path.expanduser(save_path)) + self.delete_when_train_finish = delete_when_train_finish + self.recover_fitlog = recovery_fitlog + if os.path.exists(os.path.expanduser(self.save_path)): + logger.info("The train will start from the checkpoint saved in {}.".format(self.save_path)) + if self.recover_fitlog: + states = torch.load(self.save_path) + if 'fitlog_log_dir' in states: + try: + import fitlog + log_dir = states['fitlog_log_dir'] + if 'fitlog_save_log_dir' in states: + log_dir = states['fitlog_save_log_dir'] + fitlog.set_log_dir(log_dir, new_log=True) + except: + print("Fail to recovery the fitlog states.") + + def on_train_begin(self): + """ + 当train开始时,且需要恢复上次训练时,会做以下的操作 + (1) 重新加载model权重 + (2) 重新加载optimizer的状态 + (3) 加载当前epoch数 + (4) 加载当前最佳evaluate的性能 + (5) (optional) 自动将fitlog设置到上次log出继续 + + :return: + """ + if os.path.exists(os.path.expanduser(self.save_path)): + states = torch.load(self.save_path) + self.model.load_state_dict(states['model']) + self.optimizer.load_state_dict(states['optimizer']) + self.trainer.epoch = states['epoch'] + 1 # 因为是结束储存的,所以需要从下一个epoch开始 + self.trainer.step = states['step'] + if 'best_dev_epoch' in states: + self.trainer.best_dev_perf = states['best_dev_perf'] + self.trainer.best_dev_epoch = states['best_dev_epoch'] + self.trainer.best_dev_step = states['best_dev_step'] + self.trainer.best_metric_indicator = states['best_metric_indicator'] + + def on_epoch_end(self): + """ + 保存状态,使得结果可以被恢复 + + :param self: + :return: + """ + states = {} + states['model'] = {name:param.cpu() for name, param in self.model.state_dict().items()} + states['optimizer'] = self.optimizer.state_dict() + states['epoch'] = self.epoch + states['step'] = self.step + if self.trainer.best_dev_epoch is not None: + states['best_dev_epoch'] = self.trainer.best_dev_epoch + states['best_dev_perf'] = self.trainer.best_dev_perf + states['best_dev_step'] = self.trainer.best_dev_step + states['best_metric_indicator'] = self.trainer.best_metric_indicator + if self.recover_fitlog: + try: + import fitlog + if fitlog._logger._log_dir is not None: + states['fitlog_log_dir'] = fitlog._logger._log_dir + if fitlog._logger._save_log_dir is not None: + states['fitlog_save_log_dir'] = fitlog._logger._save_log_dir + except: + pass + torch.save(states, self.save_path) + + def on_train_end(self): + # 训练结束,根据情况删除保存的内容 + if self.delete_when_train_finish: + if os.path.exists(self.save_path): + os.remove(self.save_path) + + class WarmupCallback(Callback): """ learning rate按照一定的速率从0上升到设置的learning rate。 diff --git a/fastNLP/core/trainer.py b/fastNLP/core/trainer.py index a0b93d9a..a5fea9bf 100644 --- a/fastNLP/core/trainer.py +++ b/fastNLP/core/trainer.py @@ -562,7 +562,6 @@ class Trainer(object): verbose=0, use_tqdm=self.test_use_tqdm) - self.step = 0 self.start_time = None # start timestamp if isinstance(callbacks, Callback): @@ -603,7 +602,8 @@ class Trainer(object): self.start_time = str(datetime.now().strftime('%Y-%m-%d-%H-%M-%S')) start_time = time.time() self.logger.info("training epochs started " + self.start_time) - + self.step = 0 + self.epoch = 1 try: self.callback_manager.on_train_begin() self._train() @@ -642,14 +642,12 @@ class Trainer(object): from .utils import _pseudo_tqdm as inner_tqdm else: inner_tqdm = tqdm - self.step = 0 - self.epoch = 0 start = time.time() with inner_tqdm(total=self.n_steps, postfix='loss:{0:<6.5f}', leave=False, dynamic_ncols=True) as pbar: self.pbar = pbar avg_loss = 0 self.batch_per_epoch = self.data_iterator.num_batches - for epoch in range(1, self.n_epochs + 1): + for epoch in range(self.epoch, self.n_epochs + 1): self.epoch = epoch pbar.set_description_str(desc="Epoch {}/{}".format(epoch, self.n_epochs)) # early stopping diff --git a/fastNLP/io/utils.py b/fastNLP/io/utils.py index 4b5230c0..215c8bf2 100644 --- a/fastNLP/io/utils.py +++ b/fastNLP/io/utils.py @@ -26,8 +26,8 @@ def check_loader_paths(paths: Union[str, Dict[str, str]]) -> Dict[str, str]: 如果paths为不合法的,将直接进行raise相应的错误. 如果paths内不包含train也会报错。 - :param str paths: 路径. 可以为一个文件路径(则认为该文件就是train的文件); 可以为一个文件目录,将在该目录下寻找train(文件名 - 中包含train这个字段), test.txt, dev.txt; 可以为一个dict, 则key是用户自定义的某个文件的名称,value是这个文件的路径。 + :param str paths: 路径. 可以为一个文件路径(则认为该文件就是train的文件); 可以为一个文件目录,将在该目录下寻找包含train(文件名 + 中包含train这个字段), test, dev这三个字段的文件或文件夹; 可以为一个dict, 则key是用户自定义的某个文件的名称,value是这个文件的路径。 :return: """ if isinstance(paths, (str, Path)): @@ -69,8 +69,8 @@ def check_loader_paths(paths: Union[str, Dict[str, str]]) -> Dict[str, str]: for key, value in paths.items(): if isinstance(key, str) and isinstance(value, str): value = os.path.abspath(os.path.expanduser(value)) - if not os.path.isfile(value): - raise TypeError(f"{value} is not a valid file.") + if not os.path.exists(value): + raise TypeError(f"{value} is not a valid path.") paths[key] = value else: raise TypeError("All keys and values in paths should be str.") diff --git a/test/core/test_callbacks.py b/test/core/test_callbacks.py index db95a32d..9a11793f 100644 --- a/test/core/test_callbacks.py +++ b/test/core/test_callbacks.py @@ -41,8 +41,8 @@ class TestCallback(unittest.TestCase): self.tempdir = tempfile.mkdtemp() def tearDown(self): - pass - # shutil.rmtree(self.tempdir) + import shutil + shutil.rmtree(self.tempdir) def test_gradient_clip(self): data_set, model = prepare_env() @@ -145,7 +145,54 @@ class TestCallback(unittest.TestCase): metrics=AccuracyMetric(pred="predict", target="y"), use_tqdm=True, callbacks=fitlog_callback, check_code_level=2) trainer.train() - + + def test_CheckPointCallback(self): + + from fastNLP import CheckPointCallback, Callback + from fastNLP import Tester + + class RaiseCallback(Callback): + def __init__(self, stop_step=10): + super().__init__() + self.stop_step = stop_step + + def on_backward_begin(self, loss): + if self.step > self.stop_step: + raise RuntimeError() + + data_set, model = prepare_env() + tester = Tester(data=data_set, model=model, metrics=AccuracyMetric(pred="predict", target="y")) + import fitlog + + fitlog.set_log_dir(self.tempdir) + tempfile_path = os.path.join(self.tempdir, 'chkt.pt') + callbacks = [CheckPointCallback(tempfile_path)] + + fitlog_callback = FitlogCallback(data_set, tester) + callbacks.append(fitlog_callback) + + callbacks.append(RaiseCallback(100)) + try: + trainer = Trainer(data_set, model, optimizer=SGD(lr=0.1), loss=BCELoss(pred="predict", target="y"), + batch_size=32, n_epochs=5, print_every=50, dev_data=data_set, + metrics=AccuracyMetric(pred="predict", target="y"), use_tqdm=True, + callbacks=callbacks, check_code_level=2) + trainer.train() + except: + pass + # 用下面的代码模拟重新运行 + data_set, model = prepare_env() + callbacks = [CheckPointCallback(tempfile_path)] + tester = Tester(data=data_set, model=model, metrics=AccuracyMetric(pred="predict", target="y")) + fitlog_callback = FitlogCallback(data_set, tester) + callbacks.append(fitlog_callback) + + trainer = Trainer(data_set, model, optimizer=SGD(lr=0.1), loss=BCELoss(pred="predict", target="y"), + batch_size=32, n_epochs=5, print_every=50, dev_data=data_set, + metrics=AccuracyMetric(pred="predict", target="y"), use_tqdm=True, + callbacks=callbacks, check_code_level=2) + trainer.train() + def test_save_model_callback(self): data_set, model = prepare_env() top = 3