From 05eb499eb893135d856d7eb440b1b1e1bd244956 Mon Sep 17 00:00:00 2001
From: yunfan
Date: Fri, 20 Sep 2019 16:30:37 +0800
Subject: [PATCH] [bugfix] dist_trainer's save & load

---
 fastNLP/core/__init__.py       |  3 +--
 fastNLP/core/callback.py       |  2 +-
 fastNLP/core/dist_trainer.py   | 33 +++++++++++++++++----------------
 test/core/test_dist_trainer.py |  4 +++-
 4 files changed, 22 insertions(+), 20 deletions(-)

diff --git a/fastNLP/core/__init__.py b/fastNLP/core/__init__.py
index efee08b5..bea80097 100644
--- a/fastNLP/core/__init__.py
+++ b/fastNLP/core/__init__.py
@@ -49,7 +49,6 @@ __all__ = [
     "WarmupCallback",
     'SaveModelCallback',
     "EchoCallback",
-    "TesterCallback",
 
     "CallbackException",
     "EarlyStopError",
@@ -79,7 +78,7 @@ from ._logger import logger
 from .batch import DataSetIter, BatchIter, TorchLoaderIter
 from .callback import Callback, GradientClipCallback, EarlyStopCallback, FitlogCallback, EvaluateCallback, \
     LRScheduler, ControlC, LRFinder, TensorboardCallback, WarmupCallback, SaveModelCallback, EchoCallback, \
-    TesterCallback, CallbackException, EarlyStopError
+    CallbackException, EarlyStopError
 from .const import Const
 from .dataset import DataSet
 from .field import FieldArray, Padder, AutoPadder, EngChar2DPadder
diff --git a/fastNLP/core/callback.py b/fastNLP/core/callback.py
index 734c1269..fac1f1f4 100644
--- a/fastNLP/core/callback.py
+++ b/fastNLP/core/callback.py
@@ -63,7 +63,7 @@ __all__ = [
     "WarmupCallback",
     "SaveModelCallback",
     "EchoCallback",
-    "TesterCallback",
+    "_TesterCallback",
 
     "CallbackException",
     "EarlyStopError"
diff --git a/fastNLP/core/dist_trainer.py b/fastNLP/core/dist_trainer.py
index c2804134..2451911d 100644
--- a/fastNLP/core/dist_trainer.py
+++ b/fastNLP/core/dist_trainer.py
@@ -17,7 +17,8 @@ from tqdm import tqdm
 
 from ._logger import logger
 from .batch import DataSetIter, BatchIter
-from .callback import DistCallbackManager, CallbackException, _TesterCallback
+from .callback import DistCallbackManager, CallbackException
+from .callback import _TesterCallback
 from .dataset import DataSet
 from .losses import _prepare_losser
 from .optimizer import Optimizer
@@ -174,13 +175,13 @@ class DistTrainer():
             cb = _TesterCallback(
                 dev_data, model, metrics,
                 batch_size=batch_size_per_gpu, num_workers=num_workers)
-            self.test_manager.add_callback([cb], master=True)
+            self.test_manager.add_callback([cb], master=False)
 
         # Setup logging
         dist.barrier()
         self.start_time = datetime.now().strftime('%m_%d_%Y-%H_%M')
         if self.save_path:
-            self.cp_save_path = os.path.join(self.save_path, 'checkpoints', self.start_time)
+            self.cp_save_path = os.path.join(self.save_path, 'checkpoints')
         else:
             self.cp_save_path = None
@@ -286,11 +287,11 @@
             results['seconds'] = round(time.time() - start_time, 2)
             self.logger.info("###### Train finished ######")
             self.logger.info('Total train time: {} seconds.'.
                              format(results['seconds']))
-            if load_best_model:
-                self.load_check_point('best_{}'.format(self.metric_key))
+            if load_best_model and self.cp_save_path and len(self.test_manager.callbacks):
+                self.load_check_point('best')
         finally:
             pass
-
+        dist.barrier()
         return results
 
     def _train(self):
@@ -417,29 +418,29 @@ class DistTrainer():
     def load_check_point(self, name):
         path = os.path.join(self.cp_save_path, name)
         self.logger.info('reload best model from %s', path)
-        model_load = torch.load(path)
+        model_load = torch.load(path, map_location='cpu')
         if not isinstance(model_load, dict):
             model_load = model_load.state_dict()
-        self.model.load_state_dict(model_load)
+        self.model.module.load_state_dict(model_load)
 
     def _do_validation(self):
         self.callback_manager.on_valid_begin()
+        # do evaluate on all nodes
         eval_res = self.test_manager.on_valid_begin()
         eval_res = list(filter(lambda x: x is not None, eval_res))
         if len(eval_res):
             eval_res, is_better = list(zip(*eval_res))
         else:
             eval_res, is_better = None, None
+        # save better model on master node
+        if self.is_master and is_better is not None and self.cp_save_path:
+            for i, better_flag in enumerate(is_better):
+                if better_flag:
+                    # TODO to support multiple datasets to evaluate
+                    self.save_check_point('best')
+                    break
         self.callback_manager.on_valid_end(
             eval_res, self.metric_key, self.optimizer, is_better)
-
-        # save better model
-        for i, better_flag in enumerate(is_better):
-            if better_flag:
-                # TODO to support multiple datasets to evaluate
-                name = 'best_{}'.format(self.metric_key)
-                self.save_check_point(name)
-                break
         dist.barrier()
 
     def close(self):
diff --git a/test/core/test_dist_trainer.py b/test/core/test_dist_trainer.py
index c6879634..03f613e1 100644
--- a/test/core/test_dist_trainer.py
+++ b/test/core/test_dist_trainer.py
@@ -130,12 +130,14 @@ class TestDistTrainer(unittest.TestCase):
             train_set, model, optimizer=SGD(lr=0.1),
             loss=BCELoss(pred="predict", target="y"),
             batch_size_per_gpu=32, n_epochs=3, print_every=50, dev_data=dev_set,
-            metrics=AccuracyMetric(pred="predict", target="y"), validate_every=-1, save_path=None,
+            metrics=AccuracyMetric(pred="predict", target="y"), validate_every=-1, save_path=self.save_path,
         )
         trainer.train()
         """
         # should run correctly
         """
+        if trainer.is_master and os.path.exists(self.save_path):
+            shutil.rmtree(self.save_path)
 
     def run_dist(self, run_id):
         if torch.cuda.is_available():
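
Note on the save/load pattern this patch adopts: in DistTrainer, self.model is the
DistributedDataParallel wrapper, whose state_dict keys carry a 'module.' prefix, so
load_check_point now restores the checkpoint through self.model.module, and
map_location='cpu' lets a checkpoint written by one GPU rank be loaded on any node.
Below is a minimal stand-alone sketch of that pattern, not the trainer's actual code:
the toy nn.Linear model is illustrative, the file name 'best' matches the patch, and
the DDP wrapping is shown only as a comment because it requires an initialized
process group.

    import torch
    import torch.nn as nn

    model = nn.Linear(4, 2)
    # In the trainer the model is wrapped, e.g.:
    #   ddp = nn.parallel.DistributedDataParallel(model, device_ids=[rank])
    # after which ddp.state_dict() keys carry a 'module.' prefix, while
    # ddp.module is the original, prefix-free model.

    # Save the unwrapped module's state_dict so checkpoint keys stay
    # prefix-free (in the patch, only the master rank writes 'best',
    # inside _do_validation).
    torch.save(model.state_dict(), 'best')

    # Load onto CPU first so a checkpoint saved from GPU memory can be
    # restored on any node, then hand it to the unwrapped module
    # (the trainer calls self.model.module.load_state_dict(...)).
    state = torch.load('best', map_location='cpu')
    model.load_state_dict(state)

The dist.barrier() added at the end of train(), together with the existing one at the
end of _do_validation, keeps all ranks in step around the master's checkpoint writes
and the final reload of the best model.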