Browse Source

[bugfix] dist_trainer's save & load

tags/v0.4.10
yunfan 5 years ago
parent
commit
05eb499eb8
4 changed files with 22 additions and 20 deletions
  1. +1
    -2
      fastNLP/core/__init__.py
  2. +1
    -1
      fastNLP/core/callback.py
  3. +17
    -16
      fastNLP/core/dist_trainer.py
  4. +3
    -1
      test/core/test_dist_trainer.py

+ 1
- 2
fastNLP/core/__init__.py View File

@@ -49,7 +49,6 @@ __all__ = [
     "WarmupCallback",
     'SaveModelCallback',
     "EchoCallback",
-    "TesterCallback",
     "CallbackException",
     "EarlyStopError",
@@ -79,7 +78,7 @@ from ._logger import logger
 from .batch import DataSetIter, BatchIter, TorchLoaderIter
 from .callback import Callback, GradientClipCallback, EarlyStopCallback, FitlogCallback, EvaluateCallback, \
     LRScheduler, ControlC, LRFinder, TensorboardCallback, WarmupCallback, SaveModelCallback, EchoCallback, \
-    TesterCallback, CallbackException, EarlyStopError
+    CallbackException, EarlyStopError
 from .const import Const
 from .dataset import DataSet
 from .field import FieldArray, Padder, AutoPadder, EngChar2DPadder


+ 1
- 1
fastNLP/core/callback.py View File

@@ -63,7 +63,7 @@ __all__ = [
     "WarmupCallback",
     "SaveModelCallback",
     "EchoCallback",
-    "TesterCallback",
+    "_TesterCallback",
     "CallbackException",
     "EarlyStopError"


+ 17
- 16
fastNLP/core/dist_trainer.py View File

@@ -17,7 +17,8 @@ from tqdm import tqdm

 from ._logger import logger
 from .batch import DataSetIter, BatchIter
-from .callback import DistCallbackManager, CallbackException, _TesterCallback
+from .callback import DistCallbackManager, CallbackException
+from .callback import _TesterCallback
 from .dataset import DataSet
 from .losses import _prepare_losser
 from .optimizer import Optimizer
@@ -174,13 +175,13 @@ class DistTrainer():
             cb = _TesterCallback(
                 dev_data, model, metrics,
                 batch_size=batch_size_per_gpu, num_workers=num_workers)
-            self.test_manager.add_callback([cb], master=True)
+            self.test_manager.add_callback([cb], master=False)

         # Setup logging
         dist.barrier()
         self.start_time = datetime.now().strftime('%m_%d_%Y-%H_%M')
         if self.save_path:
-            self.cp_save_path = os.path.join(self.save_path, 'checkpoints', self.start_time)
+            self.cp_save_path = os.path.join(self.save_path, 'checkpoints')
         else:
             self.cp_save_path = None


@@ -286,11 +287,11 @@ class DistTrainer():
             results['seconds'] = round(time.time() - start_time, 2)
             self.logger.info("###### Train finished ######")
             self.logger.info('Total train time: {} seconds.'. format(results['seconds']))
-            if load_best_model:
-                self.load_check_point('best_{}'.format(self.metric_key))
+            if load_best_model and self.cp_save_path and len(self.test_manager.callbacks):
+                self.load_check_point('best')
         finally:
             pass
+        dist.barrier()
         return results


def _train(self): def _train(self):
@@ -417,29 +418,29 @@ class DistTrainer():
def load_check_point(self, name): def load_check_point(self, name):
path = os.path.join(self.cp_save_path, name) path = os.path.join(self.cp_save_path, name)
self.logger.info('reload best model from %s', path) self.logger.info('reload best model from %s', path)
model_load = torch.load(path)
model_load = torch.load(path, map_location='cpu')
if not isinstance(model_load, dict): if not isinstance(model_load, dict):
model_load = model_load.state_dict() model_load = model_load.state_dict()
self.model.load_state_dict(model_load)
self.model.module.load_state_dict(model_load)


def _do_validation(self): def _do_validation(self):
self.callback_manager.on_valid_begin() self.callback_manager.on_valid_begin()
# do evaluate on all nodes
eval_res = self.test_manager.on_valid_begin() eval_res = self.test_manager.on_valid_begin()
eval_res = list(filter(lambda x: x is not None, eval_res)) eval_res = list(filter(lambda x: x is not None, eval_res))
if len(eval_res): if len(eval_res):
eval_res, is_better = list(zip(*eval_res)) eval_res, is_better = list(zip(*eval_res))
else: else:
eval_res, is_better = None, None eval_res, is_better = None, None
# save better model on master node
if self.is_master and is_better is not None and self.cp_save_path:
for i, better_flag in enumerate(is_better):
if better_flag:
# TODO to support multiple datasets to evaluate
self.save_check_point('best')
break
self.callback_manager.on_valid_end( self.callback_manager.on_valid_end(
eval_res, self.metric_key, self.optimizer, is_better) eval_res, self.metric_key, self.optimizer, is_better)

# save better model
for i, better_flag in enumerate(is_better):
if better_flag:
# TODO to support multiple datasets to evaluate
name = 'best_{}'.format(self.metric_key)
self.save_check_point(name)
break
dist.barrier() dist.barrier()


def close(self): def close(self):


+ 3
- 1
test/core/test_dist_trainer.py View File

@@ -130,12 +130,14 @@ class TestDistTrainer(unittest.TestCase):
             train_set, model, optimizer=SGD(lr=0.1),
             loss=BCELoss(pred="predict", target="y"),
             batch_size_per_gpu=32, n_epochs=3, print_every=50, dev_data=dev_set,
-            metrics=AccuracyMetric(pred="predict", target="y"), validate_every=-1, save_path=None,
+            metrics=AccuracyMetric(pred="predict", target="y"), validate_every=-1, save_path=self.save_path,
         )
         trainer.train()
         """
         # 应该正确运行
         """
+        if trainer.is_master and os.path.exists(self.save_path):
+            shutil.rmtree(self.save_path)

     def run_dist(self, run_id):
         if torch.cuda.is_available():


Loading…
Cancel
Save